Add Vertex service account support

This commit is contained in:
Oliver
2026-04-25 20:39:58 -04:00
parent 489a4d934e
commit 6d11f9ed77
17 changed files with 1243 additions and 36 deletions
@@ -15,7 +15,7 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return "", errors.New("not a gemini oauth account")
if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not a gemini oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := GeminiTokenCacheKey(account)
@@ -168,7 +171,51 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
key, err := parseVertexServiceAccountKey(account)
if err != nil {
return "", err
}
cacheKey := vertexServiceAccountCacheKey(account, key)
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
locked := false
if p.tokenCache != nil {
var lockErr error
locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(200 * time.Millisecond)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
}
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
if err != nil {
return "", err
}
if p.tokenCache != nil {
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func GeminiTokenCacheKey(account *Account) string {
if account != nil && account.Type == AccountTypeServiceAccount {
if key, err := parseVertexServiceAccountKey(account); err == nil {
return vertexServiceAccountCacheKey(account, key)
}
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID