Merge pull request #1977 from sholiverlee/vertex

feat: 支持 Vertex Service Account(Anthropic / Gemini)
This commit is contained in:
Wesley Liddick
2026-04-29 15:48:26 +08:00
committed by GitHub
19 changed files with 1330 additions and 36 deletions
+2 -2
View File
@@ -145,13 +145,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -178,7 +179,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
+6 -5
View File
@@ -26,11 +26,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI
)
// Redeem type constants
@@ -98,7 +98,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -117,7 +117,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -64,6 +64,7 @@ func isOpenAIImageModel(model string) bool {
type AccountTestService struct {
accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider
claudeTokenProvider *ClaudeTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
@@ -74,6 +75,7 @@ type AccountTestService struct {
func NewAccountTestService(
accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider,
claudeTokenProvider *ClaudeTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
@@ -82,6 +84,7 @@ func NewAccountTestService(
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
claudeTokenProvider: claudeTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
@@ -210,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
}
if account.Type == AccountTypeServiceAccount {
return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
}
// Determine authentication method and API URL
var authToken string
@@ -313,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
// testClaudeVertexServiceAccountConnection runs a streaming connectivity test for an
// Anthropic account of type service_account against the Vertex AI endpoint.
//
// The test model is taken from the account's model mapping when one matches;
// otherwise it is normalized into the Vertex Anthropic model ID form. The response
// is written to the client as a text/event-stream: failures are reported through
// sendErrorAndEnd and a successful upstream stream is handed to processClaudeStream,
// so the returned error is whatever those helpers return.
func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
	// Upper bound on how much of an upstream error body is read; the text is echoed
	// to the client and may be persisted on the account, so it must stay bounded.
	const maxErrorBodyBytes = 1 << 20

	if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
		testModelID = mappedModel
	} else {
		testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
	}

	// Switch the response to SSE before anything else is written so that every
	// later failure can be delivered as an event.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()

	payload, err := createTestPayload(testModelID)
	if err != nil {
		return s.sendErrorAndEnd(c, "Failed to create test payload")
	}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		// Previously ignored: a nil payload would have been sent on to the body builder.
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to encode test payload: %s", err.Error()))
	}
	vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
	}

	if s.claudeTokenProvider == nil {
		return s.sendErrorAndEnd(c, "Claude token provider not configured")
	}
	accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
	}

	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
	}

	s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
	if err != nil {
		return s.sendErrorAndEnd(c, "Failed to create request")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+accessToken)

	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a partial body is still useful in the error message, so a
		// read error is deliberately not treated as fatal here.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
		if resp.StatusCode == http.StatusForbidden {
			// A 403 is recorded on the account; the repository error is ignored
			// because the test result is reported to the client either way.
			_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
		}
		return s.sendErrorAndEnd(c, errMsg)
	}

	return s.processClaudeStream(c, resp.Body)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
@@ -711,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
testModelID = geminicli.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeAPIKey {
// For static upstream credentials with model mapping, map the model
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -740,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
case AccountTypeServiceAccount:
req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -893,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
}
// buildGeminiServiceAccountRequest assembles the POST request used to test a Gemini
// account of type service_account. It obtains a bearer token from the Gemini token
// provider, resolves the Vertex URL for the streamGenerateContent action from the
// account's project and location, and attaches the JSON payload.
//
// It returns an error when the token provider is missing, the token cannot be
// obtained, or the URL or request cannot be constructed.
func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
	provider := s.geminiTokenProvider
	if provider == nil {
		return nil, fmt.Errorf("gemini token provider not configured")
	}

	bearer, tokenErr := provider.GetAccessToken(ctx, account)
	if tokenErr != nil {
		return nil, fmt.Errorf("failed to get service account access token: %w", tokenErr)
	}

	projectID := account.VertexProjectID()
	location := account.VertexLocation(modelID)
	// NOTE(review): the trailing bool is presumably the streaming flag of
	// buildVertexGeminiURL; its definition is not visible here.
	endpoint, urlErr := buildVertexGeminiURL(projectID, location, modelID, "streamGenerateContent", true)
	if urlErr != nil {
		return nil, urlErr
	}

	upstreamReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if reqErr != nil {
		return nil, reqErr
	}

	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+bearer)
	return upstreamReq, nil
}
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
var inner map[string]any
@@ -17,7 +17,7 @@ const (
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not an anthropic oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := ClaudeTokenCacheKey(account)
@@ -157,3 +160,42 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
// getServiceAccountAccessToken returns an access token for an Anthropic account that
// authenticates with a Google service account key.
//
// Lookup order:
//  1. a non-blank token already present in the token cache;
//  2. if another caller holds the refresh lock, wait claudeLockWaitTime and re-check
//     the cache;
//  3. otherwise (lock acquired, lock error, no cache, or still no cached token)
//     exchange the service account key for a fresh token and cache it with the TTL
//     reported by the exchange.
//
// The wait in step 2 is abandoned with ctx.Err() when ctx is cancelled. Cache and
// lock failures never fail the call; they are logged and the token is fetched directly.
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
	key, err := parseVertexServiceAccountKey(account)
	if err != nil {
		return "", err
	}
	cacheKey := vertexServiceAccountCacheKey(account, key)

	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			return token, nil
		}

		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		switch {
		case lockErr != nil:
			// Locking is an optimisation only; fall through and refresh without it.
			slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
		case locked:
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
		default:
			// Someone else is refreshing: give them a moment, but do not keep a
			// cancelled request parked in an uninterruptible sleep.
			timer := time.NewTimer(claudeLockWaitTime)
			select {
			case <-ctx.Done():
				timer.Stop()
				return "", ctx.Err()
			case <-timer.C:
			}
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				return token, nil
			}
		}
	}

	accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
	if err != nil {
		return "", err
	}
	if p.tokenCache != nil {
		// Caching is best effort: the freshly issued token is still returned.
		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
			slog.Warn("vertex_service_account_token_cache_set_failed", "account_id", account.ID, "error", err)
		}
	}
	return accessToken, nil
}
@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
return "", errors.New("not an anthropic oauth or service account")
}
cacheKey := ClaudeTokenCacheKey(account)
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
+6 -5
View File
@@ -41,11 +41,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI
)
// Redeem type constants
@@ -0,0 +1,68 @@
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestGatewayService_BuildAnthropicVertexServiceAccountRequest verifies the upstream
// request built for an Anthropic service_account account: it must target the Vertex
// rawPredict URL for the configured project/location, carry only the Vertex bearer
// token (no inbound credentials, no anthropic-version header), keep anthropic-beta,
// and send a body without "model" but with the Vertex anthropic_version.
func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
	const wantURL = "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict"

	gin.SetMode(gin.TestMode)
	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())

	// Inbound request carrying client credentials that must not reach Vertex.
	inbound := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	inbound.Header.Set("Authorization", "Bearer inbound-token")
	inbound.Header.Set("X-Api-Key", "inbound-api-key")
	inbound.Header.Set("Anthropic-Version", "2023-06-01")
	inbound.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
	ginCtx.Request = inbound

	vertexAccount := &Account{
		ID:       301,
		Platform: PlatformAnthropic,
		Type:     AccountTypeServiceAccount,
		Credentials: map[string]any{
			"project_id": "vertex-proj",
			"location":   "us-east5",
		},
	}
	requestBody := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)

	gateway := &GatewayService{}
	upstreamReq, buildErr := gateway.buildUpstreamRequest(
		context.Background(),
		ginCtx,
		vertexAccount,
		requestBody,
		"vertex-token",
		"service_account",
		"claude-sonnet-4-5@20250929",
		false,
		false,
	)
	require.NoError(t, buildErr)

	// URL and headers.
	require.Equal(t, wantURL, upstreamReq.URL.String())
	require.Equal(t, "Bearer vertex-token", getHeaderRaw(upstreamReq.Header, "authorization"))
	require.Empty(t, getHeaderRaw(upstreamReq.Header, "x-api-key"))
	require.Empty(t, getHeaderRaw(upstreamReq.Header, "anthropic-version"))
	require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(upstreamReq.Header, "anthropic-beta"))

	// Body rewritten for Vertex.
	sentBody := readRequestBodyForTest(t, upstreamReq)
	require.Equal(t, "", gjson.GetBytes(sentBody, "model").String())
	require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(sentBody, "anthropic_version").String())
	require.Equal(t, "hello", gjson.GetBytes(sentBody, "messages.0.content").String())
}
// readRequestBodyForTest drains and returns the body of req, failing the test
// immediately when the body is nil or cannot be read.
func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
	t.Helper()

	stream := req.Body
	require.NotNil(t, stream)

	payload, readErr := io.ReadAll(stream)
	require.NoError(t, readErr)

	return payload
}
+87 -1
View File
@@ -3597,7 +3597,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
if account.Type == AccountTypeServiceAccount {
requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
} else {
requestedModel = claude.NormalizeModelID(requestedModel)
}
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
@@ -3617,6 +3621,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
return apiKey, "apikey", nil
case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
case AccountTypeServiceAccount:
if account.Platform != PlatformAnthropic {
return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
}
if s.claudeTokenProvider == nil {
return "", "", errors.New("claude token provider not configured")
}
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "service_account", nil
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -4219,6 +4235,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
if candidate, matched := account.ResolveMappedModel(reqModel); matched {
mappedModel = candidate
mappingSource = "account"
} else {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
if normalized != reqModel {
mappedModel = normalized
mappingSource = "vertex"
}
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
@@ -5688,6 +5716,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
}
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -5874,6 +5906,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
return req, nil
}
// buildUpstreamRequestAnthropicVertex builds the upstream POST request for an
// Anthropic account of type service_account.
//
// The client body is converted with buildVertexAnthropicRequestBody, the URL is
// derived from the account's project, the location resolved for modelID, and the
// stream flag. Allow-listed client headers are forwarded, except for credentials and
// anthropic-version, and the request is authorized with "Bearer <token>".
//
// It returns an error when the body cannot be converted or the URL or request cannot
// be constructed.
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	token string,
	modelID string,
	reqStream bool,
) (*http.Request, error) {
	vertexBody, err := buildVertexAnthropicRequestBody(body)
	if err != nil {
		return nil, err
	}
	setOpsUpstreamRequestBody(c, vertexBody)

	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
	if err != nil {
		return nil, err
	}

	if c != nil && c.Request != nil {
		for key, values := range c.Request.Header {
			lowerKey := strings.ToLower(strings.TrimSpace(key))
			if !allowedHeaders[lowerKey] {
				continue
			}
			// Never forward client credentials or the client's anthropic-version.
			// They are filtered here, by lower-cased name, because http.Header.Del
			// canonicalizes its argument and would miss an entry stored under a
			// non-canonical key. NOTE(review): addHeaderRaw/resolveWireCasing are
			// defined elsewhere; presumably they store the wire-cased key as-is.
			switch lowerKey {
			case "authorization", "x-api-key", "x-goog-api-key", "cookie", "anthropic-version":
				continue
			}
			wireKey := resolveWireCasing(key)
			for _, v := range values {
				addHeaderRaw(req.Header, wireKey, v)
			}
		}
	}

	// Kept as a second line of defence for canonically-keyed entries.
	req.Header.Del("authorization")
	req.Header.Del("x-api-key")
	req.Header.Del("x-goog-api-key")
	req.Header.Del("cookie")
	req.Header.Del("anthropic-version")
	setHeaderRaw(req.Header, "authorization", "Bearer "+token)
	setHeaderRaw(req.Header, "content-type", "application/json")

	s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
		"url":        req.URL.String(),
		"token_type": "service_account",
		"model":      modelID,
		"stream":     strconv.FormatBool(reqStream),
	})

	return req, nil
}
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号,需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
@@ -579,7 +579,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -712,6 +712,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
case AccountTypeServiceAccount:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
if err != nil {
return nil, "", err
}
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -1094,7 +1124,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -1213,6 +1243,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
case AccountTypeServiceAccount:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
if err != nil {
return nil, "", err
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
@@ -15,7 +15,7 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return "", errors.New("not a gemini oauth account")
if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not a gemini oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := GeminiTokenCacheKey(account)
@@ -168,7 +171,51 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
// getServiceAccountAccessToken returns an access token for a Gemini account that
// authenticates with a Google service account key.
//
// Lookup order:
//  1. a non-blank token already present in the token cache;
//  2. if another caller holds the refresh lock, wait briefly and re-check the cache;
//  3. otherwise (lock acquired, lock error, no cache, or still no cached token)
//     exchange the service account key for a fresh token and cache it with the TTL
//     reported by the exchange.
//
// The wait in step 2 is abandoned with ctx.Err() when ctx is cancelled. Cache and
// lock failures never fail the call; they are logged and the token is fetched directly.
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
	// How long to wait for a concurrent refresher before re-reading the cache.
	const lockWait = 200 * time.Millisecond

	key, err := parseVertexServiceAccountKey(account)
	if err != nil {
		return "", err
	}
	cacheKey := vertexServiceAccountCacheKey(account, key)

	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			return token, nil
		}

		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		switch {
		case lockErr != nil:
			// Locking is an optimisation only; fall through and refresh without it.
			slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
		case locked:
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
		default:
			// Someone else is refreshing: give them a moment, but do not keep a
			// cancelled request parked in an uninterruptible sleep.
			timer := time.NewTimer(lockWait)
			select {
			case <-ctx.Done():
				timer.Stop()
				return "", ctx.Err()
			case <-timer.C:
			}
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				return token, nil
			}
		}
	}

	accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
	if err != nil {
		return "", err
	}
	if p.tokenCache != nil {
		// Caching is best effort: the freshly issued token is still returned.
		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
			slog.Warn("vertex_service_account_token_cache_set_failed", "account_id", account.ID, "error", err)
		}
	}
	return accessToken, nil
}
func GeminiTokenCacheKey(account *Account) string {
if account != nil && account.Type == AccountTypeServiceAccount {
if key, err := parseVertexServiceAccountKey(account); err == nil {
return vertexServiceAccountCacheKey(account, key)
}
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID
@@ -0,0 +1,303 @@
package service
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Fixed values used for Vertex AI service account authentication and requests.
const (
// vertexDefaultLocation is the location returned by Account.VertexLocation when
// the account configures none.
vertexDefaultLocation = "us-central1"
// vertexDefaultTokenURL is the token endpoint used when the service account
// JSON carries no token_uri.
vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
// vertexCloudPlatformScope is the OAuth scope placed in the signed JWT assertion.
vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
// vertexServiceAccountCacheSkew is a 5 minute safety margin.
// NOTE(review): its use is outside the visible part of this file; presumably it
// shortens the cached token lifetime — confirm.
vertexServiceAccountCacheSkew = 5 * time.Minute
// vertexAnthropicVersion is the anthropic_version value expected in request
// bodies sent to Anthropic models on Vertex.
vertexAnthropicVersion = "vertex-2023-10-16"
)
// Regular expressions compiled once at package initialization.
var (
// vertexLocationPattern matches a non-empty string made only of lowercase
// letters, digits and hyphens.
vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)
// vertexAnthropicDatedModelIDPattern matches a model ID ending in "-" plus eight
// digits, capturing the part before the hyphen and the eight digits separately.
vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
// vertexAnthropicAlreadyDatedIDPattern matches a model ID that already ends in
// "@" plus eight digits.
vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
)
// vertexServiceAccountKey holds the fields decoded from a service account JSON
// credential. parseVertexServiceAccountJSON requires ClientEmail, PrivateKey and
// ProjectID to be non-blank and fills TokenURI with a default when it is blank.
type vertexServiceAccountKey struct {
// Type is the "type" field of the JSON document; it is decoded but not validated
// in the visible code.
Type string `json:"type"`
// ProjectID is the project the key belongs to.
ProjectID string `json:"project_id"`
// PrivateKeyID, when non-blank, is sent as the "kid" header of the JWT assertion.
PrivateKeyID string `json:"private_key_id"`
// PrivateKey is the PEM-encoded RSA private key used to sign the JWT assertion.
PrivateKey string `json:"private_key"`
// ClientEmail is used as the "iss" claim of the JWT assertion.
ClientEmail string `json:"client_email"`
// TokenURI is the token endpoint; it is also used as the "aud" claim.
TokenURI string `json:"token_uri"`
}
// vertexTokenResponse is the JSON body decoded from the token endpoint's reply,
// covering both the success fields and the error fields.
type vertexTokenResponse struct {
// AccessToken is the issued bearer token.
AccessToken string `json:"access_token"`
// TokenType is the "token_type" field of the reply.
TokenType string `json:"token_type"`
// ExpiresIn is the "expires_in" field of the reply.
// NOTE(review): presumably a lifetime in seconds — confirm against the endpoint docs.
ExpiresIn int64 `json:"expires_in"`
// Error is the "error" field of a failed reply; used as the fallback message
// when ErrorDesc is blank.
Error string `json:"error"`
// ErrorDesc is the "error_description" field of a failed reply; preferred as
// the error message.
ErrorDesc string `json:"error_description"`
}
// IsVertexServiceAccount reports whether the account's type is
// AccountTypeServiceAccount. A nil account is never a service account.
func (a *Account) IsVertexServiceAccount() bool {
	if a == nil {
		return false
	}
	return a.Type == AccountTypeServiceAccount
}
// VertexProjectID returns the project ID to use for Vertex requests.
//
// An explicit, non-blank "project_id" credential wins; otherwise the project_id of
// the parsed service account key is used. It returns "" for a nil account or when
// neither source yields a value.
func (a *Account) VertexProjectID() string {
	if a == nil {
		return ""
	}

	explicit := strings.TrimSpace(a.GetCredential("project_id"))
	if explicit != "" {
		return explicit
	}

	parsed, parseErr := parseVertexServiceAccountKey(a)
	if parseErr != nil {
		return ""
	}
	return strings.TrimSpace(parsed.ProjectID)
}
// VertexLocation returns the Vertex location to use for the given model.
//
// Resolution order, first non-blank value (trimmed) wins:
//  1. credentials["vertex_model_locations"][model], when model is non-empty;
//  2. the "location" credential;
//  3. the "vertex_location" credential;
//  4. vertexDefaultLocation.
//
// A nil account yields vertexDefaultLocation.
func (a *Account) VertexLocation(model string) string {
	if a == nil {
		return vertexDefaultLocation
	}

	// Indexing a nil Credentials map is safe and simply fails the type assertion.
	if perModel, isMap := a.Credentials["vertex_model_locations"].(map[string]any); isMap && model != "" {
		if loc, isString := perModel[model].(string); isString {
			if trimmed := strings.TrimSpace(loc); trimmed != "" {
				return trimmed
			}
		}
	}

	for _, credentialName := range []string{"location", "vertex_location"} {
		if trimmed := strings.TrimSpace(a.GetCredential(credentialName)); trimmed != "" {
			return trimmed
		}
	}
	return vertexDefaultLocation
}
// parseVertexServiceAccountKey extracts and parses the service account key stored in
// the account's credentials.
//
// The key may be stored under "service_account_json" or "service_account", either as
// a JSON string or as a nested object. String forms are tried first (in that key
// order), then nested-object forms (in that key order).
//
// It returns an error when the account or its credentials are nil, when no key is
// present, when a nested object cannot be re-encoded as JSON, or when
// parseVertexServiceAccountJSON rejects the document.
func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
	if account == nil || account.Credentials == nil {
		return nil, errors.New("service account credentials not configured")
	}

	credentialNames := [...]string{"service_account_json", "service_account"}

	for _, name := range credentialNames {
		if raw := strings.TrimSpace(account.GetCredential(name)); raw != "" {
			return parseVertexServiceAccountJSON([]byte(raw))
		}
	}

	for _, name := range credentialNames {
		nested, ok := account.Credentials[name].(map[string]any)
		if !ok {
			continue
		}
		encoded, err := json.Marshal(nested)
		if err != nil {
			// Report the real cause instead of feeding empty input to the parser,
			// which would surface as a misleading "invalid service account json".
			return nil, fmt.Errorf("encode %s credential: %w", name, err)
		}
		return parseVertexServiceAccountJSON(encoded)
	}

	return nil, errors.New("service_account_json not found in credentials")
}
// parseVertexServiceAccountJSON decodes a service account JSON document and checks
// that it is usable.
//
// client_email, private_key and project_id must be non-blank (checked in that
// order); a blank token_uri is replaced with vertexDefaultTokenURL. It returns an
// error for malformed JSON or for the first missing required field.
func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
	key := new(vertexServiceAccountKey)
	if err := json.Unmarshal(raw, key); err != nil {
		return nil, fmt.Errorf("invalid service account json: %w", err)
	}

	required := []struct{ name, value string }{
		{"client_email", key.ClientEmail},
		{"private_key", key.PrivateKey},
		{"project_id", key.ProjectID},
	}
	for _, field := range required {
		if strings.TrimSpace(field.value) == "" {
			return nil, errors.New("service account json missing " + field.name)
		}
	}

	if strings.TrimSpace(key.TokenURI) == "" {
		key.TokenURI = vertexDefaultTokenURL
	}
	return key, nil
}
// vertexServiceAccountCacheKey builds the cache key under which a service
// account's access token is stored.
//
// When a key is supplied the suffix is the hex encoding of the first 8 bytes
// of SHA-256(client email, NUL, private key id). Without a key the suffix
// falls back to "account:<ID>"; with neither argument the bare prefix is
// returned.
func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
	const prefix = "vertex:service_account:"

	switch {
	case key != nil:
		digest := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
		return prefix + hex.EncodeToString(digest[:8])
	case account != nil:
		return prefix + fmt.Sprintf("account:%d", account.ID)
	default:
		return prefix
	}
}
// exchangeVertexServiceAccountToken trades a service account key for an
// OAuth2 access token using the JWT bearer grant.
//
// It signs an RS256 assertion (iss = client email, aud = token URI, scope =
// vertexCloudPlatformScope, lifetime one hour, "kid" header set when the key
// carries a private key id) and POSTs it form-encoded to key.TokenURI with a
// 15 second client timeout. At most 1 MiB of the response body is read.
//
// On success it returns the access token and the duration it may be cached
// for: the server's expires_in (one hour when absent or non-positive),
// reduced by vertexServiceAccountCacheSkew when it exceeds that skew.
//
// Errors are returned when key is nil, the private key cannot be parsed or
// used for signing, the request cannot be built or sent, the endpoint answers
// with a non-2xx status, or the response carries no access token.
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
	if key == nil {
		return "", 0, errors.New("service account key is required")
	}

	issuedAt := time.Now()
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
		"iss":   key.ClientEmail,
		"scope": vertexCloudPlatformScope,
		"aud":   key.TokenURI,
		"iat":   issuedAt.Unix(),
		"exp":   issuedAt.Add(time.Hour).Unix(),
	})
	if strings.TrimSpace(key.PrivateKeyID) != "" {
		token.Header["kid"] = key.PrivateKeyID
	}

	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
	if err != nil {
		return "", 0, fmt.Errorf("parse service account private key: %w", err)
	}
	assertion, err := token.SignedString(privateKey)
	if err != nil {
		return "", 0, fmt.Errorf("sign service account assertion: %w", err)
	}

	form := url.Values{}
	form.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
	form.Set("assertion", assertion)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(form.Encode()))
	if err != nil {
		return "", 0, fmt.Errorf("build service account token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", 0, fmt.Errorf("service account token request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Reading and decoding are best-effort at this point: an error response is
	// not guaranteed to be complete or to be JSON. The failures are kept and
	// only reported below if no usable token came out of the response.
	body, readErr := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	var parsed vertexTokenResponse
	decodeErr := json.Unmarshal(body, &parsed)

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Prefer the most specific description the endpoint provided.
		msg := strings.TrimSpace(parsed.ErrorDesc)
		if msg == "" {
			msg = strings.TrimSpace(parsed.Error)
		}
		if msg == "" {
			msg = string(bytes.TrimSpace(body))
		}
		return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
	}

	if strings.TrimSpace(parsed.AccessToken) == "" {
		// Explain why the token is missing when the cause is known.
		switch {
		case readErr != nil:
			return "", 0, fmt.Errorf("read service account token response: %w", readErr)
		case decodeErr != nil:
			return "", 0, fmt.Errorf("decode service account token response: %w", decodeErr)
		}
		return "", 0, errors.New("service account token response missing access_token")
	}

	ttl := time.Duration(parsed.ExpiresIn) * time.Second
	if ttl <= 0 {
		ttl = time.Hour
	}
	// Expire the cached token slightly early, but never push the TTL to zero
	// or below.
	if ttl > vertexServiceAccountCacheSkew {
		ttl -= vertexServiceAccountCacheSkew
	}
	return parsed.AccessToken, ttl, nil
}
// buildVertexGeminiURL assembles the Vertex AI endpoint URL for a Google
// publisher (Gemini) model.
//
// All string arguments are trimmed first. A blank location falls back to
// vertexDefaultLocation and must match vertexLocationPattern. action must be
// one of "generateContent", "streamGenerateContent" or "countTokens". The
// "global" location uses the un-prefixed aiplatform host; every other location
// uses "<location>-aiplatform.googleapis.com". When stream is true "?alt=sse"
// is appended.
//
// An error is returned for a blank project or model, a location that fails
// the pattern, or an unsupported action.
func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
	project := strings.TrimSpace(projectID)
	if project == "" {
		return "", errors.New("vertex project_id is required")
	}

	region := strings.TrimSpace(location)
	if region == "" {
		region = vertexDefaultLocation
	}
	if !vertexLocationPattern.MatchString(region) {
		return "", fmt.Errorf("invalid vertex location: %s", region)
	}

	modelID := strings.TrimSpace(model)
	if modelID == "" {
		return "", errors.New("vertex model is required")
	}

	method := strings.TrimSpace(action)
	if method != "generateContent" && method != "streamGenerateContent" && method != "countTokens" {
		return "", fmt.Errorf("unsupported vertex gemini action: %s", method)
	}

	var host string
	if region == "global" {
		host = "aiplatform.googleapis.com"
	} else {
		host = fmt.Sprintf("%s-aiplatform.googleapis.com", region)
	}

	endpoint := fmt.Sprintf(
		"https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
		host,
		url.PathEscape(project),
		url.PathEscape(region),
		url.PathEscape(modelID),
		method,
	)
	if !stream {
		return endpoint, nil
	}
	return endpoint + "?alt=sse", nil
}
// buildVertexAnthropicURL assembles the Vertex AI endpoint URL for an
// Anthropic publisher model.
//
// All string arguments are trimmed first. A blank location falls back to
// vertexDefaultLocation and must match vertexLocationPattern. The method is
// "streamRawPredict" when stream is true and "rawPredict" otherwise. The
// "global" location uses the un-prefixed aiplatform host; every other location
// uses "<location>-aiplatform.googleapis.com". Any "%40" in the escaped model
// segment is turned back into "@".
//
// An error is returned for a blank project or model, or a location that fails
// the pattern.
func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
	project := strings.TrimSpace(projectID)
	if project == "" {
		return "", errors.New("vertex project_id is required")
	}

	region := strings.TrimSpace(location)
	if region == "" {
		region = vertexDefaultLocation
	}
	if !vertexLocationPattern.MatchString(region) {
		return "", fmt.Errorf("invalid vertex location: %s", region)
	}

	modelID := strings.TrimSpace(model)
	if modelID == "" {
		return "", errors.New("vertex model is required")
	}

	var method string
	if stream {
		method = "streamRawPredict"
	} else {
		method = "rawPredict"
	}

	var host string
	if region == "global" {
		host = "aiplatform.googleapis.com"
	} else {
		host = fmt.Sprintf("%s-aiplatform.googleapis.com", region)
	}

	modelSegment := strings.ReplaceAll(url.PathEscape(modelID), "%40", "@")
	endpoint := fmt.Sprintf(
		"https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
		host,
		url.PathEscape(project),
		url.PathEscape(region),
		modelSegment,
		method,
	)
	return endpoint, nil
}
// normalizeVertexAnthropicModelID converts an Anthropic model ID into the
// "<name>@<suffix>" form.
//
// The input is trimmed. It is returned as-is (trimmed) when it is empty, when
// it already matches vertexAnthropicAlreadyDatedIDPattern, or when
// vertexAnthropicDatedModelIDPattern does not yield exactly two capture
// groups. Otherwise the two captured parts are joined with "@", e.g.
// "claude-sonnet-4-5-20250929" becomes "claude-sonnet-4-5@20250929".
func normalizeVertexAnthropicModelID(model string) string {
	id := strings.TrimSpace(model)

	switch {
	case id == "":
		return id
	case vertexAnthropicAlreadyDatedIDPattern.MatchString(id):
		return id
	}

	parts := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(id)
	if len(parts) != 3 {
		return id
	}
	return parts[1] + "@" + parts[2]
}
// buildVertexAnthropicRequestBody rewrites an Anthropic request body for the
// Vertex endpoint: the top-level "model" field is removed and
// "anthropic_version" is set (overwriting any existing value) to
// vertexAnthropicVersion.
//
// Every other top-level field is carried over as raw JSON, so nested values
// are re-emitted without a decode/encode round trip. In particular, numeric
// literals keep their exact digits instead of passing through float64.
//
// body must be a JSON object. An error is returned when it is not valid JSON,
// is not an object, or is the literal null.
func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
	var payload map[string]json.RawMessage
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
	}
	// json.Unmarshal accepts "null" and leaves the map nil; writing to a nil
	// map would panic, so reject it explicitly.
	if payload == nil {
		return nil, errors.New("parse anthropic vertex request body: expected a JSON object, got null")
	}

	version, err := json.Marshal(vertexAnthropicVersion)
	if err != nil {
		return nil, fmt.Errorf("encode anthropic_version: %w", err)
	}

	delete(payload, "model")
	payload["anthropic_version"] = version
	return json.Marshal(payload)
}
@@ -0,0 +1,77 @@
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestBuildVertexGeminiURL checks the regional streaming endpoint produced for
// a Gemini model, including the "?alt=sse" suffix.
func TestBuildVertexGeminiURL(t *testing.T) {
	const want = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse"

	endpoint, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)

	require.NoError(t, err)
	require.Equal(t, want, endpoint)
}
// TestBuildVertexGeminiURLUsesGlobalEndpointHost checks that the "global"
// location maps to the un-prefixed aiplatform host for Gemini models.
func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
	const want = "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse"

	endpoint, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)

	require.NoError(t, err)
	require.Equal(t, want, endpoint)
}
// TestBuildVertexAnthropicURL checks the regional non-streaming endpoint for
// an Anthropic model and that "@" in the model ID is kept literal.
func TestBuildVertexAnthropicURL(t *testing.T) {
	const want = "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict"

	endpoint, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)

	require.NoError(t, err)
	require.Equal(t, want, endpoint)
}
// TestBuildVertexAnthropicURLUsesGlobalEndpointHost checks that the "global"
// location maps to the un-prefixed aiplatform host for Anthropic models and
// that streaming selects streamRawPredict.
func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
	const want = "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict"

	endpoint, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)

	require.NoError(t, err)
	require.Equal(t, want, endpoint)
}
// TestNormalizeVertexAnthropicModelID covers the three outcomes of
// normalization: a dash-dated ID is rewritten with "@", an "@"-dated ID is
// left alone, and an undated ID passes through unchanged.
func TestNormalizeVertexAnthropicModelID(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{input: "claude-sonnet-4-5-20250929", want: "claude-sonnet-4-5@20250929"},
		{input: "claude-sonnet-4-5@20250929", want: "claude-sonnet-4-5@20250929"},
		{input: "claude-sonnet-4-6", want: "claude-sonnet-4-6"},
	}

	for _, tc := range cases {
		require.Equal(t, tc.want, normalizeVertexAnthropicModelID(tc.input))
	}
}
// TestBuildVertexAnthropicRequestBody checks that the rewritten body drops
// "model", replaces "anthropic_version" with vertexAnthropicVersion, and keeps
// the remaining fields intact.
func TestBuildVertexAnthropicRequestBody(t *testing.T) {
	request := []byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)

	rewritten, err := buildVertexAnthropicRequestBody(request)
	require.NoError(t, err)

	field := func(path string) gjson.Result { return gjson.GetBytes(rewritten, path) }

	require.Equal(t, "", field("model").String())
	require.Equal(t, vertexAnthropicVersion, field("anthropic_version").String())
	require.Equal(t, int64(64), field("max_tokens").Int())
	require.Equal(t, "hi", field("messages.0.content").String())
}
// TestBuildVertexGeminiURLRejectsInvalidLocation checks that a location
// containing a path separator is refused with an "invalid vertex location"
// error.
func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
	const badLocation = "us-central1/path"

	_, buildErr := buildVertexGeminiURL("my-project", badLocation, "gemini-3-pro", "generateContent", false)

	require.Error(t, buildErr)
	require.Contains(t, buildErr.Error(), "invalid vertex location")
}
// TestParseVertexServiceAccountKey checks that a key stored as a JSON string
// under "service_account_json" is parsed, and that a missing token_uri is
// filled in with vertexDefaultTokenURL.
func TestParseVertexServiceAccountKey(t *testing.T) {
	keyJSON := `{
"type": "service_account",
"project_id": "vertex-proj",
"private_key_id": "kid",
"private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
"client_email": "svc@vertex-proj.iam.gserviceaccount.com"
}`
	credentials := map[string]any{"service_account_json": keyJSON}
	account := &Account{
		Type:        AccountTypeServiceAccount,
		Platform:    PlatformGemini,
		Credentials: credentials,
	}

	parsed, err := parseVertexServiceAccountKey(account)
	require.NoError(t, err)

	require.Equal(t, "vertex-proj", parsed.ProjectID)
	require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", parsed.ClientEmail)
	require.Equal(t, vertexDefaultTokenURL, parsed.TokenURI)
	require.True(t, strings.Contains(parsed.PrivateKey, "BEGIN PRIVATE KEY"))
}