Files
sub2api/backend/internal/service/kiro_token_provider_test.go
2026-04-30 14:02:05 +08:00

113 lines
3.1 KiB
Go

//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// kiroTokenProviderRepo is an account-repository test double. It embeds
// mockAccountRepoForGemini and additionally keeps track of what SetError
// was called with so tests can assert on it.
type kiroTokenProviderRepo struct {
	mockAccountRepoForGemini

	// setErrorCalls counts how many times SetError has been invoked.
	setErrorCalls int
	// setErrorID is the id passed to the most recent SetError call.
	setErrorID int64
	// setErrorMsg is the message passed to the most recent SetError call.
	setErrorMsg string
}
// SetError records the call rather than persisting anything: it remembers the
// id and message it was given, bumps the call counter, and always returns nil.
// The context argument is ignored.
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
	r.setErrorID, r.setErrorMsg = id, errorMsg
	r.setErrorCalls++
	return nil
}
// kiroTokenProviderSequenceRepo extends kiroTokenProviderRepo with a scripted
// list of accounts that GetByID hands out across successive reads.
type kiroTokenProviderSequenceRepo struct {
	kiroTokenProviderRepo

	// accounts is the ordered list of results GetByID serves.
	accounts []*Account
	// reads counts the successful GetByID calls made so far.
	reads int
}
// GetByID ignores both the context and the requested id. Each call returns the
// next entry of r.accounts; once the list is exhausted, the last entry is
// returned again on every further call. Every successful call increments
// r.reads. With no accounts configured it fails with "account not found".
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
	total := len(r.accounts)
	if total == 0 {
		return nil, errors.New("account not found")
	}

	// Default to the final entry; use the read counter while it is still in range.
	pos := total - 1
	if r.reads < total {
		pos = r.reads
	}
	r.reads++
	return r.accounts[pos], nil
}
// stubKiroAccountTokenRefresher is a canned token refresher for tests. Its
// RefreshAccountToken method yields either the configured error or the
// configured token info.
type stubKiroAccountTokenRefresher struct {
	// tokenInfo is returned by RefreshAccountToken when err is nil.
	tokenInfo *KiroTokenInfo
	// err, when non-nil, is returned by RefreshAccountToken instead of tokenInfo.
	err error
}
// RefreshAccountToken ignores its arguments. If the stub was configured with an
// error, that error is returned with a nil token info; otherwise the configured
// token info is returned with a nil error.
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
	if failure := s.err; failure != nil {
		return nil, failure
	}
	return s.tokenInfo, nil
}
// BuildAccountCredentials converts token info into a credentials map holding
// exactly two keys, "access_token" and "expires_at", copied from the matching
// fields. A nil tokenInfo produces a nil map.
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
	if tokenInfo == nil {
		return nil
	}

	creds := make(map[string]any, 2)
	creds["access_token"] = tokenInfo.AccessToken
	creds["expires_at"] = tokenInfo.ExpiresAt
	return creds
}
// TestKiroTokenProviderForceRefreshInvalidGrantSetsError drives
// ForceRefreshAccessToken with a refresher that fails with an invalid_grant
// error. It expects an error and an empty token back, and exactly one SetError
// call on the repo for the account, whose message mentions both the
// non-retryable refresh failure and "invalid_grant".
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
	acct := &Account{
		ID:          42,
		Platform:    PlatformKiro,
		Type:        AccountTypeOAuth,
		Credentials: map[string]any{"refresh_token": "old-refresh"},
	}

	// The repo knows only this one account, keyed by its ID.
	repo := &kiroTokenProviderRepo{}
	repo.accountsByID = map[int64]*Account{acct.ID: acct}

	refreshErr := errors.New("invalid_grant: token revoked")
	provider := NewKiroTokenProvider(repo, nil, nil)
	provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: refreshErr}

	token, err := provider.ForceRefreshAccessToken(context.Background(), acct)

	require.Error(t, err)
	require.Empty(t, token)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Equal(t, acct.ID, repo.setErrorID)
	for _, fragment := range []string{"Token refresh failed (non-retryable)", "invalid_grant"} {
		require.Contains(t, repo.setErrorMsg, fragment)
	}
}
// TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError uses a sequence
// repo that serves the account with the old refresh token on the first read
// and an account carrying fresh credentials on later reads, while the
// refresher fails with an invalid_grant error. It expects
// ForceRefreshAccessToken to succeed with the fresh access token and the repo
// to see no SetError call.
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
	const accountID = 42

	staleAccount := &Account{
		ID:          accountID,
		Platform:    PlatformKiro,
		Type:        AccountTypeOAuth,
		Credentials: map[string]any{"refresh_token": "old-refresh"},
	}
	refreshedAccount := &Account{
		ID:       accountID,
		Platform: PlatformKiro,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"refresh_token":  "new-refresh",
			"access_token":   "fresh-access",
			"_token_version": int64(2),
		},
	}

	// First read yields the stale account, every later read the refreshed one.
	repo := &kiroTokenProviderSequenceRepo{
		accounts: []*Account{staleAccount, refreshedAccount},
	}

	refreshErr := errors.New("invalid_grant: token revoked")
	provider := NewKiroTokenProvider(repo, nil, nil)
	provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: refreshErr}

	token, err := provider.ForceRefreshAccessToken(context.Background(), staleAccount)

	require.NoError(t, err)
	require.Equal(t, "fresh-access", token)
	require.Equal(t, 0, repo.setErrorCalls)
}