113 lines
3.1 KiB
Go
113 lines
3.1 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type kiroTokenProviderRepo struct {
|
|
mockAccountRepoForGemini
|
|
setErrorCalls int
|
|
setErrorID int64
|
|
setErrorMsg string
|
|
}
|
|
|
|
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
|
r.setErrorCalls++
|
|
r.setErrorID = id
|
|
r.setErrorMsg = errorMsg
|
|
return nil
|
|
}
|
|
|
|
type kiroTokenProviderSequenceRepo struct {
|
|
kiroTokenProviderRepo
|
|
accounts []*Account
|
|
reads int
|
|
}
|
|
|
|
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
|
if len(r.accounts) == 0 {
|
|
return nil, errors.New("account not found")
|
|
}
|
|
idx := r.reads
|
|
if idx >= len(r.accounts) {
|
|
idx = len(r.accounts) - 1
|
|
}
|
|
r.reads++
|
|
return r.accounts[idx], nil
|
|
}
|
|
|
|
type stubKiroAccountTokenRefresher struct {
|
|
tokenInfo *KiroTokenInfo
|
|
err error
|
|
}
|
|
|
|
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
|
|
if s.err != nil {
|
|
return nil, s.err
|
|
}
|
|
return s.tokenInfo, nil
|
|
}
|
|
|
|
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
|
if tokenInfo == nil {
|
|
return nil
|
|
}
|
|
return map[string]any{
|
|
"access_token": tokenInfo.AccessToken,
|
|
"expires_at": tokenInfo.ExpiresAt,
|
|
}
|
|
}
|
|
|
|
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
|
|
account := &Account{
|
|
ID: 42,
|
|
Platform: PlatformKiro,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
|
}
|
|
repo := &kiroTokenProviderRepo{
|
|
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
|
accountsByID: map[int64]*Account{account.ID: account},
|
|
},
|
|
}
|
|
provider := NewKiroTokenProvider(repo, nil, nil)
|
|
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
|
|
|
token, err := provider.ForceRefreshAccessToken(context.Background(), account)
|
|
require.Error(t, err)
|
|
require.Empty(t, token)
|
|
require.Equal(t, 1, repo.setErrorCalls)
|
|
require.Equal(t, account.ID, repo.setErrorID)
|
|
require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
|
|
require.Contains(t, repo.setErrorMsg, "invalid_grant")
|
|
}
|
|
|
|
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
|
|
usedAccount := &Account{
|
|
ID: 42,
|
|
Platform: PlatformKiro,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
|
}
|
|
latestAccount := &Account{
|
|
ID: 42,
|
|
Platform: PlatformKiro,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
|
|
}
|
|
repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
|
|
provider := NewKiroTokenProvider(repo, nil, nil)
|
|
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
|
|
|
token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "fresh-access", token)
|
|
require.Equal(t, 0, repo.setErrorCalls)
|
|
}
|