Merge remote-tracking branch 'pr/2131' into release/v0.1.133

# Conflicts:
#	backend/cmd/server/wire_gen.go
#	backend/internal/config/config.go
#	backend/internal/service/gateway_service.go
#	backend/internal/service/pricing_service.go
#	backend/internal/service/wire.go
#	deploy/config.example.yaml
#	frontend/src/views/admin/AccountsView.vue
This commit is contained in:
kone
2026-05-16 01:55:39 +08:00
111 changed files with 16343 additions and 439 deletions
+5
View File
@@ -93,6 +93,7 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
kiroOAuth *service.KiroOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
@@ -216,6 +217,10 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
{"KiroOAuthService", func() error {
kiroOAuth.Stop()
return nil
}},
{"OpenAIWSPool", func() error {
if openAIGateway != nil {
openAIGateway.CloseOpenAIWSPool()
+14 -5
View File
@@ -146,13 +146,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -166,6 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
kiroOAuthHandler := admin.NewKiroOAuthHandler(kiroOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
@@ -179,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
kiroCooldownStore := service.ProvideKiroCooldownStore(redisClient)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
@@ -236,7 +240,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
contentModerationHandler := admin.NewContentModerationHandler(contentModerationService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, kiroOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -260,13 +264,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -316,6 +320,7 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
kiroOAuth *service.KiroOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
@@ -438,6 +443,10 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
{"KiroOAuthService", func() error {
kiroOAuth.Stop()
return nil
}},
{"OpenAIWSPool", func() error {
if openAIGateway != nil {
openAIGateway.CloseOpenAIWSPool()
+2
View File
@@ -36,6 +36,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
antigravityOAuthSvc,
nil,
nil,
nil,
cfg,
nil,
)
@@ -72,6 +73,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
openAIOAuthSvc,
geminiOAuthSvc,
antigravityOAuthSvc,
nil, // kiroOAuth
nil, // openAIGateway
nil, // scheduledTestRunner
nil, // backupSvc
+10
View File
@@ -682,6 +682,8 @@ type GatewayConfig struct {
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
// KiroStreamKeepaliveInterval: Kiro 流式 keepalive 间隔(秒),0使用默认 25 秒
KiroStreamKeepaliveInterval int `mapstructure:"kiro_stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
MaxLineSize int `mapstructure:"max_line_size"`
@@ -1752,6 +1754,7 @@ func setDefaults() {
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
viper.SetDefault("gateway.kiro_stream_keepalive_interval", 25)
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
@@ -2369,6 +2372,13 @@ func (c *Config) Validate() error {
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
}
if c.Gateway.KiroStreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be non-negative")
}
if c.Gateway.KiroStreamKeepaliveInterval != 0 &&
(c.Gateway.KiroStreamKeepaliveInterval < 5 || c.Gateway.KiroStreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be 0 or between 5-30 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
+16
View File
@@ -23,6 +23,7 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformKiro = "kiro"
)
// Account type constants
@@ -117,6 +118,21 @@ var DefaultAntigravityModelMapping = map[string]string{
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
// DefaultKiroModelMapping 是 Kiro 平台的默认模型映射。
// 键为对外暴露/允许请求的模型名,值为实际发送到 Kiro 上游的模型名。
var DefaultKiroModelMapping = map[string]string{
"claude-opus-4-6": "claude-opus-4.6",
"claude-opus-4-6-thinking": "claude-opus-4.6",
"claude-sonnet-4-6": "claude-sonnet-4.6",
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
"claude-opus-4-5-20251101": "claude-opus-4.5",
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
}
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
+57 -1
View File
@@ -1,6 +1,9 @@
package domain
import "testing"
import (
"strings"
"testing"
)
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
t.Parallel()
@@ -24,3 +27,56 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
}
}
}
func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
t.Parallel()
expected := map[string]string{
"claude-opus-4-6": "claude-opus-4.6",
"claude-opus-4-6-thinking": "claude-opus-4.6",
"claude-sonnet-4-6": "claude-sonnet-4.6",
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
"claude-opus-4-5-20251101": "claude-opus-4.5",
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
}
if len(DefaultKiroModelMapping) != len(expected) {
t.Fatalf("expected %d Kiro mappings, got %d", len(expected), len(DefaultKiroModelMapping))
}
for model, want := range expected {
if got := DefaultKiroModelMapping[model]; got != want {
t.Fatalf("unexpected Kiro mapping for %q: got %q want %q", model, got, want)
}
}
for _, model := range []string{
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-sonnet-4",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"gpt-4o",
"gpt-4",
"deepseek-3-2",
"minimax-m2-1",
"qwen3-coder-next",
"claude-opus-4-7",
"claude-sonnet-4-6-chat",
} {
if _, ok := DefaultKiroModelMapping[model]; ok {
t.Fatalf("did not expect %q to remain in DefaultKiroModelMapping", model)
}
}
for model := range DefaultKiroModelMapping {
if strings.HasSuffix(model, "-agentic") {
t.Fatalf("did not expect agentic Kiro mapping %q", model)
}
if strings.HasSuffix(model, "-chat") {
t.Fatalf("did not expect chat-only Kiro mapping %q", model)
}
}
}
@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -179,6 +180,9 @@ type AccountWithConcurrency struct {
const accountListGroupUngroupedQueryValue = "ungrouped"
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
if h.accountUsageService != nil {
h.accountUsageService.EnrichAccountWithKiroRuntimeState(ctx, account)
}
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
CurrentConcurrency: 0,
@@ -351,6 +355,9 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
acc := &accounts[i]
if h.accountUsageService != nil {
h.accountUsageService.EnrichAccountWithKiroRuntimeState(c.Request.Context(), acc)
}
item := AccountWithConcurrency{
Account: dto.AccountFromService(acc),
CurrentConcurrency: concurrencyCounts[acc.ID],
@@ -1953,6 +1960,18 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle Kiro accounts
if account.Platform == service.PlatformKiro {
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, kiropkg.DefaultModels)
return
}
response.Success(c, buildMappedKiroModels(mapping))
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
@@ -1994,6 +2013,28 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models)
}
func buildMappedKiroModels(mapping map[string]string) []kiropkg.Model {
models := make([]kiropkg.Model, 0, len(mapping))
for requestedModel := range mapping {
var found bool
for _, dm := range kiropkg.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, kiropkg.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
})
}
}
return models
}
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
@@ -2206,6 +2247,12 @@ func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
// GetKiroDefaultModelMapping 获取 Kiro 平台的默认模型映射
// GET /api/v1/admin/accounts/kiro/default-model-mapping
func (h *AccountHandler) GetKiroDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultKiroModelMapping)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) {
@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"slices"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 44,
Name: "kiro-oauth",
Platform: service.PlatformKiro,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
}
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 47,
Name: "kiro-oauth-mapped",
Platform: service.PlatformKiro,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4.6",
"custom-model": "custom-upstream-model",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 2)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
require.True(t, slices.Contains(ids, "custom-model"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
}
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 45,
Name: "kiro-apikey",
Platform: service.PlatformKiro,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4.6",
"custom-model": "custom-upstream-model",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 2)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
require.True(t, slices.Contains(ids, "custom-model"))
}
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 46,
Name: "kiro-apikey-defaults",
Platform: service.PlatformKiro,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
}
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -123,7 +123,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -0,0 +1,16 @@
package admin
import (
"testing"
"github.com/gin-gonic/gin/binding"
"github.com/stretchr/testify/require"
)
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
require.NoError(t, binding.Validator.ValidateStruct(createReq))
updateReq := UpdateGroupRequest{Platform: "kiro"}
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
}
@@ -0,0 +1,149 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type KiroOAuthHandler struct {
kiroOAuthService *service.KiroOAuthService
}
func NewKiroOAuthHandler(kiroOAuthService *service.KiroOAuthService) *KiroOAuthHandler {
return &KiroOAuthHandler{kiroOAuthService: kiroOAuthService}
}
type KiroGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
Provider string `json:"provider"`
}
func (h *KiroOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req KiroGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
result, err := h.kiroOAuthService.GenerateAuthURL(c.Request.Context(), &service.KiroGenerateAuthURLInput{
ProxyID: req.ProxyID,
Provider: req.Provider,
})
if err != nil {
response.BadRequest(c, "生成授权链接失败: "+err.Error())
return
}
response.Success(c, result)
}
type KiroGenerateIDCAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
StartURL string `json:"start_url"`
Region string `json:"region"`
}
func (h *KiroOAuthHandler) GenerateIDCAuthURL(c *gin.Context) {
var req KiroGenerateIDCAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
result, err := h.kiroOAuthService.GenerateIDCAuthURL(c.Request.Context(), &service.KiroGenerateIDCAuthURLInput{
ProxyID: req.ProxyID,
StartURL: req.StartURL,
Region: req.Region,
})
if err != nil {
response.BadRequest(c, "生成 IDC 授权链接失败: "+err.Error())
return
}
response.Success(c, result)
}
type KiroExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
CallbackPath string `json:"callback_path"`
LoginOption string `json:"login_option"`
ProxyID *int64 `json:"proxy_id"`
}
func (h *KiroOAuthHandler) ExchangeCode(c *gin.Context) {
var req KiroExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.kiroOAuthService.ExchangeCode(c.Request.Context(), &service.KiroExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
CallbackPath: req.CallbackPath,
LoginOption: req.LoginOption,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Token 交换失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
type KiroRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
AuthMethod string `json:"auth_method"`
Provider string `json:"provider"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
StartURL string `json:"start_url"`
Region string `json:"region"`
ProfileArn string `json:"profile_arn"`
ProxyID *int64 `json:"proxy_id"`
}
func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
var req KiroRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.kiroOAuthService.RefreshToken(c.Request.Context(), &service.KiroRefreshTokenInput{
RefreshToken: req.RefreshToken,
AuthMethod: req.AuthMethod,
Provider: req.Provider,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
StartURL: req.StartURL,
Region: req.Region,
ProfileArn: req.ProfileArn,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "刷新 Kiro Token 失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
type KiroImportTokenRequest struct {
TokenJSON string `json:"token_json" binding:"required"`
DeviceRegistrationJSON string `json:"device_registration_json"`
}
func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
var req KiroImportTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
TokenJSON: req.TokenJSON,
DeviceRegistrationJSON: req.DeviceRegistrationJSON,
})
if err != nil {
response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
+6
View File
@@ -230,6 +230,12 @@ func AccountFromServiceShallow(a *service.Account) *Account {
OverloadUntil: a.OverloadUntil,
TempUnschedulableUntil: a.TempUnschedulableUntil,
TempUnschedulableReason: a.TempUnschedulableReason,
KiroQuotaState: a.KiroQuotaState,
KiroQuotaReason: a.KiroQuotaReason,
KiroQuotaResetAt: a.KiroQuotaResetAt,
KiroRuntimeState: a.KiroRuntimeState,
KiroRuntimeReason: a.KiroRuntimeReason,
KiroRuntimeResetAt: a.KiroRuntimeResetAt,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
+6
View File
@@ -183,6 +183,12 @@ type Account struct {
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
+1
View File
@@ -17,6 +17,7 @@ type AdminHandlers struct {
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
AntigravityOAuth *admin.AntigravityOAuthHandler
KiroOAuth *admin.KiroOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
+3
View File
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
kiroOAuthHandler *admin.KiroOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
@@ -52,6 +53,7 @@ func ProvideAdminHandlers(
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
AntigravityOAuth: antigravityOAuthHandler,
KiroOAuth: kiroOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Promo: promoHandler,
@@ -156,6 +158,7 @@ var ProviderSet = wire.NewSet(
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
admin.NewAntigravityOAuthHandler,
admin.NewKiroOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewPromoHandler,
@@ -36,11 +36,12 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformKiro = "kiro"
)
// AllPlatforms 返回所有支持的平台列表
func AllPlatforms() []string {
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
}
// Validate 验证规则配置的有效性
+258
View File
@@ -0,0 +1,258 @@
package kiro
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"runtime"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
type RuntimeFingerprint struct {
OIDCSDKVersion string
RuntimeSDKVersion string
StreamingSDKVersion string
OSType string
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string
}
type runtimeFingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*RuntimeFingerprint
}
var (
globalRuntimeFingerprintManager *runtimeFingerprintManager
globalRuntimeFingerprintManagerOnce sync.Once
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
runtimeSDKVersions = []string{"1.0.0"}
streamingSDKVersions = []string{"1.0.34"}
osTypes = []string{"darwin", "win32"}
osVersions = map[string][]string{
"darwin": {"24.6.0"},
"win32": {"10.0.22631"},
}
nodeVersions = []string{"22.22.0"}
kiroVersions = []string{
"0.11.132", "0.11.131", "0.11.130",
}
)
func globalRuntimeFingerprints() *runtimeFingerprintManager {
globalRuntimeFingerprintManagerOnce.Do(func() {
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
fingerprints: make(map[string]*RuntimeFingerprint),
}
})
return globalRuntimeFingerprintManager
}
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
lookupKey := fingerprintLookupKey(accountKey, "runtime")
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
m.mu.RLock()
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
m.mu.RUnlock()
return fp
}
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
return fp
}
fp := generateRuntimeFingerprint(lookupKey, machineID)
m.fingerprints[lookupKey] = fp
return fp
}
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
hash := sha256.Sum256([]byte(accountKey))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
osType := goOSToNodePlatform(runtime.GOOS)
if !containsString(osTypes, osType) {
osType = osTypes[rng.Intn(len(osTypes))]
}
osVersionPool := osVersions[osType]
if len(osVersionPool) == 0 {
osVersionPool = osVersions["darwin"]
}
return &RuntimeFingerprint{
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
OSType: osType,
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
KiroHash: machineID,
}
}
func goOSToNodePlatform(goos string) string {
switch strings.TrimSpace(goos) {
case "windows":
return "win32"
default:
return strings.TrimSpace(goos)
}
}
func containsString(items []string, target string) bool {
for _, item := range items {
if item == target {
return true
}
}
return false
}
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
switch {
case strings.TrimSpace(clientIDHash) != "":
return clientIDHash
case strings.TrimSpace(clientID) != "":
return shortSHA(clientID)
case strings.TrimSpace(refreshToken) != "":
return shortSHA(refreshToken)
case strings.TrimSpace(profileArn) != "":
return shortSHA(profileArn)
case accountID > 0:
return shortSHA(fmt.Sprintf("account:%d", accountID))
default:
return shortSHA(uuid.NewString())
}
}
func NormalizeMachineID(machineID string) (string, bool) {
trimmed := strings.TrimSpace(machineID)
if len(trimmed) == 64 && isHexString(trimmed) {
return strings.ToLower(trimmed), true
}
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
normalized := strings.ToLower(withoutDashes)
return normalized + normalized, true
}
return "", false
}
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
return sha256Hex("KotlinNativeAPI/" + refreshToken)
}
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
return sha256Hex("KiroAPIKey/" + apiKey)
}
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
return sha256Hex("KiroFallback/" + fallbackKey)
}
return sha256Hex("KiroFallback/default")
}
func shortSHA(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
func sha256Hex(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:])
}
func isHexString(value string) bool {
for _, c := range value {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
return false
}
}
return true
}
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
if normalized, ok := NormalizeMachineID(machineID); ok {
return normalized
}
return BuildMachineID("", "", fallbackKey)
}
func fingerprintLookupKey(accountKey, fallback string) string {
key := strings.TrimSpace(accountKey)
if key != "" {
return key
}
return fallback
}
func BuildRuntimeUserAgent(accountKey, machineID string) string {
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
return map[string]string{
"Content-Type": "application/json",
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
"amz-sdk-invocation-id": uuid.NewString(),
"amz-sdk-request": "attempt=1; max=4",
}
}
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
return map[string]string{
"Content-Type": "application/json",
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
"Accept": "application/json, text/plain, */*",
}
}
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := baseDelay << attempt
if delay > maxDelay {
delay = maxDelay
}
const jitterFactor = 0.3
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
return time.Duration(float64(delay) * jitter)
}
@@ -0,0 +1,91 @@
package kiro
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildLoginHeadersStable(t *testing.T) {
headers1 := BuildLoginHeaders("", "")
headers2 := BuildLoginHeaders("", "")
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
require.Equal(t, "application/json", headers1["Content-Type"])
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
}
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
machineIDA := BuildMachineID("refresh-a", "", "")
machineIDB := BuildMachineID("refresh-b", "", "")
headers1 := BuildLoginHeaders("account-a", machineIDA)
headers2 := BuildLoginHeaders("account-a", machineIDA)
headers3 := BuildLoginHeaders("account-a", machineIDB)
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
require.Contains(t, headers1["User-Agent"], machineIDA)
}
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
machineID := BuildMachineID("", "", "oidc-machine")
headers1 := BuildOIDCHeaders("account-a", machineID)
headers2 := BuildOIDCHeaders("account-a", machineID)
headers3 := BuildOIDCHeaders("account-b", machineID)
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
}
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
key1 := BuildAccountKey("", "", "", "", 42)
key2 := BuildAccountKey("", "", "", "", 42)
key3 := BuildAccountKey("", "", "", "", 43)
require.Equal(t, key1, key2)
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
require.NotEqual(t, key1, key3)
}
func TestBuildMachineID(t *testing.T) {
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
fallback1 := BuildMachineID("", "", "account:1")
fallback2 := BuildMachineID("", "", "account:1")
fallback3 := BuildMachineID("", "", "account:2")
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
require.Equal(t, fallback1, fallback2)
require.NotEqual(t, fallback1, fallback3)
require.Len(t, fallback1, 64)
}
func TestNormalizeMachineID(t *testing.T) {
hex64 := strings.Repeat("A", 64)
normalized, ok := NormalizeMachineID(hex64)
require.True(t, ok)
require.Equal(t, strings.ToLower(hex64), normalized)
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
require.True(t, ok)
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
_, ok = NormalizeMachineID("not-a-machine-id")
require.False(t, ok)
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
require.False(t, ok)
}
func expectedKiroMachineID(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:])
}
+21
View File
@@ -0,0 +1,21 @@
package kiro
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
var DefaultModels = []Model{
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
}
+43
View File
@@ -0,0 +1,43 @@
package kiro
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
ids := make([]string, 0, len(DefaultModels))
for _, model := range DefaultModels {
ids = append(ids, model.ID)
}
require.Equal(t, []string{
"claude-opus-4-6",
"claude-opus-4-6-thinking",
"claude-sonnet-4-6",
"claude-sonnet-4-6-thinking",
"claude-opus-4-5-20251101",
"claude-opus-4-5-20251101-thinking",
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-5-20250929-thinking",
"claude-haiku-4-5-20251001",
"claude-haiku-4-5-20251001-thinking",
}, ids)
require.Contains(t, ids, "claude-sonnet-4-6")
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
require.NotContains(t, ids, "auto")
require.NotContains(t, ids, "claude-sonnet-4")
require.NotContains(t, ids, "gpt-4o")
require.NotContains(t, ids, "deepseek-3-2")
require.NotContains(t, ids, "minimax-m2-1")
require.NotContains(t, ids, "qwen3-coder-next")
require.NotContains(t, ids, "claude-opus-4-7")
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
for _, id := range ids {
require.NotContains(t, id, "kiro-")
require.NotContains(t, id, "-agentic")
require.NotContains(t, id, "-chat")
}
}
+511
View File
@@ -0,0 +1,511 @@
package kiro
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/google/uuid"
)
const (
socialAuthPortalURL = "https://app.kiro.dev"
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
defaultIDCRegion = "us-east-1"
BuilderIDStartURL = "https://view.awsapps.com/start"
sessionTTL = 10 * time.Minute
sessionCleanupEvery = 32
sessionCleanupMin = 32
)
var (
socialAuthEndpointURL = socialAuthEndpoint
oidcEndpointOverride = ""
)
type SocialProvider string
const (
SocialProviderGoogle SocialProvider = "Google"
SocialProviderGitHub SocialProvider = "Github"
)
type AuthSession struct {
State string
CodeVerifier string
ProxyURL string
CreatedAt time.Time
AuthType string
Provider string
RedirectURI string
ClientID string
ClientSecret string
Region string
StartURL string
}
type SessionStore struct {
mu sync.RWMutex
data map[string]*AuthSession
setCount uint64
}
func NewSessionStore() *SessionStore {
return &SessionStore{data: make(map[string]*AuthSession)}
}
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
session, ok := s.data[id]
if ok && sessionExpired(session, now) {
delete(s.data, id)
return nil, false
}
return session, ok
}
func (s *SessionStore) Set(id string, session *AuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.setCount++
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
s.pruneExpiredLocked(time.Now())
}
s.data[id] = session
}
func (s *SessionStore) Delete(id string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, id)
}
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
for id, session := range s.data {
if sessionExpired(session, now) {
delete(s.data, id)
}
}
}
func sessionExpired(session *AuthSession, now time.Time) bool {
if session == nil {
return true
}
if session.CreatedAt.IsZero() {
return true
}
return now.After(session.CreatedAt.Add(sessionTTL))
}
type TokenData struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn,omitempty"`
ExpiresAt string `json:"expiresAt,omitempty"`
AuthMethod string `json:"authMethod,omitempty"`
Provider string `json:"provider,omitempty"`
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
ClientIDHash string `json:"clientIdHash,omitempty"`
Email string `json:"email,omitempty"`
StartURL string `json:"startUrl,omitempty"`
Region string `json:"region,omitempty"`
}
type socialTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
type registerClientResponse struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
}
type createTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
type userInfoResponse struct {
Email string `json:"email"`
}
type deviceRegistration struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
}
type RefreshTokenInvalidError struct {
StatusCode int
Body string
}
func (e *RefreshTokenInvalidError) Error() string {
if e == nil {
return ""
}
body := strings.TrimSpace(e.Body)
if body == "" {
return "kiro refresh token invalid (invalid_grant)"
}
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
}
func GenerateSessionID() string {
return uuid.NewString()
}
func GenerateState() (string, error) {
return randomURLSafe(16)
}
func GenerateCodeVerifier() (string, error) {
return randomURLSafe(32)
}
func randomURLSafe(n int) (string, error) {
buf := make([]byte, n)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func GenerateCodeChallenge(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
params := url.Values{}
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("redirect_uri", redirectURI)
params.Set("redirect_from", "KiroIDE")
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
}
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
if redirectURI == "" {
return ""
}
path := strings.TrimSpace(callbackPath)
if path == "" {
path = "/oauth/callback"
} else if !strings.HasPrefix(path, "/") {
path = "/" + path
}
fullRedirectURI := redirectURI + path
if option := strings.TrimSpace(loginOption); option != "" {
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
}
return fullRedirectURI
}
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
payload := map[string]string{
"code": code,
"code_verifier": codeVerifier,
"redirect_uri": redirectURI,
}
var resp socialTokenResponse
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
return &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "social",
Region: defaultIDCRegion,
}, nil
}
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
payload := map[string]string{
"refreshToken": refreshToken,
}
var resp socialTokenResponse
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
return &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "social",
Provider: provider,
Region: defaultIDCRegion,
}, nil
}
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]any{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"authorization_code", "refresh_token"},
"redirectUris": []string{redirectURI},
"issuerUrl": issuerURL,
}
var resp registerClientResponse
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
return nil, err
}
return &resp, nil
}
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
if region == "" {
region = defaultIDCRegion
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scopes", strings.Join([]string{
"codewhisperer:completions",
"codewhisperer:analysis",
"codewhisperer:conversations",
"codewhisperer:transformations",
"codewhisperer:taskassist",
}, " "))
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
}
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"code": code,
"codeVerifier": codeVerifier,
"redirectUri": redirectURI,
"grantType": "authorization_code",
}
var resp createTokenResponse
accountKey := BuildAccountKey(clientID, "", "", "", 0)
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
token := &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
return token, nil
}
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"refreshToken": refreshToken,
"grantType": "refresh_token",
}
var resp createTokenResponse
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
token := &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
return token, nil
}
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
if strings.TrimSpace(accessToken) == "" {
return ""
}
var resp userInfoResponse
headers := map[string]string{
"Authorization": "Bearer " + accessToken,
}
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
return ""
}
return strings.TrimSpace(resp.Email)
}
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
var token TokenData
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
}
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
if strings.TrimSpace(token.AccessToken) == "" {
return nil, fmt.Errorf("access token is empty")
}
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
var reg deviceRegistration
if err := json.Unmarshal([]byte(deviceRegistrationJSON), &reg); err != nil {
return nil, fmt.Errorf("failed to parse device registration: %w", err)
}
if reg.ClientID != "" {
token.ClientID = reg.ClientID
}
if reg.ClientSecret != "" {
token.ClientSecret = reg.ClientSecret
}
}
return &token, nil
}
func getOIDCEndpoint(region string) string {
if strings.TrimSpace(oidcEndpointOverride) != "" {
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
}
if region == "" {
region = defaultIDCRegion
}
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
}
func oidcHeaders(accountKey, machineID string) map[string]string {
headers := BuildOIDCHeaders(accountKey, machineID)
if headers["amz-sdk-invocation-id"] == "" {
headers["amz-sdk-invocation-id"] = uuid.NewString()
}
if headers["amz-sdk-request"] == "" {
headers["amz-sdk-request"] = "attempt=1; max=4"
}
return headers
}
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
client, err := newHTTPClient(proxyURL)
if err != nil {
return err
}
var body io.Reader
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
body = bytes.NewReader(encoded)
}
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
if err != nil {
return err
}
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
for key, value := range extraHeaders {
req.Header.Set(key, value)
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
bodyText := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
}
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
}
if out == nil || len(respBody) == 0 {
return nil
}
return json.Unmarshal(respBody, out)
}
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
_, parsed, err := proxyurl.Parse(rawProxyURL)
if err != nil {
return nil, err
}
transport := &http.Transport{}
if parsed != nil {
transport.Proxy = http.ProxyURL(parsed)
}
return &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}, nil
}
@@ -0,0 +1,105 @@
package kiro
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/refreshToken", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
}))
defer server.Close()
previous := socialAuthEndpointURL
socialAuthEndpointURL = server.URL
t.Cleanup(func() { socialAuthEndpointURL = previous })
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
require.Error(t, err)
var invalid *RefreshTokenInvalidError
require.True(t, errors.As(err, &invalid))
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
require.Contains(t, invalid.Body, "invalid_grant")
}
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/token", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
require.Error(t, err)
var invalid *RefreshTokenInvalidError
require.True(t, errors.As(err, &invalid))
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
require.Contains(t, invalid.Body, "invalid_grant")
}
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
default:
t.Fatalf("unexpected path: %s", r.URL.Path)
}
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
require.NoError(t, err)
require.Equal(t, profileArn, token.ProfileArn)
require.Equal(t, "kiro@example.com", token.Email)
}
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
default:
t.Fatalf("unexpected path: %s", r.URL.Path)
}
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
require.NoError(t, err)
require.Equal(t, profileArn, token.ProfileArn)
require.Equal(t, "kiro@example.com", token.Email)
}
+56
View File
@@ -0,0 +1,56 @@
//go:build unit
package kiro
import (
"fmt"
"testing"
"time"
)
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
if got != want {
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
}
}
func TestBuildSocialTokenRedirectURI(t *testing.T) {
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
want := "http://localhost:49153/oauth/callback?login_option=github"
if got != want {
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
}
}
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
store := NewSessionStore()
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
session, ok := store.Get("expired")
if ok || session != nil {
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
}
if _, exists := store.data["expired"]; exists {
t.Fatalf("expired session should be deleted from the store")
}
}
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
store := NewSessionStore()
now := time.Now()
for i := 0; i < sessionCleanupMin; i++ {
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
}
store.setCount = sessionCleanupEvery - 1
store.Set("fresh", &AuthSession{CreatedAt: now})
if len(store.data) != 1 {
t.Fatalf("store size = %d, want 1", len(store.data))
}
if _, ok := store.data["fresh"]; !ok {
t.Fatalf("fresh session should remain after pruning")
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+368
View File
@@ -0,0 +1,368 @@
package kiro
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
var cachedWebSearchDescription atomic.Value // stores string
type MCPRequest struct {
ID string `json:"id"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}
type MCPResponse struct {
Result *struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
} `json:"result,omitempty"`
Error *struct {
Code *int `json:"code,omitempty"`
Message *string `json:"message,omitempty"`
} `json:"error,omitempty"`
}
type WebSearchResults struct {
Results []WebSearchResult `json:"results"`
}
type WebSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Snippet *string `json:"snippet,omitempty"`
PublishedDate *int64 `json:"publishedDate,omitempty"`
ID *string `json:"id,omitempty"`
Domain *string `json:"domain,omitempty"`
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
PublicDomain *bool `json:"publicDomain,omitempty"`
}
type SearchIndicator struct {
ToolUseID string
Query string
Results *WebSearchResults
}
func GetCachedWebSearchDescription() string {
if v := cachedWebSearchDescription.Load(); v != nil {
return strings.TrimSpace(v.(string))
}
return ""
}
func SetCachedWebSearchDescription(desc string) {
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
}
func BuildMcpEndpoint(region string) string {
if strings.TrimSpace(region) == "" {
region = "us-east-1"
}
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
}
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
return nil
}
for _, item := range resp.Result.Content {
if item.Type != "" && item.Type != "text" {
continue
}
var results WebSearchResults
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
return &results
}
}
return nil
}
func ExtractSearchQuery(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
arr := messages.Array()
for i := len(arr) - 1; i >= 0; i-- {
msg := arr[i]
if msg.Get("role").String() != "user" {
continue
}
text := extractSearchText(msg.Get("content"))
const prefix = "Perform a web search for the query: "
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
if text != "" {
return text
}
}
return ""
}
func extractSearchText(content gjson.Result) string {
if content.Type == gjson.String {
return content.String()
}
if !content.IsArray() {
return ""
}
for _, block := range content.Array() {
if block.Get("type").String() == "text" {
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
return text
}
}
}
return ""
}
func GenerateToolUseID() string {
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
}
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err != nil {
return body, err
}
rawTools, ok := payload["tools"].([]interface{})
if !ok {
return body, nil
}
replaced := make([]interface{}, 0, len(rawTools))
for _, rawTool := range rawTools {
tool, ok := rawTool.(map[string]interface{})
if !ok {
replaced = append(replaced, rawTool)
continue
}
name := getInterfaceString(tool["name"])
toolType := getInterfaceString(tool["type"])
if !isWebSearchToolName(name, toolType) {
replaced = append(replaced, rawTool)
continue
}
replaced = append(replaced, map[string]interface{}{
"name": "web_search",
"description": minimalWebSearchDescription,
"input_schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "The search query to execute",
},
},
"required": []string{"query"},
"additionalProperties": false,
},
})
}
payload["tools"] = replaced
updated, err := json.Marshal(payload)
if err != nil {
return body, err
}
return updated, nil
}
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(claudePayload, &payload); err != nil {
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
}
rawMessages, ok := payload["messages"].([]interface{})
if !ok {
return claudePayload, fmt.Errorf("claude payload missing messages array")
}
assistantMsg := map[string]interface{}{
"role": "assistant",
"content": []interface{}{
map[string]interface{}{
"type": "tool_use",
"id": toolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": query},
},
},
}
userContent := []interface{}{
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolUseID,
"content": formatToolResultText(results),
},
}
if guidance := searchGuidanceText(); guidance != "" {
userContent = append(userContent, map[string]interface{}{
"type": "text",
"text": guidance,
})
}
userMsg := map[string]interface{}{
"role": "user",
"content": userContent,
}
rawMessages = append(rawMessages, assistantMsg, userMsg)
payload["messages"] = rawMessages
updated, err := json.Marshal(payload)
if err != nil {
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
}
return updated, nil
}
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
if len(searches) == 0 {
return responsePayload, nil
}
var response map[string]interface{}
if err := json.Unmarshal(responsePayload, &response); err != nil {
return responsePayload, err
}
content, _ := response["content"].([]interface{})
updated := make([]interface{}, 0, len(searches)*2+len(content))
for _, search := range searches {
updated = append(updated, map[string]interface{}{
"type": "server_tool_use",
"id": search.ToolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": search.Query},
})
updated = append(updated, map[string]interface{}{
"type": "web_search_tool_result",
"content": buildSearchResultContent(search.Results),
})
}
updated = append(updated, content...)
response["content"] = updated
encoded, err := json.Marshal(response)
if err != nil {
return responsePayload, err
}
return encoded, nil
}
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
content := make([]map[string]interface{}, 0)
if results == nil {
return content
}
for _, result := range results.Results {
snippet := ""
if result.Snippet != nil {
snippet = strings.TrimSpace(*result.Snippet)
}
content = append(content, map[string]interface{}{
"type": "web_search_result",
"title": result.Title,
"url": result.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
return content
}
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
content := gjson.GetBytes(responsePayload, "content")
if !content.IsArray() {
return "", "", false
}
for _, block := range content.Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if !isWebSearchToolName(name, "") {
continue
}
query = strings.TrimSpace(block.Get("input.query").String())
if query == "" {
continue
}
return block.Get("id").String(), query, true
}
return "", "", false
}
func isWebSearchToolName(name, toolType string) bool {
name = strings.ToLower(strings.TrimSpace(name))
toolType = strings.ToLower(strings.TrimSpace(toolType))
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
return true
}
switch name {
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
return true
default:
return false
}
}
func getInterfaceString(v interface{}) string {
if v == nil {
return ""
}
switch val := v.(type) {
case string:
return strings.TrimSpace(val)
default:
return strings.TrimSpace(fmt.Sprint(val))
}
}
func formatToolResultText(results *WebSearchResults) string {
if results == nil || len(results.Results) == 0 {
return "No search results found."
}
payload, err := json.MarshalIndent(results.Results, "", " ")
if err != nil {
return "Found search results, but failed to format them."
}
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
}
func searchGuidanceText() string {
now := time.Now()
return fmt.Sprintf(`<search_guidance>
Current date: %s (%s)
IMPORTANT: Evaluate the search results above carefully. If the results are:
- Mostly spam, SEO junk, or unrelated websites
- Missing actual information about the query topic
- Outdated or not matching the requested time frame
Then you MUST use the web_search tool again with a refined query. Try:
- Rephrasing in English for better coverage
- Using more specific keywords
- Adding date context
Do NOT apologize for bad results without first attempting a re-search.
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
}
@@ -0,0 +1,297 @@
package kiro
import (
"encoding/json"
"strings"
)
type BufferedStreamResult struct {
StopReason string
WebSearchQuery string
WebSearchToolUseID string
HasWebSearchToolUse bool
WebSearchToolUseIndex int
}
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
searchContent := make([]map[string]interface{}, 0)
if results != nil {
for _, result := range results.Results {
snippet := ""
if result.Snippet != nil {
snippet = strings.TrimSpace(*result.Snippet)
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": result.Title,
"url": result.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
inputJSON, _ := json.Marshal(map[string]string{"query": query})
events := []map[string]interface{}{
{
"type": "content_block_start",
"index": startIndex,
"content_block": map[string]interface{}{
"type": "server_tool_use",
"id": toolUseID,
"name": "web_search",
"input": map[string]interface{}{},
},
},
{
"type": "content_block_delta",
"index": startIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
},
{
"type": "content_block_stop",
"index": startIndex,
},
{
"type": "content_block_start",
"index": startIndex + 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"content": searchContent,
},
},
{
"type": "content_block_stop",
"index": startIndex + 1,
},
}
result := make([][]byte, 0, len(events))
for _, event := range events {
eventType, _ := event["type"].(string)
payload, _ := json.Marshal(event)
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
}
return result
}
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
var currentToolName string
currentToolIndex := -1
var toolInputBuilder strings.Builder
for _, chunk := range chunks {
lines := strings.Split(string(chunk), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "" || payload == "[DONE]" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
switch eventType, _ := event["type"].(string); eventType {
case "message_delta":
if delta, ok := event["delta"].(map[string]interface{}); ok {
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
result.StopReason = stopReason
}
}
case "content_block_start":
contentBlock, ok := event["content_block"].(map[string]interface{})
if !ok {
continue
}
blockType, _ := contentBlock["type"].(string)
if blockType != "tool_use" {
continue
}
currentToolName, _ = contentBlock["name"].(string)
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
if idx, ok := event["index"].(float64); ok {
currentToolIndex = int(idx)
}
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
}
toolInputBuilder.Reset()
case "content_block_delta":
if currentToolName == "" {
continue
}
delta, ok := event["delta"].(map[string]interface{})
if !ok {
continue
}
deltaType, _ := delta["type"].(string)
if deltaType != "input_json_delta" {
continue
}
if partialJSON, ok := delta["partial_json"].(string); ok {
toolInputBuilder.WriteString(partialJSON)
}
case "content_block_stop":
if !isWebSearchToolName(currentToolName, "") {
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
continue
}
result.HasWebSearchToolUse = true
result.WebSearchToolUseIndex = currentToolIndex
var input map[string]string
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
result.WebSearchQuery = strings.TrimSpace(input["query"])
}
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
}
}
}
return result
}
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
filtered := make([][]byte, 0, len(chunks))
for _, chunk := range chunks {
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
if shouldForward {
filtered = append(filtered, adjusted)
}
}
return filtered
}
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
return filterSSEChunk(chunk, -1, offset)
}
func MaxContentBlockIndex(chunks [][]byte) int {
maxIndex := -1
for _, chunk := range chunks {
lines := strings.Split(string(chunk), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "" || payload == "[DONE]" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
switch eventType, _ := event["type"].(string); eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
maxIndex = int(idx)
}
}
}
}
return maxIndex
}
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
lines := strings.Split(string(chunk), "\n")
var builder strings.Builder
hasContent := false
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "event: ") {
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
i++
continue
}
}
builder.WriteString(line + "\n")
hasContent = true
continue
}
if strings.HasPrefix(line, "data: ") {
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "[DONE]" {
continue
}
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
continue
}
adjusted := adjustEventPayload(payload, indexOffset)
if adjusted == "" {
continue
}
builder.WriteString("data: " + adjusted + "\n")
hasContent = true
continue
}
builder.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
if !hasContent {
return nil, false
}
return []byte(builder.String()), true
}
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
if payload == "" {
return false
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
return false
}
eventType, _ := event["type"].(string)
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
return true
}
if webSearchToolUseIndex < 0 {
return false
}
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
return true
}
return false
}
func adjustEventPayload(payload string, indexOffset int) string {
if payload == "" || indexOffset == 0 {
return payload
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
return payload
}
switch eventType, _ := event["type"].(string); eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + indexOffset
if adjusted, err := json.Marshal(event); err == nil {
return string(adjusted)
}
}
}
return payload
}
@@ -0,0 +1,73 @@
package kiro
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
snippet := "result snippet"
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
Results: []WebSearchResult{
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
},
}, 0)
require.Len(t, events, 5)
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
require.Contains(t, string(events[0]), `"input":{}`)
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
require.NotContains(t, string(events[3]), `"tool_use_id"`)
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
}
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
chunks := [][]byte{
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
}
result := AnalyzeBufferedStream(chunks)
require.True(t, result.HasWebSearchToolUse)
require.Equal(t, "golang concurrency", result.WebSearchQuery)
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
require.Equal(t, 1, result.WebSearchToolUseIndex)
require.Equal(t, "tool_use", result.StopReason)
}
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
chunks := [][]byte{
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
}
filtered := FilterChunksForClient(chunks, 1, 2)
require.NotEmpty(t, filtered)
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
require.NotContains(t, joined, `"type":"message_start"`)
require.NotContains(t, joined, `"type":"message_delta"`)
require.NotContains(t, joined, `"name":"web_search"`)
require.Contains(t, joined, `"index":2`)
require.Equal(t, 2, MaxContentBlockIndex(filtered))
}
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
require.False(t, shouldForward)
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
require.True(t, shouldForward)
require.Contains(t, string(adjusted), `"index":3`)
}
+138
View File
@@ -0,0 +1,138 @@
package kiro
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
body := []byte(`{
"tools":[{"type":"web_search_20250305","description":"old"}],
"messages":[{"role":"user","content":"golang"}]
}`)
updated, err := ReplaceWebSearchToolDescription(body)
require.NoError(t, err)
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
}
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
body := []byte(`{
"messages":[{"role":"user","content":"what is golang"}]
}`)
results := &WebSearchResults{
Results: []WebSearchResult{
{Title: "Go", URL: "https://go.dev"},
},
}
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
require.NoError(t, err)
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
}
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
response := []byte(`{
"content":[
{"type":"text","text":"let me search"},
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
]
}`)
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
require.True(t, ok)
require.Equal(t, "srvtoolu_next", toolUseID)
require.Equal(t, "golang concurrency", query)
}
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
response := []byte(`{
"id":"msg_1",
"type":"message",
"role":"assistant",
"model":"kiro",
"content":[{"type":"text","text":"final"}],
"stop_reason":"end_turn",
"usage":{"input_tokens":1,"output_tokens":1}
}`)
snippet := "result snippet"
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
{
ToolUseID: "srvtoolu_test",
Query: "golang",
Results: &WebSearchResults{
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
},
},
})
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(updated, &decoded))
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
}
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
resp := &MCPResponse{
Result: &struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
}{
Content: []struct {
Type string `json:"type"`
Text string `json:"text"`
}{
{
Type: "text",
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
},
},
},
}
results := ParseSearchResults(resp)
require.NotNil(t, results)
require.Len(t, results.Results, 1)
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
require.Equal(t, "doc-1", *results.Results[0].ID)
require.Equal(t, "go.dev", *results.Results[0].Domain)
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
require.True(t, *results.Results[0].PublicDomain)
}
func TestSearchGuidanceText_IsStructured(t *testing.T) {
guidance := searchGuidanceText()
require.Contains(t, guidance, "<search_guidance>")
require.Contains(t, guidance, "Current date:")
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
require.Contains(t, guidance, "Rephrasing in English for better coverage")
}
+479
View File
@@ -0,0 +1,479 @@
package kirocooldown
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
const (
MinRequestInterval = time.Second
MaxRequestInterval = 2 * time.Second
CooldownReason429 = "rate_limit_exceeded"
CooldownReasonSuspended = "account_suspended"
ShortCooldown = time.Minute
MaxCooldown = 5 * time.Minute
LongCooldown = 24 * time.Hour
redisTimeout = 3 * time.Second
activeTTL = 10 * time.Second
stateTTL = 25 * time.Hour
keyPrefix = "kiro:cooldown:"
)
var (
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
reserveRequestScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
local interval_ms = tonumber(ARGV[1])
local active_ttl_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
if cooldown_until_ms > now_ms then
return {1, cooldown_until_ms - now_ms, cooldown_reason}
end
if cooldown_until_ms > 0 then
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
end
local next_slot_ms = now_ms
if last_request_ms > 0 then
local candidate_ms = last_request_ms + interval_ms
if candidate_ms > now_ms then
next_slot_ms = candidate_ms
end
end
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
if fail_count > 0 or cooldown_until_ms > now_ms then
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
else
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
end
return {0, next_slot_ms - now_ms, ''}
`)
mark429Script = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
local short_cooldown_ms = tonumber(ARGV[1])
local max_cooldown_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
if cooldown_ms > max_cooldown_ms then
cooldown_ms = max_cooldown_ms
end
redis.call('HSET', KEYS[1],
'fail_count', fail_count,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[4]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
markSuccessScript = redis.NewScript(`
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', 0,
'cooldown_reason', ''
)
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
return 1
`)
markSuspendedScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local cooldown_ms = tonumber(ARGV[1])
local state_ttl_ms = tonumber(ARGV[2])
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[3]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
)
type Error struct {
remaining time.Duration
reason string
}
type State struct {
Active bool
Reason string
CooldownUntil time.Time
Remaining time.Duration
FailCount int
}
func NewError(remaining time.Duration, reason string) error {
return &Error{remaining: remaining, reason: reason}
}
func (e *Error) Error() string {
if e == nil {
return ""
}
if e.reason == "" {
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
}
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
}
func Calculate429Cooldown(retryCount int) time.Duration {
if retryCount < 0 {
retryCount = 0
}
cooldown := ShortCooldown * time.Duration(1<<retryCount)
if cooldown > MaxCooldown {
return MaxCooldown
}
return cooldown
}
type Store struct {
client *redis.Client
rngMu sync.Mutex
rng *rand.Rand
}
func NewStore(client *redis.Client) *Store {
return &Store{
client: client,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
values, err := reserveRequestScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
s.nextInterval().Milliseconds(),
activeTTL.Milliseconds(),
stateTTL.Milliseconds(),
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
}
parts, ok := values.([]interface{})
if !ok || len(parts) != 3 {
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
}
state, err := luaInt64(parts[0])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
}
waitMS, err := luaInt64(parts[1])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
}
reason, err := luaString(parts[2])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
}
if state == 1 {
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
}
if waitMS <= 0 {
return 0, nil
}
return time.Duration(waitMS) * time.Millisecond, nil
}
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
if err := s.validate(); err != nil {
return err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
if err := markSuccessScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
activeTTL.Milliseconds(),
).Err(); err != nil {
return fmt.Errorf("kiro cooldown mark success: %w", err)
}
return nil
}
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
result, err := mark429Script.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
ShortCooldown.Milliseconds(),
MaxCooldown.Milliseconds(),
stateTTL.Milliseconds(),
CooldownReason429,
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
}
cooldownMS, err := luaInt64(result)
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
}
return time.Duration(cooldownMS) * time.Millisecond, nil
}
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
result, err := markSuspendedScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
LongCooldown.Milliseconds(),
stateTTL.Milliseconds(),
CooldownReasonSuspended,
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
}
cooldownMS, err := luaInt64(result)
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
}
return time.Duration(cooldownMS) * time.Millisecond, nil
}
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
if err := s.validate(); err != nil {
return nil, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
values, err := s.client.HMGet(
cacheCtx,
RedisKey(tokenKey),
"cooldown_until_ms",
"cooldown_reason",
"fail_count",
).Result()
if err != nil {
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
}
if len(values) != 3 {
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
}
cooldownUntilMS, err := luaInt64(values[0])
if err != nil && values[0] != nil {
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
}
reason, err := luaString(values[1])
if err != nil {
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
}
failCount, err := luaInt64(values[2])
if err != nil && values[2] != nil {
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
}
if cooldownUntilMS <= 0 {
return nil, nil
}
cooldownUntil := time.UnixMilli(cooldownUntilMS)
remaining := time.Until(cooldownUntil)
if remaining <= 0 {
return nil, nil
}
return &State{
Active: true,
Reason: reason,
CooldownUntil: cooldownUntil,
Remaining: remaining,
FailCount: int(failCount),
}, nil
}
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
if err := s.validate(); err != nil {
return false, err
}
uniqueKeys := make([]string, 0, len(tokenKeys))
seen := make(map[string]struct{}, len(tokenKeys))
for _, tokenKey := range tokenKeys {
tokenKey = strings.TrimSpace(tokenKey)
if tokenKey == "" {
continue
}
redisKey := RedisKey(tokenKey)
if _, ok := seen[redisKey]; ok {
continue
}
seen[redisKey] = struct{}{}
uniqueKeys = append(uniqueKeys, redisKey)
}
if len(uniqueKeys) == 0 {
return false, nil
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
type candidate struct {
redisKey string
cooldownUntilMS int64
failCount int64
}
now := time.Now().UnixMilli()
var best *candidate
pipe := s.client.Pipeline()
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
for _, redisKey := range uniqueKeys {
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
}
if _, err := pipe.Exec(cacheCtx); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
}
for i, cmd := range cmds {
values, err := cmd.Result()
if err != nil {
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
}
if len(values) != 3 {
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
}
cooldownUntilMS, err := luaInt64(values[0])
if err != nil && values[0] != nil {
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
}
reason, err := luaString(values[1])
if err != nil {
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
}
failCount, err := luaInt64(values[2])
if err != nil && values[2] != nil {
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
}
if cooldownUntilMS <= now || reason != CooldownReason429 {
continue
}
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
if best == nil ||
current.cooldownUntilMS < best.cooldownUntilMS ||
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
best = current
}
}
if best == nil {
return false, nil
}
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
}
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
}
return true, nil
}
func RedisKey(tokenKey string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
digest := hex.EncodeToString(sum[:])
return keyPrefix + "{" + digest + "}"
}
func ActiveTTL() time.Duration {
return activeTTL
}
func StateTTL() time.Duration {
return stateTTL
}
func (s *Store) validate() error {
if s == nil || s.client == nil {
return ErrStoreUnavailable
}
return nil
}
func (s *Store) nextInterval() time.Duration {
s.rngMu.Lock()
defer s.rngMu.Unlock()
if MaxRequestInterval <= MinRequestInterval {
return MinRequestInterval
}
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
}
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
return context.WithTimeout(ctx, redisTimeout)
}
func luaInt64(v any) (int64, error) {
switch n := v.(type) {
case int64:
return n, nil
case int:
return int64(n), nil
case string:
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
case []byte:
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
default:
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
}
}
func luaString(v any) (string, error) {
switch s := v.(type) {
case string:
return s, nil
case []byte:
return string(s), nil
case nil:
return "", nil
default:
return "", fmt.Errorf("unsupported lua string type %T", v)
}
}
@@ -0,0 +1,32 @@
package kirocooldown
import (
"context"
"testing"
"github.com/redis/go-redis/v9"
)
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
if err != nil {
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
}
if cleared {
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
}
}
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
store := NewStore(nil)
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
if err == nil {
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
}
if cleared {
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
}
}
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
service.PlatformOpenAI: 1,
service.PlatformGemini: 1,
service.PlatformAntigravity: 2,
service.PlatformKiro: 1,
}
for platform, minCount := range requiredByPlatform {
+15
View File
@@ -41,6 +41,9 @@ func RegisterAdminRoutes(
// Antigravity OAuth
registerAntigravityOAuthRoutes(admin, h)
// Kiro OAuth / IDC
registerKiroOAuthRoutes(admin, h)
// 代理管理
registerProxyRoutes(admin, h)
@@ -315,6 +318,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Antigravity 默认模型映射
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
accounts.GET("/kiro/default-model-mapping", h.Admin.Account.GetKiroDefaultModelMapping)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
@@ -367,6 +371,17 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
}
}
func registerKiroOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
kiro := admin.Group("/kiro")
{
kiro.POST("/oauth/auth-url", h.Admin.KiroOAuth.GenerateAuthURL)
kiro.POST("/oauth/idc-auth-url", h.Admin.KiroOAuth.GenerateIDCAuthURL)
kiro.POST("/oauth/exchange-code", h.Admin.KiroOAuth.ExchangeCode)
kiro.POST("/oauth/refresh-token", h.Admin.KiroOAuth.RefreshToken)
kiro.POST("/oauth/import-token", h.Admin.KiroOAuth.ImportToken)
}
}
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies := admin.Group("/proxies")
{
+37 -13
View File
@@ -48,6 +48,13 @@ type Account struct {
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
KiroQuotaState string
KiroQuotaReason string
KiroQuotaResetAt *time.Time
KiroRuntimeState string
KiroRuntimeReason string
KiroRuntimeResetAt *time.Time
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
@@ -164,6 +171,10 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) IsKiro() bool {
return a.Platform == PlatformKiro
}
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
@@ -478,17 +489,17 @@ func (a *Account) GetModelMapping() map[string]string {
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
// 部分平台在未显式配置 model_mapping 时仍应使用默认映射
// 以限制可调度/可转发的模型集合。
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
return defaults
}
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
return nil
}
if len(rawMapping) == 0 {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
return defaults
}
return nil
}
@@ -510,13 +521,23 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
return result
}
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
return defaults
}
return nil
}
func defaultModelMappingForPlatform(platform string) map[string]string {
switch platform {
case domain.PlatformAntigravity:
return domain.DefaultAntigravityModelMapping
case domain.PlatformKiro:
return domain.DefaultKiroModelMapping
default:
return nil
}
}
func mapPtr(m map[string]any) uintptr {
if m == nil {
return 0
@@ -608,8 +629,8 @@ func resolveRequestedModelInMapping(mapping map[string]string, requestedModel st
return matchWildcardMappingResult(mapping, requestedModel)
}
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
// 如果未配置 mapping,返回 true(允许所有模型)
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时也会先回退到默认映射。
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
@@ -622,8 +643,8 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
}
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时返回默认映射结果。
func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel
@@ -725,6 +746,9 @@ func (a *Account) GetBaseURL() string {
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
if a.Platform == PlatformKiro {
return ""
}
return "https://api.anthropic.com"
}
if a.Platform == PlatformAntigravity {
+2 -2
View File
@@ -180,7 +180,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
if err != nil {
return nil, err
}
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
}
}
@@ -296,7 +296,7 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if err != nil {
return nil, err
}
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
}
}
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
@@ -66,6 +67,7 @@ type AccountTestService struct {
accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider
claudeTokenProvider *ClaudeTokenProvider
kiroTokenProvider *KiroTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
@@ -77,6 +79,7 @@ func NewAccountTestService(
accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider,
claudeTokenProvider *ClaudeTokenProvider,
kiroTokenProvider *KiroTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
@@ -86,6 +89,7 @@ func NewAccountTestService(
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
claudeTokenProvider: claudeTokenProvider,
kiroTokenProvider: kiroTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
@@ -192,6 +196,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt)
}
if account.IsKiro() && account.Type == AccountTypeOAuth {
return s.testKiroAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -240,6 +248,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
}
baseURL := account.GetBaseURL()
if baseURL == "" && account.Platform == PlatformKiro {
return s.sendErrorAndEnd(c, "Kiro API Key accounts require a Base URL")
}
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
@@ -388,6 +399,149 @@ func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Con
return s.processClaudeStream(c, resp.Body)
}
func (s *AccountTestService) testKiroAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
testModelID := strings.TrimSpace(modelID)
if testModelID == "" {
testModelID = "claude-sonnet-4-6"
}
if mappedModel := account.GetMappedModel(testModelID); strings.TrimSpace(mappedModel) != "" {
testModelID = mappedModel
}
if account.Type != AccountTypeOAuth {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Kiro account type: %s", account.Type))
}
if s.kiroTokenProvider == nil {
return s.sendErrorAndEnd(c, "Kiro token provider not configured")
}
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get Kiro access token: %s", err.Error()))
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payload, err := createTestPayload(testModelID)
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create test payload")
}
payloadBytes, _ := json.Marshal(payload)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
resp, err := s.executeKiroTestUpstream(ctx, account, payloadBytes, testModelID, accessToken)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, formatKiroTestError(resp.StatusCode, body, testModelID, account))
}
pr, pw := io.Pipe()
go func() {
defer func() { _ = resp.Body.Close() }()
_, streamErr := kiropkg.StreamEventStreamAsAnthropic(ctx, resp.Body, pw, testModelID, estimateKiroInputTokens(payloadBytes))
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
}
_ = pw.Close()
}()
return s.processClaudeStream(c, pr)
}
func formatKiroTestError(statusCode int, body []byte, requestedModel string, account *Account) string {
return fmt.Sprintf("API returned %d: %s", statusCode, string(body))
}
func (s *AccountTestService) executeKiroTestUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string) (*http.Response, error) {
modelID := kiropkg.MapModel(mappedModel)
currentToken := token
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
if err != nil {
return nil, err
}
payload := buildResult.Payload
endpoints := buildKiroEndpoints(account)
proxyURL := kiroProxyURL(account)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
accountKey := buildKiroAccountKey(account)
maxRetries := 2
for idx, endpoint := range endpoints {
for attempt := 0; attempt <= maxRetries; attempt++ {
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
if err != nil {
return nil, err
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
if idx+1 < len(endpoints) {
_ = resp.Body.Close()
break
}
return resp, nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, readErr
}
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
if err != nil {
return nil, err
}
payload = buildResult.Payload
continue
}
}
resetHTTPResponseBody(resp, respBody)
return resp, nil
}
if resp.StatusCode == http.StatusBadRequest {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, readErr
}
resetHTTPResponseBody(resp, respBody)
return resp, nil
}
return resp, nil
}
}
return nil, fmt.Errorf("kiro upstream endpoints exhausted")
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
@@ -0,0 +1,84 @@
//go:build unit
package service
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
)
func TestAccountTestService_KiroAPIKeyUsesGenericAnthropicCompatiblePath(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 19,
Name: "kiro-apikey-test",
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"base_url": "https://kiro-upstream.example.com",
"api_key": "kiro-api-key",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4-6",
},
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"invalid api key"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
req := upstream.requests[0]
require.Equal(t, "kiro-upstream.example.com", req.URL.Host)
require.Equal(t, "/v1/messages", req.URL.Path)
require.Equal(t, "kiro-api-key", req.Header.Get("x-api-key"))
require.Empty(t, req.Header.Get("Authorization"))
require.Equal(t, claude.APIKeyBetaHeader, req.Header.Get("anthropic-beta"))
}
func TestAccountTestService_KiroAPIKeyWithoutBaseURLErrors(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 20,
Name: "kiro-apikey-missing-base-url",
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "kiro-api-key",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: &queuedHTTPUpstream{},
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Contains(t, err.Error(), "Base URL")
}
@@ -0,0 +1,317 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountTestService_KiroUsesKiroUpstreamInsteadOfAnthropic(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 1,
Name: "kiro-test",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/TESTSOCIAL",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{1: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"Invalid bearer token"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "gpt-4o", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
req := upstream.requests[0]
require.Equal(t, "q.us-east-1.amazonaws.com", req.URL.Host)
require.Equal(t, "/generateAssistantResponse", req.URL.Path)
require.Equal(t, "Bearer kiro-access-token", req.Header.Get("Authorization"))
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
require.Empty(t, req.Header.Get("anthropic-version"))
require.NotContains(t, req.URL.Host, "api.anthropic.com")
}
func TestAccountTestService_Kiro429DoesNotFallbackToCodeWhispererEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 2,
Name: "kiro-fallback",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"api_region": "us-west-2",
"region": "us-west-2",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTFALLBACK",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{2: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusTooManyRequests, `{"message":"slow down"}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
require.Contains(t, err.Error(), "API returned 429")
}
func TestAccountTestService_KiroIDCWithoutProfileArnOmitsProfileArnAndUsesDefaultRuntimeRegion(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 5,
Name: "kiro-idc-default-region",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"auth_method": "idc",
"provider": "AWS",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{5: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
require.Equal(t, "q.us-east-1.amazonaws.com", upstream.requests[0].URL.Host)
body, readErr := io.ReadAll(upstream.requests[0].Body)
require.NoError(t, readErr)
require.NotContains(t, string(body), `"profileArn":`)
}
func TestAccountTestService_KiroInvalidModelErrorPassthrough(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 6,
Name: "kiro-invalid-model",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTINVALIDMODEL",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Equal(t, `API returned 400: {"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`, err.Error())
}
func TestAccountTestService_KiroInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 7,
Name: "kiro-invalid-model-refresh",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{7: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Contains(t, err.Error(), "API returned 400")
require.Len(t, upstream.requests, 1)
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
require.NoError(t, readErr)
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
}
func TestAccountTestService_KiroPreferredEndpointIsIgnored(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 6,
Name: "kiro-preferred-endpoint",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"api_region": "us-west-2",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/PREFERRED",
"preferred_endpoint": "codewhisperer",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
}
func TestBuildKiroPayloadForAccount_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
account := &Account{
ID: 3,
Name: "kiro-builder-id",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"auth_method": "idc",
"provider": "BuilderId",
"region": "us-east-1",
"client_id": "builder-client-id",
},
}
testPayload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(testPayload)
require.NoError(t, err)
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotContains(t, string(kiroPayload), `"profileArn":`)
}
func TestBuildKiroPayloadForAccount_KiroBuilderIDUsesCredentialProfileArn(t *testing.T) {
account := &Account{
ID: 33,
Name: "kiro-builder-id-cached",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"auth_method": "builder-id",
"provider": "BuilderId",
"region": "us-east-1",
"client_id": "builder-client-id",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED",
},
}
testPayload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(testPayload)
require.NoError(t, err)
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.Contains(t, string(kiroPayload), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED"`)
}
func TestBuildKiroPayloadForAccount_KiroEnterpriseIDCOmitsMissingProfileArn(t *testing.T) {
account := &Account{
ID: 4,
Name: "kiro-enterprise-idc",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"auth_method": "idc",
"provider": "AWS",
"region": "us-east-1",
"client_id": "enterprise-client-id",
"start_url": "https://d-example.awsapps.com/start",
},
}
testPayload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(testPayload)
require.NoError(t, err)
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotContains(t, string(kiroPayload), `"profileArn":`)
}
@@ -103,10 +103,17 @@ type antigravityUsageCache struct {
timestamp time.Time
}
// kiroUsageCache 缓存 Kiro 额度快照
type kiroUsageCache struct {
usageInfo *UsageInfo
timestamp time.Time
}
const (
apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
kiroUsageErrorTTL = 1 * time.Minute // Kiro 错误缓存 TTL(可恢复错误)
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
@@ -118,8 +125,10 @@ type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
kiroUsageCache sync.Map // accountID -> *kiroUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
kiroUsageFlight singleflight.Group // 防止同一 Kiro 账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time
}
@@ -176,6 +185,23 @@ type AICredit struct {
MinimumBalance float64 `json:"minimum_balance,omitempty"`
}
// KiroCreditProgress 表示 Kiro 主额度或 Bonus 的用量进度。
type KiroCreditProgress struct {
CurrentUsage float64 `json:"current_usage"`
UsageLimit float64 `json:"usage_limit"`
PercentageUsed float64 `json:"percentage_used"`
DaysRemaining int `json:"days_remaining,omitempty"`
ExpiryDate *time.Time `json:"expiry_date,omitempty"`
}
// KiroOverageInfo 表示 Kiro 账号的 overage 状态。
type KiroOverageInfo struct {
CurrentOverages float64 `json:"current_overages"`
OverageCharges float64 `json:"overage_charges"`
CurrencyCode string `json:"currency_code,omitempty"`
CurrencySymbol string `json:"currency_symbol,omitempty"`
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
Source string `json:"source,omitempty"` // "passive" or "active"
@@ -203,6 +229,21 @@ type UsageInfo struct {
// Antigravity AI Credits 余额
AICredits []AICredit `json:"ai_credits,omitempty"`
// Kiro Credits 额度与 overage 信息
KiroSubscriptionName string `json:"kiro_subscription_name,omitempty"`
KiroSubscriptionType string `json:"kiro_subscription_type,omitempty"`
KiroResetAt *time.Time `json:"kiro_reset_at,omitempty"`
KiroOveragesEnabled bool `json:"kiro_overages_enabled,omitempty"`
KiroCredit *KiroCreditProgress `json:"kiro_credit,omitempty"`
KiroBonus *KiroCreditProgress `json:"kiro_bonus,omitempty"`
KiroOverage *KiroOverageInfo `json:"kiro_overage,omitempty"`
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
@@ -266,6 +307,7 @@ type AccountUsageService struct {
cache *UsageCache
identityCache IdentityCache
tlsFPProfileService *TLSFingerprintProfileService
kiroCooldownStore KiroCooldownStore
}
// NewAccountUsageService 创建AccountUsageService实例
@@ -291,6 +333,13 @@ func NewAccountUsageService(
}
}
func (s *AccountUsageService) SetKiroCooldownStore(store KiroCooldownStore) *AccountUsageService {
if s != nil {
s.kiroCooldownStore = store
}
return s
}
// GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope
@@ -317,6 +366,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return usage, err
}
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
return s.getKiroUsage(ctx, account, "active", false)
}
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
usage, err := s.getAntigravityUsage(ctx, account)
@@ -425,6 +478,13 @@ func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformKiro {
if account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("passive usage only supported for Kiro OAuth accounts")
}
return s.getKiroUsage(ctx, account, "passive", false)
}
if !account.IsAnthropicOAuthOrSetupToken() {
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
}
@@ -0,0 +1,40 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccountUsageService_GetUsage_KiroAPIKeyUnsupported(t *testing.T) {
account := &Account{
ID: 9101,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
usage, err := svc.GetUsage(context.Background(), account.ID)
require.Nil(t, usage)
require.Error(t, err)
require.Contains(t, err.Error(), "does not support usage query")
}
func TestAccountUsageService_GetPassiveUsage_KiroAPIKeyUnsupported(t *testing.T) {
account := &Account{
ID: 9102,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
usage, err := svc.GetPassiveUsage(context.Background(), account.ID)
require.Nil(t, usage)
require.Error(t, err)
require.Contains(t, err.Error(), "Kiro OAuth")
}
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
requestedModel: "any-model",
expected: true,
},
{
name: "kiro no mapping falls back to default whitelist",
platform: PlatformKiro,
credentials: nil,
requestedModel: "claude-sonnet-4-6",
expected: true,
},
{
name: "kiro no mapping rejects model outside default whitelist",
platform: PlatformKiro,
credentials: nil,
requestedModel: "auto",
expected: false,
},
// 精确匹配
{
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
{
name: "kiro no mapping uses default upstream mapping",
platform: PlatformKiro,
credentials: nil,
requestedModel: "claude-sonnet-4-6",
expected: "claude-sonnet-4.6",
},
// 精确匹配
{
+2 -2
View File
@@ -1716,7 +1716,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
// require_oauth_only: 过滤掉 apikey 类型账号
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
@@ -2008,7 +2008,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
// require_oauth_only: 过滤掉 apikey 类型账号
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
+17 -2
View File
@@ -198,6 +198,13 @@ func (s *BillingService) initFallbackPricing() {
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
@@ -278,13 +285,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
return s.fallbackPrices["claude-3-opus"]
}
if strings.Contains(modelLower, "sonnet") {
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
switch {
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
return s.fallbackPrices["claude-sonnet-4.6"]
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
return s.fallbackPrices["claude-sonnet-4.5"]
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
return s.fallbackPrices["claude-sonnet-4"]
}
return s.fallbackPrices["claude-3-5-sonnet"]
}
if strings.Contains(modelLower, "haiku") {
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
switch {
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
return s.fallbackPrices["claude-haiku-4.5"]
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
return s.fallbackPrices["claude-3-5-haiku"]
}
return s.fallbackPrices["claude-3-haiku"]
@@ -39,6 +39,7 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
PlatformKiro = domain.PlatformKiro
)
// Account type constants
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
@@ -105,6 +109,24 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
var resp *http.Response
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
} else {
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -126,7 +148,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
@@ -144,6 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
}
defer func() { _ = resp.Body.Close() }()
// 12. Handle error response with failover
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
@@ -102,6 +106,24 @@ func (s *GatewayService) ForwardAsResponses(
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
var resp *http.Response
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
} else {
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -123,7 +145,7 @@ func (s *GatewayService) ForwardAsResponses(
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
@@ -141,6 +163,7 @@ func (s *GatewayService) ForwardAsResponses(
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
}
defer func() { _ = resp.Body.Close() }()
// 12. Handle error response with failover
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
nil,
)
}
+194 -13
View File
@@ -27,6 +27,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
@@ -56,6 +57,7 @@ const (
defaultModelsListCacheTTL = 15 * time.Second
postUsageBillingTimeout = 15 * time.Second
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
defaultKiroStreamKeepalive = 25 * time.Second
)
const (
@@ -70,6 +72,7 @@ const (
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
type kiroCooldownRecoveryAttemptedKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
@@ -78,6 +81,7 @@ type accountWithLoad struct {
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
var kiroCooldownRecoveryAttemptedKey = kiroCooldownRecoveryAttemptedKeyType{}
var (
windowCostPrefetchCacheHitTotal atomic.Int64
@@ -554,6 +558,8 @@ type GatewayService struct {
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
kiroTokenProvider *KiroTokenProvider
kiroCooldownStore KiroCooldownStore
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver
@@ -592,6 +598,8 @@ func NewGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
kiroTokenProvider *KiroTokenProvider,
kiroCooldownStore KiroCooldownStore,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
digestStore *DigestSessionStore,
@@ -624,6 +632,8 @@ func NewGatewayService(
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
kiroTokenProvider: kiroTokenProvider,
kiroCooldownStore: kiroCooldownStore,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
@@ -902,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
injectMetadata bool
metadataUserID string
stripSystemCacheControl bool
preserveToolChoice bool
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
@@ -1116,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
modified = true
}
}
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
if !gjson.GetBytes(out, "max_tokens").Exists() {
@@ -1967,6 +1984,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if len(candidates) == 0 {
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, useMixed) {
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
return s.SelectAccountWithLoadAwareness(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID)
}
return nil, ErrNoAvailableAccounts
}
@@ -2346,14 +2367,91 @@ func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool
if account == nil {
return false
}
return account.IsSchedulable()
if !account.IsSchedulable() {
return false
}
return s.isKiroRuntimeSchedulable(context.Background(), account)
}
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
if account == nil {
return false
}
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false
}
return s.isKiroRuntimeSchedulable(ctx, account)
}
func (s *GatewayService) isKiroRuntimeSchedulable(ctx context.Context, account *Account) bool {
if account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth || s == nil || s.kiroCooldownStore == nil {
return true
}
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(account))
if err != nil {
return true
}
return state == nil || !state.Active
}
func (s *GatewayService) tryRecoverKiroCooldownPool(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) bool {
if s == nil || s.kiroCooldownStore == nil || ctx.Value(kiroCooldownRecoveryAttemptedKey) == true {
return false
}
tokenKeys := s.kiroTransientCooldownRecoveryKeys(ctx, accounts, requestedModel, excludedIDs, allowMixedScheduling)
if len(tokenKeys) == 0 {
return false
}
cleared, err := s.kiroCooldownStore.ClearEarliestTransientCooldown(ctx, tokenKeys)
if err != nil {
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery failed: %v", err)
return false
}
if cleared {
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery cleared one transient cooldown")
}
return cleared
}
func (s *GatewayService) kiroTransientCooldownRecoveryKeys(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) []string {
tokenKeys := make([]string, 0, len(accounts))
eligible := 0
for i := range accounts {
acc := &accounts[i]
if acc == nil || acc.Platform != PlatformKiro || acc.Type != AccountTypeOAuth {
if allowMixedScheduling {
continue
}
return nil
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) ||
!s.isAccountSchedulableForWindowCost(ctx, acc, false) ||
!s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
eligible++
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc))
if err != nil || state == nil || !state.Active {
return nil
}
if state.Reason != kirocooldown.CooldownReason429 {
return nil
}
tokenKeys = append(tokenKeys, buildKiroAccountKey(acc))
}
if eligible == 0 || len(tokenKeys) != eligible {
return nil
}
return tokenKeys
}
// isAccountInGroup checks if the account belongs to the specified group.
@@ -3232,6 +3330,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if selected == nil {
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, false) {
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
return s.selectAccountForModelWithPlatform(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
if requestedModel != "" {
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
}
@@ -3611,6 +3713,17 @@ func (s *GatewayService) diagnoseSelectionFailure(
if _, excluded := excludedIDs[acc.ID]; excluded {
return selectionFailureDiagnosis{Category: "excluded"}
}
if !acc.IsSchedulable() {
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
if acc.Platform == PlatformKiro && acc.Type == AccountTypeOAuth {
if state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc)); err == nil && state != nil && state.Active {
return selectionFailureDiagnosis{
Category: "unschedulable",
Detail: fmt.Sprintf("kiro_runtime_%s remaining=%s", state.Reason, state.Remaining.Truncate(time.Second)),
}
}
}
if !s.isAccountSchedulableForSelection(acc) {
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
@@ -3774,6 +3887,13 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
}
return accessToken, "oauth", nil
}
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth && s.kiroTokenProvider != nil {
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 其他情况(Gemini 有自己的 TokenProvidersetup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token")
@@ -4343,11 +4463,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model
@@ -4371,6 +4486,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return s.forwardBedrock(ctx, c, account, parsed, startTime)
}
if account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
return s.forwardKiroMessages(ctx, c, account, parsed, startTime)
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
@@ -4425,7 +4549,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
normalizeOpts := claudeOAuthNormalizeOptions{
stripSystemCacheControl: !systemRewritten,
preserveToolChoice: account.Platform == PlatformKiro,
}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
@@ -4462,7 +4589,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(reqModel); next != "" && next != reqModel {
mappedModel = next
mappingSource = "account"
}
} else if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
mappingSource = "account"
@@ -5967,6 +6099,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
if baseURL == "" && account.Platform == PlatformKiro {
return nil, fmt.Errorf("kiro api key account requires base_url")
}
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
@@ -7228,10 +7363,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
keepaliveInterval := s.streamKeepaliveIntervalForAccount(account)
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
@@ -8277,6 +8409,9 @@ type recordUsageOpts struct {
// 长上下文计费(仅 Gemini 路径需要)
LongContextThreshold int
LongContextMultiplier float64
// Kiro 账号在上游返回 auto 等无法定价模型时使用保守计费兜底。
IsKiroAccount bool
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -8414,6 +8549,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
}
// 计算费用
opts.IsKiroAccount = account != nil && account.Platform == PlatformKiro
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式
@@ -8492,6 +8628,28 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
const kiroConservativeFallbackBillingModel = "claude-opus-4-6"
func shouldUseKiroConservativeBillingFallback(result *ForwardResult, billingModel string, opts *recordUsageOpts) bool {
if result == nil {
return false
}
return opts != nil && opts.IsKiroAccount
}
func (s *GatewayService) calculateKiroConservativeTokenCost(tokens UsageTokens, multiplier float64) *CostBreakdown {
if s == nil || s.billingService == nil {
return nil
}
cost, err := s.billingService.CalculateCost(kiroConservativeFallbackBillingModel, tokens, multiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate conservative Kiro fallback cost failed: %v", err)
return nil
}
return cost
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -8596,6 +8754,12 @@ func (s *GatewayService) calculateTokenCost(
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
if shouldUseKiroConservativeBillingFallback(result, billingModel, opts) {
if fallback := s.calculateKiroConservativeTokenCost(tokens, multiplier); fallback != nil {
logger.LegacyPrintf("service.gateway", "Using conservative Kiro fallback pricing for model=%s", billingModel)
return fallback
}
}
return &CostBreakdown{ActualCost: 0}
}
return cost
@@ -8856,6 +9020,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
return nil
}
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
return nil
}
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
@@ -9486,6 +9654,19 @@ func reconcileCachedTokens(usage map[string]any) bool {
return true
}
func (s *GatewayService) streamKeepaliveIntervalForAccount(account *Account) time.Duration {
if account != nil && account.Platform == PlatformKiro {
if s != nil && s.cfg != nil && s.cfg.Gateway.KiroStreamKeepaliveInterval > 0 {
return time.Duration(s.cfg.Gateway.KiroStreamKeepaliveInterval) * time.Second
}
return defaultKiroStreamKeepalive
}
if s != nil && s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
return time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
return 0
}
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
// initDebugGatewayBodyFile 初始化网关调试日志文件。
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
// Anthropic API Key keeps the existing account-level override:
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
// Kiro OAuth uses channel-level switch only.
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
@@ -62,13 +64,29 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
return false
}
if account == nil {
return false
}
switch {
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
return false
default: // "default" → follow channel config
default:
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
}
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
default:
return false
}
}
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
if groupID == nil || s.channelService == nil {
return false
}
@@ -76,8 +94,7 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(account.Platform)
}
return ch.IsWebSearchEmulationEnabled(platform)
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
"message": map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
"usage": map[string]int{
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
}
return flushSSEJSON(w, "message_start", evt)
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
"name": toolNameWebSearch, "input": map[string]any{},
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
inputJSON, err := json.Marshal(map[string]string{"query": query})
if err != nil {
return fmt.Errorf("marshal query: %w", err)
}
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
}); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
// --- Helpers ---
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
blocks := make([]map[string]string, 0, len(results))
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
blocks := make([]map[string]any, 0, len(results))
for _, r := range results {
block := map[string]string{
block := map[string]any{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
}
if r.Snippet != "" {
block["page_content"] = r.Snippet
"encrypted_content": r.Snippet,
"page_age": nil,
}
if r.PageAge != "" {
block["page_age"] = r.PageAge
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
@@ -13,6 +14,31 @@ import (
"github.com/stretchr/testify/require"
)
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
rec := httptest.NewRecorder()
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, `"cache_creation_input_tokens":0`)
require.Contains(t, body, `"cache_read_input_tokens":0`)
}
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
rec := httptest.NewRecorder()
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, `event: content_block_start`)
require.Contains(t, body, `"type":"server_tool_use"`)
require.Contains(t, body, `"input":{}`)
require.Contains(t, body, `event: content_block_delta`)
require.Contains(t, body, `"type":"input_json_delta"`)
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
require.Contains(t, body, `event: content_block_stop`)
}
// --- isOnlyWebSearchToolInBody ---
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
require.Len(t, blocks, 2)
require.Equal(t, "web_search_result", blocks[0]["type"])
require.Equal(t, "https://a.com", blocks[0]["url"])
require.Equal(t, "snippet a", blocks[0]["page_content"])
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
require.Equal(t, "2 days", blocks[0]["page_age"])
// Second result has no PageAge
require.Equal(t, "https://b.com", blocks[1]["url"])
_, hasPageAge := blocks[1]["page_age"]
require.False(t, hasPageAge)
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
require.Nil(t, blocks[1]["page_age"])
}
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
_, hasContent := blocks[0]["page_content"]
require.False(t, hasContent)
require.Equal(t, "", blocks[0]["encrypted_content"])
require.Nil(t, blocks[0]["page_age"])
}
// --- buildTextSummary ---
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
}
}
func newKiroOAuthAccount() *Account {
return &Account{
ID: 2,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
}
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
// nil channelService + default mode → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 11,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
},
}
channelSvc := newChannelServiceWithCache(77, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newKiroOAuthAccount()
groupID := int64(77)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 12,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
},
}
channelSvc := newChannelServiceWithCache(78, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newKiroOAuthAccount()
groupID := int64(78)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newKiroOAuthAccount()
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
@@ -0,0 +1,623 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/stretchr/testify/require"
)
type kiroUsageCooldownStore struct {
state *kirocooldown.State
err error
}
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
return nil
}
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
return s.state, s.err
}
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
return false, nil
}
func kiroFloatPtr(v float64) *float64 {
return &v
}
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
},
}
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
}
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.billingService = NewBillingService(svc.cfg, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-sonnet-4-6": {
InputCostPerToken: 2.5e-6,
OutputCostPerToken: 10e-6,
},
},
})
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_kiro_billing_normalized",
Model: "claude-sonnet-4-6",
UpstreamModel: "claude-sonnet-4.6",
Usage: OpenAIUsage{
InputTokens: 20,
OutputTokens: 10,
},
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
account := Account{
ID: 701,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
require.Equal(t, "*/*", r.Header.Get("Accept"))
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
"overageConfiguration": {"overageStatus":"ENABLED"},
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
"usageBreakdownList": [{
"currency":"USD",
"currentOveragesWithPrecision":2,
"currentUsageWithPrecision":125,
"freeTrialInfo":{
"currentUsageWithPrecision":25,
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
"freeTrialStatus":"ACTIVE",
"usageLimitWithPrecision":500
},
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
"overageCharges":0.08,
"resourceType":"CREDIT",
"usageLimitWithPrecision":2000
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, "active", usage.Source)
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
require.True(t, usage.KiroOveragesEnabled)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
require.NotNil(t, usage.KiroBonus)
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
require.NotNil(t, usage.KiroOverage)
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
require.NotNil(t, usage.KiroResetAt)
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
require.NotNil(t, usage.KiroQuotaResetAt)
}
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
account := Account{
ID: 702,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":300,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer successServer.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, firstUsage)
require.NotNil(t, firstUsage.KiroCredit)
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
}))
defer failingServer.Close()
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
require.Empty(t, usage.Error)
require.Empty(t, usage.ErrorCode)
}
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
account := Account{
ID: 703,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "BuilderId",
"auth_method": "idc",
"region": "us-east-1",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Empty(t, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":42,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
}
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
account := Account{
ID: 707,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"region": "us-east-1",
"start_url": "https://d-example.awsapps.com/start",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":64,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
}
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
account := Account{
ID: 709,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"api_region": "eu-west-1",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
gotRegions := make([]string, 0, 2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":11,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(region string) string {
gotRegions = append(gotRegions, region)
return server.URL
}
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, []string{"eu-west-1"}, gotRegions)
}
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
account := Account{
ID: 710,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
gotRegions := make([]string, 0, 2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Empty(t, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":7,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(region string) string {
gotRegions = append(gotRegions, region)
return server.URL
}
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
}
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
account := Account{
ID: 704,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
SetKiroCooldownStore(&kiroUsageCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(90 * time.Second),
Remaining: 90 * time.Second,
},
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":42,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.Equal(t, "cooldown", usage.KiroRuntimeState)
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
require.NotNil(t, usage.KiroRuntimeResetAt)
}
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
StatusCode: http.StatusBadRequest,
Body: `{"message":"profileArn is required for this request."}`,
})
require.Equal(t, errorCodeForbidden, info.ErrorCode)
require.False(t, info.NeedsReauth)
}
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
StatusCode: http.StatusTooManyRequests,
Body: `{"message":"overage exhausted for this billing window"}`,
})
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
}
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
account := Account{
ID: 708,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, firstUsage)
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, secondUsage)
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
require.Equal(t, 1, requestCount)
}
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
NextDateReset: "2099-03-13T12:00:00Z",
OverageConfiguration: kiroOverageConfiguration{
OverageStatus: "DISABLED",
},
UsageBreakdownList: []kiroUsageBreakdown{
{
ResourceType: "CREDIT",
CurrentUsageWithPrecision: kiroFloatPtr(2000),
UsageLimitWithPrecision: kiroFloatPtr(2000),
CurrentOveragesWithPrecision: kiroFloatPtr(0),
},
},
})
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
require.NotNil(t, info.KiroQuotaResetAt)
}
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
SetKiroCooldownStore(&kiroUsageCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(2 * time.Minute),
Remaining: 2 * time.Minute,
},
})
account := &Account{
ID: 705,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "kiro-access-token"},
}
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
require.Equal(t, "cooldown", account.KiroRuntimeState)
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
require.NotNil(t, account.KiroRuntimeResetAt)
}
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
account := Account{
ID: 706,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"nextDateReset":"2099-03-13T12:00:00Z",
"overageConfiguration":{"overageStatus":"ENABLED"},
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":2000,
"currentOveragesWithPrecision":4,
"overageCharges":0.2,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
_, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
target := &Account{
ID: account.ID,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "kiro-access-token"},
}
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
require.NotNil(t, target.KiroQuotaResetAt)
}
@@ -0,0 +1,163 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
account := Account{
Type: AccountTypeAPIKey,
Platform: PlatformKiro,
Credentials: map[string]any{},
}
require.Empty(t, account.GetBaseURL())
}
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
svc := &GatewayService{}
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
require.Equal(t, 25*time.Second, got)
}
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
StreamKeepaliveInterval: 10,
KiroStreamKeepaliveInterval: 25,
},
},
}
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
}
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
svc := NewBillingService(&config.Config{}, nil)
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
}
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
tests := []struct {
name string
requestedModel string
upstreamModel string
want string
}{
{
name: "kiro claude sonnet 4.6 uses pricing key format",
requestedModel: "claude-sonnet-4-6",
upstreamModel: "claude-sonnet-4.6",
want: "claude-sonnet-4-6",
},
{
name: "falls back to upstream when requested model empty",
requestedModel: "",
upstreamModel: "claude-haiku-4-5",
want: "claude-haiku-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
})
}
}
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.billingService = NewBillingService(svc.cfg, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-sonnet-4-6": {
InputCostPerToken: 2.5e-6,
OutputCostPerToken: 10e-6,
},
},
})
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_kiro_billing_normalized",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
Model: "claude-sonnet-4-6",
UpstreamModel: "claude-sonnet-4.6",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_kiro_auto_fallback_cost",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
Model: "auto",
UpstreamModel: "auto",
Duration: time.Second,
},
APIKey: &APIKey{ID: 601, Quota: 100},
User: &User{ID: 701},
Account: &Account{ID: 801, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
@@ -0,0 +1,222 @@
package service
import (
"errors"
"net"
"net/http"
"strings"
"github.com/tidwall/gjson"
)
const (
kiroErrorAuthError = "auth_error"
kiroErrorMonthlyRequest = "monthly_request_count"
kiroErrorProfileError = "profile_error"
kiroErrorQuotaExhausted = "quota_exhausted"
kiroErrorOverageExhausted = "overage_exhausted"
kiroErrorRateLimited = "rate_limited"
kiroErrorSuspended = "suspended"
kiroErrorUsageForbidden = "usage_forbidden"
kiroErrorUpstreamTransient = "upstream_transient"
kiroErrorBadRequestSchema = "bad_request_schema"
kiroErrorBadRequestToolPairing = "bad_request_tool_pairing"
kiroErrorBadRequestInvalidModel = "bad_request_invalid_model"
kiroErrorBadRequestAuth = "bad_request_auth"
kiroErrorBadRequestQuota = "bad_request_quota"
kiroErrorBadRequestUnknown = "bad_request_unknown"
kiroErrorRefreshTokenInvalid = "refresh_token_invalid"
kiroQuotaStateNormal = "normal"
kiroQuotaStateOverageActive = "overage_active"
kiroQuotaStateCreditsExhausted = "credits_exhausted"
kiroQuotaStateOverageExhausted = "overage_exhausted"
)
type kiroErrorClassification struct {
Category string
StatusCode int
Message string
}
func classifyKiroHTTPError(statusCode int, body string) kiroErrorClassification {
trimmed := strings.TrimSpace(body)
lower := strings.ToLower(trimmed)
switch {
case statusCode == http.StatusUnauthorized:
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusPaymentRequired && looksLikeKiroMonthlyRequestCountError(trimmed):
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusForbidden && isKiroSuspendedBody([]byte(trimmed)):
return kiroErrorClassification{Category: kiroErrorSuspended, StatusCode: statusCode, Message: trimmed}
case looksLikeKiroProfileError(lower):
return kiroErrorClassification{Category: kiroErrorProfileError, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusBadRequest:
return classifyKiroBadRequest(trimmed, lower)
case statusCode == http.StatusForbidden && isKiroTokenErrorBody([]byte(trimmed)):
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
case looksLikeKiroOverageExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorOverageExhausted, StatusCode: statusCode, Message: trimmed}
case looksLikeKiroQuotaExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusTooManyRequests:
return kiroErrorClassification{Category: kiroErrorRateLimited, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusForbidden:
return kiroErrorClassification{Category: kiroErrorUsageForbidden, StatusCode: statusCode, Message: trimmed}
case statusCode >= 500:
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
default:
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
}
}
func classifyKiroError(err error) kiroErrorClassification {
if err == nil {
return kiroErrorClassification{}
}
var httpErr *kiroUsageHTTPError
if errors.As(err, &httpErr) && httpErr != nil {
return classifyKiroHTTPError(httpErr.StatusCode, httpErr.Body)
}
errStr := strings.TrimSpace(err.Error())
lower := strings.ToLower(errStr)
switch {
case looksLikeKiroInvalidGrantError(lower):
return kiroErrorClassification{Category: kiroErrorRefreshTokenInvalid, Message: errStr}
case looksLikeKiroMonthlyRequestCountError(errStr):
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, Message: errStr}
case looksLikeKiroProfileError(lower):
return kiroErrorClassification{Category: kiroErrorProfileError, Message: errStr}
case looksLikeKiroOverageExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorOverageExhausted, Message: errStr}
case looksLikeKiroQuotaExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, Message: errStr}
case strings.Contains(lower, "context deadline exceeded"),
strings.Contains(lower, "timeout"),
isNetErr(err):
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
default:
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
}
}
func classifyKiroBadRequest(trimmed, lower string) kiroErrorClassification {
switch {
case looksLikeKiroBadRequestSchemaError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestSchema, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroBadRequestToolPairingError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestToolPairing, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroBadRequestInvalidModelError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestInvalidModel, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroInvalidGrantError(lower) || looksLikeKiroBadRequestAuthError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestAuth, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroQuotaExhaustedError(lower) || looksLikeKiroMonthlyRequestCountError(trimmed):
return kiroErrorClassification{Category: kiroErrorBadRequestQuota, StatusCode: http.StatusBadRequest, Message: trimmed}
default:
return kiroErrorClassification{Category: kiroErrorBadRequestUnknown, StatusCode: http.StatusBadRequest, Message: trimmed}
}
}
func looksLikeKiroBadRequestSchemaError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "schema") ||
strings.Contains(lower, "inputschema") ||
strings.Contains(lower, "improperly formed request") ||
strings.Contains(lower, "additionalproperties") ||
(strings.Contains(lower, "properties") && strings.Contains(lower, "required"))
}
func looksLikeKiroBadRequestToolPairingError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "tool_use") ||
strings.Contains(lower, "tool_result") ||
strings.Contains(lower, "tooluseid") ||
strings.Contains(lower, "toolresults") ||
strings.Contains(lower, "must be paired") ||
strings.Contains(lower, "missing tool result")
}
func looksLikeKiroBadRequestInvalidModelError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "invalid model") ||
strings.Contains(lower, "invalid_model_id") ||
strings.Contains(lower, "model not supported") ||
strings.Contains(lower, "unsupportedmodel") ||
strings.Contains(lower, "modelid")
}
func looksLikeKiroBadRequestAuthError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "invalid token") ||
strings.Contains(lower, "expired token") ||
strings.Contains(lower, "access token") ||
strings.Contains(lower, "refresh token")
}
func looksLikeKiroInvalidGrantError(lower string) bool {
return strings.Contains(lower, "invalid_grant")
}
func looksLikeKiroMonthlyRequestCountError(body string) bool {
trimmed := strings.TrimSpace(body)
if trimmed == "" {
return false
}
if strings.Contains(trimmed, "MONTHLY_REQUEST_COUNT") {
return true
}
if !gjson.Valid(trimmed) {
return false
}
return gjson.Get(trimmed, "reason").String() == "MONTHLY_REQUEST_COUNT" ||
gjson.Get(trimmed, "error.reason").String() == "MONTHLY_REQUEST_COUNT"
}
func looksLikeKiroProfileError(lower string) bool {
if lower == "" {
return false
}
return (strings.Contains(lower, "profilearn") && strings.Contains(lower, "required")) ||
(strings.Contains(lower, "profile arn") && strings.Contains(lower, "required")) ||
(strings.Contains(lower, "profile") && strings.Contains(lower, "not found")) ||
(strings.Contains(lower, "invalid profile")) ||
(strings.Contains(lower, "listavailableprofiles"))
}
func looksLikeKiroQuotaExhaustedError(lower string) bool {
if lower == "" {
return false
}
return (strings.Contains(lower, "credit") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "depleted"))) ||
(strings.Contains(lower, "quota") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "exceeded") || strings.Contains(lower, "depleted"))) ||
(strings.Contains(lower, "usage limit") && (strings.Contains(lower, "reached") || strings.Contains(lower, "exceeded"))) ||
(strings.Contains(lower, "resource has been exhausted"))
}
func looksLikeKiroOverageExhaustedError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "overage") &&
(strings.Contains(lower, "exhaust") ||
strings.Contains(lower, "disabled") ||
strings.Contains(lower, "not enabled") ||
strings.Contains(lower, "not allowed") ||
strings.Contains(lower, "limit"))
}
func isNetErr(err error) bool {
var netErr net.Error
return errors.As(err, &netErr)
}
@@ -0,0 +1,66 @@
package service
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestClassifyKiroHTTPErrorBadRequestCategories(t *testing.T) {
tests := []struct {
name string
body string
want string
}{
{
name: "schema",
body: `{"message":"Improperly formed request: inputSchema.properties must be an object"}`,
want: kiroErrorBadRequestSchema,
},
{
name: "tool pairing",
body: `{"message":"tool_use must be paired with a matching tool_result"}`,
want: kiroErrorBadRequestToolPairing,
},
{
name: "invalid model id",
body: `{"message":"invalid modelId: model not supported"}`,
want: kiroErrorBadRequestInvalidModel,
},
{
name: "invalid model upstream",
body: `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
want: kiroErrorBadRequestInvalidModel,
},
{
name: "invalid model reason",
body: `{"message":"model route unavailable","reason":"INVALID_MODEL_ID"}`,
want: kiroErrorBadRequestInvalidModel,
},
{
name: "auth",
body: `{"error":"invalid_grant","message":"Invalid refresh token provided"}`,
want: kiroErrorBadRequestAuth,
},
{
name: "quota",
body: `{"message":"resource has been exhausted"}`,
want: kiroErrorBadRequestQuota,
},
{
name: "unknown",
body: `{"message":"bad request"}`,
want: kiroErrorBadRequestUnknown,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
classification := classifyKiroHTTPError(http.StatusBadRequest, tt.body)
require.Equal(t, tt.want, classification.Category)
require.Equal(t, http.StatusBadRequest, classification.StatusCode)
require.Equal(t, tt.body, classification.Message)
})
}
}
@@ -0,0 +1,180 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/google/uuid"
)
func buildKiroAccountKey(account *Account) string {
if account == nil {
return ""
}
return kiropkg.BuildAccountKey(
account.GetCredential("client_id"),
account.GetCredential("client_id_hash"),
account.GetCredential("refresh_token"),
account.GetCredential("profile_arn"),
account.ID,
)
}
func buildKiroMachineID(account *Account) string {
if account == nil {
return kiropkg.BuildMachineID("", "", "account:nil")
}
for _, key := range []string{"machine_id", "machineId"} {
if machineID, ok := kiropkg.NormalizeMachineID(account.GetCredential(key)); ok {
return machineID
}
}
fallbackKey := buildKiroMachineIDFallbackKey(account)
if account.Type == AccountTypeAPIKey {
return kiropkg.BuildMachineID("", firstKiroCredential(account, "kiro_api_key", "kiroApiKey", "api_key"), fallbackKey)
}
return kiropkg.BuildMachineID(account.GetCredential("refresh_token"), "", fallbackKey)
}
func firstKiroCredential(account *Account, keys ...string) string {
if account == nil {
return ""
}
for _, key := range keys {
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
return value
}
}
return ""
}
func buildKiroMachineIDFallbackKey(account *Account) string {
if account == nil {
return "account:nil"
}
if account.ID > 0 {
return fmt.Sprintf("account:%d", account.ID)
}
for _, key := range []string{"client_id", "profile_arn"} {
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
return key + ":" + value
}
}
if name := strings.TrimSpace(account.Name); name != "" {
return "name:" + name
}
return "account:unknown"
}
func buildKiroRequestID(resp *http.Response) string {
if resp == nil {
return ""
}
if requestID := strings.TrimSpace(resp.Header.Get("x-request-id")); requestID != "" {
return requestID
}
if requestID := strings.TrimSpace(resp.Header.Get("x-amzn-requestid")); requestID != "" {
return requestID
}
return strings.TrimSpace(resp.Header.Get("x-amz-request-id"))
}
func isKiroInvalidModelIDBody(respBody []byte) bool {
var payload struct {
Reason string `json:"reason"`
Message string `json:"message"`
Error struct {
Reason string `json:"reason"`
Message string `json:"message"`
} `json:"error"`
}
if json.Unmarshal(respBody, &payload) != nil {
return looksLikeKiroBadRequestInvalidModelError(strings.ToLower(string(respBody)))
}
return strings.EqualFold(strings.TrimSpace(payload.Reason), "INVALID_MODEL_ID") ||
strings.EqualFold(strings.TrimSpace(payload.Error.Reason), "INVALID_MODEL_ID") ||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Message)) ||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Error.Message))
}
func isKiroSuspendedBody(respBody []byte) bool {
body := string(respBody)
return strings.Contains(body, "SUSPENDED") || strings.Contains(body, "TEMPORARILY_SUSPENDED")
}
func isKiroTokenErrorBody(respBody []byte) bool {
lower := strings.ToLower(string(respBody))
return strings.Contains(lower, "token") ||
strings.Contains(lower, "expired") ||
strings.Contains(lower, "invalid") ||
strings.Contains(lower, "unauthorized")
}
func kiroProxyURL(account *Account) string {
if account != nil && account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}
func kiroAPIRegion(account *Account) string {
if account == nil {
return kiroDefaultRegion
}
region := strings.TrimSpace(account.GetCredential("api_region"))
if region == "" {
region = kiroDefaultRegion
}
return region
}
func applyKiroConditionalHeaders(req *http.Request, account *Account) {
if req == nil || account == nil {
return
}
if strings.EqualFold(strings.TrimSpace(account.GetCredential("auth_method")), "external_idp") {
req.Header.Set("TokenType", "EXTERNAL_IDP")
}
if strings.EqualFold(strings.TrimSpace(account.GetCredential("provider")), "Internal") {
req.Header.Set("redirect-for-internal", "true")
}
}
func resolveKiroPayloadProfileArn(account *Account) string {
if account == nil {
return ""
}
return strings.TrimSpace(account.GetCredential("profile_arn"))
}
func newKiroJSONRequest(ctx context.Context, endpointURL string, payload []byte, token, accountKey, machineID, amzTarget string, account *Account) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
req.Header.Set("x-amzn-codewhisperer-optout", "true")
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
if amzTarget != "" {
req.Header.Set("X-Amz-Target", amzTarget)
}
if account != nil {
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
if profileArn != "" {
req.Header.Set("x-amzn-kiro-profile-arn", profileArn)
}
}
applyKiroConditionalHeaders(req, account)
return req, nil
}
@@ -0,0 +1,263 @@
//go:build unit
package service
import (
"context"
"net/http"
"strings"
"testing"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
accountA := &Account{
ID: 99,
Credentials: map[string]any{
"access_token": "token-a",
},
}
accountB := &Account{
ID: 99,
Credentials: map[string]any{
"access_token": "token-b",
},
}
require.Equal(t, buildKiroAccountKey(accountA), buildKiroAccountKey(accountB))
}
func TestBuildKiroMachineIDPrefersExplicitCredential(t *testing.T) {
account := &Account{
ID: 101,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"machineId": "2582956e-cc88-4669-b546-07adbffcb894",
"refresh_token": "refresh-token",
},
}
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", buildKiroMachineID(account))
}
func TestBuildKiroMachineIDDerivesFromRefreshToken(t *testing.T) {
account := &Account{
ID: 102,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token",
},
}
require.Equal(t, kiropkg.BuildMachineID("refresh-token", "", "account:102"), buildKiroMachineID(account))
}
func TestBuildKiroMachineIDDerivesFromAPIKeyAccount(t *testing.T) {
account := &Account{
ID: 103,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"kiroApiKey": "kiro-api-key",
},
}
require.Equal(t, kiropkg.BuildMachineID("", "kiro-api-key", "account:103"), buildKiroMachineID(account))
}
func TestNewKiroJSONRequestAddsConditionalHeaders(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"auth_method": "external_idp",
"provider": "Internal",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER",
},
}
req, err := newKiroJSONRequest(
context.Background(),
"https://q.us-east-1.amazonaws.com/generateAssistantResponse",
[]byte(`{"ok":true}`),
"access-token",
"account-key",
buildKiroMachineID(account),
"",
account,
)
require.NoError(t, err)
require.Equal(t, "EXTERNAL_IDP", req.Header.Get("TokenType"))
require.Equal(t, "true", req.Header.Get("redirect-for-internal"))
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER", req.Header.Get("x-amzn-kiro-profile-arn"))
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
require.Equal(t, "true", req.Header.Get("x-amzn-codewhisperer-optout"))
require.Contains(t, req.Header.Get("User-Agent"), "aws-sdk-js/1.0.34")
require.Contains(t, req.Header.Get("User-Agent"), "md/nodejs#22.22.0")
require.Contains(t, req.Header.Get("User-Agent"), buildKiroMachineID(account))
require.Contains(t, req.Header.Get("X-Amz-User-Agent"), buildKiroMachineID(account))
require.True(t, strings.Contains(req.Header.Get("User-Agent"), "api/codewhispererstreaming#1.0.34"))
require.Empty(t, req.Header.Get("Anthropic-Beta"))
}
func TestIsKiroInvalidModelIDBodyRecognizesKnownForms(t *testing.T) {
tests := []string{
`{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`,
`{"message":"Invalid model. Please select a different model to continue."}`,
`API Error: 400 {"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
}
for _, body := range tests {
require.True(t, isKiroInvalidModelIDBody([]byte(body)), body)
}
}
func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
account := &Account{
ID: 7,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/test",
},
}
body := []byte(`{
"model":"claude-sonnet-4-6",
"messages":[{"role":"user","content":"hello"}]
}`)
headers := http.Header{}
headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-sonnet-4.6",
"kiro-access-token",
"claude-sonnet-4-6",
headers,
)
require.NoError(t, err)
require.NotContains(t, string(payload), "CHUNKED WRITE PROTOCOL")
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
}
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
account := &Account{
ID: 8,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
body := []byte(`{
"model":"claude-opus-4.6",
"messages":[{"role":"user","content":"hello"}]
}`)
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-opus-4.6",
"kiro-access-token",
"claude-opus-4-6-thinking",
nil,
)
require.NoError(t, err)
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
}
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
account := &Account{
ID: 9,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
body := []byte(`{
"model":"claude-opus-4.6",
"messages":[{"role":"user","content":"hello"}]
}`)
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-opus-4.6",
"kiro-access-token",
"claude-opus-4-6",
nil,
)
require.NoError(t, err)
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
require.NotContains(t, systemContent, "<thinking_mode>")
}
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"api_region": "eu-west-1",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
"region": "ap-northeast-1",
},
}
require.Equal(t, "eu-west-1", kiroAPIRegion(account))
}
func TestKiroAPIRegionIgnoresProfileARNRegionFallback(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
},
}
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
}
func TestKiroAPIRegionIgnoresOIDCRegionFallback(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"region": "ap-northeast-2",
},
}
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
}
func TestBuildKiroEndpointsUsesOnlyAmazonQEndpoint(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"api_region": "us-west-2",
"preferred_endpoint": "cw",
},
}
endpoints := buildKiroEndpoints(account)
require.Len(t, endpoints, 1)
require.Equal(t, "AmazonQ", endpoints[0].Name)
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
require.Empty(t, endpoints[0].AmzTarget)
}
func TestBuildKiroEndpointsIgnoresPreferredEndpoint(t *testing.T) {
for _, preferred := range []string{"codewhisperer", "cw", "unknown"} {
account := &Account{
Credentials: map[string]any{
"api_region": "us-west-2",
"preferred_endpoint": preferred,
},
}
endpoints := buildKiroEndpoints(account)
require.Len(t, endpoints, 1)
require.Equal(t, "AmazonQ", endpoints[0].Name)
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
}
}
@@ -0,0 +1,74 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestAccountKiroDefaultMappingRestrictsUnsupportedModels(t *testing.T) {
account := &Account{Platform: PlatformKiro}
require.False(t, account.IsModelSupported("gpt-4o"))
require.False(t, account.IsModelSupported("kiro-gpt-4o"))
require.False(t, account.IsModelSupported("auto"))
require.Equal(t, "claude-sonnet-4.6", account.GetMappedModel("claude-sonnet-4-6"))
}
func TestGatewayServiceCalculateTokenCost_KiroAutoUsesConservativeFallback(t *testing.T) {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
svc := NewGatewayService(
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
result := &ForwardResult{
Model: "auto",
UpstreamModel: "auto",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
}
expected, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
cost := svc.calculateTokenCost(context.Background(), result, &APIKey{}, "auto", 1.1, &recordUsageOpts{IsKiroAccount: true})
require.NotNil(t, cost)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-12)
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-12)
}
@@ -0,0 +1,369 @@
package service
import (
"context"
"fmt"
"strings"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
const (
// Kiro desktop social auth uses localhost loopback callbacks from a fixed
// allowlist. Use one of the bundled ports from the official client.
kiroSocialRedirectURI = "http://localhost:49153"
// AWS IAM Identity Center native/public clients require an explicit loopback IP redirect URI.
kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
)
type KiroOAuthService struct {
sessionStore *kiropkg.SessionStore
proxyRepo ProxyRepository
}
func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
return &KiroOAuthService{
sessionStore: kiropkg.NewSessionStore(),
proxyRepo: proxyRepo,
}
}
func (s *KiroOAuthService) Stop() {}
type KiroAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
}
type KiroIDCAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
ClientID string `json:"client_id"`
Region string `json:"region"`
StartURL string `json:"start_url"`
}
type KiroTokenInfo struct {
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ProfileArn string `json:"profile_arn,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
AuthMethod string `json:"auth_method,omitempty"`
Provider string `json:"provider,omitempty"`
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ClientIDHash string `json:"client_id_hash,omitempty"`
Email string `json:"email,omitempty"`
StartURL string `json:"start_url,omitempty"`
Region string `json:"region,omitempty"`
}
type KiroGenerateAuthURLInput struct {
ProxyID *int64
Provider string
}
type KiroExchangeCodeInput struct {
SessionID string
State string
Code string
CallbackPath string
LoginOption string
ProxyID *int64
}
type KiroGenerateIDCAuthURLInput struct {
ProxyID *int64
StartURL string
Region string
}
type KiroRefreshTokenInput struct {
RefreshToken string
AuthMethod string
Provider string
ClientID string
ClientSecret string
StartURL string
Region string
ProfileArn string
ProxyID *int64
}
type KiroImportTokenInput struct {
TokenJSON string
DeviceRegistrationJSON string
}
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
provider := strings.TrimSpace(input.Provider)
if provider == "" {
provider = string(kiropkg.SocialProviderGoogle)
}
if provider != string(kiropkg.SocialProviderGoogle) && provider != string(kiropkg.SocialProviderGitHub) {
return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
}
state, err := kiropkg.GenerateState()
if err != nil {
return nil, fmt.Errorf("generate state failed: %w", err)
}
codeVerifier, err := kiropkg.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("generate code verifier failed: %w", err)
}
sessionID := kiropkg.GenerateSessionID()
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
AuthType: "social",
Provider: provider,
RedirectURI: kiroSocialRedirectURI,
})
return &KiroAuthURLResult{
AuthURL: kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
SessionID: sessionID,
State: state,
}, nil
}
func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
return nil, fmt.Errorf("state invalid")
}
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
}
switch session.AuthType {
case "social":
token, err := kiropkg.CreateSocialToken(
ctx,
proxyURL,
input.Code,
session.CodeVerifier,
buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
)
if err != nil {
return nil, err
}
token.Provider = session.Provider
s.sessionStore.Delete(input.SessionID)
return toKiroTokenInfo(token), nil
case "idc":
token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
if err != nil {
return nil, err
}
s.sessionStore.Delete(input.SessionID)
return toKiroTokenInfo(token), nil
default:
return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
}
}
func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
option := strings.ToLower(strings.TrimSpace(loginOption))
if option == "" {
switch provider {
case string(kiropkg.SocialProviderGitHub):
option = "github"
case string(kiropkg.SocialProviderGoogle):
option = "google"
}
}
return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, option)
}
func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
startURL := strings.TrimSpace(input.StartURL)
if startURL == "" {
startURL = kiropkg.BuilderIDStartURL
}
region := strings.TrimSpace(input.Region)
if region == "" {
region = "us-east-1"
}
state, err := kiropkg.GenerateState()
if err != nil {
return nil, fmt.Errorf("generate state failed: %w", err)
}
codeVerifier, err := kiropkg.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("generate code verifier failed: %w", err)
}
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
reg, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
if err != nil {
return nil, err
}
sessionID := kiropkg.GenerateSessionID()
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
AuthType: "idc",
RedirectURI: kiroIDCRedirectURI,
ClientID: reg.ClientID,
ClientSecret: reg.ClientSecret,
Region: region,
StartURL: startURL,
})
return &KiroIDCAuthURLResult{
AuthURL: kiropkg.BuildIDCAuthURL(reg.ClientID, kiroIDCRedirectURI, state, kiropkg.GenerateCodeChallenge(codeVerifier), region),
SessionID: sessionID,
State: state,
ClientID: reg.ClientID,
Region: region,
StartURL: startURL,
}, nil
}
func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
if authMethod == "" {
authMethod = "social"
}
var token *kiropkg.TokenData
var err error
switch authMethod {
case "idc":
token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
default:
token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
}
if err != nil {
return nil, err
}
if token.ProfileArn == "" {
token.ProfileArn = input.ProfileArn
}
if token.ClientID == "" {
token.ClientID = input.ClientID
}
if token.ClientSecret == "" {
token.ClientSecret = input.ClientSecret
}
if token.StartURL == "" {
token.StartURL = input.StartURL
}
if token.Region == "" {
token.Region = input.Region
}
return toKiroTokenInfo(token), nil
}
func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("not a kiro oauth account")
}
return s.RefreshToken(ctx, &KiroRefreshTokenInput{
RefreshToken: account.GetCredential("refresh_token"),
AuthMethod: account.GetCredential("auth_method"),
Provider: account.GetCredential("provider"),
ClientID: account.GetCredential("client_id"),
ClientSecret: account.GetCredential("client_secret"),
StartURL: account.GetCredential("start_url"),
Region: account.GetCredential("region"),
ProfileArn: account.GetCredential("profile_arn"),
ProxyID: account.ProxyID,
})
}
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
if err != nil {
return nil, err
}
return toKiroTokenInfo(token), nil
}
func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
if tokenInfo == nil {
return map[string]any{}
}
creds := map[string]any{}
if tokenInfo.AccessToken != "" {
creds["access_token"] = tokenInfo.AccessToken
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.ProfileArn != "" {
creds["profile_arn"] = tokenInfo.ProfileArn
}
if tokenInfo.ExpiresAt != "" {
creds["expires_at"] = tokenInfo.ExpiresAt
}
if tokenInfo.AuthMethod != "" {
creds["auth_method"] = tokenInfo.AuthMethod
}
if tokenInfo.Provider != "" {
creds["provider"] = tokenInfo.Provider
}
if tokenInfo.ClientID != "" {
creds["client_id"] = tokenInfo.ClientID
}
if tokenInfo.ClientSecret != "" {
creds["client_secret"] = tokenInfo.ClientSecret
}
if tokenInfo.ClientIDHash != "" {
creds["client_id_hash"] = tokenInfo.ClientIDHash
}
if tokenInfo.Email != "" {
creds["email"] = tokenInfo.Email
}
if tokenInfo.StartURL != "" {
creds["start_url"] = tokenInfo.StartURL
}
if tokenInfo.Region != "" {
creds["region"] = tokenInfo.Region
}
return creds
}
func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
if token == nil {
return nil
}
return &KiroTokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
ProfileArn: token.ProfileArn,
ExpiresAt: token.ExpiresAt,
AuthMethod: token.AuthMethod,
Provider: token.Provider,
ClientID: token.ClientID,
ClientSecret: token.ClientSecret,
ClientIDHash: token.ClientIDHash,
Email: token.Email,
StartURL: token.StartURL,
Region: token.Region,
}
}
func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
if proxyID == nil || s.proxyRepo == nil {
return "", nil
}
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err != nil || proxy == nil {
return "", err
}
return proxy.URL(), nil
}
@@ -0,0 +1,51 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/stretchr/testify/require"
)
func TestKiroIDCAuthRedirectURIUsesLoopbackIP(t *testing.T) {
require.Equal(t, "http://127.0.0.1:9876/oauth/callback", kiroIDCRedirectURI)
}
func TestKiroSocialAuthRedirectURIUsesLoopbackIP(t *testing.T) {
require.Equal(t, "http://localhost:49153", kiroSocialRedirectURI)
}
func TestBuildKiroSocialExchangeRedirectURIUsesProviderDefault(t *testing.T) {
require.Equal(
t,
"http://localhost:49153/oauth/callback?login_option=github",
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "", ""),
)
}
func TestBuildKiroSocialExchangeRedirectURIPreservesParsedCallbackData(t *testing.T) {
require.Equal(
t,
"http://localhost:49153/signin/callback?login_option=google",
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "/signin/callback", "google"),
)
}
func TestKiroOAuthService_ExchangeCodeRejectsExpiredSession(t *testing.T) {
svc := NewKiroOAuthService(nil)
svc.sessionStore.Set("expired-session", &kiropkg.AuthSession{
State: "expected-state",
CreatedAt: time.Now().Add(-11 * time.Minute),
})
_, err := svc.ExchangeCode(context.Background(), &KiroExchangeCodeInput{
SessionID: "expired-session",
State: "expected-state",
Code: "auth-code",
})
require.EqualError(t, err, "session not found or expired")
}
+740
View File
@@ -0,0 +1,740 @@
package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
mathrand "math/rand"
"net/http"
"strings"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
type kiroEndpointConfig struct {
URL string
AmzTarget string
Name string
}
const kiroInvalidModelTempUnschedDuration = time.Minute
const (
kiroRetryBaseDelay = 200 * time.Millisecond
kiroRetryMaxDelay = 2 * time.Second
)
var kiroRetrySleep = sleepWithContext
func kiroRetryBackoffDelay(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := kiroRetryBaseDelay * time.Duration(1<<attempt)
if delay > kiroRetryMaxDelay {
delay = kiroRetryMaxDelay
}
jitterMax := delay / 4
if jitterMax <= 0 {
return delay
}
return delay + time.Duration(mathrand.Int63n(int64(jitterMax)+1))
}
func sleepKiroRetry(ctx context.Context, attempt int) error {
return kiroRetrySleep(ctx, kiroRetryBackoffDelay(attempt))
}
func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) {
if account == nil || parsed == nil {
return nil, fmt.Errorf("kiro forward: missing account or request")
}
originalModel := parsed.Model
mappedModel := originalModel
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
body := parsed.Body
if mappedModel != originalModel {
body = s.replaceModelInBody(body, mappedModel)
}
logger.L().Debug("gateway forward_kiro_messages: request prepared",
zap.Int64("account_id", account.ID),
zap.String("auth_method", strings.TrimSpace(account.GetCredential("auth_method"))),
zap.String("requested_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("has_profile_arn", strings.TrimSpace(account.GetCredential("profile_arn")) != ""),
)
if s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, body) {
parsedForEmulation := *parsed
parsedForEmulation.Body = body
return s.handleWebSearchEmulation(ctx, c, account, &parsedForEmulation)
}
if parsed.Stream {
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: failoverErr.StatusCode,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(err.Error()),
})
return nil, failoverErr
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
}
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel, false)
if err != nil {
return nil, err
}
if streamResult.usage == nil {
streamResult.usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *streamResult.usage,
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: streamResult.firstTokenMs,
ClientDisconnect: streamResult.clientDisconnect,
}, nil
}
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
if tokenType != "oauth" {
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
}
if isOnlyWebSearchToolInBody(body) {
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
switch {
case errors.Is(webSearchErr, errKiroWebSearchFallback):
case webSearchErr == nil:
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
c.Header("Content-Type", "application/json")
if webSearchResult.RequestID != "" {
c.Header("x-request-id", webSearchResult.RequestID)
}
c.Data(http.StatusOK, "application/json", webSearchResult.ResponseBody)
return &ForwardResult{
RequestID: webSearchResult.RequestID,
Usage: webSearchResult.Usage,
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
default:
var httpErr *kiroWebSearchHTTPError
if errors.As(webSearchErr, &httpErr) && httpErr.Response != nil {
return nil, s.handleKiroHTTPError(ctx, httpErr.Response, c, account, mappedModel, body)
}
var failoverErr *UpstreamFailoverError
if errors.As(webSearchErr, &failoverErr) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: failoverErr.StatusCode,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(webSearchErr.Error()),
})
return nil, failoverErr
}
safeErr := sanitizeUpstreamErrorMessage(webSearchErr.Error())
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
}
}
inputTokens := estimateKiroInputTokens(body)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: failoverErr.StatusCode,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(err.Error()),
})
return nil, failoverErr
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
}
parseResult, err := kiropkg.ParseNonStreamingEventStreamWithContext(resp.Body, mappedModel, requestCtx)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Failed to parse Kiro upstream response",
},
})
return nil, err
}
c.Header("Content-Type", "application/json")
if requestID := resp.Header.Get("x-request-id"); requestID != "" {
c.Header("x-request-id", requestID)
}
c.Data(http.StatusOK, "application/json", parseResult.ResponseBody)
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, 0, err
}
if tokenType != "oauth" {
return nil, 0, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
}
inputTokens := estimateKiroInputTokens(anthropicBody)
if isOnlyWebSearchToolInBody(anthropicBody) {
pr, pw := io.Pipe()
headers := make(http.Header)
headers.Set("Content-Type", "text/event-stream")
go func() {
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
}
_ = pw.Close()
}()
return &http.Response{
StatusCode: http.StatusOK,
Header: headers,
Body: pr,
}, inputTokens, nil
}
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
return nil, inputTokens, err
}
return nil, inputTokens, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return resp, inputTokens, nil
}
pr, pw := io.Pipe()
wrappedHeaders := resp.Header.Clone()
wrappedHeaders.Set("Content-Type", "text/event-stream")
if requestID := buildKiroRequestID(resp); requestID != "" {
wrappedHeaders.Set("x-request-id", requestID)
}
go func() {
defer func() { _ = resp.Body.Close() }()
_, streamErr := kiropkg.StreamEventStreamAsAnthropicWithContext(ctx, resp.Body, pw, mappedModel, inputTokens, requestCtx)
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
}
_ = pw.Close()
}()
return &http.Response{
StatusCode: resp.StatusCode,
Header: wrappedHeaders,
Body: pr,
}, inputTokens, nil
}
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
var requestCtx kiropkg.KiroRequestContext
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
return nil, requestCtx, failoverErr
}
return nil, requestCtx, err
}
modelID := kiropkg.MapModel(mappedModel)
currentToken := token
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
if err != nil {
return nil, requestCtx, err
}
payload := buildResult.Payload
requestCtx = buildResult.Context
endpoints := buildKiroEndpoints(account)
proxyURL := kiroProxyURL(account)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
accountKey := buildKiroAccountKey(account)
maxRetries := 2
for idx, endpoint := range endpoints {
for attempt := 0; attempt <= maxRetries; attempt++ {
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
if err != nil {
return nil, requestCtx, err
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
if attempt < maxRetries {
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
continue
}
return nil, requestCtx, err
}
if resp.StatusCode == http.StatusTooManyRequests {
cooldown, err := s.markKiro429(ctx, accountKey)
if err != nil {
_ = resp.Body.Close()
return nil, requestCtx, err
}
if idx+1 < len(endpoints) {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
break
}
resp.Header.Set("x-kiro-cooldown", cooldown.String())
return resp, requestCtx, nil
}
if resp.StatusCode == http.StatusRequestTimeout || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
if attempt < maxRetries {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
continue
}
if idx+1 < len(endpoints) {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
break
}
return resp, requestCtx, nil
}
if resp.StatusCode == http.StatusPaymentRequired {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, requestCtx, readErr
}
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
if classification.Category == kiroErrorMonthlyRequest {
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
}
return nil, requestCtx, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, requestCtx, readErr
}
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
return nil, requestCtx, err
}
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
if err != nil {
return nil, requestCtx, err
}
payload = buildResult.Payload
requestCtx = buildResult.Context
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
continue
}
if refreshErr != nil && isNonRetryableRefreshError(refreshErr) {
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
}
if classifyKiroHTTPError(resp.StatusCode, string(respBody)).Category == kiroErrorAuthError {
s.markKiroAuthTemporarilyUnavailable(ctx, account, resp.StatusCode, string(respBody))
}
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
if resp.StatusCode == http.StatusBadRequest {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, requestCtx, readErr
}
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
logKiroBadRequestClassification(classification, account, mappedModel, resp.Header, respBody)
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
_ = resp.Body.Close()
return nil, requestCtx, err
}
}
return resp, requestCtx, nil
}
}
return nil, requestCtx, fmt.Errorf("kiro upstream endpoints exhausted")
}
func buildKiroEndpoints(account *Account) []kiroEndpointConfig {
region := kiroAPIRegion(account)
return []kiroEndpointConfig{
{
URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
Name: "AmazonQ",
},
}
}
func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) ([]byte, error) {
result, err := buildKiroPayloadForAccountWithRepo(ctx, nil, account, anthropicBody, modelID, token, requestModel, headers)
if err != nil {
return nil, err
}
return result.Payload, nil
}
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
profileArn := resolveKiroPayloadProfileArn(account)
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
}
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
requestModel = strings.TrimSpace(requestModel)
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
return anthropicBody
}
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
return anthropicBody
}
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
return next
}
return anthropicBody
}
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
until := time.Now().Add(10 * time.Minute)
reason := fmt.Sprintf("kiro auth failure (%d): %s", statusCode, strings.TrimSpace(body))
_ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason)
}
func (s *GatewayService) markKiroMonthlyRequestCountRateLimited(ctx context.Context, account *Account, body string) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
resetAt := nextKiroMonthlyResetUTC(time.Now())
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
logger.L().Warn("kiro monthly request count rate-limit failed",
zap.Int64("account_id", account.ID),
zap.Time("reset_at", resetAt),
zap.Error(err),
)
return
}
reason := "kiro monthly request count exhausted (402): MONTHLY_REQUEST_COUNT"
if trimmed := strings.TrimSpace(body); trimmed != "" {
reason = fmt.Sprintf("%s body=%s", reason, truncateForLog([]byte(trimmed), 512))
}
logger.L().Warn("kiro monthly request count rate-limited",
zap.Int64("account_id", account.ID),
zap.Time("reset_at", resetAt),
zap.String("reason", reason),
)
}
func nextKiroMonthlyResetUTC(now time.Time) time.Time {
utc := now.UTC()
year, month, _ := utc.Date()
return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC)
}
func resetHTTPResponseBody(resp *http.Response, body []byte) {
if resp == nil {
return
}
resp.Body = io.NopCloser(bytes.NewReader(body))
resp.ContentLength = int64(len(body))
}
func estimateKiroInputTokens(body []byte) int {
if len(body) == 0 {
return 0
}
if tokens := gjson.GetBytes(body, "metadata.input_tokens").Int(); tokens > 0 {
return int(tokens)
}
tokens := len(body) / 4
if tokens == 0 {
return 1
}
return tokens
}
func kiroUsageToClaude(usage kiropkg.Usage, fallbackInput int) ClaudeUsage {
inputTokens := usage.InputTokens
if inputTokens == 0 {
inputTokens = fallbackInput
}
return ClaudeUsage{
InputTokens: inputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
}
}
func (s *GatewayService) markKiroInvalidModelRateLimited(ctx context.Context, account *Account, mappedModel string) {
if s == nil || s.accountRepo == nil || account == nil || account.Type != AccountTypeOAuth {
return
}
resetAt := time.Now().Add(kiroInvalidModelTempUnschedDuration)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
logger.L().Warn("kiro invalid model rate-limit failed",
zap.Int64("account_id", account.ID),
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
zap.Time("reset_at", resetAt),
zap.Error(err),
)
return
}
logger.L().Warn("kiro invalid model rate-limited",
zap.Int64("account_id", account.ID),
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
zap.Time("reset_at", resetAt),
)
}
func (s *GatewayService) handleKiroHTTPError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, mappedModel string, requestBody []byte) error {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg == "" {
upstreamMsg = strings.TrimSpace(string(respBody))
}
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
if resp.StatusCode == http.StatusBadRequest {
logKiroBadRequestClassification(classification, account, "", resp.Header, respBody)
}
if classification.Category == kiroErrorMonthlyRequest {
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
}
if classification.Category == kiroErrorBadRequestInvalidModel && account != nil && account.Type == AccountTypeOAuth {
s.markKiroInvalidModelRateLimited(ctx, account, mappedModel)
event := s.buildKiroInvalidModelUpstreamEvent(account, resp, upstreamMsg, mappedModel, requestBody, c)
appendOpsUpstreamError(c, event)
return &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
if resp.StatusCode == http.StatusPaymentRequired || s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
})
c.JSON(mapUpstreamStatusCode(resp.StatusCode), gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": coalesceKiroErrorMessage(resp.StatusCode, upstreamMsg),
},
})
return fmt.Errorf("kiro upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
func (s *GatewayService) buildKiroInvalidModelUpstreamEvent(account *Account, resp *http.Response, upstreamMsg, mappedModel string, requestBody []byte, c *gin.Context) OpsUpstreamErrorEvent {
_ = s
requestedModel := strings.TrimSpace(gjson.GetBytes(requestBody, "model").String())
hasTools := gjson.GetBytes(requestBody, "tools").Exists()
hasAdaptiveThinking := strings.EqualFold(strings.TrimSpace(gjson.GetBytes(requestBody, "thinking.type").String()), "adaptive")
hasContext1MBeta := false
if c != nil {
hasContext1MBeta = strings.Contains(c.GetHeader("Anthropic-Beta"), "context-1m")
}
return OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
RequestedModel: requestedModel,
MappedModel: strings.TrimSpace(mappedModel),
KiroModelID: kiropkg.MapModel(mappedModel),
HasTools: hasTools,
HasAdaptiveThinking: hasAdaptiveThinking,
HasContext1MBeta: hasContext1MBeta,
}
}
func logKiroBadRequestClassification(classification kiroErrorClassification, account *Account, model string, headers http.Header, body []byte) {
if classification.StatusCode != http.StatusBadRequest {
return
}
var accountID int64
if account != nil {
accountID = account.ID
}
logger.L().Warn("kiro upstream bad request classified",
zap.String("category", classification.Category),
zap.Int("status", classification.StatusCode),
zap.Int64("account_id", accountID),
zap.String("model", strings.TrimSpace(model)),
zap.String("request_id", headers.Get("x-request-id")),
zap.String("body_excerpt", truncateForLog(body, 512)),
)
}
func coalesceKiroErrorMessage(statusCode int, upstreamMsg string) string {
if upstreamMsg != "" {
return upstreamMsg
}
switch statusCode {
case http.StatusTooManyRequests:
return "Rate limit exceeded"
case http.StatusForbidden:
return "Access denied"
case http.StatusUnauthorized:
return "Authentication failed"
default:
return "Upstream request failed"
}
}
@@ -0,0 +1,99 @@
package service
import (
"context"
"errors"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
)
var errKiroCooldownStoreUnavailable = errors.New("kiro cooldown store unavailable")
type KiroCooldownStore interface {
ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error)
MarkSuccess(ctx context.Context, tokenKey string) error
Mark429(ctx context.Context, tokenKey string) (time.Duration, error)
MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error)
GetState(ctx context.Context, tokenKey string) (*kirocooldown.State, error)
ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error)
}
func asKiroCooldownFailoverError(err error) *UpstreamFailoverError {
if err == nil {
return nil
}
var cooldownErr *kirocooldown.Error
if !errors.As(err, &cooldownErr) {
return nil
}
return &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseBody: []byte(cooldownErr.Error()),
}
}
func (s *GatewayService) checkAndWaitKiroCooldown(ctx context.Context, tokenKey string) error {
if s == nil || s.kiroCooldownStore == nil {
return errKiroCooldownStoreUnavailable
}
waitFor, err := s.kiroCooldownStore.ReserveRequest(ctx, tokenKey)
if err != nil {
return err
}
if waitFor <= 0 {
return nil
}
timer := time.NewTimer(waitFor)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ctx.Err()
case <-timer.C:
return nil
}
}
func (s *GatewayService) markKiroSuccess(ctx context.Context, tokenKey string) error {
if s == nil || s.kiroCooldownStore == nil {
return errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.MarkSuccess(ctx, tokenKey)
}
func (s *GatewayService) markKiro429(ctx context.Context, tokenKey string) (time.Duration, error) {
if s == nil || s.kiroCooldownStore == nil {
return 0, errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.Mark429(ctx, tokenKey)
}
func (s *GatewayService) markKiroSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
if s == nil || s.kiroCooldownStore == nil {
return 0, errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.MarkSuspended(ctx, tokenKey)
}
func (s *GatewayService) getKiroCooldownState(ctx context.Context, tokenKey string) (*kirocooldown.State, error) {
if s == nil || s.kiroCooldownStore == nil {
return nil, errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.GetState(ctx, tokenKey)
}
func kiroRuntimeStateSnapshot(state *kirocooldown.State) (string, string, *time.Time) {
if state == nil || !state.Active {
return "", "", nil
}
resetAt := state.CooldownUntil
switch state.Reason {
case kirocooldown.CooldownReasonSuspended:
return "suspended", state.Reason, &resetAt
default:
return "cooldown", state.Reason, &resetAt
}
}
@@ -0,0 +1,192 @@
//go:build integration
package service
import (
"context"
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const kiroCooldownRedisImageTag = "redis:8.4-alpine"
func TestRedisKiroCooldownStoreSharesCooldownAcrossInstances(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
storeA := kirocooldown.NewStore(rdb)
storeB := kirocooldown.NewStore(rdb)
cooldown, err := storeA.Mark429(ctx, "token-shared")
require.NoError(t, err)
require.Equal(t, time.Minute, cooldown)
wait, err := storeB.ReserveRequest(ctx, "token-shared")
require.Zero(t, wait)
require.Error(t, err)
require.Contains(t, err.Error(), kirocooldown.CooldownReason429)
require.NoError(t, storeB.MarkSuccess(ctx, "token-shared"))
wait, err = storeA.ReserveRequest(ctx, "token-shared")
require.NoError(t, err)
require.GreaterOrEqual(t, wait, 0*time.Second)
}
func TestRedisKiroCooldownStoreSharesReservationAcrossInstances(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
storeA := kirocooldown.NewStore(rdb)
storeB := kirocooldown.NewStore(rdb)
wait, err := storeA.ReserveRequest(ctx, "token-rate")
require.NoError(t, err)
require.Zero(t, wait)
wait, err = storeB.ReserveRequest(ctx, "token-rate")
require.NoError(t, err)
require.Greater(t, wait, 0*time.Millisecond)
require.LessOrEqual(t, wait, kirocooldown.MaxRequestInterval)
}
func TestRedisKiroCooldownStoreSharesSuspendedStateAcrossInstances(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
storeA := kirocooldown.NewStore(rdb)
storeB := kirocooldown.NewStore(rdb)
cooldown, err := storeA.MarkSuspended(ctx, "token-suspended")
require.NoError(t, err)
require.Equal(t, kirocooldown.LongCooldown, cooldown)
wait, err := storeB.ReserveRequest(ctx, "token-suspended")
require.Zero(t, wait)
require.Error(t, err)
require.Contains(t, err.Error(), kirocooldown.CooldownReasonSuspended)
}
func TestRedisKiroCooldownStoreSuspendedResetsFailCount(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
store := kirocooldown.NewStore(rdb)
_, err := store.Mark429(ctx, "token-reset")
require.NoError(t, err)
_, err = store.Mark429(ctx, "token-reset")
require.NoError(t, err)
cooldown, err := store.MarkSuspended(ctx, "token-reset")
require.NoError(t, err)
require.Equal(t, kirocooldown.LongCooldown, cooldown)
cooldown, err = store.Mark429(ctx, "token-reset")
require.NoError(t, err)
require.Equal(t, time.Minute, cooldown)
}
func TestRedisKiroCooldownStoreReserveDifferentTokenIgnoresOldCooldown(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
store := kirocooldown.NewStore(rdb)
_, err := store.Mark429(ctx, "token-old")
require.NoError(t, err)
wait, err := store.ReserveRequest(ctx, "token-new")
require.NoError(t, err)
require.Zero(t, wait)
}
func TestRedisKiroCooldownStoreUsesExpectedTTLs(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
store := kirocooldown.NewStore(rdb)
_, err := store.ReserveRequest(ctx, "token-ttl-active")
require.NoError(t, err)
activeTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-active")).Result()
require.NoError(t, err)
require.Greater(t, activeTTL, 0*time.Second)
require.LessOrEqual(t, activeTTL, kirocooldown.ActiveTTL())
_, err = store.MarkSuspended(ctx, "token-ttl-state")
require.NoError(t, err)
stateTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-state")).Result()
require.NoError(t, err)
require.Greater(t, stateTTL, 24*time.Hour)
require.LessOrEqual(t, stateTTL, kirocooldown.StateTTL())
}
func startKiroCooldownRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
ensureKiroCooldownDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, kiroCooldownRedisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
host, err := redisContainer.Host(ctx)
require.NoError(t, err)
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", host, port.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}
func ensureKiroCooldownDockerAvailable(t *testing.T) {
t.Helper()
if kiroCooldownDockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过依赖 testcontainers 的 Kiro cooldown 集成测试")
}
func kiroCooldownDockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func kiroCooldownUserHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}
@@ -0,0 +1,583 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type stubKiroCooldownStore struct {
reserveWait time.Duration
reserveErr error
successErr error
mark429TTL time.Duration
mark429Err error
suspendedTTL time.Duration
suspendedErr error
state *kirocooldown.State
stateErr error
clearCalled bool
clearKeys []string
clearResult bool
clearErr error
}
type recordingKiroTempUnschedRepo struct {
mockAccountRepoForGemini
called bool
id int64
until time.Time
reason string
rateCalled bool
rateID int64
rateLimitReset time.Time
rateLimitedCall int
}
func (r *recordingKiroTempUnschedRepo) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
r.called = true
r.id = id
r.until = until
r.reason = reason
return nil
}
func (r *recordingKiroTempUnschedRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
r.rateCalled = true
r.rateID = id
r.rateLimitReset = resetAt
r.rateLimitedCall++
return nil
}
type recordingKiroErrorRepo struct {
recordingKiroTempUnschedRepo
setErrorCalls int
errorID int64
errorMsg string
}
func (r *recordingKiroErrorRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.errorID = id
r.errorMsg = errorMsg
return nil
}
func (s *stubKiroCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
return s.reserveWait, s.reserveErr
}
func (s *stubKiroCooldownStore) MarkSuccess(context.Context, string) error {
return s.successErr
}
func (s *stubKiroCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
return s.mark429TTL, s.mark429Err
}
func (s *stubKiroCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
return s.suspendedTTL, s.suspendedErr
}
func (s *stubKiroCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
if s.clearCalled && s.clearResult {
return nil, nil
}
return s.state, s.stateErr
}
func (s *stubKiroCooldownStore) ClearEarliestTransientCooldown(_ context.Context, tokenKeys []string) (bool, error) {
s.clearCalled = true
s.clearKeys = append([]string(nil), tokenKeys...)
return s.clearResult, s.clearErr
}
func TestCalculateKiro429Cooldown(t *testing.T) {
require.Equal(t, time.Minute, kirocooldown.Calculate429Cooldown(0))
require.Equal(t, 2*time.Minute, kirocooldown.Calculate429Cooldown(1))
require.Equal(t, 4*time.Minute, kirocooldown.Calculate429Cooldown(2))
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(3))
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(10))
}
func TestGatewayServiceCheckAndWaitKiroCooldownReturnsNilWithoutWait(t *testing.T) {
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{},
}
require.NoError(t, svc.checkAndWaitKiroCooldown(context.Background(), "token1"))
}
func TestGatewayServiceCheckAndWaitKiroCooldownPropagatesReserveError(t *testing.T) {
expected := errors.New("redis unavailable")
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{reserveErr: expected},
}
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
require.ErrorIs(t, err, expected)
}
func TestGatewayServiceCheckAndWaitKiroCooldownRequiresStore(t *testing.T) {
svc := &GatewayService{}
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
require.ErrorIs(t, err, errKiroCooldownStoreUnavailable)
}
func TestGatewayServiceCheckAndWaitKiroCooldownWaitsAndHonorsContext(t *testing.T) {
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{reserveWait: 200 * time.Millisecond},
}
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
err := svc.checkAndWaitKiroCooldown(ctx, "token1")
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestAsKiroCooldownFailoverError(t *testing.T) {
err := kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429)
var cooldownErr *kirocooldown.Error
require.ErrorAs(t, err, &cooldownErr)
failoverErr := asKiroCooldownFailoverError(err)
require.NotNil(t, failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
require.False(t, failoverErr.RetryableOnSameAccount)
}
func TestAsKiroCooldownFailoverErrorIgnoresNonCooldownErrors(t *testing.T) {
require.Nil(t, asKiroCooldownFailoverError(errors.New("redis unavailable")))
}
func TestGatewayServiceTryRecoverKiroCooldownPoolClearsOnlyTransientCooldown(t *testing.T) {
store := &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(time.Minute),
Remaining: time.Minute,
},
clearResult: true,
}
svc := &GatewayService{kiroCooldownStore: store}
accounts := []Account{
{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
},
}
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
require.True(t, recovered)
require.True(t, store.clearCalled)
require.Len(t, store.clearKeys, 1)
require.Equal(t, buildKiroAccountKey(&accounts[0]), store.clearKeys[0])
}
func TestGatewayServiceTryRecoverKiroCooldownPoolSkipsSuspended(t *testing.T) {
store := &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReasonSuspended,
CooldownUntil: time.Now().Add(time.Hour),
Remaining: time.Hour,
},
clearResult: true,
}
svc := &GatewayService{kiroCooldownStore: store}
accounts := []Account{
{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
},
}
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
require.False(t, recovered)
require.False(t, store.clearCalled)
}
func TestSelectAccountWithLoadAwarenessRecoversKiroCooldownPool(t *testing.T) {
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
account := Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
store := &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(time.Minute),
Remaining: time.Minute,
},
clearResult: true,
}
svc := &GatewayService{
accountRepo: &mockAccountRepoForGemini{accounts: []Account{account}},
concurrencyService: NewConcurrencyService(&mockConcurrencyCache{}),
cfg: cfg,
kiroCooldownStore: store,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformKiro)
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "", nil, "", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, account.ID, result.Account.ID)
require.True(t, store.clearCalled)
require.Equal(t, []string{buildKiroAccountKey(&account)}, store.clearKeys)
}
func TestClassifyKiroHTTPErrorMonthlyRequestCount(t *testing.T) {
tests := []string{
`{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
`{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}`,
`API returned 402: {"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
}
for _, body := range tests {
classification := classifyKiroHTTPError(http.StatusPaymentRequired, body)
require.Equal(t, kiroErrorMonthlyRequest, classification.Category)
}
}
func TestClassifyKiroHTTPErrorPlain402IsTransient(t *testing.T) {
classification := classifyKiroHTTPError(http.StatusPaymentRequired, `{"message":"payment required"}`)
require.Equal(t, kiroErrorUpstreamTransient, classification.Category)
}
func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{
reserveErr: kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429),
},
}
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
require.False(t, failoverErr.RetryableOnSameAccount)
}
func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
},
}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
payload, err := createTestPayload("claude-opus-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
require.Len(t, upstream.requests, 1)
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
require.NoError(t, readErr)
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
}
func TestHandleKiroHTTPErrorOAuthInvalidModelRateLimitsAndFailovers(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Anthropic-Beta", "context-1m-2025-08-07")
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Name: "kiro-oauth",
}
repo := &recordingKiroTempUnschedRepo{}
svc := &GatewayService{accountRepo: repo}
requestBody := []byte(`{"model":"claude-opus-4-7","tools":[{"name":"search"}],"thinking":{"type":"adaptive"}}`)
resp := newJSONResponse(http.StatusBadRequest, `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`)
resp.Header.Set("x-request-id", "req-invalid-model")
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", requestBody)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "Invalid model")
require.False(t, failoverErr.RetryableOnSameAccount)
require.False(t, repo.called)
require.True(t, repo.rateCalled)
require.Equal(t, account.ID, repo.rateID)
require.WithinDuration(t, time.Now().Add(kiroInvalidModelTempUnschedDuration), repo.rateLimitReset, 5*time.Second)
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, PlatformKiro, events[0].Platform)
require.Equal(t, account.ID, events[0].AccountID)
require.Equal(t, account.Name, events[0].AccountName)
require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
require.Equal(t, "req-invalid-model", events[0].UpstreamRequestID)
require.Equal(t, "failover", events[0].Kind)
require.Equal(t, "claude-opus-4-7", events[0].RequestedModel)
require.Equal(t, "claude-opus-4.6", events[0].MappedModel)
require.Equal(t, "claude-opus-4.6", events[0].KiroModelID)
require.True(t, events[0].HasTools)
require.True(t, events[0].HasAdaptiveThinking)
require.True(t, events[0].HasContext1MBeta)
}
func TestHandleKiroHTTPErrorAPIKeyInvalidModelDoesNotFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
account := &Account{
ID: 43,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
}
repo := &recordingKiroTempUnschedRepo{}
svc := &GatewayService{accountRepo: repo}
resp := newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model. Please select a different model to continue."}`)
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", []byte(`{"model":"claude-opus-4-7"}`))
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.NotErrorAs(t, err, &failoverErr)
require.False(t, repo.called)
require.False(t, repo.rateCalled)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestNextKiroMonthlyResetUTC(t *testing.T) {
tests := []struct {
name string
now time.Time
want time.Time
}{
{
name: "middle of month",
now: time.Date(2026, time.April, 27, 10, 30, 45, 123, time.FixedZone("CST", 8*3600)),
want: time.Date(2026, time.May, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "december rolls year",
now: time.Date(2026, time.December, 31, 23, 59, 59, 0, time.UTC),
want: time.Date(2027, time.January, 1, 0, 0, 0, 0, time.UTC),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, nextKiroMonthlyResetUTC(tt.now))
})
}
}
func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailovers(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
repo := &recordingKiroTempUnschedRepo{}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusPaymentRequired, `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`),
},
}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
payload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "MONTHLY_REQUEST_COUNT")
require.False(t, repo.called)
require.True(t, repo.rateCalled)
require.Equal(t, account.ID, repo.rateID)
require.Equal(t, nextKiroMonthlyResetUTC(time.Now()), repo.rateLimitReset)
}
func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
repo := &recordingKiroTempUnschedRepo{}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusPaymentRequired, `{"message":"payment required"}`),
},
}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
payload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
require.False(t, repo.called)
require.False(t, repo.rateCalled)
}
func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedule(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"refresh_token": "old-refresh",
},
}
repo := &recordingKiroErrorRepo{
recordingKiroTempUnschedRepo: recordingKiroTempUnschedRepo{
mockAccountRepoForGemini: mockAccountRepoForGemini{
accountsByID: map[int64]*Account{account.ID: account},
},
},
}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"message":"token expired"}`),
},
}
provider := NewKiroTokenProvider(repo, nil, nil)
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
kiroTokenProvider: provider,
}
payload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, account.ID, repo.errorID)
require.Contains(t, repo.errorMsg, "invalid_grant")
require.False(t, repo.called, "non-retryable refresh errors should not mark temporary unschedulable")
}
func TestGatewayServiceIsAccountSchedulableForSelectionSkipsActiveKiroCooldown(t *testing.T) {
now := time.Now().Add(2 * time.Minute)
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: now,
Remaining: 2 * time.Minute,
},
},
}
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
}
require.False(t, svc.isAccountSchedulableForSelection(account))
}
@@ -0,0 +1,221 @@
package service
import (
"context"
"errors"
"strconv"
"strings"
"time"
)
const (
kiroTokenRefreshSkew = 3 * time.Minute
kiroTokenCacheSkew = 5 * time.Minute
)
type KiroTokenCache = GeminiTokenCache
type kiroAccountTokenRefresher interface {
RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error)
BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any
}
type KiroTokenProvider struct {
accountRepo AccountRepository
tokenCache KiroTokenCache
kiroOAuthService kiroAccountTokenRefresher
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewKiroTokenProvider(
accountRepo AccountRepository,
tokenCache KiroTokenCache,
kiroOAuthService *KiroOAuthService,
) *KiroTokenProvider {
return &KiroTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
kiroOAuthService: kiroOAuthService,
refreshPolicy: GeminiProviderRefreshPolicy(),
}
}
func (p *KiroTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
func (p *KiroTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *KiroTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return "", errors.New("not a kiro oauth account")
}
cacheKey := KiroTokenCacheKey(account)
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= kiroTokenRefreshSkew
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, kiroTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > kiroTokenCacheSkew:
ttl = until - kiroTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
}
return accessToken, nil
}
func KiroTokenCacheKey(account *Account) string {
if account == nil {
return "kiro:account:0"
}
if clientIDHash := strings.TrimSpace(account.GetCredential("client_id_hash")); clientIDHash != "" {
return "kiro:" + clientIDHash
}
if clientID := strings.TrimSpace(account.GetCredential("client_id")); clientID != "" {
return "kiro:client:" + clientID
}
return "kiro:account:" + strconv.FormatInt(account.ID, 10)
}
func (p *KiroTokenProvider) ForceRefreshAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return "", errors.New("not a kiro oauth account")
}
if p.kiroOAuthService == nil {
return "", errors.New("kiro oauth service is nil")
}
cacheKey := KiroTokenCacheKey(account)
lockHeld := false
if p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
lockHeld = true
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
}
}
if p.accountRepo != nil {
if latestAccount, err := p.accountRepo.GetByID(ctx, account.ID); err == nil && latestAccount != nil {
account = latestAccount
}
}
tokenInfo, err := p.kiroOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
if !lockHeld {
if latestAccount, stale := CheckTokenVersion(ctx, account, p.accountRepo); stale && latestAccount != nil {
account = latestAccount
if accessToken := strings.TrimSpace(account.GetCredential("access_token")); accessToken != "" {
_ = p.cacheAccessToken(ctx, account, accessToken)
return accessToken, nil
}
}
}
if isNonRetryableRefreshError(err) && p.accountRepo != nil {
errorMsg := "Token refresh failed (non-retryable): " + err.Error()
_ = p.accountRepo.SetError(ctx, account.ID, errorMsg)
}
return "", err
}
newCredentials := MergeCredentials(account.Credentials, p.kiroOAuthService.BuildAccountCredentials(tokenInfo))
newCredentials["_token_version"] = time.Now().UnixMilli()
if err := persistAccountCredentials(ctx, p.accountRepo, account, newCredentials); err != nil {
return "", err
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken == "" {
accessToken = strings.TrimSpace(tokenInfo.AccessToken)
}
if accessToken == "" {
return "", errors.New("access_token not found after kiro refresh")
}
if err := p.cacheAccessToken(ctx, account, accessToken); err != nil {
return "", err
}
return accessToken, nil
}
func (p *KiroTokenProvider) cacheAccessToken(ctx context.Context, account *Account, accessToken string) error {
if p.tokenCache == nil || account == nil || strings.TrimSpace(accessToken) == "" {
return nil
}
ttl := 30 * time.Minute
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > kiroTokenCacheSkew:
ttl = until - kiroTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
return p.tokenCache.SetAccessToken(ctx, KiroTokenCacheKey(account), accessToken, ttl)
}
@@ -0,0 +1,112 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
type kiroTokenProviderRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setErrorID int64
setErrorMsg string
}
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.setErrorID = id
r.setErrorMsg = errorMsg
return nil
}
type kiroTokenProviderSequenceRepo struct {
kiroTokenProviderRepo
accounts []*Account
reads int
}
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
if len(r.accounts) == 0 {
return nil, errors.New("account not found")
}
idx := r.reads
if idx >= len(r.accounts) {
idx = len(r.accounts) - 1
}
r.reads++
return r.accounts[idx], nil
}
type stubKiroAccountTokenRefresher struct {
tokenInfo *KiroTokenInfo
err error
}
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
if s.err != nil {
return nil, s.err
}
return s.tokenInfo, nil
}
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
if tokenInfo == nil {
return nil
}
return map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": tokenInfo.ExpiresAt,
}
}
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"refresh_token": "old-refresh"},
}
repo := &kiroTokenProviderRepo{
mockAccountRepoForGemini: mockAccountRepoForGemini{
accountsByID: map[int64]*Account{account.ID: account},
},
}
provider := NewKiroTokenProvider(repo, nil, nil)
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
token, err := provider.ForceRefreshAccessToken(context.Background(), account)
require.Error(t, err)
require.Empty(t, token)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
require.Contains(t, repo.setErrorMsg, "invalid_grant")
}
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
usedAccount := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"refresh_token": "old-refresh"},
}
latestAccount := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
}
repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
provider := NewKiroTokenProvider(repo, nil, nil)
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
require.NoError(t, err)
require.Equal(t, "fresh-access", token)
require.Equal(t, 0, repo.setErrorCalls)
}
@@ -0,0 +1,47 @@
package service
import (
"context"
"time"
)
const kiroRefreshWindow = 15 * time.Minute
type KiroTokenRefresher struct {
kiroOAuthService *KiroOAuthService
}
func NewKiroTokenRefresher(kiroOAuthService *KiroOAuthService) *KiroTokenRefresher {
return &KiroTokenRefresher{
kiroOAuthService: kiroOAuthService,
}
}
func (r *KiroTokenRefresher) CacheKey(account *Account) string {
return KiroTokenCacheKey(account)
}
func (r *KiroTokenRefresher) CanRefresh(account *Account) bool {
return account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
}
func (r *KiroTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
return time.Until(*expiresAt) <= kiroRefreshWindow
}
func (r *KiroTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.kiroOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
newCredentials := r.kiroOAuthService.BuildAccountCredentials(tokenInfo)
return MergeCredentials(account.Credentials, newCredentials), nil
}
@@ -0,0 +1,608 @@
package service
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/google/uuid"
)
const (
kiroUsageOrigin = "AI_EDITOR"
kiroUsageResourceType = "AGENTIC_REQUEST"
kiroDefaultRegion = "us-east-1"
)
var resolveKiroRuntimeEndpoint = kiroRuntimeEndpoint
type kiroUsageLimitsResponse struct {
NextDateReset any `json:"nextDateReset"`
OverageConfiguration kiroOverageConfiguration `json:"overageConfiguration"`
SubscriptionInfo kiroSubscriptionInfo `json:"subscriptionInfo"`
UsageBreakdownList []kiroUsageBreakdown `json:"usageBreakdownList"`
}
type kiroOverageConfiguration struct {
OverageStatus string `json:"overageStatus"`
}
type kiroSubscriptionInfo struct {
SubscriptionTitle string `json:"subscriptionTitle"`
Type string `json:"type"`
}
type kiroUsageBreakdown struct {
Currency string `json:"currency"`
CurrentOverages *float64 `json:"currentOverages"`
CurrentOveragesWithPrecision *float64 `json:"currentOveragesWithPrecision"`
CurrentUsage *float64 `json:"currentUsage"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
DisplayName string `json:"displayName"`
DisplayNamePlural string `json:"displayNamePlural"`
FreeTrialInfo *kiroFreeTrialInfo `json:"freeTrialInfo"`
NextDateReset any `json:"nextDateReset"`
OverageCharges *float64 `json:"overageCharges"`
ResourceType string `json:"resourceType"`
UsageLimit *float64 `json:"usageLimit"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
}
type kiroFreeTrialInfo struct {
CurrentUsage *float64 `json:"currentUsage"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
FreeTrialExpiry any `json:"freeTrialExpiry"`
FreeTrialStatus string `json:"freeTrialStatus"`
UsageLimit *float64 `json:"usageLimit"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
}
type kiroUsageHTTPError struct {
StatusCode int
Body string
}
func (e *kiroUsageHTTPError) Error() string {
if e == nil {
return "kiro usage request failed"
}
if strings.TrimSpace(e.Body) == "" {
return fmt.Sprintf("kiro usage request failed (status %d)", e.StatusCode)
}
return fmt.Sprintf("kiro usage request failed (status %d): %s", e.StatusCode, e.Body)
}
func (s *AccountUsageService) getKiroUsage(ctx context.Context, account *Account, source string, forceRefresh bool) (*UsageInfo, error) {
now := time.Now()
if account == nil {
return &UsageInfo{
Source: source,
UpdatedAt: &now,
Error: "account is nil",
ErrorCode: errorCodeNetworkError,
}, nil
}
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return &UsageInfo{
Source: source,
UpdatedAt: &now,
}, nil
}
cached, hasCached := s.getCachedKiroUsage(account.ID)
if hasCached && (cached.ErrorCode != "" || cached.Error != "") {
cached.Source = source
s.attachKiroRuntimeState(ctx, account, cached)
return cached, nil
}
if !forceRefresh && hasCached {
cached.Source = source
s.attachKiroRuntimeState(ctx, account, cached)
return cached, nil
}
flightKey := fmt.Sprintf("kiro-usage:%d", account.ID)
result, fetchErr, _ := s.cache.kiroUsageFlight.Do(flightKey, func() (any, error) {
if !forceRefresh {
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
return usage, nil
}
}
usage, err := s.fetchAndCacheKiroUsage(ctx, account, source)
if err != nil {
return nil, err
}
return usage, nil
})
if fetchErr == nil {
if usage, ok := result.(*UsageInfo); ok && usage != nil {
usage.Source = source
s.attachKiroRuntimeState(ctx, account, usage)
if source == "active" {
s.tryClearRecoverableAccountError(ctx, account)
}
return usage, nil
}
}
degraded := buildKiroDegradedUsage(fetchErr)
degraded.Source = source
if hasCached {
cached.Error = degraded.Error
cached.ErrorCode = degraded.ErrorCode
cached.NeedsReauth = degraded.NeedsReauth
cached.KiroQuotaState = degraded.KiroQuotaState
cached.KiroQuotaReason = degraded.KiroQuotaReason
cached.KiroQuotaResetAt = degraded.KiroQuotaResetAt
cached.Source = source
s.attachKiroRuntimeState(ctx, account, cached)
return cached, nil
}
s.storeKiroUsageSnapshot(account.ID, degraded)
s.attachKiroRuntimeState(ctx, account, degraded)
return degraded, nil
}
func (s *AccountUsageService) fetchAndCacheKiroUsage(ctx context.Context, account *Account, source string) (*UsageInfo, error) {
token := strings.TrimSpace(account.GetCredential("access_token"))
if token == "" {
return nil, fmt.Errorf("no access token available")
}
region := kiroAPIRegion(account)
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
resp, err := s.requestKiroUsageLimits(ctx, account, region, profileArn, token)
if err != nil {
return nil, err
}
usage := mapKiroUsageToInfo(resp)
usage.Source = source
s.storeKiroUsageSnapshot(account.ID, usage)
return usage, nil
}
func (s *AccountUsageService) storeKiroUsageSnapshot(accountID int64, usage *UsageInfo) {
if s == nil || s.cache == nil || accountID <= 0 || usage == nil {
return
}
now := time.Now()
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
s.cache.kiroUsageCache.Store(accountID, &kiroUsageCache{
usageInfo: cloneUsageInfo(usage),
timestamp: now,
})
}
func (s *AccountUsageService) getCachedKiroUsage(accountID int64) (*UsageInfo, bool) {
if s == nil || s.cache == nil || accountID <= 0 {
return nil, false
}
cached, ok := s.cache.kiroUsageCache.Load(accountID)
if !ok {
return nil, false
}
cache, ok := cached.(*kiroUsageCache)
if !ok || cache == nil || cache.usageInfo == nil {
return nil, false
}
if time.Since(cache.timestamp) >= kiroCacheTTL(cache.usageInfo) {
return nil, false
}
return cloneUsageInfo(cache.usageInfo), true
}
func kiroCacheTTL(info *UsageInfo) time.Duration {
if info == nil {
return kiroUsageErrorTTL
}
if info.ErrorCode != "" || info.Error != "" {
return kiroUsageErrorTTL
}
return apiCacheTTL
}
func cloneUsageInfo(info *UsageInfo) *UsageInfo {
if info == nil {
return nil
}
cloned := *info
return &cloned
}
func (s *AccountUsageService) requestKiroUsageLimits(ctx context.Context, account *Account, region, profileArn, token string) (*kiroUsageLimitsResponse, error) {
endpoint := resolveKiroRuntimeEndpoint(region)
reqURL, err := url.Parse(endpoint + "/getUsageLimits")
if err != nil {
return nil, fmt.Errorf("build kiro usage url failed: %w", err)
}
q := reqURL.Query()
q.Set("origin", kiroUsageOrigin)
if profileArn = strings.TrimSpace(profileArn); profileArn != "" {
q.Set("profileArn", profileArn)
}
q.Set("resourceType", kiroUsageResourceType)
reqURL.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("create kiro usage request failed: %w", err)
}
s.applyKiroRuntimeHeaders(req, account, token)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: accountProxyURL(account),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: isLoopbackEndpoint(endpoint),
})
if err != nil {
return nil, fmt.Errorf("create kiro usage client failed: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("kiro usage request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("read kiro usage response failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
}
var parsed kiroUsageLimitsResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("decode kiro usage response failed: %w", err)
}
return &parsed, nil
}
func (s *AccountUsageService) applyKiroRuntimeHeaders(req *http.Request, account *Account, token string) {
if req == nil {
return
}
accountKey := buildKiroAccountKey(account)
machineID := buildKiroMachineID(account)
req.Header.Set("Accept", "*/*")
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
req.Header.Set("x-amzn-codewhisperer-optout", "true")
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
if account == nil {
return
}
applyKiroConditionalHeaders(req, account)
}
func accountProxyURL(account *Account) string {
if account == nil || account.ProxyID == nil || account.Proxy == nil {
return ""
}
return account.Proxy.URL()
}
func kiroRuntimeEndpoint(region string) string {
region = strings.TrimSpace(region)
if region == "" {
region = kiroDefaultRegion
}
switch region {
case "us-east-1":
return "https://q.us-east-1.amazonaws.com"
case "eu-central-1":
return "https://q.eu-central-1.amazonaws.com"
case "us-gov-east-1":
return "https://q-fips.us-gov-east-1.amazonaws.com"
case "us-gov-west-1":
return "https://q-fips.us-gov-west-1.amazonaws.com"
case "us-iso-east-1":
return "https://q.us-iso-east-1.c2s.ic.gov"
case "us-isob-east-1":
return "https://q.us-isob-east-1.sc2s.sgov.gov"
case "us-isof-south-1":
return "https://q.us-isof-south-1.csp.hci.ic.gov"
case "us-isof-east-1":
return "https://q.us-isof-east-1.csp.hci.ic.gov"
default:
if strings.HasPrefix(region, "us-gov-") {
return "https://q-fips." + region + ".amazonaws.com"
}
return "https://q." + region + ".amazonaws.com"
}
}
func isLoopbackEndpoint(raw string) bool {
parsed, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return false
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return false
}
if strings.EqualFold(host, "localhost") {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func mapKiroUsageToInfo(resp *kiroUsageLimitsResponse) *UsageInfo {
now := time.Now()
if resp == nil {
return &UsageInfo{UpdatedAt: &now}
}
info := &UsageInfo{
UpdatedAt: &now,
KiroSubscriptionName: strings.TrimSpace(resp.SubscriptionInfo.SubscriptionTitle),
KiroSubscriptionType: strings.TrimSpace(resp.SubscriptionInfo.Type),
KiroOveragesEnabled: strings.EqualFold(strings.TrimSpace(resp.OverageConfiguration.OverageStatus), "ENABLED"),
}
resetAt := parseKiroTimestamp(resp.NextDateReset)
if credit := selectKiroCreditBreakdown(resp.UsageBreakdownList); credit != nil {
if breakdownReset := parseKiroTimestamp(credit.NextDateReset); breakdownReset != nil {
resetAt = breakdownReset
}
info.KiroCredit = &KiroCreditProgress{
CurrentUsage: selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage),
UsageLimit: selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit),
PercentageUsed: percentageOrZero(selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage), selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit)),
}
info.KiroOverage = &KiroOverageInfo{
CurrentOverages: selectKiroFloat(credit.CurrentOveragesWithPrecision, credit.CurrentOverages),
OverageCharges: selectKiroFloat(credit.OverageCharges, nil),
CurrencyCode: strings.TrimSpace(credit.Currency),
CurrencySymbol: kiroCurrencySymbol(strings.TrimSpace(credit.Currency)),
}
if ft := credit.FreeTrialInfo; ft != nil && strings.EqualFold(strings.TrimSpace(ft.FreeTrialStatus), "ACTIVE") {
expiry := parseKiroTimestamp(ft.FreeTrialExpiry)
daysRemaining := 0
if expiry != nil {
daysRemaining = int(time.Until(*expiry).Hours() / 24)
if time.Until(*expiry)%(24*time.Hour) != 0 {
daysRemaining++
}
if daysRemaining < 0 {
daysRemaining = 0
}
}
current := selectKiroFloat(ft.CurrentUsageWithPrecision, ft.CurrentUsage)
limit := selectKiroFloat(ft.UsageLimitWithPrecision, ft.UsageLimit)
info.KiroBonus = &KiroCreditProgress{
CurrentUsage: current,
UsageLimit: limit,
PercentageUsed: percentageOrZero(current, limit),
DaysRemaining: daysRemaining,
ExpiryDate: expiry,
}
}
}
info.KiroResetAt = resetAt
setKiroQuotaStateFromUsage(info)
return info
}
func selectKiroCreditBreakdown(items []kiroUsageBreakdown) *kiroUsageBreakdown {
for i := range items {
if strings.EqualFold(strings.TrimSpace(items[i].ResourceType), "CREDIT") {
return &items[i]
}
}
if len(items) == 0 {
return nil
}
return &items[0]
}
func selectKiroFloat(preferred *float64, fallback *float64) float64 {
switch {
case preferred != nil:
return *preferred
case fallback != nil:
return *fallback
default:
return 0
}
}
func percentageOrZero(current, limit float64) float64 {
if limit <= 0 {
return 0
}
return current / limit * 100
}
func parseKiroTimestamp(raw any) *time.Time {
if raw == nil {
return nil
}
switch v := raw.(type) {
case string:
trimmed := strings.TrimSpace(v)
if trimmed == "" {
return nil
}
if parsed, err := time.Parse(time.RFC3339, trimmed); err == nil {
return &parsed
}
if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
return unixishToTime(i)
}
if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
return unixishFloatToTime(f)
}
case float64:
return unixishFloatToTime(v)
case int64:
return unixishToTime(v)
case int:
return unixishToTime(int64(v))
case json.Number:
if i, err := v.Int64(); err == nil {
return unixishToTime(i)
}
if f, err := v.Float64(); err == nil {
return unixishFloatToTime(f)
}
}
return nil
}
func unixishFloatToTime(v float64) *time.Time {
if v <= 0 {
return nil
}
if v >= 1e12 {
t := time.UnixMilli(int64(v))
return &t
}
t := time.Unix(int64(v), 0)
return &t
}
func unixishToTime(v int64) *time.Time {
if v <= 0 {
return nil
}
if v >= 1e12 {
t := time.UnixMilli(v)
return &t
}
t := time.Unix(v, 0)
return &t
}
func kiroCurrencySymbol(code string) string {
switch strings.ToUpper(strings.TrimSpace(code)) {
case "USD":
return "$"
default:
return ""
}
}
func buildKiroDegradedUsage(err error) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
Error: "usage API error",
ErrorCode: errorCodeNetworkError,
}
if err == nil {
return info
}
info.Error = fmt.Sprintf("usage API error: %v", err)
classification := classifyKiroError(err)
switch classification.Category {
case kiroErrorAuthError:
info.ErrorCode = errorCodeUnauthenticated
info.NeedsReauth = true
case kiroErrorRateLimited:
info.ErrorCode = errorCodeRateLimited
case kiroErrorQuotaExhausted:
info.ErrorCode = errorCodeNetworkError
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
info.KiroQuotaReason = classification.Message
case kiroErrorOverageExhausted:
info.ErrorCode = errorCodeNetworkError
info.KiroQuotaState = kiroQuotaStateOverageExhausted
info.KiroQuotaReason = classification.Message
case kiroErrorSuspended, kiroErrorUsageForbidden, kiroErrorProfileError:
info.ErrorCode = errorCodeForbidden
default:
info.ErrorCode = errorCodeNetworkError
}
return info
}
func (s *AccountUsageService) attachKiroRuntimeState(ctx context.Context, account *Account, usage *UsageInfo) {
if s == nil || usage == nil || account == nil || account.Platform != PlatformKiro || s.kiroCooldownStore == nil {
return
}
usage.KiroRuntimeState = ""
usage.KiroRuntimeReason = ""
usage.KiroRuntimeResetAt = nil
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
if err != nil || state == nil {
return
}
usage.KiroRuntimeState, usage.KiroRuntimeReason, usage.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
}
func (s *AccountUsageService) EnrichAccountWithKiroRuntimeState(ctx context.Context, account *Account) {
if s == nil || account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return
}
account.KiroQuotaState = ""
account.KiroQuotaReason = ""
account.KiroQuotaResetAt = nil
account.KiroRuntimeState = ""
account.KiroRuntimeReason = ""
account.KiroRuntimeResetAt = nil
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
account.KiroQuotaState = usage.KiroQuotaState
account.KiroQuotaReason = usage.KiroQuotaReason
account.KiroQuotaResetAt = usage.KiroQuotaResetAt
}
if s.kiroCooldownStore == nil {
return
}
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
if err != nil || state == nil {
return
}
account.KiroRuntimeState, account.KiroRuntimeReason, account.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
}
func setKiroQuotaStateFromUsage(info *UsageInfo) {
if info == nil {
return
}
info.KiroQuotaState = ""
info.KiroQuotaReason = ""
info.KiroQuotaResetAt = nil
creditExhausted := false
if info.KiroCredit != nil && info.KiroCredit.UsageLimit > 0 {
creditExhausted = info.KiroCredit.CurrentUsage >= info.KiroCredit.UsageLimit
}
overageActive := info.KiroOverage != nil &&
(info.KiroOverage.CurrentOverages > 0 || info.KiroOverage.OverageCharges > 0)
switch {
case info.KiroOveragesEnabled && (overageActive || creditExhausted):
info.KiroQuotaState = kiroQuotaStateOverageActive
info.KiroQuotaReason = "overages_enabled"
info.KiroQuotaResetAt = info.KiroResetAt
case creditExhausted:
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
info.KiroQuotaReason = "credits_exhausted"
info.KiroQuotaResetAt = info.KiroResetAt
default:
info.KiroQuotaState = kiroQuotaStateNormal
}
}
+458
View File
@@ -0,0 +1,458 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
const kiroMaxWebSearchIterations = 5
var (
errKiroWebSearchFallback = errors.New("kiro web search fallback")
kiroWebSearchDescCache sync.Map
)
type kiroWebSearchExecution struct {
ResponseBody []byte
Usage ClaudeUsage
RequestID string
}
type kiroWebSearchHTTPError struct {
Response *http.Response
}
type kiroStreamChunkCollector struct {
chunks [][]byte
}
func (e *kiroWebSearchHTTPError) Error() string {
if e == nil || e.Response == nil {
return "kiro web search http error"
}
return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
}
func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
if len(p) > 0 {
w.chunks = append(w.chunks, append([]byte(nil), p...))
}
return len(p), nil
}
func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
collector := &kiroStreamChunkCollector{}
result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens)
if err != nil {
return nil, nil, err
}
return collector.chunks, result, nil
}
func writeSSEChunks(w io.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if len(chunk) == 0 {
continue
}
if _, err := w.Write(chunk); err != nil {
return err
}
}
return nil
}
func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
if strings.TrimSpace(msgID) == "" {
msgID = "msg_" + kiropkg.GenerateToolUseID()
}
if strings.TrimSpace(model) == "" {
model = "kiro"
}
payload, err := json.Marshal(map[string]any{
"type": "message_start",
"message": map[string]any{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": inputTokens,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
})
if err != nil {
return err
}
_, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n")
return err
}
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
) error {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
return errKiroWebSearchFallback
}
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
if err != nil {
currentBody = anthropicBody
}
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
nextContentBlockIndex := 0
if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil {
return err
}
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
s.prefetchKiroWebSearchDescription(ctx, account, token)
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
if strings.TrimSpace(nextToken) != "" {
token = nextToken
}
if mcpErr != nil {
results = nil
}
if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil {
return err
}
nextContentBlockIndex += 2
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
if err != nil {
return errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return &kiroWebSearchHTTPError{Response: resp}
}
chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) {
defer func() { _ = resp.Body.Close() }()
return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
}()
if streamErr != nil {
return streamErr
}
analysis := kiropkg.AnalyzeBufferedStream(chunks)
if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations {
filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex)
if err := writeSSEChunks(w, filtered); err != nil {
return err
}
if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex {
nextContentBlockIndex = maxIndex + 1
}
query = analysis.WebSearchQuery
if strings.TrimSpace(analysis.WebSearchToolUseID) == "" {
currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
} else {
currentToolUseID = analysis.WebSearchToolUseID
}
continue
}
for _, chunk := range chunks {
adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex)
if !shouldForward {
continue
}
if _, err := w.Write(adjusted); err != nil {
return err
}
}
return nil
}
return fmt.Errorf("kiro web search exceeded max iterations")
}
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
return nil, errKiroWebSearchFallback
}
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
if err != nil {
currentBody = anthropicBody
}
inputTokens := estimateKiroInputTokens(anthropicBody)
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
searches := make([]kiropkg.SearchIndicator, 0, 2)
requestID := ""
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
s.prefetchKiroWebSearchDescription(ctx, account, token)
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
if strings.TrimSpace(nextToken) != "" {
token = nextToken
}
if mcpErr != nil {
results = nil
}
searches = append(searches, kiropkg.SearchIndicator{
ToolUseID: currentToolUseID,
Query: query,
Results: results,
})
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
if err != nil {
return nil, errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, &kiroWebSearchHTTPError{Response: resp}
}
parseResult, parseErr := func() (*kiropkg.ParseResult, error) {
defer func() { _ = resp.Body.Close() }()
return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
}()
if parseErr != nil {
return nil, parseErr
}
if requestID == "" {
requestID = buildKiroRequestID(resp)
}
nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody)
if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations {
finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches)
if injectErr == nil {
parseResult.ResponseBody = finalBody
}
return &kiroWebSearchExecution{
ResponseBody: parseResult.ResponseBody,
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
RequestID: requestID,
}, nil
}
query = nextQuery
if strings.TrimSpace(nextToolUseID) == "" {
nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
}
currentToolUseID = nextToolUseID
}
return nil, fmt.Errorf("kiro web search exceeded max iterations")
}
func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
kiropkg.SetCachedWebSearchDescription(desc)
}
return
}
reqBody, _ := json.Marshal(kiropkg.MCPRequest{
ID: "tools_list",
JSONRPC: "2.0",
Method: "tools/list",
})
resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
if err != nil || resp == nil {
return
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return
}
var result kiropkg.MCPResponse
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
return
}
for _, tool := range result.Result.Tools {
if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
kiroWebSearchDescCache.Store(endpoint, tool.Description)
kiropkg.SetCachedWebSearchDescription(tool.Description)
return
}
}
}
func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
if err != nil {
return nil, token, err
}
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
if err != nil {
return nil, nextToken, err
}
if resp == nil {
return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nextToken, err
}
if resp.StatusCode != http.StatusOK {
return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var parsed kiropkg.MCPResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, nextToken, err
}
if parsed.Error != nil {
msg := "unknown error"
if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" {
msg = strings.TrimSpace(*parsed.Error.Message)
}
code := 0
if parsed.Error.Code != nil {
code = *parsed.Error.Code
}
return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
}
return kiropkg.ParseSearchResults(&parsed), nextToken, nil
}
func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
return kiropkg.MCPRequest{
ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
JSONRPC: "2.0",
Method: "tools/call",
Params: map[string]interface{}{
"name": "web_search",
"arguments": map[string]interface{}{
"query": query,
"_meta": map[string]interface{}{
"_isValid": true,
"_activePath": []string{"query"},
"_completedPaths": [][]string{{"query"}},
},
},
},
}
}
func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
currentToken := token
accountKey := buildKiroAccountKey(account)
proxyURL := kiroProxyURL(account)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
for attempt := 0; attempt < 3; attempt++ {
if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
return nil, currentToken, failoverErr
}
return nil, currentToken, err
}
req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
if err != nil {
return nil, currentToken, err
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return nil, currentToken, err
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, currentToken, readErr
}
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
return nil, currentToken, err
}
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) {
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
if s.kiroTokenProvider == nil {
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
if refreshErr != nil {
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, currentToken, sleepErr
}
continue
}
if resp.StatusCode == http.StatusTooManyRequests {
if _, err := s.markKiro429(ctx, accountKey); err != nil {
_ = resp.Body.Close()
return nil, currentToken, err
}
}
if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
if attempt < 2 {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, currentToken, sleepErr
}
continue
}
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
_ = resp.Body.Close()
return nil, currentToken, err
}
}
return resp, currentToken, nil
}
return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
}
@@ -0,0 +1,28 @@
//go:build unit
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildKiroWebSearchMCPRequest_UsesUnderscoredMetaKeys(t *testing.T) {
req := buildKiroWebSearchMCPRequest("golang concurrency")
body, err := json.Marshal(req)
require.NoError(t, err)
require.Equal(t, "tools/call", gjson.GetBytes(body, "method").String())
require.Equal(t, "web_search", gjson.GetBytes(body, "params.name").String())
require.Equal(t, "golang concurrency", gjson.GetBytes(body, "params.arguments.query").String())
require.True(t, gjson.GetBytes(body, "params.arguments._meta._isValid").Bool())
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._activePath.0").String())
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._completedPaths.0.0").String())
require.False(t, gjson.GetBytes(body, "params.arguments._meta.isValid").Exists())
require.False(t, gjson.GetBytes(body, "params.arguments._meta.activePath").Exists())
require.False(t, gjson.GetBytes(body, "params.arguments._meta.completedPaths").Exists())
}
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
t.Run(mode, func(t *testing.T) {
t.Parallel()
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
privacyCalls := 0
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
privacyCalls++
+3
View File
@@ -289,6 +289,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
out := *ev
out.Platform = strings.TrimSpace(out.Platform)
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
@@ -89,6 +89,14 @@ type OpsUpstreamErrorEvent struct {
AccountID int64 `json:"account_id,omitempty"`
AccountName string `json:"account_name,omitempty"`
// Model diagnostics.
RequestedModel string `json:"requested_model,omitempty"`
MappedModel string `json:"mapped_model,omitempty"`
KiroModelID string `json:"kiro_model_id,omitempty"`
HasTools bool `json:"has_tools,omitempty"`
HasAdaptiveThinking bool `json:"has_adaptive_thinking,omitempty"`
HasContext1MBeta bool `json:"has_context_1m_beta,omitempty"`
// Outcome
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
+31 -1
View File
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-2.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model)
model = canonicalModelNameForPricing(model)
model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/")
model = strings.TrimPrefix(model, "publishers/google/models/")
@@ -628,7 +628,31 @@ func normalizeModelNameForPricing(model string) string {
if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" {
return canonical
}
return canonicalModelNameForPricing(model)
}
func canonicalModelNameForPricing(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {
return ""
}
switch model {
case "claude-opus-4.5":
return "claude-opus-4-5"
case "claude-opus-4.6":
return "claude-opus-4-6"
case "claude-opus-4.7":
return "claude-opus-4-7"
case "claude-sonnet-4.5":
return "claude-sonnet-4-5"
case "claude-sonnet-4.6":
return "claude-sonnet-4-6"
case "claude-haiku-4.5":
return "claude-haiku-4-5"
default:
return model
}
}
func lastSegment(model string) string {
@@ -674,8 +698,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
@@ -713,6 +739,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
}
case strings.Contains(model, "sonnet"):
switch {
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
fallbackName = "sonnet-4.6"
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
fallbackName = "sonnet-4.5"
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
@@ -722,6 +750,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
}
case strings.Contains(model, "haiku"):
switch {
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
fallbackName = "haiku-4.5"
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
fallbackName = "haiku-3.5"
default:
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
if len(groupIDs) == 0 {
return nil
}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
var firstErr error
for _, platform := range platforms {
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
buckets := make([]SchedulerBucket, 0)
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
for _, platform := range platforms {
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
@@ -42,6 +42,9 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformKiro:
keysToDelete = append(keysToDelete, KiroTokenCacheKey(account))
keysToDelete = append(keysToDelete, "kiro:"+accountIDKey)
case PlatformOpenAI:
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic:
@@ -10,6 +10,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
@@ -44,6 +45,7 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
kiroOAuthService *KiroOAuthService,
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
@@ -64,6 +66,7 @@ func NewTokenRefreshService(
claudeRefresher := NewClaudeTokenRefresher(oauthService)
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
kiroRefresher := NewKiroTokenRefresher(kiroOAuthService)
// 注册平台特定的刷新器(TokenRefresher 接口)
s.refreshers = []TokenRefresher{
@@ -71,6 +74,7 @@ func NewTokenRefreshService(
openAIRefresher,
geminiRefresher,
agRefresher,
kiroRefresher,
}
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
@@ -79,6 +83,7 @@ func NewTokenRefreshService(
openAIRefresher,
geminiRefresher,
agRefresher,
kiroRefresher,
}
return s
@@ -415,6 +420,10 @@ func isNonRetryableRefreshError(err error) bool {
if err == nil {
return false
}
var kiroInvalidGrant *kiropkg.RefreshTokenInvalidError
if errors.As(err, &kiroInvalidGrant) {
return true
}
msg := strings.ToLower(err.Error())
nonRetryable := []string{
"invalid_grant", // refresh_token 已失效
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 5,
Platform: PlatformGemini,
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 6,
Platform: PlatformGemini,
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 7,
Platform: PlatformGemini,
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 9,
Platform: PlatformGemini,
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 11,
Platform: PlatformGemini,
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 12,
Platform: PlatformGemini,
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
until := time.Now().Add(10 * time.Minute)
account := &Account{
ID: 15,
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 16,
Platform: tt.platform,
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 18,
Platform: PlatformOpenAI,
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
@@ -22,9 +22,9 @@ func optionalNonEqualStringPtr(value, compare string) *string {
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
return trimmed
return normalizeModelNameForPricing(trimmed)
}
return strings.TrimSpace(upstreamModel)
return normalizeModelNameForPricing(upstreamModel)
}
func optionalInt64Ptr(v int64) *int64 {
+23 -1
View File
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
@@ -52,6 +53,7 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
kiroOAuthService *KiroOAuthService,
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
@@ -60,7 +62,7 @@ func ProvideTokenRefreshService(
proxyRepo ProxyRepository,
refreshAPI *OAuthRefreshAPI,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 OpenAI privacy opt-out 依赖
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
@@ -129,6 +131,23 @@ func ProvideAntigravityTokenProvider(
return p
}
func ProvideKiroTokenProvider(
accountRepo AccountRepository,
tokenCache GeminiTokenCache,
kiroOAuthService *KiroOAuthService,
refreshAPI *OAuthRefreshAPI,
) *KiroTokenProvider {
p := NewKiroTokenProvider(accountRepo, tokenCache, kiroOAuthService)
executor := NewKiroTokenRefresher(kiroOAuthService)
p.SetRefreshAPI(refreshAPI, executor)
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
return p
}
func ProvideKiroCooldownStore(redisClient *redis.Client) KiroCooldownStore {
return kirocooldown.NewStore(redisClient)
}
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
@@ -457,8 +476,11 @@ var ProviderSet = wire.NewSet(
NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService,
NewKiroOAuthService,
ProvideOAuthRefreshAPI,
ProvideGeminiTokenProvider,
ProvideKiroTokenProvider,
ProvideKiroCooldownStore,
NewGeminiMessagesCompatService,
ProvideAntigravityTokenProvider,
ProvideOpenAITokenProvider,
+12
View File
@@ -89,6 +89,10 @@ security:
enabled: false
# Allowed upstream hosts for API proxying
# 允许代理的上游 API 主机列表
# If you enable Kiro OAuth / IDC, also allow Kiro auth and AWS SSO OIDC hosts.
# 如果启用 Kiro OAuth / IDC,请同时放行 Kiro 鉴权域名和 AWS SSO OIDC 域名。
# If you enable Kiro runtime forwarding, also allow the corresponding AWS Q API endpoint.
# 如果启用 Kiro 运行时转发,还需要放行对应的 AWS Q API 域名。
upstream_hosts:
- "api.openai.com"
- "api.anthropic.com"
@@ -97,6 +101,11 @@ security:
- "api.minimaxi.com"
- "generativelanguage.googleapis.com"
- "cloudcode-pa.googleapis.com"
- "prod.us-east-1.auth.desktop.kiro.dev"
- "oidc.us-east-1.amazonaws.com"
- "oidc.*.amazonaws.com"
- "device.sso.*.amazonaws.com"
- "q.*.amazonaws.com"
- "*.openai.azure.com"
# Allowed hosts for pricing data download
# 允许下载定价数据的主机列表
@@ -372,6 +381,9 @@ gateway:
# Max image requests waiting in this process when overflow_mode=wait, 0=unlimited
# wait 模式当前进程允许排队等待的图片请求数,0=不限制
max_waiting_requests: 100
# Kiro stream keepalive interval (seconds), 0=use default 25
# Kiro 流式 keepalive 间隔(秒),0=使用默认 25
kiro_stream_keepalive_interval: 25
# SSE max line size in bytes (default: 40MB)
# SSE 单行最大字节数(默认 40MB)
max_line_size: 41943040
+7
View File
@@ -565,6 +565,13 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
return data
}
export async function getKiroDefaultModelMapping(): Promise<Record<string, string>> {
const { data } = await apiClient.get<Record<string, string>>(
'/admin/accounts/kiro/default-model-mapping'
)
return data
}
/**
* Refresh OpenAI token using refresh token
* @param refreshToken - The refresh token
+3
View File
@@ -17,6 +17,7 @@ import subscriptionsAPI from './subscriptions'
import usageAPI from './usage'
import geminiAPI from './gemini'
import antigravityAPI from './antigravity'
import kiroAPI from './kiro'
import userAttributesAPI from './userAttributes'
import opsAPI from './ops'
import errorPassthroughAPI from './errorPassthrough'
@@ -50,6 +51,7 @@ export const adminAPI = {
usage: usageAPI,
gemini: geminiAPI,
antigravity: antigravityAPI,
kiro: kiroAPI,
userAttributes: userAttributesAPI,
ops: opsAPI,
errorPassthrough: errorPassthroughAPI,
@@ -81,6 +83,7 @@ export {
usageAPI,
geminiAPI,
antigravityAPI,
kiroAPI,
userAttributesAPI,
opsAPI,
errorPassthroughAPI,
+89
View File
@@ -0,0 +1,89 @@
import { apiClient } from '../client'
export interface KiroAuthUrlResponse {
auth_url: string
session_id: string
state: string
}
export interface KiroIDCAuthUrlResponse extends KiroAuthUrlResponse {
client_id?: string
region?: string
start_url?: string
}
export interface KiroTokenInfo {
access_token?: string
refresh_token?: string
profile_arn?: string
expires_at?: string
auth_method?: string
provider?: string
client_id?: string
client_secret?: string
client_id_hash?: string
email?: string
start_url?: string
region?: string
[key: string]: unknown
}
export async function generateAuthUrl(payload: {
proxy_id?: number
provider?: string
}): Promise<KiroAuthUrlResponse> {
const { data } = await apiClient.post<KiroAuthUrlResponse>('/admin/kiro/oauth/auth-url', payload)
return data
}
export async function generateIDCAuthUrl(payload: {
proxy_id?: number
start_url?: string
region?: string
}): Promise<KiroIDCAuthUrlResponse> {
const { data } = await apiClient.post<KiroIDCAuthUrlResponse>('/admin/kiro/oauth/idc-auth-url', payload)
return data
}
export async function exchangeCode(payload: {
session_id: string
state: string
code: string
callback_path?: string
login_option?: string
proxy_id?: number
}): Promise<KiroTokenInfo> {
const { data } = await apiClient.post<KiroTokenInfo>('/admin/kiro/oauth/exchange-code', payload)
return data
}
export async function refreshToken(payload: {
refresh_token: string
auth_method?: string
provider?: string
client_id?: string
client_secret?: string
start_url?: string
region?: string
profile_arn?: string
proxy_id?: number
}): Promise<KiroTokenInfo> {
const { data } = await apiClient.post<KiroTokenInfo>('/admin/kiro/oauth/refresh-token', payload)
return data
}
export async function importToken(payload: {
token_json: string
device_registration_json?: string
}): Promise<KiroTokenInfo> {
const { data } = await apiClient.post<KiroTokenInfo>('/admin/kiro/oauth/import-token', payload)
return data
}
export default {
generateAuthUrl,
generateIDCAuthUrl,
exchangeCode,
refreshToken,
importToken
}
@@ -12,6 +12,11 @@
<span class="text-[11px] text-gray-400 dark:text-gray-500">{{ overloadCountdown }}</span>
</div>
<div v-else-if="kiroQuotaBadgeLabel" class="flex flex-col items-center gap-1">
<span :class="['badge text-xs', kiroQuotaBadgeClass]">{{ kiroQuotaBadgeLabel }}</span>
<span v-if="kiroQuotaHint" class="text-[11px] text-gray-400 dark:text-gray-500">{{ kiroQuotaHint }}</span>
</div>
<!-- Main Status Badge (shown when not rate limited/overloaded) -->
<template v-else>
<button
@@ -69,7 +74,7 @@
<div
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 w-56 -translate-x-1/2 whitespace-normal rounded bg-gray-900 px-3 py-2 text-center text-xs leading-relaxed text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
>
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatDateTime(account.rate_limit_reset_at) }) }}
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatDateTime(activeKiroRuntimeResetAt || account.rate_limit_reset_at) }) }}
<div
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700"
></div>
@@ -172,11 +177,37 @@ const emit = defineEmits<{
}>()
// Computed: is rate limited (429)
const activeKiroRuntimeResetAt = computed(() => {
if (props.account.platform !== 'kiro') return null
if (props.account.kiro_runtime_state !== 'cooldown') return null
if (!props.account.kiro_runtime_reset_at) return null
const resetAt = new Date(props.account.kiro_runtime_reset_at)
if (Number.isNaN(resetAt.getTime()) || resetAt <= new Date()) return null
return props.account.kiro_runtime_reset_at
})
const isRateLimited = computed(() => {
if (activeKiroRuntimeResetAt.value) return true
if (!props.account.rate_limit_reset_at) return false
return new Date(props.account.rate_limit_reset_at) > new Date()
})
const isKiroRuntimeSuspended = computed(() => {
if (props.account.platform !== 'kiro') return false
if (props.account.kiro_runtime_state !== 'suspended') return false
if (!props.account.kiro_runtime_reset_at) return true
const resetAt = new Date(props.account.kiro_runtime_reset_at)
return Number.isNaN(resetAt.getTime()) || resetAt > new Date()
})
const activeKiroQuotaResetAt = computed(() => {
if (props.account.platform !== 'kiro') return null
if (!props.account.kiro_quota_reset_at) return null
const resetAt = new Date(props.account.kiro_quota_reset_at)
if (Number.isNaN(resetAt.getTime()) || resetAt <= new Date()) return null
return props.account.kiro_quota_reset_at
})
type AccountModelStatusItem = {
kind: 'rate_limit' | 'credits_exhausted' | 'credits_active'
model: string
@@ -281,7 +312,7 @@ const isTempUnschedulable = computed(() => {
// Computed: has error status
const hasError = computed(() => {
return props.account.status === 'error'
return props.account.status === 'error' || isKiroRuntimeSuspended.value
})
const isQuotaExceeded = computed(() => {
@@ -296,7 +327,7 @@ const isQuotaExceeded = computed(() => {
// Computed: countdown text for rate limit (429)
const rateLimitCountdown = computed(() => {
return formatCountdown(props.account.rate_limit_reset_at)
return formatCountdown(activeKiroRuntimeResetAt.value || props.account.rate_limit_reset_at)
})
const rateLimitResumeText = computed(() => {
@@ -309,8 +340,45 @@ const overloadCountdown = computed(() => {
return formatCountdownWithSuffix(props.account.overload_until)
})
const kiroQuotaBadgeLabel = computed(() => {
if (props.account.platform !== 'kiro') return ''
switch (props.account.kiro_quota_state) {
case 'credits_exhausted':
return t('admin.accounts.status.creditsExhausted')
case 'overage_exhausted':
return t('admin.accounts.status.overageExhausted')
default:
return ''
}
})
const kiroQuotaBadgeClass = computed(() => {
switch (props.account.kiro_quota_state) {
case 'credits_exhausted':
case 'overage_exhausted':
return 'badge-danger'
default:
return 'badge-gray'
}
})
const kiroQuotaHint = computed(() => {
if (!activeKiroQuotaResetAt.value) return ''
switch (props.account.kiro_quota_state) {
case 'credits_exhausted':
return t('admin.accounts.status.creditsExhaustedUntil', { time: formatDateTime(activeKiroQuotaResetAt.value) })
case 'overage_exhausted':
return t('admin.accounts.status.overageExhaustedUntil', { time: formatDateTime(activeKiroQuotaResetAt.value) })
default:
return ''
}
})
// Computed: status badge class
const statusClass = computed(() => {
if (isKiroRuntimeSuspended.value) {
return 'badge-danger'
}
if (hasError.value) {
return 'badge-danger'
}
@@ -331,6 +399,9 @@ const statusClass = computed(() => {
// Computed: status text
const statusText = computed(() => {
if (isKiroRuntimeSuspended.value) {
return t('admin.accounts.forbidden')
}
if (hasError.value) {
return t('admin.accounts.status.error')
}
@@ -395,6 +395,72 @@
</div>
</template>
<!-- Kiro platform: show credits + bonus + overage summary -->
<template v-else-if="account.platform === 'kiro' && account.type === 'oauth'">
<div v-if="loading" class="space-y-1.5">
<div class="h-4 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div class="space-y-1">
<div class="h-1.5 w-32 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
<div class="h-1.5 w-28 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
</div>
</div>
<div v-else-if="error" class="text-xs text-red-500">
{{ error }}
</div>
<div v-else-if="kiroUsageAvailable || kiroStatusBadgeLabel" class="space-y-2">
<div v-if="kiroStatusBadgeLabel" class="flex flex-wrap items-center gap-x-2 gap-y-1">
<span
:class="[
'inline-flex items-center gap-1 text-[10px] font-medium',
kiroStatusToneClass
]"
:title="usageInfo?.error || undefined"
>
<span class="h-1.5 w-1.5 rounded-full bg-current opacity-80"></span>
{{ kiroStatusBadgeLabel }}
</span>
</div>
<div v-if="kiroStatusHint" class="text-[9px] leading-tight text-gray-500 dark:text-gray-400">
{{ kiroStatusHint }}
</div>
<div v-if="usageInfo?.kiro_credit" class="space-y-1">
<div class="flex items-baseline justify-between gap-2 text-[11px] text-gray-600 dark:text-gray-300">
<span class="font-medium tracking-[0.01em]">{{ t('admin.accounts.usageWindow.kiroCredits') }}</span>
<span class="font-semibold tabular-nums text-gray-700 dark:text-gray-200">{{ formatKiroAmount(usageInfo.kiro_credit.current_usage) }} / {{ formatKiroAmount(usageInfo.kiro_credit.usage_limit) }}</span>
</div>
<div class="h-1.5 overflow-hidden rounded-full bg-gray-200 dark:bg-gray-700">
<div class="h-full rounded-full bg-amber-500 transition-all" :style="{ width: `${kiroCreditPercent}%` }"></div>
</div>
</div>
<div v-if="usageInfo?.kiro_bonus" class="space-y-1">
<div class="flex items-baseline justify-between gap-2 text-[11px] text-gray-600 dark:text-gray-300">
<span class="font-medium tracking-[0.01em]">{{ t('admin.accounts.usageWindow.kiroBonus') }}</span>
<span class="font-semibold tabular-nums text-gray-700 dark:text-gray-200">{{ formatKiroAmount(usageInfo.kiro_bonus.current_usage) }} / {{ formatKiroAmount(usageInfo.kiro_bonus.usage_limit) }}</span>
</div>
<div class="h-1.5 overflow-hidden rounded-full bg-gray-200 dark:bg-gray-700">
<div class="h-full rounded-full bg-emerald-500 transition-all" :style="{ width: `${kiroBonusPercent}%` }"></div>
</div>
<div v-if="kiroBonusMeta" class="text-[9px] leading-tight text-gray-500 dark:text-gray-400">
{{ kiroBonusMeta }}
</div>
</div>
<div class="flex flex-wrap items-center gap-x-3 gap-y-1 text-[10px] text-gray-500 dark:text-gray-400">
<span v-if="kiroResetDisplay" class="inline-flex items-center gap-1">
<span class="text-gray-400 dark:text-gray-500">{{ t('admin.accounts.usageWindow.kiroReset') }}</span>
<span class="font-medium tabular-nums text-gray-600 dark:text-gray-300">{{ kiroResetDisplay }}</span>
</span>
<span v-if="kiroOverageSummary" class="inline-flex items-center gap-1 font-medium">
{{ kiroOverageSummary }}
</span>
</div>
</div>
<div v-else class="text-xs text-gray-400">-</div>
</template>
<!-- Other accounts: no usage window -->
<template v-else>
<div class="text-xs text-gray-400">-</div>
@@ -498,6 +564,10 @@ const props = withDefaults(
}
)
const emit = defineEmits<{
kiroUsageMeta: [meta: { plan_type?: string; kiro_overages_enabled: boolean }]
}>()
const { t } = useI18n()
const desktopViewportQuery = '(min-width: 768px)'
@@ -510,7 +580,9 @@ const error = ref<string | null>(null)
const usageInfo = ref<AccountUsageInfo | null>(null)
const rootRef = ref<HTMLElement | null>(null)
const isDesktopViewport = ref(
typeof window === 'undefined' ? true : window.matchMedia(desktopViewportQuery).matches
typeof window === 'undefined' || typeof window.matchMedia !== 'function'
? true
: window.matchMedia(desktopViewportQuery).matches
)
const hasEnteredViewport = ref(false)
const pendingAutoLoad = ref(false)
@@ -531,6 +603,9 @@ const shouldFetchUsage = computed(() => {
if (props.account.platform === 'anthropic') {
return props.account.type === 'oauth' || props.account.type === 'setup-token'
}
if (props.account.platform === 'kiro') {
return props.account.type === 'oauth'
}
if (props.account.platform === 'gemini') {
return true
}
@@ -984,6 +1059,179 @@ const isAnthropicOAuthOrSetupToken = computed(() => {
return props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')
})
const isKiroOAuth = computed(() => {
return props.account.platform === 'kiro' && props.account.type === 'oauth'
})
const defaultUsageSource = computed<'passive' | 'active' | undefined>(() => {
if (isAnthropicOAuthOrSetupToken.value || isKiroOAuth.value) {
return 'passive'
}
return undefined
})
const manualRefreshUsageSource = computed<'passive' | 'active' | undefined>(() => {
if (isKiroOAuth.value) {
return 'active'
}
return defaultUsageSource.value
})
const kiroUsageAvailable = computed(() => {
return !!(
usageInfo.value?.kiro_credit ||
usageInfo.value?.kiro_bonus ||
usageInfo.value?.kiro_overage ||
usageInfo.value?.kiro_reset_at ||
usageInfo.value?.kiro_overages_enabled
)
})
const syncKiroUsageMeta = (info?: AccountUsageInfo | null) => {
if (!isKiroOAuth.value) return
const planType = (
info?.kiro_subscription_name ||
info?.kiro_subscription_type ||
''
).trim()
emit('kiroUsageMeta', {
...(planType ? { plan_type: planType } : {}),
kiro_overages_enabled: info?.kiro_overages_enabled === true
})
}
const clampPercent = (value?: number | null) => {
if (value == null || !Number.isFinite(value)) return 0
return Math.max(0, Math.min(100, value))
}
const kiroCreditPercent = computed(() => clampPercent(usageInfo.value?.kiro_credit?.percentage_used))
const kiroBonusPercent = computed(() => clampPercent(usageInfo.value?.kiro_bonus?.percentage_used))
const formatKiroAmount = (value?: number | null) => {
if (value == null || !Number.isFinite(value)) return '0'
if (Math.abs(value) >= 1000 || Number.isInteger(value)) {
return formatCompactNumber(value, { allowBillions: false })
}
return value.toFixed(2).replace(/\.?0+$/, '')
}
const kiroResetDisplay = computed(() => {
const raw = usageInfo.value?.kiro_reset_at
if (!raw) return ''
const parsed = new Date(raw)
if (Number.isNaN(parsed.getTime())) return ''
return parsed.toLocaleDateString()
})
const kiroBonusMeta = computed(() => {
const bonus = usageInfo.value?.kiro_bonus
if (!bonus) return ''
if ((bonus.days_remaining ?? 0) > 0) {
return t('admin.accounts.usageWindow.kiroDaysLeft', { days: bonus.days_remaining })
}
if (bonus.expiry_date) {
const parsed = new Date(bonus.expiry_date)
if (!Number.isNaN(parsed.getTime())) {
return `${t('admin.accounts.usageWindow.kiroExpires')} ${parsed.toLocaleDateString()}`
}
}
return ''
})
const kiroRuntimeResetDisplay = computed(() => {
const raw = usageInfo.value?.kiro_runtime_reset_at
if (!raw) return ''
const parsed = new Date(raw)
if (Number.isNaN(parsed.getTime())) return ''
return parsed.toLocaleString()
})
const kiroQuotaResetDisplay = computed(() => {
const raw = usageInfo.value?.kiro_quota_reset_at
if (!raw) return ''
const parsed = new Date(raw)
if (Number.isNaN(parsed.getTime())) return ''
return parsed.toLocaleString()
})
const isKiroProfileError = computed(() => {
if (!isKiroOAuth.value) return false
const err = (usageInfo.value?.error || '').toLowerCase()
return err.includes('profilearn is required') ||
(err.includes('profile arn') && err.includes('required')) ||
err.includes('profilearn') ||
err.includes('listavailableprofiles')
})
const isKiroUsageForbidden = computed(() => {
if (!isKiroOAuth.value) return false
return usageInfo.value?.error_code === 'forbidden' && !usageInfo.value?.needs_reauth && !isKiroProfileError.value
})
const kiroQuotaState = computed(() => usageInfo.value?.kiro_quota_state || '')
const kiroStatusBadgeLabel = computed(() => {
const runtimeState = usageInfo.value?.kiro_runtime_state
if (runtimeState === 'suspended') return t('admin.accounts.forbidden')
if (runtimeState === 'cooldown') return t('admin.accounts.status.rateLimited')
if (usageInfo.value?.needs_reauth) return t('admin.accounts.needsReauth')
if (isKiroProfileError.value) return t('admin.accounts.usageError')
if (isKiroUsageForbidden.value) return t('admin.accounts.forbidden')
if (kiroQuotaState.value === 'overage_active') return t('admin.accounts.status.overageActive')
if (kiroQuotaState.value === 'credits_exhausted') return t('admin.accounts.status.creditsExhausted')
if (kiroQuotaState.value === 'overage_exhausted') return t('admin.accounts.status.overageExhausted')
return ''
})
const kiroStatusToneClass = computed(() => {
const runtimeState = usageInfo.value?.kiro_runtime_state
if (runtimeState === 'suspended') return 'text-red-700 dark:text-red-300'
if (runtimeState === 'cooldown') return 'text-amber-700 dark:text-amber-300'
if (usageInfo.value?.needs_reauth) return 'text-orange-700 dark:text-orange-300'
if (isKiroProfileError.value) return 'text-yellow-700 dark:text-yellow-300'
if (isKiroUsageForbidden.value) return 'text-rose-700 dark:text-rose-300'
if (kiroQuotaState.value === 'overage_active') return 'text-amber-700 dark:text-amber-300'
if (kiroQuotaState.value === 'credits_exhausted' || kiroQuotaState.value === 'overage_exhausted') {
return 'text-red-700 dark:text-red-300'
}
return 'text-gray-600 dark:text-gray-300'
})
const kiroStatusHint = computed(() => {
const runtimeState = usageInfo.value?.kiro_runtime_state
if (runtimeState === 'cooldown' && kiroRuntimeResetDisplay.value) {
return t('admin.accounts.status.rateLimitedUntil', { time: kiroRuntimeResetDisplay.value })
}
if (kiroQuotaState.value === 'credits_exhausted' && kiroQuotaResetDisplay.value) {
return t('admin.accounts.status.creditsExhaustedUntil', { time: kiroQuotaResetDisplay.value })
}
if (kiroQuotaState.value === 'overage_exhausted' && kiroQuotaResetDisplay.value) {
return t('admin.accounts.status.overageExhaustedUntil', { time: kiroQuotaResetDisplay.value })
}
return ''
})
const kiroOverageSummary = computed(() => {
const overage = usageInfo.value?.kiro_overage
if (!overage) return ''
const hasOverageCount = (overage.current_overages ?? 0) > 0
const hasCharges = (overage.overage_charges ?? 0) > 0
if (!hasOverageCount && !hasCharges) return ''
const parts: string[] = [t('admin.accounts.usageWindow.kiroOverage')]
if (hasOverageCount) {
parts.push(formatKiroAmount(overage.current_overages))
}
if (hasCharges) {
const symbol = overage.currency_symbol || overage.currency_code || ''
parts.push(`(${symbol}${(overage.overage_charges ?? 0).toFixed(2)})`)
}
return parts.join(' ')
})
const loadUsage = async (options?: { source?: 'passive' | 'active'; bypassCache?: boolean }) => {
if (!shouldFetchUsage.value) return
@@ -992,6 +1240,7 @@ const loadUsage = async (options?: { source?: 'passive' | 'active'; bypassCache?
const cached = _usageCache.get(props.account.id)
if (cached && Date.now() - cached.ts < USAGE_CACHE_TTL) {
usageInfo.value = cached.data
syncKiroUsageMeta(cached.data)
loading.value = false
return
}
@@ -1001,10 +1250,13 @@ const loadUsage = async (options?: { source?: 'passive' | 'active'; bypassCache?
error.value = null
try {
const fetchFn = () => adminAPI.accounts.getUsage(props.account.id, options?.source)
const fetchFn = () => options?.source
? adminAPI.accounts.getUsage(props.account.id, options.source)
: adminAPI.accounts.getUsage(props.account.id)
const result = await enqueueUsageRequest(props.account, fetchFn)
if (!unmounted.value) {
usageInfo.value = result
syncKiroUsageMeta(result)
_usageCache.set(props.account.id, { data: result, ts: Date.now() })
}
} catch (e: any) {
@@ -1070,7 +1322,10 @@ const attachVisibilityObserver = () => {
const loadActiveUsage = async () => {
activeQueryLoading.value = true
try {
usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active')
const result = await adminAPI.accounts.getUsage(props.account.id, 'active')
usageInfo.value = result
syncKiroUsageMeta(result)
_usageCache.set(props.account.id, { data: result, ts: Date.now() })
} catch (e: any) {
console.error('Failed to load active usage:', e)
} finally {
@@ -1166,7 +1421,7 @@ const formatKeyUserCost = computed(() => {
})
onMounted(() => {
if (typeof window !== 'undefined') {
if (typeof window !== 'undefined' && typeof window.matchMedia === 'function') {
desktopViewportMediaQuery = window.matchMedia(desktopViewportQuery)
isDesktopViewport.value = desktopViewportMediaQuery.matches
desktopViewportListener = (event: MediaQueryListEvent) => {
@@ -1180,15 +1435,17 @@ onMounted(() => {
}
if (!shouldAutoLoadUsageOnMount.value) return
const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined
requestAutoLoad(source)
requestAutoLoad(defaultUsageSource.value)
})
watch(openAIUsageRefreshKey, (nextKey, prevKey) => {
if (!prevKey || nextKey === prevKey) return
if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return
requestAutoLoad()
_usageCache.delete(props.account.id)
loadUsage({ bypassCache: true }).catch((e) => {
console.error('Failed to reload OpenAI usage after row refresh:', e)
})
})
watch(
@@ -1197,9 +1454,8 @@ watch(
if (nextToken === prevToken) return
if (!shouldFetchUsage.value) return
const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined
_usageCache.delete(props.account.id)
loadUsage({ source, bypassCache: true }).catch((e) => {
loadUsage({ source: manualRefreshUsageSource.value, bypassCache: true }).catch((e) => {
console.error('Failed to refresh usage after manual refresh:', e)
})
}
@@ -147,6 +147,19 @@
<Icon name="cloud" size="sm" />
Antigravity
</button>
<button
type="button"
@click="form.platform = 'kiro'"
:class="[
'flex flex-1 items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
form.platform === 'kiro'
? 'bg-white text-amber-700 shadow-sm dark:bg-dark-600 dark:text-amber-300'
: 'text-gray-600 hover:text-gray-900 dark:text-gray-400 dark:hover:text-gray-200'
]"
>
<Icon name="sparkles" size="sm" />
Kiro
</button>
</div>
</div>
@@ -774,6 +787,457 @@
</div>
</div>
<!-- Kiro account type selection -->
<div v-if="form.platform === 'kiro'">
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
<div class="mt-2 grid grid-cols-2 gap-3">
<button
type="button"
@click="accountCategory = 'oauth-based'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'oauth-based'
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]"
>
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', accountCategory === 'oauth-based' ? 'bg-amber-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
<Icon name="key" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.types.oauth') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.types.kiroOauth') }}
</span>
</div>
</button>
<button
type="button"
@click="accountCategory = 'apikey'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'apikey'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]"
>
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', accountCategory === 'apikey' ? 'bg-purple-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
API Key
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.types.kiroApikey') }}
</span>
</div>
</button>
</div>
</div>
<!-- Kiro OAuth auth mode selection -->
<div v-if="form.platform === 'kiro' && accountCategory === 'oauth-based'">
<label class="input-label">{{ t('admin.accounts.oauth.kiro.authModeTitle') }}</label>
<div class="mt-2 grid grid-cols-1 gap-3 md:grid-cols-3">
<button
type="button"
@click="kiroAccountType = 'oauth'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroAccountType === 'oauth'
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]"
>
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', kiroAccountType === 'oauth' ? 'bg-amber-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
<Icon name="key" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.oauthTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.oauthSubtitle') }}
</span>
</div>
</button>
<button
type="button"
@click="kiroAccountType = 'idc'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroAccountType === 'idc'
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
: 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
]"
>
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', kiroAccountType === 'idc' ? 'bg-blue-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.idcTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.idcSubtitle') }}
</span>
</div>
</button>
<button
type="button"
@click="kiroAccountType = 'import'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroAccountType === 'import'
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
: 'border-gray-200 hover:border-slate-300 dark:border-dark-600 dark:hover:border-slate-700'
]"
>
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', kiroAccountType === 'import' ? 'bg-slate-700 text-white dark:bg-slate-500' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
<Icon name="download" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.importTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.importSubtitle') }}
</span>
</div>
</button>
</div>
</div>
<div v-if="form.platform === 'kiro' && accountCategory === 'oauth-based' && kiroAccountType === 'oauth'" class="mt-4 space-y-3">
<div class="flex items-center justify-between">
<label class="input-label">{{ t('admin.accounts.oauth.kiro.oauthProviderTitle') }}</label>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.oauth.kiro.socialSubtitle') }}</span>
</div>
<div class="grid grid-cols-2 gap-3">
<button
type="button"
@click="kiroOAuthProvider = 'google'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroOAuthProvider === 'google'
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroOAuthProvider === 'google'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="user" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.googleTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.googleDesc') }}
</span>
</div>
</button>
<button
type="button"
@click="kiroOAuthProvider = 'github'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroOAuthProvider === 'github'
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
: 'border-gray-200 hover:border-slate-300 dark:border-dark-600 dark:hover:border-slate-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroOAuthProvider === 'github'
? 'bg-slate-700 text-white dark:bg-slate-500'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="terminal" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.githubTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.githubDesc') }}
</span>
</div>
</button>
</div>
</div>
<div v-if="form.platform === 'kiro' && accountCategory === 'oauth-based' && kiroAccountType === 'idc'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.startUrlLabel') }}</label>
<input
v-model="kiroIDCStartUrl"
type="text"
class="input"
:placeholder="t('admin.accounts.oauth.kiro.startUrlPlaceholder')"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.regionLabel') }}</label>
<input
v-model="kiroIDCRegion"
type="text"
class="input"
:placeholder="t('admin.accounts.oauth.kiro.regionPlaceholder')"
/>
</div>
</div>
<div v-if="form.platform === 'kiro' && accountCategory === 'apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
<input
v-model="apiKeyBaseUrl"
type="text"
required
class="input"
placeholder="https://your-kiro-upstream.example.com"
/>
<p class="input-hint">{{ baseUrlHint }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.apiKeyRequired') }}</label>
<input
v-model="apiKeyValue"
type="password"
required
class="input font-mono"
placeholder="sk-..."
/>
<p class="input-hint">{{ apiKeyHint }}</p>
</div>
</div>
<div v-if="form.platform === 'kiro' && accountCategory === 'apikey'" class="space-y-4">
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.poolModeHint') }}
</p>
</div>
<button
type="button"
@click="poolModeEnabled = !poolModeEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
<p class="text-xs text-blue-700 dark:text-blue-400">
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
{{ t('admin.accounts.poolModeInfo') }}
</p>
</div>
<div v-if="poolModeEnabled" class="mt-3">
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
<input
v-model.number="poolModeRetryCount"
type="number"
min="0"
:max="MAX_POOL_MODE_RETRY_COUNT"
step="1"
class="input"
/>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{
t('admin.accounts.poolModeRetryCountHint', {
default: DEFAULT_POOL_MODE_RETRY_COUNT,
max: MAX_POOL_MODE_RETRY_COUNT
})
}}
</p>
</div>
</div>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.customErrorCodesHint') }}
</p>
</div>
<button
type="button"
@click="customErrorCodesEnabled = !customErrorCodesEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customErrorCodesEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customErrorCodesEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="customErrorCodesEnabled" class="space-y-3">
<div class="rounded-lg bg-amber-50 p-3 dark:bg-amber-900/20">
<p class="text-xs text-amber-700 dark:text-amber-400">
<Icon name="exclamationTriangle" size="sm" class="mr-1 inline" :stroke-width="2" />
{{ t('admin.accounts.customErrorCodesWarning') }}
</p>
</div>
<div class="flex flex-wrap gap-2">
<button
v-for="code in commonErrorCodes"
:key="code.value"
type="button"
@click="toggleErrorCode(code.value)"
:class="[
'rounded-lg px-3 py-1.5 text-sm font-medium transition-colors',
selectedErrorCodes.includes(code.value)
? 'bg-red-100 text-red-700 ring-1 ring-red-500 dark:bg-red-900/30 dark:text-red-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ code.value }} {{ code.label }}
</button>
</div>
<div class="flex items-center gap-2">
<input
v-model.number="customErrorCodeInput"
type="number"
min="100"
max="599"
class="input flex-1"
:placeholder="t('admin.accounts.enterErrorCode')"
@keyup.enter="addCustomErrorCode"
/>
<button
type="button"
@click="addCustomErrorCode"
class="btn btn-secondary shrink-0"
>
{{ t('admin.accounts.add') }}
</button>
</div>
<div class="flex flex-wrap gap-1.5">
<span
v-for="code in selectedErrorCodes.sort((a, b) => a - b)"
:key="code"
class="inline-flex items-center gap-1 rounded-full bg-red-100 px-2.5 py-0.5 text-sm font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400"
>
{{ code }}
<button
type="button"
@click="removeErrorCode(code)"
class="hover:text-red-900 dark:hover:text-red-300"
>
<Icon name="x" size="sm" :stroke-width="2" />
</button>
</span>
<span v-if="selectedErrorCodes.length === 0" class="text-xs text-gray-400">
{{ t('admin.accounts.noneSelectedUsesDefault') }}
</span>
</div>
</div>
</div>
</div>
<!-- Kiro 只支持模型映射模式不支持白名单模式 -->
<div v-if="form.platform === 'kiro'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div>
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
<p class="text-xs text-purple-700 dark:text-purple-400">
{{ t('admin.accounts.mapRequestModels') }}
</p>
</div>
<div v-if="kiroModelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in kiroModelMappings"
:key="getKiroModelMappingKey(mapping)"
class="space-y-1"
>
<div class="flex items-center gap-2">
<input
v-model="mapping.from"
type="text"
:class="[
'input flex-1',
!isValidWildcardPattern(mapping.from) ? 'border-red-500 dark:border-red-500' : ''
]"
:placeholder="t('admin.accounts.requestModel')"
/>
<svg class="h-4 w-4 flex-shrink-0 text-gray-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3" />
</svg>
<input
v-model="mapping.to"
type="text"
:class="[
'input flex-1',
mapping.to.includes('*') ? 'border-red-500 dark:border-red-500' : ''
]"
:placeholder="t('admin.accounts.actualModel')"
/>
<button
type="button"
@click="removeKiroModelMapping(index)"
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
>
<Icon name="x" size="sm" />
</button>
</div>
</div>
</div>
<button type="button" @click="addKiroModelMapping" class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300">
<svg class="mr-1 inline h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4" />
</svg>
{{ t('admin.accounts.addMapping') }}
</button>
<div class="flex flex-wrap gap-2">
<button
v-for="preset in kiroPresetMappings"
:key="preset.label"
type="button"
@click="addKiroPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
<!-- Upstream config (only for Antigravity upstream type) -->
<div v-if="form.platform === 'antigravity' && antigravityAccountType === 'upstream'" class="space-y-4">
<div>
@@ -1009,7 +1473,7 @@
</div>
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && form.platform !== 'kiro'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
<input
@@ -2750,7 +3214,23 @@
<!-- Step 2: OAuth Authorization -->
<div v-else class="space-y-5">
<div v-if="isKiroImportMode" class="space-y-4 rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-700 dark:bg-amber-900/20">
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.tokenJsonLabel') }}</label>
<textarea v-model="kiroTokenJson" rows="8" class="input font-mono text-xs" placeholder='{"accessToken":"...","refreshToken":"..."}'></textarea>
<p class="input-hint">{{ t('admin.accounts.oauth.kiro.tokenJsonHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.deviceRegistrationLabel') }}</label>
<textarea v-model="kiroDeviceRegistrationJson" rows="6" class="input font-mono text-xs" placeholder='{"clientId":"...","clientSecret":"..."}'></textarea>
<p class="input-hint">{{ t('admin.accounts.oauth.kiro.deviceRegistrationHint') }}</p>
</div>
<div v-if="currentOAuthError" class="rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30">
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">{{ currentOAuthError }}</p>
</div>
</div>
<OAuthAuthorizationFlow
v-else
ref="oauthFlowRef"
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
:auth-url="currentAuthUrl"
@@ -2824,7 +3304,16 @@
{{ t('common.back') }}
</button>
<button
v-if="isManualInputMethod"
v-if="isKiroImportMode"
type="button"
:disabled="currentOAuthLoading || !kiroTokenJson.trim()"
class="btn btn-primary"
@click="handleKiroImport"
>
{{ currentOAuthLoading ? t('admin.accounts.creating') : t('common.create') }}
</button>
<button
v-else-if="isManualInputMethod"
type="button"
:disabled="!canExchangeCode"
class="btn btn-primary"
@@ -3101,6 +3590,7 @@ import {
commonErrorCodes,
buildModelMappingObject,
fetchAntigravityDefaultMappings,
fetchKiroDefaultMappings,
isValidWildcardPattern
} from '@/composables/useModelWhitelist'
import { useAuthStore } from '@/stores/auth'
@@ -3114,6 +3604,7 @@ import {
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
import { useKiroOAuth } from '@/composables/useKiroOAuth'
import type {
Proxy,
AdminGroup,
@@ -3151,6 +3642,8 @@ import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
interface OAuthFlowExposed {
authCode: string
oauthState: string
oauthCallbackPath: string
oauthLoginOption: string
projectId: string
sessionKey: string
refreshToken: string
@@ -3167,6 +3660,11 @@ const oauthStepTitle = computed(() => {
if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title')
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
if (form.platform === 'kiro') {
return kiroAccountType.value === 'import'
? t('admin.accounts.oauth.kiro.importDialogTitle')
: t('admin.accounts.oauth.kiro.title')
}
return t('admin.accounts.oauth.title')
})
@@ -3174,12 +3672,14 @@ const oauthStepTitle = computed(() => {
const baseUrlHint = computed(() => {
if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint')
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
if (form.platform === 'kiro') return t('admin.accounts.kiro.baseUrlHint')
return t('admin.accounts.baseUrlHint')
})
const apiKeyHint = computed(() => {
if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint')
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
if (form.platform === 'kiro') return t('admin.accounts.kiro.apiKeyHint')
return t('admin.accounts.apiKeyHint')
})
@@ -3202,12 +3702,14 @@ const oauth = useAccountOAuth() // For Anthropic OAuth
const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
const kiroOAuth = useKiroOAuth() // For Kiro OAuth / IDC
// Computed: current OAuth state for template binding
const currentAuthUrl = computed(() => {
if (form.platform === 'openai') return openaiOAuth.authUrl.value
if (form.platform === 'gemini') return geminiOAuth.authUrl.value
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
if (form.platform === 'kiro') return kiroOAuth.authUrl.value
return oauth.authUrl.value
})
@@ -3215,6 +3717,7 @@ const currentSessionId = computed(() => {
if (form.platform === 'openai') return openaiOAuth.sessionId.value
if (form.platform === 'gemini') return geminiOAuth.sessionId.value
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
if (form.platform === 'kiro') return kiroOAuth.sessionId.value
return oauth.sessionId.value
})
@@ -3222,6 +3725,7 @@ const currentOAuthLoading = computed(() => {
if (form.platform === 'openai') return openaiOAuth.loading.value
if (form.platform === 'gemini') return geminiOAuth.loading.value
if (form.platform === 'antigravity') return antigravityOAuth.loading.value
if (form.platform === 'kiro') return kiroOAuth.loading.value
return oauth.loading.value
})
@@ -3229,6 +3733,7 @@ const currentOAuthError = computed(() => {
if (form.platform === 'openai') return openaiOAuth.error.value
if (form.platform === 'gemini') return geminiOAuth.error.value
if (form.platform === 'antigravity') return antigravityOAuth.error.value
if (form.platform === 'kiro') return kiroOAuth.error.value
return oauth.error.value
})
@@ -3307,6 +3812,14 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist'
const antigravityWhitelistModels = ref<string[]>([])
const antigravityModelMappings = ref<ModelMapping[]>([])
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
const kiroAccountType = ref<'oauth' | 'idc' | 'import'>('oauth')
const kiroOAuthProvider = ref<'google' | 'github'>('google')
const kiroIDCStartUrl = ref('https://view.awsapps.com/start')
const kiroIDCRegion = ref('us-east-1')
const kiroTokenJson = ref('')
const kiroDeviceRegistrationJson = ref('')
const kiroModelMappings = ref<ModelMapping[]>([])
const kiroPresetMappings = computed(() => getPresetMappingsByPlatform('kiro'))
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
// Bedrock credentials
@@ -3328,6 +3841,7 @@ const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
const getOpenAICompactModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-antigravity-model-mapping')
const getKiroModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-kiro-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('create-temp-unsched-rule')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false)
@@ -3513,6 +4027,8 @@ const isOAuthFlow = computed(() => {
return accountCategory.value === 'oauth-based'
})
const isKiroImportMode = computed(() => form.platform === 'kiro' && kiroAccountType.value === 'import')
const isManualInputMethod = computed(() => {
return oauthFlowRef.value?.inputMethod === 'manual'
})
@@ -3535,6 +4051,9 @@ const canExchangeCode = computed(() => {
if (form.platform === 'antigravity') {
return authCode.trim() && antigravityOAuth.sessionId.value && !antigravityOAuth.loading.value
}
if (form.platform === 'kiro') {
return authCode.trim() && kiroOAuth.sessionId.value && !kiroOAuth.loading.value
}
return authCode.trim() && oauth.sessionId.value && !oauth.loading.value
})
@@ -3556,10 +4075,15 @@ watch(
antigravityModelMappings.value = [...mappings]
})
antigravityWhitelistModels.value = []
} else if (form.platform === 'kiro') {
fetchKiroDefaultMappings().then(mappings => {
kiroModelMappings.value = [...mappings]
})
} else {
antigravityWhitelistModels.value = []
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
kiroModelMappings.value = []
}
} else {
resetForm()
@@ -3576,6 +4100,10 @@ watch(
form.type = 'apikey'
return
}
if (form.platform === 'kiro') {
form.type = category === 'oauth-based' ? 'oauth' : 'apikey'
return
}
// Bedrock
if (form.platform === 'anthropic' && category === 'bedrock') {
form.type = 'bedrock' as AccountType
@@ -3602,6 +4130,8 @@ watch(
? 'https://api.openai.com'
: newPlatform === 'gemini'
? 'https://generativelanguage.googleapis.com'
: newPlatform === 'kiro'
? ''
: 'https://api.anthropic.com'
// Clear model-related settings
allowedModels.value = []
@@ -3615,11 +4145,21 @@ watch(
antigravityWhitelistModels.value = []
accountCategory.value = 'oauth-based'
antigravityAccountType.value = 'oauth'
} else if (newPlatform === 'kiro') {
fetchKiroDefaultMappings().then(mappings => {
kiroModelMappings.value = [...mappings]
})
accountCategory.value = 'oauth-based'
kiroAccountType.value = 'oauth'
kiroOAuthProvider.value = 'google'
apiKeyBaseUrl.value = ''
apiKeyValue.value = ''
} else {
allowOverages.value = false
antigravityWhitelistModels.value = []
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
kiroModelMappings.value = []
}
if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
accountCategory.value = 'oauth-based'
@@ -3659,6 +4199,7 @@ watch(
geminiOAuth.resetState()
antigravityOAuth.resetState()
kiroOAuth.resetState()
}
)
@@ -3760,6 +4301,22 @@ const addAntigravityPresetMapping = (from: string, to: string) => {
antigravityModelMappings.value.push({ from, to })
}
const addKiroModelMapping = () => {
kiroModelMappings.value.push({ from: '', to: '' })
}
const removeKiroModelMapping = (index: number) => {
kiroModelMappings.value.splice(index, 1)
}
const addKiroPresetMapping = (from: string, to: string) => {
if (kiroModelMappings.value.some((m) => m.from === from)) {
appStore.showInfo(t('admin.accounts.mappingExists', { model: from }))
return
}
kiroModelMappings.value.push({ from, to })
}
// Error code toggle helper
const toggleErrorCode = (code: number) => {
const index = selectedErrorCodes.value.indexOf(code)
@@ -4033,6 +4590,15 @@ const resetForm = () => {
fetchAntigravityDefaultMappings().then(mappings => {
antigravityModelMappings.value = [...mappings]
})
kiroAccountType.value = 'oauth'
kiroOAuthProvider.value = 'google'
kiroIDCStartUrl.value = 'https://view.awsapps.com/start'
kiroIDCRegion.value = 'us-east-1'
kiroTokenJson.value = ''
kiroDeviceRegistrationJson.value = ''
fetchKiroDefaultMappings().then(mappings => {
kiroModelMappings.value = [...mappings]
})
poolModeEnabled.value = false
poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT
customErrorCodesEnabled.value = false
@@ -4084,6 +4650,7 @@ const resetForm = () => {
openaiOAuth.resetState()
geminiOAuth.resetState()
antigravityOAuth.resetState()
kiroOAuth.resetState()
oauthFlowRef.value?.reset()
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
@@ -4379,6 +4946,45 @@ const handleSubmit = async () => {
return
}
// For Kiro API key type, create directly
if (form.platform === 'kiro' && accountCategory.value === 'apikey') {
if (!form.name.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
if (!apiKeyBaseUrl.value.trim()) {
appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl'))
return
}
if (!apiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
return
}
const credentials: Record<string, unknown> = {
base_url: apiKeyBaseUrl.value.trim(),
api_key: apiKeyValue.value.trim()
}
const modelMapping = buildModelMappingObject('mapping', [], kiroModelMappings.value)
if (modelMapping) {
credentials.model_mapping = modelMapping
}
if (poolModeEnabled.value) {
credentials.pool_mode = true
credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
}
if (customErrorCodesEnabled.value) {
credentials.custom_error_codes_enabled = true
credentials.custom_error_codes = [...selectedErrorCodes.value]
}
await createAccountAndFinish('kiro', 'apikey', credentials)
return
}
// For apikey type, create directly
if (!apiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
@@ -4450,6 +5056,7 @@ const goBackToBasicInfo = () => {
openaiOAuth.resetState()
geminiOAuth.resetState()
antigravityOAuth.resetState()
kiroOAuth.resetState()
oauthFlowRef.value?.reset()
}
@@ -4465,6 +5072,19 @@ const handleGenerateUrl = async () => {
)
} else if (form.platform === 'antigravity') {
await antigravityOAuth.generateAuthUrl(form.proxy_id)
} else if (form.platform === 'kiro') {
if (kiroAccountType.value === 'idc') {
await kiroOAuth.generateIDCAuthUrl({
proxyId: form.proxy_id,
startUrl: kiroIDCStartUrl.value.trim() || undefined,
region: kiroIDCRegion.value.trim() || undefined
})
} else {
await kiroOAuth.generateAuthUrl(
form.proxy_id,
kiroOAuthProvider.value === 'github' ? 'Github' : 'Google'
)
}
} else {
await oauth.generateAuthUrl(addMethod.value, form.proxy_id)
}
@@ -5035,6 +5655,50 @@ const handleAntigravityExchange = async (authCode: string) => {
}
}
const buildKiroCredentials = (tokenInfo: Parameters<typeof kiroOAuth.buildCredentials>[0]) => {
const credentials = kiroOAuth.buildCredentials(tokenInfo)
const modelMapping = buildModelMappingObject('mapping', [], kiroModelMappings.value)
if (modelMapping) {
credentials.model_mapping = modelMapping
}
return credentials
}
const handleKiroExchange = async (authCode: string) => {
if (!authCode.trim() || !kiroOAuth.sessionId.value) return
kiroOAuth.loading.value = true
kiroOAuth.error.value = ''
try {
const stateFromInput = oauthFlowRef.value?.oauthState || ''
const stateToUse = stateFromInput || kiroOAuth.state.value
if (!stateToUse) {
kiroOAuth.error.value = t('admin.accounts.oauth.authFailed')
appStore.showError(kiroOAuth.error.value)
return
}
const tokenInfo = await kiroOAuth.exchangeAuthCode({
code: authCode.trim(),
sessionId: kiroOAuth.sessionId.value,
state: stateToUse,
callbackPath: oauthFlowRef.value?.oauthCallbackPath || '',
loginOption: oauthFlowRef.value?.oauthLoginOption || '',
proxyId: form.proxy_id
})
if (!tokenInfo) return
const credentials = buildKiroCredentials(tokenInfo)
await createAccountAndFinish('kiro', 'oauth', credentials)
} catch (error: any) {
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(kiroOAuth.error.value)
} finally {
kiroOAuth.loading.value = false
}
}
// Anthropic OAuth
const handleAnthropicExchange = async (authCode: string) => {
if (!authCode.trim() || !oauth.sessionId.value) return
@@ -5133,6 +5797,8 @@ const handleExchangeCode = async () => {
return handleOpenAIExchange(authCode)
case 'gemini':
return handleGeminiExchange(authCode)
case 'kiro':
return handleKiroExchange(authCode)
case 'antigravity':
return handleAntigravityExchange(authCode)
default:
@@ -5140,6 +5806,24 @@ const handleExchangeCode = async () => {
}
}
const handleKiroImport = async () => {
if (!isKiroImportMode.value || !kiroTokenJson.value.trim()) return
const tokenInfo = await kiroOAuth.importToken(
kiroTokenJson.value,
kiroDeviceRegistrationJson.value || undefined
)
if (!tokenInfo) return
try {
const credentials = buildKiroCredentials(tokenInfo)
await createAccountAndFinish('kiro', 'oauth', credentials)
} catch (error: any) {
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(kiroOAuth.error.value)
}
}
const handleCookieAuth = async (sessionKey: string) => {
oauth.loading.value = true
oauth.error.value = ''
@@ -39,6 +39,8 @@
? 'https://api.openai.com'
: account.platform === 'gemini'
? 'https://generativelanguage.googleapis.com'
: account.platform === 'kiro'
? 'https://your-kiro-upstream.example.com'
: account.platform === 'antigravity'
? 'https://cloudcode-pa.googleapis.com'
: 'https://api.anthropic.com'
@@ -61,6 +63,8 @@
? 'sk-proj-...'
: account.platform === 'gemini'
? 'AIza...'
: account.platform === 'kiro'
? 'sk-...'
: account.platform === 'antigravity'
? 'sk-...'
: 'sk-ant-...'
@@ -69,8 +73,93 @@
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
</div>
<!-- Model Restriction Section (不适用于 Antigravity) -->
<div v-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div v-if="account.platform === 'kiro'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
<p class="text-xs text-purple-700 dark:text-purple-400">
{{ t('admin.accounts.mapRequestModels') }}
</p>
</div>
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in modelMappings"
:key="getModelMappingKey(mapping)"
class="space-y-1"
>
<div class="flex items-center gap-2">
<input
v-model="mapping.from"
type="text"
:class="[
'input flex-1',
!isValidWildcardPattern(mapping.from) ? 'border-red-500 dark:border-red-500' : ''
]"
:placeholder="t('admin.accounts.requestModel')"
/>
<svg class="h-4 w-4 flex-shrink-0 text-gray-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3" />
</svg>
<input
v-model="mapping.to"
type="text"
:class="[
'input flex-1',
mapping.to.includes('*') ? 'border-red-500 dark:border-red-500' : ''
]"
:placeholder="t('admin.accounts.actualModel')"
/>
<button
type="button"
@click="removeModelMapping(index)"
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
/>
</svg>
</button>
</div>
<p v-if="!isValidWildcardPattern(mapping.from)" class="text-xs text-red-500">
{{ t('admin.accounts.wildcardOnlyAtEnd') }}
</p>
<p v-if="mapping.to.includes('*')" class="text-xs text-red-500">
{{ t('admin.accounts.targetNoWildcard') }}
</p>
</div>
</div>
<button
type="button"
@click="addModelMapping"
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
>
<svg class="mr-1 inline h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4" />
</svg>
{{ t('admin.accounts.addMapping') }}
</button>
<div class="flex flex-wrap gap-2">
<button
v-for="preset in presetMappings"
:key="preset.label"
type="button"
@click="addPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div>
<!-- Model Restriction Section (不适用于 Antigravity / Kiro) -->
<div v-else-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div
@@ -407,9 +496,9 @@
</div>
<!-- OpenAI OAuth Model Mapping (OAuth 类型没有 apikey 容器需要独立的模型映射区域) -->
<!-- OpenAI / Kiro OAuth Model Restriction (OAuth 类型没有 apikey 容器需要独立区域) -->
<div
v-if="account.platform === 'openai' && account.type === 'oauth'"
v-if="(account.platform === 'openai' || account.platform === 'kiro') && account.type === 'oauth'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
@@ -423,6 +512,82 @@
</p>
</div>
<template v-else-if="account.platform === 'kiro'">
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
<p class="text-xs text-purple-700 dark:text-purple-400">
{{ t('admin.accounts.mapRequestModels') }}
</p>
</div>
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in modelMappings"
:key="'oauth-' + getModelMappingKey(mapping)"
class="flex items-center gap-2"
>
<input
v-model="mapping.from"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.requestModel')"
/>
<svg
class="h-4 w-4 flex-shrink-0 text-gray-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M14 5l7 7m0 0l-7 7m7-7H3"
/>
</svg>
<input
v-model="mapping.to"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.actualModel')"
/>
<button
type="button"
@click="removeModelMapping(index)"
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
/>
</svg>
</button>
</div>
</div>
<button
type="button"
@click="addModelMapping"
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
>
+ {{ t('admin.accounts.addMapping') }}
</button>
<div class="flex flex-wrap gap-2">
<button
v-for="preset in presetMappings"
:key="'oauth-' + preset.label"
type="button"
@click="addPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</template>
<template v-else>
<!-- Mode Toggle -->
<div class="mb-4 flex gap-2">
@@ -2205,6 +2370,7 @@ import {
resolveOpenAIWSModeFromExtra
} from '@/utils/openaiWsMode'
import {
fetchKiroDefaultMappings,
getPresetMappingsByPlatform,
commonErrorCodes,
buildModelMappingObject,
@@ -2233,11 +2399,13 @@ const baseUrlHint = computed(() => {
if (!props.account) return t('admin.accounts.baseUrlHint')
if (props.account.platform === 'openai') return t('admin.accounts.openai.baseUrlHint')
if (props.account.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
if (props.account.platform === 'kiro') return t('admin.accounts.kiro.baseUrlHint')
return t('admin.accounts.baseUrlHint')
})
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
const isKiroOAuthAccount = computed(() => props.account?.platform === 'kiro' && props.account?.type === 'oauth')
// Model mapping type
interface ModelMapping {
@@ -2295,6 +2463,21 @@ const getOpenAICompactModelMappingKey = createStableObjectKeyResolver<ModelMappi
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
const applyKiroModelMappings = (entries: Array<[string, string]>) => {
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
const loadDefaultKiroModelMappings = () => {
fetchKiroDefaultMappings().then(mappings => {
if (!isKiroOAuthAccount.value) return
modelRestrictionMode.value = 'mapping'
modelMappings.value = mappings.map(({ from, to }) => ({ from, to }))
allowedModels.value = []
})
}
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
@@ -2486,6 +2669,7 @@ const tempUnschedPresets = computed(() => [
const defaultBaseUrl = computed(() => {
if (props.account?.platform === 'openai') return 'https://api.openai.com'
if (props.account?.platform === 'gemini') return 'https://generativelanguage.googleapis.com'
if (props.account?.platform === 'kiro') return ''
return 'https://api.anthropic.com'
})
@@ -2709,6 +2893,8 @@ const syncFormFromAccount = (newAccount: Account | null) => {
? 'https://api.openai.com'
: newAccount.platform === 'gemini'
? 'https://generativelanguage.googleapis.com'
: newAccount.platform === 'kiro'
? ''
: 'https://api.anthropic.com'
editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl
@@ -2717,6 +2903,11 @@ const syncFormFromAccount = (newAccount: Account | null) => {
if (existingMappings && typeof existingMappings === 'object') {
const entries = Object.entries(existingMappings)
if (newAccount.platform === 'kiro') {
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
} else {
// Detect if this is whitelist mode (all from === to) or mapping mode
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
@@ -2731,6 +2922,16 @@ const syncFormFromAccount = (newAccount: Account | null) => {
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
}
} else if (newAccount.platform === 'kiro') {
fetchKiroDefaultMappings().then(mappings => {
if (props.account?.id !== newAccount.id || props.account?.type !== 'apikey' || props.account?.platform !== 'kiro') {
return
}
modelRestrictionMode.value = 'mapping'
modelMappings.value = mappings.map(({ from, to }) => ({ from, to }))
allowedModels.value = []
})
} else {
// No mappings: default to whitelist mode with empty selection (allow all)
modelRestrictionMode.value = 'whitelist'
@@ -2785,15 +2986,18 @@ const syncFormFromAccount = (newAccount: Account | null) => {
const entries = Object.entries(existingMappings)
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
if (isWhitelistMode) {
// Whitelist mode: populate allowedModels
modelRestrictionMode.value = 'whitelist'
allowedModels.value = entries.map(([from]) => from)
modelMappings.value = []
} else {
// Mapping mode: populate modelMappings
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
} else {
// No mappings: default to whitelist mode with empty selection (allow all)
modelRestrictionMode.value = 'whitelist'
modelMappings.value = []
allowedModels.value = []
@@ -2835,8 +3039,16 @@ const syncFormFromAccount = (newAccount: Account | null) => {
: 'https://api.anthropic.com'
editBaseUrl.value = platformDefaultUrl
// Load model mappings for OpenAI OAuth accounts
if (newAccount.platform === 'openai' && newAccount.credentials) {
// Load model mappings for OpenAI/Kiro OAuth accounts
if (newAccount.platform === 'kiro' && newAccount.credentials) {
const oauthCredentials = newAccount.credentials as Record<string, unknown>
const existingMappings = oauthCredentials.model_mapping as Record<string, string> | undefined
if (existingMappings && typeof existingMappings === 'object' && Object.keys(existingMappings).length > 0) {
applyKiroModelMappings(Object.entries(existingMappings))
} else {
loadDefaultKiroModelMappings()
}
} else if (newAccount.platform === 'openai' && newAccount.credentials) {
const oauthCredentials = newAccount.credentials as Record<string, unknown>
const existingMappings = oauthCredentials.model_mapping as Record<string, string> | undefined
if (existingMappings && typeof existingMappings === 'object') {
@@ -2892,6 +3104,7 @@ watch(
{ immediate: true }
)
// Model mapping helpers
const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
@@ -3333,9 +3546,16 @@ const handleSubmit = async () => {
// For apikey type, handle credentials update
if (props.account.type === 'apikey') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newBaseUrl = editBaseUrl.value.trim() || defaultBaseUrl.value
const newBaseUrl = props.account.platform === 'kiro'
? editBaseUrl.value.trim()
: (editBaseUrl.value.trim() || defaultBaseUrl.value)
const shouldApplyModelMapping = !(props.account.platform === 'openai' && openaiPassthroughEnabled.value)
if (!newBaseUrl) {
appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl'))
return
}
// Always update credentials for apikey type to handle model mapping changes
const newCredentials: Record<string, unknown> = {
...currentCredentials,
@@ -3356,7 +3576,11 @@ const handleSubmit = async () => {
// Add model mapping if configuredOpenAI
if (shouldApplyModelMapping) {
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
const modelMapping = buildModelMappingObject(
props.account.platform === 'kiro' ? 'mapping' : modelRestrictionMode.value,
props.account.platform === 'kiro' ? [] : allowedModels.value,
modelMappings.value
)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
@@ -3494,7 +3718,7 @@ const handleSubmit = async () => {
}
// Model mapping
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
const modelMapping = buildModelMappingObject('mapping', [], modelMappings.value)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
@@ -3548,6 +3772,22 @@ const handleSubmit = async () => {
updatePayload.credentials = newCredentials
}
// Kiro OAuth: persist model mapping to credentials
if (props.account.platform === 'kiro' && props.account.type === 'oauth') {
const currentCredentials = (updatePayload.credentials as Record<string, unknown>) ||
((props.account.credentials as Record<string, unknown>) || {})
const newCredentials: Record<string, unknown> = { ...currentCredentials }
const modelMapping = buildModelMappingObject('mapping', [], modelMappings.value)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
delete newCredentials.model_mapping
}
updatePayload.credentials = newCredentials
}
// Antigravity: persist model mapping to credentials (applies to all antigravity types)
// Antigravity
if (props.account.platform === 'antigravity') {
@@ -696,6 +696,7 @@ const getOAuthKey = (key: string) => {
if (props.platform === 'openai') return `admin.accounts.oauth.openai.${key}`
if (props.platform === 'gemini') return `admin.accounts.oauth.gemini.${key}`
if (props.platform === 'antigravity') return `admin.accounts.oauth.antigravity.${key}`
if (props.platform === 'kiro') return `admin.accounts.oauth.kiro.${key}`
return `admin.accounts.oauth.${key}`
}
@@ -726,6 +727,8 @@ const sessionTokenInput = ref('')
const codexSessionInput = ref('')
const showHelpDialog = ref(false)
const oauthState = ref('')
const oauthCallbackPath = ref('')
const oauthLoginOption = ref('')
const projectId = ref('')
// Computed: show method selection when either cookie or refresh token option is enabled
@@ -765,10 +768,10 @@ watch(inputMethod, (newVal) => {
emit('update:inputMethod', newVal)
})
// Auto-extract code from callback URL (OpenAI/Gemini/Antigravity)
// Auto-extract code from callback URL (OpenAI/Gemini/Antigravity/Kiro)
// e.g., http://localhost:8085/callback?code=xxx...&state=...
watch(authCodeInput, (newVal) => {
if (props.platform !== 'openai' && props.platform !== 'gemini' && props.platform !== 'antigravity') return
if (props.platform !== 'openai' && props.platform !== 'gemini' && props.platform !== 'antigravity' && props.platform !== 'kiro') return
const trimmed = newVal.trim()
// Check if it looks like a URL with code parameter
@@ -778,7 +781,11 @@ watch(authCodeInput, (newVal) => {
const url = new URL(trimmed)
const code = url.searchParams.get('code')
const stateParam = url.searchParams.get('state')
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
if (props.platform === 'kiro') {
oauthCallbackPath.value = url.pathname || ''
oauthLoginOption.value = url.searchParams.get('login_option') || ''
}
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity' || props.platform === 'kiro') && stateParam) {
oauthState.value = stateParam
}
if (code && code !== trimmed) {
@@ -789,7 +796,13 @@ watch(authCodeInput, (newVal) => {
// If URL parsing fails, try regex extraction
const match = trimmed.match(/[?&]code=([^&]+)/)
const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
if (props.platform === 'kiro') {
const pathMatch = trimmed.match(/^https?:\/\/[^/]+(\/[^?]*)/)
oauthCallbackPath.value = pathMatch?.[1] || oauthCallbackPath.value
const loginOptionMatch = trimmed.match(/[?&]login_option=([^&]+)/)
oauthLoginOption.value = loginOptionMatch?.[1] || oauthLoginOption.value
}
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity' || props.platform === 'kiro') && stateMatch && stateMatch[1]) {
oauthState.value = stateMatch[1]
}
if (match && match[1] && match[1] !== trimmed) {
@@ -841,6 +854,8 @@ const handleImportCodexSession = () => {
defineExpose({
authCode: authCodeInput,
oauthState,
oauthCallbackPath,
oauthLoginOption,
projectId,
sessionKey: sessionKeyInput,
refreshToken: refreshTokenInput,
@@ -850,6 +865,8 @@ defineExpose({
reset: () => {
authCodeInput.value = ''
oauthState.value = ''
oauthCallbackPath.value = ''
oauthLoginOption.value = ''
projectId.value = ''
sessionKeyInput.value = ''
refreshTokenInput.value = ''
@@ -13,6 +13,15 @@ vi.mock('vue-i18n', async () => {
}
})
vi.mock('@/i18n', () => ({
i18n: {
global: {
t: (key: string) => key
}
},
getLocale: () => 'en'
}))
function makeAccount(overrides: Partial<Account>): Account {
return {
id: 1,
@@ -35,6 +44,12 @@ function makeAccount(overrides: Partial<Account>): Account {
overload_until: null,
temp_unschedulable_until: null,
temp_unschedulable_reason: null,
kiro_quota_state: null,
kiro_quota_reason: null,
kiro_quota_reset_at: null,
kiro_runtime_state: null,
kiro_runtime_reason: null,
kiro_runtime_reset_at: null,
session_window_start: null,
session_window_end: null,
session_window_status: null,
@@ -159,4 +174,96 @@ describe('AccountStatusIndicator', () => {
// AICredits 积分耗尽状态应显示
expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
it('Kiro 运行时冷却在状态列复用限流展示', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 5,
name: 'kiro-cooldown',
platform: 'kiro',
kiro_runtime_state: 'cooldown',
kiro_runtime_reason: 'rate_limit_exceeded',
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedAutoResume')
expect(wrapper.text()).toContain('429')
})
it('Kiro suspended 在状态列显示为 forbidden', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 6,
name: 'kiro-suspended',
platform: 'kiro',
kiro_runtime_state: 'suspended',
kiro_runtime_reason: 'account_suspended',
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.forbidden')
})
it('Kiro overage active 在状态列仍显示正常状态', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 7,
name: 'kiro-overage-active',
platform: 'kiro',
kiro_quota_state: 'overage_active',
kiro_quota_reason: 'overages_enabled',
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.status.active')
expect(wrapper.text()).not.toContain('admin.accounts.status.overageActive')
})
it('Kiro overage exhausted 在状态列显示危险徽章', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 8,
name: 'kiro-overage-exhausted',
platform: 'kiro',
kiro_quota_state: 'overage_exhausted',
kiro_quota_reason: 'overage disabled after quota exhaustion',
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.status.overageExhausted')
expect(wrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
})
})
@@ -25,6 +25,15 @@ vi.mock('vue-i18n', async () => {
}
})
vi.mock('@/i18n', () => ({
i18n: {
global: {
t: (key: string) => key
}
},
getLocale: () => 'en'
}))
function makeAccount(overrides: Partial<Account>): Account {
return {
id: 1,
@@ -47,6 +56,12 @@ function makeAccount(overrides: Partial<Account>): Account {
overload_until: null,
temp_unschedulable_until: null,
temp_unschedulable_reason: null,
kiro_quota_state: null,
kiro_quota_reason: null,
kiro_quota_reset_at: null,
kiro_runtime_state: null,
kiro_runtime_reason: null,
kiro_runtime_reset_at: null,
session_window_start: null,
session_window_end: null,
session_window_status: null,
@@ -530,6 +545,234 @@ describe('AccountUsageCell', () => {
expect(wrapper.text()).toContain('7d|100|106540000')
})
it('Kiro OAuth 会用 passive source 拉取并展示 credits 额度', async () => {
const account = makeAccount({
id: 3001,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
getUsage.mockResolvedValue({
source: 'passive',
kiro_subscription_name: 'KIRO PRO+',
kiro_overages_enabled: true,
kiro_credit: {
current_usage: 125,
usage_limit: 2000,
percentage_used: 6.25,
},
kiro_bonus: {
current_usage: 25,
usage_limit: 500,
percentage_used: 5,
days_remaining: 7,
},
kiro_overage: {
current_overages: 2,
overage_charges: 0.08,
currency_symbol: '$',
currency_code: 'USD',
},
kiro_reset_at: '2099-03-13T12:00:00Z',
})
const wrapper = mount(AccountUsageCell, {
props: {
account
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(getUsage).toHaveBeenCalledWith(3001, 'passive')
expect(wrapper.emitted('kiroUsageMeta')?.[0]).toEqual([
{
plan_type: 'KIRO PRO+',
kiro_overages_enabled: true
}
])
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroCredits')
expect(wrapper.text()).toContain('125 / 2.0K')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroBonus')
expect(wrapper.text()).toContain('25 / 500')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroDaysLeft')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroReset')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroOverage 2 ($0.08)')
})
it('Kiro OAuth 会展示运行时冷却状态', async () => {
getUsage.mockResolvedValue({
source: 'passive',
kiro_runtime_state: 'cooldown',
kiro_runtime_reason: 'rate_limit_exceeded',
kiro_runtime_reset_at: '2099-03-13T12:00:00Z',
kiro_credit: {
current_usage: 10,
usage_limit: 100,
percentage_used: 10,
},
})
const wrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3002,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedUntil')
})
it('Kiro OAuth 会展示 overage active 与 exhausted 状态', async () => {
getUsage.mockResolvedValueOnce({
source: 'passive',
kiro_quota_state: 'overage_active',
kiro_quota_reason: 'overages_enabled',
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
kiro_overages_enabled: true,
kiro_credit: {
current_usage: 2100,
usage_limit: 2000,
percentage_used: 100,
},
kiro_overage: {
current_overages: 3,
overage_charges: 0.12,
currency_symbol: '$',
},
})
const activeWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3005,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(activeWrapper.text()).toContain('admin.accounts.status.overageActive')
expect(activeWrapper.text()).not.toContain('admin.accounts.status.overageActiveUntil')
getUsage.mockResolvedValueOnce({
source: 'passive',
kiro_quota_state: 'overage_exhausted',
kiro_quota_reason: 'usage API error: overage exhausted',
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
error: 'usage API error: kiro usage request failed (status 429): {"message":"overage exhausted"}',
})
const exhaustedWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3006,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhausted')
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
})
it('Kiro OAuth 会展示 profile 异常和 usage forbidden 徽章', async () => {
getUsage.mockResolvedValueOnce({
source: 'passive',
error_code: 'forbidden',
error: 'usage API error: kiro usage request failed (status 400): {"message":"profileArn is required for this request."}',
})
const profileWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3003,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(profileWrapper.text()).toContain('admin.accounts.usageError')
getUsage.mockResolvedValueOnce({
source: 'passive',
error_code: 'forbidden',
error: 'usage API error: kiro usage request failed (status 403): {"message":"User is not authorized to access this feature."}',
})
const forbiddenWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3004,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(forbiddenWrapper.text()).toContain('admin.accounts.forbidden')
})
it('Key 账号会展示 today stats 徽章并带 A/U 提示', async () => {
const wrapper = mount(AccountUsageCell, {
props: {
@@ -0,0 +1,56 @@
import { describe, expect, it, vi } from 'vitest'
import { nextTick } from 'vue'
import { mount } from '@vue/test-utils'
vi.mock('@/stores/app', () => ({
useAppStore: () => ({
showSuccess: vi.fn(),
showError: vi.fn()
})
}))
vi.mock('@/composables/useClipboard', () => ({
useClipboard: () => ({
copied: { value: false },
copyToClipboard: vi.fn()
})
}))
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
import OAuthAuthorizationFlow from '../OAuthAuthorizationFlow.vue'
describe('OAuthAuthorizationFlow', () => {
it('extracts code, state, and callback metadata from a full Kiro callback URL', async () => {
const wrapper = mount(OAuthAuthorizationFlow, {
props: {
addMethod: 'oauth',
platform: 'kiro',
authUrl: 'https://example.com/authorize',
sessionId: 'session-1'
},
global: {
stubs: {
Icon: true
}
}
})
const textarea = wrapper.get('textarea')
await textarea.setValue('http://localhost:49153/oauth/callback?code=abc123&state=state456&login_option=github')
await nextTick()
expect((textarea.element as HTMLTextAreaElement).value).toBe('abc123')
expect((wrapper.vm as any).oauthState).toBe('state456')
expect((wrapper.vm as any).oauthCallbackPath).toBe('/oauth/callback')
expect((wrapper.vm as any).oauthLoginOption).toBe('github')
})
})
@@ -489,7 +489,8 @@ const platformOptions = [
{ value: 'anthropic', label: 'Anthropic' },
{ value: 'openai', label: 'OpenAI' },
{ value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' }
{ value: 'antigravity', label: 'Antigravity' },
{ value: 'kiro', label: 'Kiro' }
]
// Load rules when dialog opens
@@ -25,7 +25,7 @@ const updateType = (value: string | number | boolean | null) => { emit('update:f
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
const updatePrivacyMode = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, privacy_mode: value }) }
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }])
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'kiro', label: 'Kiro' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }, { value: 'unschedulable', label: t('admin.accounts.status.unschedulable') }])
const privacyOpts = computed(() => [
@@ -2,14 +2,11 @@
<BaseDialog
:show="show"
:title="t('admin.accounts.reAuthorizeAccount')"
width="normal"
:width="isKiro ? 'wide' : 'normal'"
@close="handleClose"
>
<div v-if="account" class="space-y-4">
<!-- Account Info -->
<div
class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700"
>
<div class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
<div class="flex items-center gap-3">
<div
:class="[
@@ -18,6 +15,8 @@
? 'from-green-500 to-green-600'
: isGemini
? 'from-blue-500 to-blue-600'
: isKiro
? 'from-amber-500 to-amber-600'
: isAntigravity
? 'from-purple-500 to-purple-600'
: 'from-orange-500 to-orange-600'
@@ -26,15 +25,15 @@
<Icon name="sparkles" size="md" class="text-white" />
</div>
<div>
<span class="block font-semibold text-gray-900 dark:text-white">{{
account.name
}}</span>
<span class="block font-semibold text-gray-900 dark:text-white">{{ account.name }}</span>
<span class="text-sm text-gray-500 dark:text-gray-400">
{{
isOpenAI
? t('admin.accounts.openaiAccount')
: isGemini
? t('admin.accounts.geminiAccount')
: isKiro
? t('admin.accounts.kiroAccount')
: isAntigravity
? t('admin.accounts.antigravityAccount')
: t('admin.accounts.claudeCodeAccount')
@@ -44,7 +43,6 @@
</div>
</div>
<!-- Add Method Selection (Claude only) -->
<fieldset v-if="isAnthropic" class="border-0 p-0">
<legend class="input-label">{{ t('admin.accounts.oauth.authMethod') }}</legend>
<div class="mt-2 flex gap-4">
@@ -55,9 +53,7 @@
value="oauth"
class="mr-2 text-primary-600 focus:ring-primary-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{
t('admin.accounts.types.oauth')
}}</span>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.types.oauth') }}</span>
</label>
<label class="flex cursor-pointer items-center">
<input
@@ -66,14 +62,11 @@
value="setup-token"
class="mr-2 text-primary-600 focus:ring-primary-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{
t('admin.accounts.setupTokenLongLived')
}}</span>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.setupTokenLongLived') }}</span>
</label>
</div>
</fieldset>
<!-- Gemini OAuth Type Display (read-only) -->
<div v-if="isGemini" class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
<div class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
@@ -116,7 +109,187 @@
</div>
</div>
<div v-if="isKiro" class="rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-700/40 dark:bg-amber-900/20">
<div class="mb-3 text-sm font-medium text-amber-900 dark:text-amber-100">
{{ t('admin.accounts.oauth.kiro.authModeTitle') }}
</div>
<div class="grid grid-cols-1 gap-3 md:grid-cols-3">
<button
type="button"
@click="kiroAccountType = 'oauth'"
:class="kiroModeClass('oauth')"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroAccountType === 'oauth'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="key" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.oauthTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.oauthSubtitle') }}
</span>
</div>
</button>
<button
type="button"
@click="kiroAccountType = 'idc'"
:class="kiroModeClass('idc')"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroAccountType === 'idc'
? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.idcTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.idcSubtitle') }}
</span>
</div>
</button>
<button
type="button"
@click="kiroAccountType = 'import'"
:class="kiroModeClass('import')"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroAccountType === 'import'
? 'bg-slate-700 text-white dark:bg-slate-500'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="download" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.importTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.importSubtitle') }}
</span>
</div>
</button>
</div>
<div v-if="kiroAccountType === 'oauth'" class="mt-3 space-y-3">
<div class="text-xs text-amber-800 dark:text-amber-200">
{{ t('admin.accounts.oauth.kiro.oauthSubtitle') }}
</div>
<div class="grid grid-cols-2 gap-3">
<button
type="button"
@click="kiroOAuthProvider = 'google'"
:class="kiroProviderClass('google')"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroOAuthProvider === 'google'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="user" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.googleTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.googleDesc') }}
</span>
</div>
</button>
<button
type="button"
@click="kiroOAuthProvider = 'github'"
:class="kiroProviderClass('github')"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
kiroOAuthProvider === 'github'
? 'bg-slate-700 text-white dark:bg-slate-500'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="terminal" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.oauth.kiro.githubTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.oauth.kiro.githubDesc') }}
</span>
</div>
</button>
</div>
</div>
<div v-if="kiroAccountType === 'idc'" class="mt-3 grid gap-3 md:grid-cols-2">
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.idcStartUrlLabel') }}</label>
<input
v-model="kiroIDCStartUrl"
type="text"
class="input"
:placeholder="t('admin.accounts.oauth.kiro.startUrlPlaceholder')"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.regionLabel') }}</label>
<input
v-model="kiroIDCRegion"
type="text"
class="input"
:placeholder="t('admin.accounts.oauth.kiro.regionPlaceholder')"
/>
</div>
</div>
<div v-if="isKiroImportMode" class="mt-3 space-y-3">
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.tokenJsonLabel') }}</label>
<textarea
v-model="kiroTokenJson"
rows="7"
class="input font-mono text-xs"
placeholder='{"accessToken":"...","refreshToken":"..."}'
/>
<p class="input-hint">{{ t('admin.accounts.oauth.kiro.tokenJsonHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.oauth.kiro.deviceRegistrationLabel') }}</label>
<textarea
v-model="kiroDeviceRegistrationJson"
rows="4"
class="input font-mono text-xs"
placeholder='{"clientId":"...","clientSecret":"..."}'
/>
</div>
</div>
</div>
<OAuthAuthorizationFlow
v-if="!isKiroImportMode"
ref="oauthFlowRef"
:add-method="addMethod"
:auth-url="currentAuthUrl"
@@ -128,12 +301,11 @@
:show-cookie-option="isAnthropic"
:allow-multiple="false"
:method-label="t('admin.accounts.inputMethod')"
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
:platform="oauthPlatform"
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
@generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth"
/>
</div>
<template #footer>
@@ -142,7 +314,16 @@
{{ t('common.cancel') }}
</button>
<button
v-if="isManualInputMethod"
v-if="isKiroImportMode"
type="button"
:disabled="currentLoading || !kiroTokenJson.trim()"
class="btn btn-primary"
@click="handleKiroImport"
>
{{ currentLoading ? t('admin.accounts.oauth.verifying') : t('admin.accounts.oauth.kiro.importAndUpdate') }}
</button>
<button
v-else-if="isManualInputMethod"
type="button"
:disabled="!canExchangeCode"
class="btn btn-primary"
@@ -161,18 +342,14 @@
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
/>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
/>
</svg>
{{
currentLoading
? t('admin.accounts.oauth.verifying')
: t('admin.accounts.oauth.completeAuth')
}}
{{ currentLoading ? t('admin.accounts.oauth.verifying') : t('admin.accounts.oauth.completeAuth') }}
</button>
</div>
</template>
@@ -180,28 +357,29 @@
</template>
<script setup lang="ts">
import { ref, computed, watch } from 'vue'
import { computed, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import {
useAccountOAuth,
type AddMethod,
type AuthInputMethod
} from '@/composables/useAccountOAuth'
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
import type { Account } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Icon from '@/components/icons/Icon.vue'
import OAuthAuthorizationFlow from '@/components/account/OAuthAuthorizationFlow.vue'
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
import {
type AddMethod,
type AuthInputMethod,
useAccountOAuth
} from '@/composables/useAccountOAuth'
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
import { useKiroOAuth } from '@/composables/useKiroOAuth'
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
import { useAppStore } from '@/stores/app'
import type { Account, AccountPlatform } from '@/types'
// Type for exposed OAuthAuthorizationFlow component
// Note: defineExpose automatically unwraps refs, so we use the unwrapped types
interface OAuthFlowExposed {
authCode: string
oauthState: string
oauthCallbackPath: string
oauthLoginOption: string
projectId: string
sessionKey: string
inputMethod: AuthInputMethod
@@ -222,77 +400,96 @@ const emit = defineEmits<{
const appStore = useAppStore()
const { t } = useI18n()
// OAuth composables
const claudeOAuth = useAccountOAuth()
const openaiOAuth = useOpenAIOAuth()
const geminiOAuth = useGeminiOAuth()
const antigravityOAuth = useAntigravityOAuth()
const kiroOAuth = useKiroOAuth()
// Refs
const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
// State
const addMethod = ref<AddMethod>('oauth')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist')
const kiroAccountType = ref<'oauth' | 'idc' | 'import'>('oauth')
const kiroOAuthProvider = ref<'google' | 'github'>('google')
const kiroIDCStartUrl = ref('https://view.awsapps.com/start')
const kiroIDCRegion = ref('us-east-1')
const kiroTokenJson = ref('')
const kiroDeviceRegistrationJson = ref('')
// Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai')
const isOpenAILike = computed(() => isOpenAI.value)
const isGemini = computed(() => props.account?.platform === 'gemini')
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
const isKiro = computed(() => props.account?.platform === 'kiro')
const oauthPlatform = computed<AccountPlatform>(() => {
if (isOpenAI.value) return 'openai'
if (isGemini.value) return 'gemini'
if (isKiro.value) return 'kiro'
if (isAntigravity.value) return 'antigravity'
return 'anthropic'
})
// Computed - current OAuth state based on platform
const currentAuthUrl = computed(() => {
if (isOpenAILike.value) return openaiOAuth.authUrl.value
if (isGemini.value) return geminiOAuth.authUrl.value
if (isKiro.value) return kiroOAuth.authUrl.value
if (isAntigravity.value) return antigravityOAuth.authUrl.value
return claudeOAuth.authUrl.value
})
const currentSessionId = computed(() => {
if (isOpenAILike.value) return openaiOAuth.sessionId.value
if (isGemini.value) return geminiOAuth.sessionId.value
if (isKiro.value) return kiroOAuth.sessionId.value
if (isAntigravity.value) return antigravityOAuth.sessionId.value
return claudeOAuth.sessionId.value
})
const currentLoading = computed(() => {
if (isOpenAILike.value) return openaiOAuth.loading.value
if (isGemini.value) return geminiOAuth.loading.value
if (isKiro.value) return kiroOAuth.loading.value
if (isAntigravity.value) return antigravityOAuth.loading.value
return claudeOAuth.loading.value
})
const currentError = computed(() => {
if (isOpenAILike.value) return openaiOAuth.error.value
if (isGemini.value) return geminiOAuth.error.value
if (isKiro.value) return kiroOAuth.error.value
if (isAntigravity.value) return antigravityOAuth.error.value
return claudeOAuth.error.value
})
// Computed
const isKiroImportMode = computed(() => isKiro.value && kiroAccountType.value === 'import')
const isManualInputMethod = computed(() => {
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option)
return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
return isOpenAILike.value || isGemini.value || isKiro.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
})
const canExchangeCode = computed(() => {
if (isKiroImportMode.value) {
return false
}
const authCode = oauthFlowRef.value?.authCode || ''
const sessionId = currentSessionId.value
const loading = currentLoading.value
return authCode.trim() && sessionId && !loading
return !!(authCode.trim() && currentSessionId.value && !currentLoading.value)
})
// Watchers
watch(
() => props.show,
(newVal) => {
if (newVal && props.account) {
// Initialize addMethod based on current account type (Claude only)
if (
isAnthropic.value &&
(props.account.type === 'oauth' || props.account.type === 'setup-token')
) {
if (!newVal || !props.account) {
resetState()
return
}
if (isAnthropic.value && (props.account.type === 'oauth' || props.account.type === 'setup-token')) {
addMethod.value = props.account.type as AddMethod
}
if (isGemini.value) {
const creds = (props.account.credentials || {}) as Record<string, unknown>
geminiOAuthType.value =
@@ -302,42 +499,123 @@ watch(
? 'ai_studio'
: 'code_assist'
}
} else {
resetState()
if (isKiro.value) {
const creds = (props.account.credentials || {}) as Record<string, unknown>
const authMethod = typeof creds.auth_method === 'string' ? creds.auth_method : ''
const provider = String(creds.provider || '').toLowerCase()
kiroIDCStartUrl.value = typeof creds.start_url === 'string' && creds.start_url ? creds.start_url : 'https://view.awsapps.com/start'
kiroIDCRegion.value = typeof creds.region === 'string' && creds.region ? creds.region : 'us-east-1'
kiroAccountType.value = authMethod === 'idc' ? 'idc' : 'oauth'
kiroOAuthProvider.value = provider === 'github' ? 'github' : 'google'
}
}
)
// Methods
const resetState = () => {
addMethod.value = 'oauth'
geminiOAuthType.value = 'code_assist'
kiroAccountType.value = 'oauth'
kiroOAuthProvider.value = 'google'
kiroIDCStartUrl.value = 'https://view.awsapps.com/start'
kiroIDCRegion.value = 'us-east-1'
kiroTokenJson.value = ''
kiroDeviceRegistrationJson.value = ''
claudeOAuth.resetState()
openaiOAuth.resetState()
geminiOAuth.resetState()
antigravityOAuth.resetState()
kiroOAuth.resetState()
oauthFlowRef.value?.reset()
}
const kiroModeClass = (mode: typeof kiroAccountType.value) => [
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroAccountType.value === mode
? mode === 'idc'
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
: mode === 'import'
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
: 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: mode === 'idc'
? 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
: mode === 'import'
? 'border-gray-200 hover:border-slate-300 dark:border-dark-600 dark:hover:border-slate-700'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]
const kiroProviderClass = (provider: typeof kiroOAuthProvider.value) => [
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
kiroOAuthProvider.value === provider
? provider === 'github'
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
: 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: provider === 'github'
? 'border-amber-200 hover:border-slate-300 dark:border-amber-700/40 dark:hover:border-slate-700'
: 'border-amber-200 hover:border-amber-300 dark:border-amber-700/40 dark:hover:border-amber-600'
]
const handleClose = () => {
emit('close')
}
const buildUpdatedCredentials = (next: Record<string, unknown>) => ({
...((props.account?.credentials || {}) as Record<string, unknown>),
...next
})
const updateAccountCredentials = async (payload: {
type: 'oauth' | 'setup-token'
credentials: Record<string, unknown>
extra?: Record<string, unknown>
}) => {
if (!props.account) return
await adminAPI.accounts.update(props.account.id, payload)
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized', updatedAccount)
handleClose()
}
const handleGenerateUrl = async () => {
if (!props.account) return
if (isOpenAILike.value) {
await openaiOAuth.generateAuthUrl(props.account.proxy_id)
} else if (isGemini.value) {
return
}
if (isGemini.value) {
const creds = (props.account.credentials || {}) as Record<string, unknown>
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
const projectId = geminiOAuthType.value === 'code_assist' ? oauthFlowRef.value?.projectId : undefined
await geminiOAuth.generateAuthUrl(props.account.proxy_id, projectId, geminiOAuthType.value, tierId)
} else if (isAntigravity.value) {
await antigravityOAuth.generateAuthUrl(props.account.proxy_id)
} else {
await claudeOAuth.generateAuthUrl(addMethod.value, props.account.proxy_id)
return
}
if (isKiro.value) {
if (kiroAccountType.value === 'idc') {
await kiroOAuth.generateIDCAuthUrl({
proxyId: props.account.proxy_id,
startUrl: kiroIDCStartUrl.value,
region: kiroIDCRegion.value
})
return
}
await kiroOAuth.generateAuthUrl(
props.account.proxy_id,
kiroOAuthProvider.value === 'github' ? 'Github' : 'Google'
)
return
}
if (isAntigravity.value) {
await antigravityOAuth.generateAuthUrl(props.account.proxy_id)
return
}
await claudeOAuth.generateAuthUrl(addMethod.value, props.account.proxy_id)
}
const handleExchangeCode = async () => {
@@ -347,53 +625,37 @@ const handleExchangeCode = async () => {
if (!authCode.trim()) return
if (isOpenAILike.value) {
// OpenAI OAuth flow
const oauthClient = openaiOAuth
const sessionId = oauthClient.sessionId.value
const sessionId = openaiOAuth.sessionId.value
if (!sessionId) return
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
const stateToUse = (oauthFlowRef.value?.oauthState || openaiOAuth.oauthState.value || '').trim()
if (!stateToUse) {
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
appStore.showError(oauthClient.error.value)
openaiOAuth.error.value = t('admin.accounts.oauth.authFailed')
appStore.showError(openaiOAuth.error.value)
return
}
const tokenInfo = await oauthClient.exchangeAuthCode(
authCode.trim(),
sessionId,
stateToUse,
props.account.proxy_id
)
const tokenInfo = await openaiOAuth.exchangeAuthCode(authCode.trim(), sessionId, stateToUse, props.account.proxy_id)
if (!tokenInfo) return
// Build credentials and extra info
const credentials = oauthClient.buildCredentials(tokenInfo)
const extra = oauthClient.buildExtraInfo(tokenInfo)
try {
// Update account with new credentials
await adminAPI.accounts.update(props.account.id, {
type: 'oauth', // OpenAI OAuth is always 'oauth' type
credentials,
extra
await updateAccountCredentials({
type: 'oauth',
credentials: buildUpdatedCredentials(openaiOAuth.buildCredentials(tokenInfo)),
extra: openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
})
// Clear error status after successful re-authorization
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized', updatedAccount)
handleClose()
} catch (error: any) {
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(oauthClient.error.value)
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(openaiOAuth.error.value)
}
} else if (isGemini.value) {
return
}
if (isGemini.value) {
const sessionId = geminiOAuth.sessionId.value
if (!sessionId) return
const stateFromInput = oauthFlowRef.value?.oauthState || ''
const stateToUse = stateFromInput || geminiOAuth.state.value
const stateToUse = oauthFlowRef.value?.oauthState || geminiOAuth.state.value
if (!stateToUse) return
const tokenInfo = await geminiOAuth.exchangeAuthCode({
@@ -402,32 +664,58 @@ const handleExchangeCode = async () => {
state: stateToUse,
proxyId: props.account.proxy_id,
oauthType: geminiOAuthType.value,
tierId: typeof (props.account.credentials as any)?.tier_id === 'string' ? ((props.account.credentials as any).tier_id as string) : undefined
tierId: typeof (props.account.credentials as any)?.tier_id === 'string'
? ((props.account.credentials as any).tier_id as string)
: undefined
})
if (!tokenInfo) return
const credentials = geminiOAuth.buildCredentials(tokenInfo)
try {
await adminAPI.accounts.update(props.account.id, {
await updateAccountCredentials({
type: 'oauth',
credentials
credentials: buildUpdatedCredentials(geminiOAuth.buildCredentials(tokenInfo))
})
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized', updatedAccount)
handleClose()
} catch (error: any) {
geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(geminiOAuth.error.value)
}
} else if (isAntigravity.value) {
// Antigravity OAuth flow
return
}
if (isKiro.value) {
const sessionId = kiroOAuth.sessionId.value
if (!sessionId) return
const stateToUse = oauthFlowRef.value?.oauthState || kiroOAuth.state.value
if (!stateToUse) return
const tokenInfo = await kiroOAuth.exchangeAuthCode({
code: authCode.trim(),
sessionId,
state: stateToUse,
callbackPath: oauthFlowRef.value?.oauthCallbackPath || '',
loginOption: oauthFlowRef.value?.oauthLoginOption || '',
proxyId: props.account.proxy_id
})
if (!tokenInfo) return
try {
await updateAccountCredentials({
type: 'oauth',
credentials: buildUpdatedCredentials(kiroOAuth.buildCredentials(tokenInfo))
})
} catch (error: any) {
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(kiroOAuth.error.value)
}
return
}
if (isAntigravity.value) {
const sessionId = antigravityOAuth.sessionId.value
if (!sessionId) return
const stateFromInput = oauthFlowRef.value?.oauthState || ''
const stateToUse = stateFromInput || antigravityOAuth.state.value
const stateToUse = oauthFlowRef.value?.oauthState || antigravityOAuth.state.value
if (!stateToUse) return
const tokenInfo = await antigravityOAuth.exchangeAuthCode({
@@ -438,23 +726,18 @@ const handleExchangeCode = async () => {
})
if (!tokenInfo) return
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
try {
await adminAPI.accounts.update(props.account.id, {
await updateAccountCredentials({
type: 'oauth',
credentials
credentials: buildUpdatedCredentials(antigravityOAuth.buildCredentials(tokenInfo))
})
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized', updatedAccount)
handleClose()
} catch (error: any) {
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(antigravityOAuth.error.value)
}
} else {
// Claude OAuth flow
return
}
const sessionId = claudeOAuth.sessionId.value
if (!sessionId) return
@@ -474,32 +757,41 @@ const handleExchangeCode = async () => {
...proxyConfig
})
const extra = claudeOAuth.buildExtraInfo(tokenInfo)
// Update account with new credentials and type
await adminAPI.accounts.update(props.account.id, {
type: addMethod.value, // Update type based on selected method
credentials: tokenInfo,
extra
await updateAccountCredentials({
type: addMethod.value,
credentials: buildUpdatedCredentials(tokenInfo),
extra: claudeOAuth.buildExtraInfo(tokenInfo)
})
// Clear error status after successful re-authorization
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized', updatedAccount)
handleClose()
} catch (error: any) {
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(claudeOAuth.error.value)
} finally {
claudeOAuth.loading.value = false
}
}
const handleKiroImport = async () => {
if (!props.account || !isKiroImportMode.value || !kiroTokenJson.value.trim()) return
const tokenInfo = await kiroOAuth.importToken(
kiroTokenJson.value,
kiroDeviceRegistrationJson.value || undefined
)
if (!tokenInfo) return
try {
await updateAccountCredentials({
type: 'oauth',
credentials: buildUpdatedCredentials(kiroOAuth.buildCredentials(tokenInfo))
})
} catch (error: any) {
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(kiroOAuth.error.value)
}
}
const handleCookieAuth = async (sessionKey: string) => {
if (!props.account || isOpenAILike.value) return
if (!props.account || isOpenAILike.value || isKiro.value) return
claudeOAuth.loading.value = true
claudeOAuth.error.value = ''
@@ -517,24 +809,13 @@ const handleCookieAuth = async (sessionKey: string) => {
...proxyConfig
})
const extra = claudeOAuth.buildExtraInfo(tokenInfo)
// Update account with new credentials and type
await adminAPI.accounts.update(props.account.id, {
type: addMethod.value, // Update type based on selected method
credentials: tokenInfo,
extra
await updateAccountCredentials({
type: addMethod.value,
credentials: buildUpdatedCredentials(tokenInfo),
extra: claudeOAuth.buildExtraInfo(tokenInfo)
})
// Clear error status after successful re-authorization
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized', updatedAccount)
handleClose()
} catch (error: any) {
claudeOAuth.error.value =
error.response?.data?.detail || t('admin.accounts.oauth.cookieAuthFailed')
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.cookieAuthFailed')
} finally {
claudeOAuth.loading.value = false
}
@@ -115,7 +115,7 @@ const labelClass = computed(() => {
}
//
if (props.platform === 'anthropic') {
if (props.platform === 'anthropic' || props.platform === 'kiro') {
return `${base} bg-orange-200/60 text-orange-800 dark:bg-orange-800/40 dark:text-orange-300`
}
if (props.platform === 'openai') {
@@ -129,7 +129,7 @@ const labelClass = computed(() => {
// Badge color based on platform and subscription type
const badgeClass = computed(() => {
if (props.platform === 'anthropic') {
if (props.platform === 'anthropic' || props.platform === 'kiro') {
// Claude: orange theme
return isSubscription.value
? 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
@@ -91,6 +91,8 @@ const ratePillClass = computed(() => {
return 'bg-green-50 text-green-700 dark:bg-green-900/20 dark:text-green-400'
case 'gemini':
return 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400'
case 'kiro':
return 'bg-amber-50 text-amber-700 dark:bg-amber-900/20 dark:text-amber-400'
default: // antigravity and others
return 'bg-violet-50 text-violet-700 dark:bg-violet-900/20 dark:text-violet-400'
}
@@ -1,6 +1,6 @@
<template>
<!-- Claude/Anthropic logo -->
<svg v-if="platform === 'anthropic'" :class="sizeClass" viewBox="0 0 16 16" fill="currentColor">
<svg v-if="platform === 'anthropic' || platform === 'kiro'" :class="sizeClass" viewBox="0 0 16 16" fill="currentColor">
<path
d="m3.127 10.604 3.135-1.76.053-.153-.053-.085H6.11l-.525-.032-1.791-.048-1.554-.065-1.505-.08-.38-.081L0 7.832l.036-.234.32-.214.455.04 1.009.069 1.513.105 1.097.064 1.626.17h.259l.036-.105-.089-.065-.068-.064-1.566-1.062-1.695-1.121-.887-.646-.48-.327-.243-.306-.104-.67.435-.48.585.04.15.04.593.456 1.267.981 1.654 1.218.242.202.097-.068.012-.049-.109-.181-.9-1.626-.96-1.655-.428-.686-.113-.411a2 2 0 0 1-.068-.484l.496-.674L4.446 0l.662.089.279.242.411.94.666 1.48 1.033 2.014.302.597.162.553.06.17h.105v-.097l.085-1.134.157-1.392.154-1.792.052-.504.25-.605.497-.327.387.186.319.456-.045.294-.19 1.23-.37 1.93-.243 1.29h.142l.161-.16.654-.868 1.097-1.372.484-.545.565-.601.363-.287h.686l.505.751-.226.775-.707.895-.585.759-.839 1.13-.524.904.048.072.125-.012 1.897-.403 1.024-.186 1.223-.21.553.258.06.263-.218.536-1.307.323-1.533.307-2.284.54-.028.02.032.04 1.029.098.44.024h1.077l2.005.15.525.346.315.424-.053.323-.807.411-3.631-.863-.872-.218h-.12v.073l.726.71 1.331 1.202 1.667 1.55.084.383-.214.302-.226-.032-1.464-1.101-.565-.497-1.28-1.077h-.084v.113l.295.432 1.557 2.34.08.718-.112.234-.404.141-.444-.08-.911-1.28-.94-1.44-.759-1.291-.093.053-.448 4.821-.21.246-.484.186-.403-.307-.214-.496.214-.98.258-1.28.21-1.016.19-1.263.112-.42-.008-.028-.092.012-.953 1.307-1.448 1.957-1.146 1.227-.274.109-.477-.247.045-.44.266-.39 1.586-2.018.956-1.25.617-.723-.004-.105h-.036l-4.212 2.736-.75.096-.324-.302.04-.496.154-.162 1.267-.871z"
/>
@@ -31,10 +31,18 @@
</span>
</div>
<!-- Row 2: Plan type + Privacy mode (only if either exists) -->
<div v-if="planLabel || privacyBadge" class="inline-flex items-center overflow-hidden rounded-md">
<div v-if="planLabel || privacyBadge || overagesBadge" class="inline-flex items-center overflow-hidden rounded-md">
<span v-if="planLabel" :class="['inline-flex items-center gap-1 px-1.5 py-1', planBadgeClass]">
<span>{{ planLabel }}</span>
</span>
<span
v-if="overagesBadge"
:class="['inline-flex items-center gap-1 px-1.5 py-1', overagesBadge.class]"
:title="overagesBadge.title"
>
<Icon name="sparkles" size="xs" />
<span>{{ overagesBadge.label }}</span>
</span>
<span
v-if="privacyBadge"
:class="['inline-flex items-center gap-1 px-1.5 py-1', privacyBadge.class]"
@@ -66,6 +74,7 @@ interface Props {
platform: AccountPlatform
type: AccountType
planType?: string
overagesEnabled?: boolean
privacyMode?: string
subscriptionExpiresAt?: string
}
@@ -76,6 +85,7 @@ const platformLabel = computed(() => {
if (props.platform === 'anthropic') return 'Anthropic'
if (props.platform === 'openai') return 'OpenAI'
if (props.platform === 'antigravity') return 'Antigravity'
if (props.platform === 'kiro') return 'Kiro'
return 'Gemini'
})
@@ -126,6 +136,9 @@ const platformClass = computed(() => {
if (props.platform === 'antigravity') {
return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
}
if (props.platform === 'kiro') {
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
}
return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
})
@@ -139,6 +152,9 @@ const typeClass = computed(() => {
if (props.platform === 'antigravity') {
return 'bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400'
}
if (props.platform === 'kiro') {
return 'bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400'
}
return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400'
})
@@ -149,6 +165,15 @@ const planBadgeClass = computed(() => {
return typeClass.value
})
const overagesBadge = computed(() => {
if (props.platform !== 'kiro' || !props.overagesEnabled) return null
return {
label: t('admin.accounts.status.overageActive'),
title: t('admin.accounts.usageWindow.kiroOverage'),
class: 'bg-amber-100 text-amber-700 dark:bg-amber-900/30 dark:text-amber-300'
}
})
// Subscription expiration label (non-free only)
const expiresLabel = computed(() => {
if (!props.subscriptionExpiresAt || !props.planType) return ''
@@ -0,0 +1,42 @@
import { mount } from '@vue/test-utils'
import { describe, expect, it, vi } from 'vitest'
import PlatformTypeBadge from '../PlatformTypeBadge.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key === 'admin.accounts.status.overageActive' ? 'Overage' : key
})
}
})
describe('PlatformTypeBadge', () => {
it('shows Kiro overages tag next to the plan tag when enabled', () => {
const wrapper = mount(PlatformTypeBadge, {
props: {
platform: 'kiro',
type: 'oauth',
planType: 'KIRO PRO+',
overagesEnabled: true
}
})
expect(wrapper.text()).toContain('KIRO PRO+')
expect(wrapper.text()).toContain('Overage')
})
it('does not show overages tag for non-Kiro accounts', () => {
const wrapper = mount(PlatformTypeBadge, {
props: {
platform: 'openai',
type: 'oauth',
planType: 'Pro',
overagesEnabled: true
}
})
expect(wrapper.text()).not.toContain('Overage')
})
})
@@ -4,7 +4,12 @@ vi.mock('@/api/admin/accounts', () => ({
getAntigravityDefaultModelMapping: vi.fn()
}))
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
import {
buildModelMappingObject,
fetchKiroDefaultMappings,
getModelsByPlatform,
getPresetMappingsByPlatform
} from '../useModelWhitelist'
describe('useModelWhitelist', () => {
it('openai 模型列表包含 GPT-5.4 官方快照', () => {
@@ -51,6 +56,58 @@ describe('useModelWhitelist', () => {
expect(models.indexOf('gemini-2.5-flash-image')).toBeLessThan(models.indexOf('gemini-2.5-flash-lite'))
})
it('kiro 模型列表不暴露旧的 -agentic / -chat 后缀', () => {
const models = getModelsByPlatform('kiro')
expect(models).toContain('claude-sonnet-4-6')
expect(models).toContain('claude-sonnet-4-6-thinking')
expect(models).not.toContain('claude-sonnet-4-6-chat')
expect(models.every((model) => !model.endsWith('-agentic') && !model.endsWith('-chat'))).toBe(true)
})
it('kiro 模型列表只保留 Claude 模型', () => {
const models = getModelsByPlatform('kiro')
expect(models).toEqual([
'claude-opus-4-6',
'claude-opus-4-6-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-6-thinking',
'claude-opus-4-5-20251101',
'claude-opus-4-5-20251101-thinking',
'claude-sonnet-4-5-20250929',
'claude-sonnet-4-5-20250929-thinking',
'claude-haiku-4-5-20251001',
'claude-haiku-4-5-20251001-thinking'
])
expect(models.every(model => model.startsWith('claude-'))).toBe(true)
expect(models.some(model => model.endsWith('-agentic'))).toBe(false)
expect(models.some(model => model.endsWith('-chat'))).toBe(false)
expect(models).not.toContain('kiro-auto')
expect(models).not.toContain('claude-opus-4-5')
expect(models).not.toContain('claude-sonnet-4-5')
expect(models).not.toContain('claude-sonnet-4')
expect(models).not.toContain('claude-3-5-sonnet-20241022')
expect(models).not.toContain('claude-3-5-haiku-20241022')
expect(models).not.toContain('claude-haiku-4-5')
expect(models).not.toContain('gpt-4o')
expect(models).not.toContain('gpt-4')
expect(models).not.toContain('gpt-4-turbo')
expect(models).not.toContain('gpt-3.5-turbo')
expect(models).not.toContain('deepseek-3-2')
expect(models).not.toContain('minimax-m2-1')
expect(models).not.toContain('qwen3-coder-next')
})
it('claude 模型列表包含 dated 和 thinking 兼容别名', () => {
const models = getModelsByPlatform('claude')
expect(models).toContain('claude-opus-4-6-thinking')
expect(models).toContain('claude-opus-4-5-20251101-thinking')
expect(models).toContain('claude-sonnet-4-20250514-thinking')
expect(models).toContain('claude-haiku-4-5-20251001-thinking')
})
it('whitelist 模式会忽略通配符条目', () => {
const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
expect(mapping).toEqual({
@@ -73,4 +130,68 @@ describe('useModelWhitelist', () => {
'gpt-5.4-mini': 'gpt-5.4-mini'
})
})
it('kiro 预设映射只暴露 Claude 入口', () => {
const mappings = getPresetMappingsByPlatform('kiro')
const mappingTargets = mappings.map(item => item.to)
expect(mappings.map(({ from, to }) => ({ from, to }))).toEqual([
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
])
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
expect(mappingTargets.every(model => model.startsWith('claude-'))).toBe(true)
expect(mappingTargets.some(model => model.endsWith('-agentic'))).toBe(false)
expect(mappingTargets.some(model => model.endsWith('-chat'))).toBe(false)
expect(mappingTargets).not.toContain('kiro-auto')
expect(mappingTargets.some(model => model.startsWith('kiro-'))).toBe(false)
expect(mappings.some(item => item.from === 'claude-opus-4-5')).toBe(false)
expect(mappings.some(item => item.from === 'claude-sonnet-4-5')).toBe(false)
expect(mappings.some(item => item.from === 'claude-sonnet-4')).toBe(false)
expect(mappings.some(item => item.from === 'claude-3-5-sonnet-20241022')).toBe(false)
expect(mappings.some(item => item.from === 'claude-3-5-haiku-20241022')).toBe(false)
expect(mappings.some(item => item.from === 'claude-haiku-4-5')).toBe(false)
expect(mappingTargets).not.toContain('gpt-4o')
expect(mappingTargets).not.toContain('gpt-4')
expect(mappingTargets).not.toContain('gpt-4-turbo')
expect(mappingTargets).not.toContain('gpt-3.5-turbo')
expect(mappingTargets).not.toContain('deepseek-3.2')
expect(mappingTargets).not.toContain('minimax-m2.1')
expect(mappingTargets).not.toContain('qwen3-coder-next')
})
it('kiro 默认映射会在前端填充所有可精确定价模型', async () => {
const mappings = await fetchKiroDefaultMappings()
expect(mappings).toEqual(expect.arrayContaining([
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
]))
expect(mappings).toHaveLength(10)
expect(mappings.every(item => !item.from.startsWith('kiro-'))).toBe(true)
expect(mappings.every(item => !item.to.startsWith('kiro-'))).toBe(true)
expect(mappings.every(item => !item.from.endsWith('-agentic'))).toBe(true)
expect(mappings.every(item => !item.to.endsWith('-agentic'))).toBe(true)
expect(mappings.every(item => !item.from.endsWith('-chat'))).toBe(true)
expect(mappings.every(item => !item.to.endsWith('-chat'))).toBe(true)
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
expect(mappings.every(item => item.to.startsWith('claude-'))).toBe(true)
expect(mappings.some(item => item.to === 'claude-opus-4-7')).toBe(false)
})
})

Some files were not shown because too many files have changed in this diff Show More