feat(backend): add kiro account support

This commit is contained in:
nianzs
2026-04-29 16:29:21 +08:00
parent 9d801595c9
commit 05bc424c9a
60 changed files with 11916 additions and 38 deletions
+5
View File
@@ -93,6 +93,7 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService, antigravityOAuth *service.AntigravityOAuthService,
kiroOAuth *service.KiroOAuthService,
openAIGateway *service.OpenAIGatewayService, openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService, scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService, backupSvc *service.BackupService,
@@ -216,6 +217,10 @@ func provideCleanup(
antigravityOAuth.Stop() antigravityOAuth.Stop()
return nil return nil
}}, }},
{"KiroOAuthService", func() error {
kiroOAuth.Stop()
return nil
}},
{"OpenAIWSPool", func() error { {"OpenAIWSPool", func() error {
if openAIGateway != nil { if openAIGateway != nil {
openAIGateway.CloseOpenAIWSPool() openAIGateway.CloseOpenAIWSPool()
+14 -5
View File
@@ -146,13 +146,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient) internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -166,6 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
kiroOAuthHandler := admin.NewKiroOAuthHandler(kiroOAuthService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService) adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService) promoHandler := admin.NewPromoHandler(promoService)
@@ -179,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService) billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
kiroCooldownStore := service.ProvideKiroCooldownStore(redisClient)
digestSessionStore := service.NewDigestSessionStore() digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db) channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService) affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, kiroOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -256,13 +260,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -312,6 +316,7 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService, antigravityOAuth *service.AntigravityOAuthService,
kiroOAuth *service.KiroOAuthService,
openAIGateway *service.OpenAIGatewayService, openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService, scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService, backupSvc *service.BackupService,
@@ -434,6 +439,10 @@ func provideCleanup(
antigravityOAuth.Stop() antigravityOAuth.Stop()
return nil return nil
}}, }},
{"KiroOAuthService", func() error {
kiroOAuth.Stop()
return nil
}},
{"OpenAIWSPool", func() error { {"OpenAIWSPool", func() error {
if openAIGateway != nil { if openAIGateway != nil {
openAIGateway.CloseOpenAIWSPool() openAIGateway.CloseOpenAIWSPool()
+2
View File
@@ -36,6 +36,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
antigravityOAuthSvc, antigravityOAuthSvc,
nil, nil,
nil, nil,
nil,
cfg, cfg,
nil, nil,
) )
@@ -72,6 +73,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
openAIOAuthSvc, openAIOAuthSvc,
geminiOAuthSvc, geminiOAuthSvc,
antigravityOAuthSvc, antigravityOAuthSvc,
nil, // kiroOAuth
nil, // openAIGateway nil, // openAIGateway
nil, // scheduledTestRunner nil, // scheduledTestRunner
nil, // backupSvc nil, // backupSvc
+10
View File
@@ -635,6 +635,8 @@ type GatewayConfig struct {
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用 // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
// KiroStreamKeepaliveInterval: Kiro 流式 keepalive 间隔(秒),0使用默认 25 秒
KiroStreamKeepaliveInterval int `mapstructure:"kiro_stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值) // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
MaxLineSize int `mapstructure:"max_line_size"` MaxLineSize int `mapstructure:"max_line_size"`
@@ -1689,6 +1691,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10) viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.kiro_stream_keepalive_interval", 25)
viper.SetDefault("gateway.max_line_size", 500*1024*1024) viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
@@ -2277,6 +2280,13 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
} }
if c.Gateway.KiroStreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be non-negative")
}
if c.Gateway.KiroStreamKeepaliveInterval != 0 &&
(c.Gateway.KiroStreamKeepaliveInterval < 5 || c.Gateway.KiroStreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be 0 or between 5-30 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds // 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
+16
View File
@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai" PlatformOpenAI = "openai"
PlatformGemini = "gemini" PlatformGemini = "gemini"
PlatformAntigravity = "antigravity" PlatformAntigravity = "antigravity"
PlatformKiro = "kiro"
) )
// Account type constants // Account type constants
@@ -116,6 +117,21 @@ var DefaultAntigravityModelMapping = map[string]string{
"tab_flash_lite_preview": "tab_flash_lite_preview", "tab_flash_lite_preview": "tab_flash_lite_preview",
} }
// DefaultKiroModelMapping 是 Kiro 平台的默认模型映射。
// 键为对外暴露/允许请求的模型名,值为实际发送到 Kiro 上游的模型名。
var DefaultKiroModelMapping = map[string]string{
"claude-opus-4-6": "claude-opus-4.6",
"claude-opus-4-6-thinking": "claude-opus-4.6",
"claude-sonnet-4-6": "claude-sonnet-4.6",
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
"claude-opus-4-5-20251101": "claude-opus-4.5",
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
}
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射 // DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID // 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的 // 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
+55 -1
View File
@@ -1,6 +1,9 @@
package domain package domain
import "testing" import (
"strings"
"testing"
)
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) { func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
t.Parallel() t.Parallel()
@@ -24,3 +27,54 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
} }
} }
} }
func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
t.Parallel()
expected := map[string]string{
"claude-opus-4-6": "claude-opus-4.6",
"claude-opus-4-6-thinking": "claude-opus-4.6",
"claude-sonnet-4-6": "claude-sonnet-4.6",
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
"claude-opus-4-5-20251101": "claude-opus-4.5",
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
}
if len(DefaultKiroModelMapping) != len(expected) {
t.Fatalf("expected %d Kiro mappings, got %d", len(expected), len(DefaultKiroModelMapping))
}
for model, want := range expected {
if got := DefaultKiroModelMapping[model]; got != want {
t.Fatalf("unexpected Kiro mapping for %q: got %q want %q", model, got, want)
}
}
for _, model := range []string{
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-sonnet-4",
"gpt-4o",
"gpt-4",
"deepseek-3-2",
"minimax-m2-1",
"qwen3-coder-next",
"claude-opus-4-7",
"claude-sonnet-4-6-chat",
} {
if _, ok := DefaultKiroModelMapping[model]; ok {
t.Fatalf("did not expect %q to remain in DefaultKiroModelMapping", model)
}
}
for model := range DefaultKiroModelMapping {
if strings.HasSuffix(model, "-agentic") {
t.Fatalf("did not expect agentic Kiro mapping %q", model)
}
if strings.HasSuffix(model, "-chat") {
t.Fatalf("did not expect chat-only Kiro mapping %q", model)
}
}
}
@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -179,6 +180,9 @@ type AccountWithConcurrency struct {
const accountListGroupUngroupedQueryValue = "ungrouped" const accountListGroupUngroupedQueryValue = "ungrouped"
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
if h.accountUsageService != nil {
h.accountUsageService.EnrichAccountWithKiroRuntimeState(ctx, account)
}
item := AccountWithConcurrency{ item := AccountWithConcurrency{
Account: dto.AccountFromService(account), Account: dto.AccountFromService(account),
CurrentConcurrency: 0, CurrentConcurrency: 0,
@@ -351,6 +355,9 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts)) result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if h.accountUsageService != nil {
h.accountUsageService.EnrichAccountWithKiroRuntimeState(c.Request.Context(), acc)
}
item := AccountWithConcurrency{ item := AccountWithConcurrency{
Account: dto.AccountFromService(acc), Account: dto.AccountFromService(acc),
CurrentConcurrency: concurrencyCounts[acc.ID], CurrentConcurrency: concurrencyCounts[acc.ID],
@@ -1913,6 +1920,18 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return return
} }
// Handle Kiro accounts
if account.Platform == service.PlatformKiro {
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, kiropkg.DefaultModels)
return
}
response.Success(c, buildMappedKiroModels(mapping))
return
}
// Handle Claude/Anthropic accounts // Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models // For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() { if account.IsOAuth() {
@@ -1954,6 +1973,28 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models) response.Success(c, models)
} }
func buildMappedKiroModels(mapping map[string]string) []kiropkg.Model {
models := make([]kiropkg.Model, 0, len(mapping))
for requestedModel := range mapping {
var found bool
for _, dm := range kiropkg.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, kiropkg.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
})
}
}
return models
}
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account // SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy // POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) { func (h *AccountHandler) SetPrivacy(c *gin.Context) {
@@ -2166,6 +2207,12 @@ func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping) response.Success(c, domain.DefaultAntigravityModelMapping)
} }
// GetKiroDefaultModelMapping 获取 Kiro 平台的默认模型映射
// GET /api/v1/admin/accounts/kiro/default-model-mapping
func (h *AccountHandler) GetKiroDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultKiroModelMapping)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。 // sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。 // 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) { func sanitizeExtraBaseRPM(extra map[string]any) {
@@ -0,0 +1,149 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type KiroOAuthHandler struct {
kiroOAuthService *service.KiroOAuthService
}
func NewKiroOAuthHandler(kiroOAuthService *service.KiroOAuthService) *KiroOAuthHandler {
return &KiroOAuthHandler{kiroOAuthService: kiroOAuthService}
}
type KiroGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
Provider string `json:"provider"`
}
func (h *KiroOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req KiroGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
result, err := h.kiroOAuthService.GenerateAuthURL(c.Request.Context(), &service.KiroGenerateAuthURLInput{
ProxyID: req.ProxyID,
Provider: req.Provider,
})
if err != nil {
response.BadRequest(c, "生成授权链接失败: "+err.Error())
return
}
response.Success(c, result)
}
type KiroGenerateIDCAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
StartURL string `json:"start_url"`
Region string `json:"region"`
}
func (h *KiroOAuthHandler) GenerateIDCAuthURL(c *gin.Context) {
var req KiroGenerateIDCAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
result, err := h.kiroOAuthService.GenerateIDCAuthURL(c.Request.Context(), &service.KiroGenerateIDCAuthURLInput{
ProxyID: req.ProxyID,
StartURL: req.StartURL,
Region: req.Region,
})
if err != nil {
response.BadRequest(c, "生成 IDC 授权链接失败: "+err.Error())
return
}
response.Success(c, result)
}
type KiroExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
CallbackPath string `json:"callback_path"`
LoginOption string `json:"login_option"`
ProxyID *int64 `json:"proxy_id"`
}
func (h *KiroOAuthHandler) ExchangeCode(c *gin.Context) {
var req KiroExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.kiroOAuthService.ExchangeCode(c.Request.Context(), &service.KiroExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
CallbackPath: req.CallbackPath,
LoginOption: req.LoginOption,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Token 交换失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
type KiroRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
AuthMethod string `json:"auth_method"`
Provider string `json:"provider"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
StartURL string `json:"start_url"`
Region string `json:"region"`
ProfileArn string `json:"profile_arn"`
ProxyID *int64 `json:"proxy_id"`
}
func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
var req KiroRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.kiroOAuthService.RefreshToken(c.Request.Context(), &service.KiroRefreshTokenInput{
RefreshToken: req.RefreshToken,
AuthMethod: req.AuthMethod,
Provider: req.Provider,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
StartURL: req.StartURL,
Region: req.Region,
ProfileArn: req.ProfileArn,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "刷新 Kiro Token 失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
type KiroImportTokenRequest struct {
TokenJSON string `json:"token_json" binding:"required"`
DeviceRegistrationJSON string `json:"device_registration_json"`
}
func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
var req KiroImportTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
TokenJSON: req.TokenJSON,
DeviceRegistrationJSON: req.DeviceRegistrationJSON,
})
if err != nil {
response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
+6
View File
@@ -221,6 +221,12 @@ func AccountFromServiceShallow(a *service.Account) *Account {
OverloadUntil: a.OverloadUntil, OverloadUntil: a.OverloadUntil,
TempUnschedulableUntil: a.TempUnschedulableUntil, TempUnschedulableUntil: a.TempUnschedulableUntil,
TempUnschedulableReason: a.TempUnschedulableReason, TempUnschedulableReason: a.TempUnschedulableReason,
KiroQuotaState: a.KiroQuotaState,
KiroQuotaReason: a.KiroQuotaReason,
KiroQuotaResetAt: a.KiroQuotaResetAt,
KiroRuntimeState: a.KiroRuntimeState,
KiroRuntimeReason: a.KiroRuntimeReason,
KiroRuntimeResetAt: a.KiroRuntimeResetAt,
SessionWindowStart: a.SessionWindowStart, SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd, SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus, SessionWindowStatus: a.SessionWindowStatus,
+6
View File
@@ -174,6 +174,12 @@ type Account struct {
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"` TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"` TempUnschedulableReason string `json:"temp_unschedulable_reason"`
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
SessionWindowStart *time.Time `json:"session_window_start"` SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"` SessionWindowEnd *time.Time `json:"session_window_end"`
+1
View File
@@ -17,6 +17,7 @@ type AdminHandlers struct {
OpenAIOAuth *admin.OpenAIOAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler
AntigravityOAuth *admin.AntigravityOAuthHandler AntigravityOAuth *admin.AntigravityOAuthHandler
KiroOAuth *admin.KiroOAuthHandler
Proxy *admin.ProxyHandler Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler Redeem *admin.RedeemHandler
Promo *admin.PromoHandler Promo *admin.PromoHandler
+3
View File
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
openaiOAuthHandler *admin.OpenAIOAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler,
antigravityOAuthHandler *admin.AntigravityOAuthHandler, antigravityOAuthHandler *admin.AntigravityOAuthHandler,
kiroOAuthHandler *admin.KiroOAuthHandler,
proxyHandler *admin.ProxyHandler, proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler, redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler, promoHandler *admin.PromoHandler,
@@ -51,6 +52,7 @@ func ProvideAdminHandlers(
OpenAIOAuth: openaiOAuthHandler, OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler, GeminiOAuth: geminiOAuthHandler,
AntigravityOAuth: antigravityOAuthHandler, AntigravityOAuth: antigravityOAuthHandler,
KiroOAuth: kiroOAuthHandler,
Proxy: proxyHandler, Proxy: proxyHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Promo: promoHandler, Promo: promoHandler,
@@ -154,6 +156,7 @@ var ProviderSet = wire.NewSet(
admin.NewOpenAIOAuthHandler, admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler, admin.NewGeminiOAuthHandler,
admin.NewAntigravityOAuthHandler, admin.NewAntigravityOAuthHandler,
admin.NewKiroOAuthHandler,
admin.NewProxyHandler, admin.NewProxyHandler,
admin.NewRedeemHandler, admin.NewRedeemHandler,
admin.NewPromoHandler, admin.NewPromoHandler,
+258
View File
@@ -0,0 +1,258 @@
package kiro
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"runtime"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
type RuntimeFingerprint struct {
OIDCSDKVersion string
RuntimeSDKVersion string
StreamingSDKVersion string
OSType string
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string
}
type runtimeFingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*RuntimeFingerprint
}
var (
globalRuntimeFingerprintManager *runtimeFingerprintManager
globalRuntimeFingerprintManagerOnce sync.Once
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
runtimeSDKVersions = []string{"1.0.0"}
streamingSDKVersions = []string{"1.0.34"}
osTypes = []string{"darwin", "win32"}
osVersions = map[string][]string{
"darwin": {"24.6.0"},
"win32": {"10.0.22631"},
}
nodeVersions = []string{"22.22.0"}
kiroVersions = []string{
"0.11.132", "0.11.131", "0.11.130",
}
)
func globalRuntimeFingerprints() *runtimeFingerprintManager {
globalRuntimeFingerprintManagerOnce.Do(func() {
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
fingerprints: make(map[string]*RuntimeFingerprint),
}
})
return globalRuntimeFingerprintManager
}
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
lookupKey := fingerprintLookupKey(accountKey, "runtime")
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
m.mu.RLock()
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
m.mu.RUnlock()
return fp
}
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
return fp
}
fp := generateRuntimeFingerprint(lookupKey, machineID)
m.fingerprints[lookupKey] = fp
return fp
}
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
hash := sha256.Sum256([]byte(accountKey))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
osType := goOSToNodePlatform(runtime.GOOS)
if !containsString(osTypes, osType) {
osType = osTypes[rng.Intn(len(osTypes))]
}
osVersionPool := osVersions[osType]
if len(osVersionPool) == 0 {
osVersionPool = osVersions["darwin"]
}
return &RuntimeFingerprint{
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
OSType: osType,
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
KiroHash: machineID,
}
}
func goOSToNodePlatform(goos string) string {
switch strings.TrimSpace(goos) {
case "windows":
return "win32"
default:
return strings.TrimSpace(goos)
}
}
func containsString(items []string, target string) bool {
for _, item := range items {
if item == target {
return true
}
}
return false
}
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
switch {
case strings.TrimSpace(clientIDHash) != "":
return clientIDHash
case strings.TrimSpace(clientID) != "":
return shortSHA(clientID)
case strings.TrimSpace(refreshToken) != "":
return shortSHA(refreshToken)
case strings.TrimSpace(profileArn) != "":
return shortSHA(profileArn)
case accountID > 0:
return shortSHA(fmt.Sprintf("account:%d", accountID))
default:
return shortSHA(uuid.NewString())
}
}
func NormalizeMachineID(machineID string) (string, bool) {
trimmed := strings.TrimSpace(machineID)
if len(trimmed) == 64 && isHexString(trimmed) {
return strings.ToLower(trimmed), true
}
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
normalized := strings.ToLower(withoutDashes)
return normalized + normalized, true
}
return "", false
}
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
return sha256Hex("KotlinNativeAPI/" + refreshToken)
}
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
return sha256Hex("KiroAPIKey/" + apiKey)
}
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
return sha256Hex("KiroFallback/" + fallbackKey)
}
return sha256Hex("KiroFallback/default")
}
func shortSHA(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
func sha256Hex(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:])
}
func isHexString(value string) bool {
for _, c := range value {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
return false
}
}
return true
}
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
if normalized, ok := NormalizeMachineID(machineID); ok {
return normalized
}
return BuildMachineID("", "", fallbackKey)
}
func fingerprintLookupKey(accountKey, fallback string) string {
key := strings.TrimSpace(accountKey)
if key != "" {
return key
}
return fallback
}
func BuildRuntimeUserAgent(accountKey, machineID string) string {
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
return map[string]string{
"Content-Type": "application/json",
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
"amz-sdk-invocation-id": uuid.NewString(),
"amz-sdk-request": "attempt=1; max=4",
}
}
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
return map[string]string{
"Content-Type": "application/json",
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
"Accept": "application/json, text/plain, */*",
}
}
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := baseDelay << attempt
if delay > maxDelay {
delay = maxDelay
}
const jitterFactor = 0.3
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
return time.Duration(float64(delay) * jitter)
}
@@ -0,0 +1,91 @@
package kiro
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildLoginHeadersStable(t *testing.T) {
headers1 := BuildLoginHeaders("", "")
headers2 := BuildLoginHeaders("", "")
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
require.Equal(t, "application/json", headers1["Content-Type"])
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
}
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
machineIDA := BuildMachineID("refresh-a", "", "")
machineIDB := BuildMachineID("refresh-b", "", "")
headers1 := BuildLoginHeaders("account-a", machineIDA)
headers2 := BuildLoginHeaders("account-a", machineIDA)
headers3 := BuildLoginHeaders("account-a", machineIDB)
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
require.Contains(t, headers1["User-Agent"], machineIDA)
}
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
machineID := BuildMachineID("", "", "oidc-machine")
headers1 := BuildOIDCHeaders("account-a", machineID)
headers2 := BuildOIDCHeaders("account-a", machineID)
headers3 := BuildOIDCHeaders("account-b", machineID)
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
}
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
key1 := BuildAccountKey("", "", "", "", 42)
key2 := BuildAccountKey("", "", "", "", 42)
key3 := BuildAccountKey("", "", "", "", 43)
require.Equal(t, key1, key2)
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
require.NotEqual(t, key1, key3)
}
func TestBuildMachineID(t *testing.T) {
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
fallback1 := BuildMachineID("", "", "account:1")
fallback2 := BuildMachineID("", "", "account:1")
fallback3 := BuildMachineID("", "", "account:2")
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
require.Equal(t, fallback1, fallback2)
require.NotEqual(t, fallback1, fallback3)
require.Len(t, fallback1, 64)
}
func TestNormalizeMachineID(t *testing.T) {
hex64 := strings.Repeat("A", 64)
normalized, ok := NormalizeMachineID(hex64)
require.True(t, ok)
require.Equal(t, strings.ToLower(hex64), normalized)
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
require.True(t, ok)
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
_, ok = NormalizeMachineID("not-a-machine-id")
require.False(t, ok)
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
require.False(t, ok)
}
func expectedKiroMachineID(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:])
}
+21
View File
@@ -0,0 +1,21 @@
package kiro
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
var DefaultModels = []Model{
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
}
+43
View File
@@ -0,0 +1,43 @@
package kiro
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
ids := make([]string, 0, len(DefaultModels))
for _, model := range DefaultModels {
ids = append(ids, model.ID)
}
require.Equal(t, []string{
"claude-opus-4-6",
"claude-opus-4-6-thinking",
"claude-sonnet-4-6",
"claude-sonnet-4-6-thinking",
"claude-opus-4-5-20251101",
"claude-opus-4-5-20251101-thinking",
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-5-20250929-thinking",
"claude-haiku-4-5-20251001",
"claude-haiku-4-5-20251001-thinking",
}, ids)
require.Contains(t, ids, "claude-sonnet-4-6")
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
require.NotContains(t, ids, "auto")
require.NotContains(t, ids, "claude-sonnet-4")
require.NotContains(t, ids, "gpt-4o")
require.NotContains(t, ids, "deepseek-3-2")
require.NotContains(t, ids, "minimax-m2-1")
require.NotContains(t, ids, "qwen3-coder-next")
require.NotContains(t, ids, "claude-opus-4-7")
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
for _, id := range ids {
require.NotContains(t, id, "kiro-")
require.NotContains(t, id, "-agentic")
require.NotContains(t, id, "-chat")
}
}
+511
View File
@@ -0,0 +1,511 @@
package kiro
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/google/uuid"
)
const (
socialAuthPortalURL = "https://app.kiro.dev"
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
defaultIDCRegion = "us-east-1"
BuilderIDStartURL = "https://view.awsapps.com/start"
sessionTTL = 10 * time.Minute
sessionCleanupEvery = 32
sessionCleanupMin = 32
)
var (
socialAuthEndpointURL = socialAuthEndpoint
oidcEndpointOverride = ""
)
type SocialProvider string
const (
SocialProviderGoogle SocialProvider = "Google"
SocialProviderGitHub SocialProvider = "Github"
)
type AuthSession struct {
State string
CodeVerifier string
ProxyURL string
CreatedAt time.Time
AuthType string
Provider string
RedirectURI string
ClientID string
ClientSecret string
Region string
StartURL string
}
type SessionStore struct {
mu sync.RWMutex
data map[string]*AuthSession
setCount uint64
}
func NewSessionStore() *SessionStore {
return &SessionStore{data: make(map[string]*AuthSession)}
}
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
session, ok := s.data[id]
if ok && sessionExpired(session, now) {
delete(s.data, id)
return nil, false
}
return session, ok
}
func (s *SessionStore) Set(id string, session *AuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.setCount++
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
s.pruneExpiredLocked(time.Now())
}
s.data[id] = session
}
func (s *SessionStore) Delete(id string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, id)
}
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
for id, session := range s.data {
if sessionExpired(session, now) {
delete(s.data, id)
}
}
}
func sessionExpired(session *AuthSession, now time.Time) bool {
if session == nil {
return true
}
if session.CreatedAt.IsZero() {
return true
}
return now.After(session.CreatedAt.Add(sessionTTL))
}
type TokenData struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn,omitempty"`
ExpiresAt string `json:"expiresAt,omitempty"`
AuthMethod string `json:"authMethod,omitempty"`
Provider string `json:"provider,omitempty"`
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
ClientIDHash string `json:"clientIdHash,omitempty"`
Email string `json:"email,omitempty"`
StartURL string `json:"startUrl,omitempty"`
Region string `json:"region,omitempty"`
}
type socialTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
type registerClientResponse struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
}
type createTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
type userInfoResponse struct {
Email string `json:"email"`
}
type deviceRegistration struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
}
type RefreshTokenInvalidError struct {
StatusCode int
Body string
}
func (e *RefreshTokenInvalidError) Error() string {
if e == nil {
return ""
}
body := strings.TrimSpace(e.Body)
if body == "" {
return "kiro refresh token invalid (invalid_grant)"
}
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
}
func GenerateSessionID() string {
return uuid.NewString()
}
func GenerateState() (string, error) {
return randomURLSafe(16)
}
func GenerateCodeVerifier() (string, error) {
return randomURLSafe(32)
}
func randomURLSafe(n int) (string, error) {
buf := make([]byte, n)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func GenerateCodeChallenge(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
params := url.Values{}
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("redirect_uri", redirectURI)
params.Set("redirect_from", "KiroIDE")
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
}
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
if redirectURI == "" {
return ""
}
path := strings.TrimSpace(callbackPath)
if path == "" {
path = "/oauth/callback"
} else if !strings.HasPrefix(path, "/") {
path = "/" + path
}
fullRedirectURI := redirectURI + path
if option := strings.TrimSpace(loginOption); option != "" {
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
}
return fullRedirectURI
}
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
payload := map[string]string{
"code": code,
"code_verifier": codeVerifier,
"redirect_uri": redirectURI,
}
var resp socialTokenResponse
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
return &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "social",
Region: defaultIDCRegion,
}, nil
}
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
payload := map[string]string{
"refreshToken": refreshToken,
}
var resp socialTokenResponse
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
return &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "social",
Provider: provider,
Region: defaultIDCRegion,
}, nil
}
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]any{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"authorization_code", "refresh_token"},
"redirectUris": []string{redirectURI},
"issuerUrl": issuerURL,
}
var resp registerClientResponse
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
return nil, err
}
return &resp, nil
}
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
if region == "" {
region = defaultIDCRegion
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scopes", strings.Join([]string{
"codewhisperer:completions",
"codewhisperer:analysis",
"codewhisperer:conversations",
"codewhisperer:transformations",
"codewhisperer:taskassist",
}, " "))
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
}
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"code": code,
"codeVerifier": codeVerifier,
"redirectUri": redirectURI,
"grantType": "authorization_code",
}
var resp createTokenResponse
accountKey := BuildAccountKey(clientID, "", "", "", 0)
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
token := &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
return token, nil
}
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"refreshToken": refreshToken,
"grantType": "refresh_token",
}
var resp createTokenResponse
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
token := &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
return token, nil
}
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
if strings.TrimSpace(accessToken) == "" {
return ""
}
var resp userInfoResponse
headers := map[string]string{
"Authorization": "Bearer " + accessToken,
}
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
return ""
}
return strings.TrimSpace(resp.Email)
}
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
var token TokenData
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
}
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
if strings.TrimSpace(token.AccessToken) == "" {
return nil, fmt.Errorf("access token is empty")
}
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
var reg deviceRegistration
if err := json.Unmarshal([]byte(deviceRegistrationJSON), &reg); err != nil {
return nil, fmt.Errorf("failed to parse device registration: %w", err)
}
if reg.ClientID != "" {
token.ClientID = reg.ClientID
}
if reg.ClientSecret != "" {
token.ClientSecret = reg.ClientSecret
}
}
return &token, nil
}
func getOIDCEndpoint(region string) string {
if strings.TrimSpace(oidcEndpointOverride) != "" {
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
}
if region == "" {
region = defaultIDCRegion
}
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
}
func oidcHeaders(accountKey, machineID string) map[string]string {
headers := BuildOIDCHeaders(accountKey, machineID)
if headers["amz-sdk-invocation-id"] == "" {
headers["amz-sdk-invocation-id"] = uuid.NewString()
}
if headers["amz-sdk-request"] == "" {
headers["amz-sdk-request"] = "attempt=1; max=4"
}
return headers
}
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
client, err := newHTTPClient(proxyURL)
if err != nil {
return err
}
var body io.Reader
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
body = bytes.NewReader(encoded)
}
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
if err != nil {
return err
}
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
for key, value := range extraHeaders {
req.Header.Set(key, value)
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
bodyText := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
}
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
}
if out == nil || len(respBody) == 0 {
return nil
}
return json.Unmarshal(respBody, out)
}
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
_, parsed, err := proxyurl.Parse(rawProxyURL)
if err != nil {
return nil, err
}
transport := &http.Transport{}
if parsed != nil {
transport.Proxy = http.ProxyURL(parsed)
}
return &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}, nil
}
@@ -0,0 +1,105 @@
package kiro
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/refreshToken", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
}))
defer server.Close()
previous := socialAuthEndpointURL
socialAuthEndpointURL = server.URL
t.Cleanup(func() { socialAuthEndpointURL = previous })
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
require.Error(t, err)
var invalid *RefreshTokenInvalidError
require.True(t, errors.As(err, &invalid))
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
require.Contains(t, invalid.Body, "invalid_grant")
}
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/token", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
require.Error(t, err)
var invalid *RefreshTokenInvalidError
require.True(t, errors.As(err, &invalid))
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
require.Contains(t, invalid.Body, "invalid_grant")
}
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
default:
t.Fatalf("unexpected path: %s", r.URL.Path)
}
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
require.NoError(t, err)
require.Equal(t, profileArn, token.ProfileArn)
require.Equal(t, "kiro@example.com", token.Email)
}
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
default:
t.Fatalf("unexpected path: %s", r.URL.Path)
}
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
require.NoError(t, err)
require.Equal(t, profileArn, token.ProfileArn)
require.Equal(t, "kiro@example.com", token.Email)
}
+56
View File
@@ -0,0 +1,56 @@
//go:build unit
package kiro
import (
"fmt"
"testing"
"time"
)
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
if got != want {
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
}
}
func TestBuildSocialTokenRedirectURI(t *testing.T) {
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
want := "http://localhost:49153/oauth/callback?login_option=github"
if got != want {
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
}
}
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
store := NewSessionStore()
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
session, ok := store.Get("expired")
if ok || session != nil {
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
}
if _, exists := store.data["expired"]; exists {
t.Fatalf("expired session should be deleted from the store")
}
}
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
store := NewSessionStore()
now := time.Now()
for i := 0; i < sessionCleanupMin; i++ {
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
}
store.setCount = sessionCleanupEvery - 1
store.Set("fresh", &AuthSession{CreatedAt: now})
if len(store.data) != 1 {
t.Fatalf("store size = %d, want 1", len(store.data))
}
if _, ok := store.data["fresh"]; !ok {
t.Fatalf("fresh session should remain after pruning")
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+368
View File
@@ -0,0 +1,368 @@
package kiro
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
var cachedWebSearchDescription atomic.Value // stores string
type MCPRequest struct {
ID string `json:"id"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}
type MCPResponse struct {
Result *struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
} `json:"result,omitempty"`
Error *struct {
Code *int `json:"code,omitempty"`
Message *string `json:"message,omitempty"`
} `json:"error,omitempty"`
}
type WebSearchResults struct {
Results []WebSearchResult `json:"results"`
}
type WebSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Snippet *string `json:"snippet,omitempty"`
PublishedDate *int64 `json:"publishedDate,omitempty"`
ID *string `json:"id,omitempty"`
Domain *string `json:"domain,omitempty"`
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
PublicDomain *bool `json:"publicDomain,omitempty"`
}
type SearchIndicator struct {
ToolUseID string
Query string
Results *WebSearchResults
}
func GetCachedWebSearchDescription() string {
if v := cachedWebSearchDescription.Load(); v != nil {
return strings.TrimSpace(v.(string))
}
return ""
}
func SetCachedWebSearchDescription(desc string) {
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
}
func BuildMcpEndpoint(region string) string {
if strings.TrimSpace(region) == "" {
region = "us-east-1"
}
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
}
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
return nil
}
for _, item := range resp.Result.Content {
if item.Type != "" && item.Type != "text" {
continue
}
var results WebSearchResults
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
return &results
}
}
return nil
}
func ExtractSearchQuery(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
arr := messages.Array()
for i := len(arr) - 1; i >= 0; i-- {
msg := arr[i]
if msg.Get("role").String() != "user" {
continue
}
text := extractSearchText(msg.Get("content"))
const prefix = "Perform a web search for the query: "
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
if text != "" {
return text
}
}
return ""
}
func extractSearchText(content gjson.Result) string {
if content.Type == gjson.String {
return content.String()
}
if !content.IsArray() {
return ""
}
for _, block := range content.Array() {
if block.Get("type").String() == "text" {
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
return text
}
}
}
return ""
}
func GenerateToolUseID() string {
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
}
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err != nil {
return body, err
}
rawTools, ok := payload["tools"].([]interface{})
if !ok {
return body, nil
}
replaced := make([]interface{}, 0, len(rawTools))
for _, rawTool := range rawTools {
tool, ok := rawTool.(map[string]interface{})
if !ok {
replaced = append(replaced, rawTool)
continue
}
name := getInterfaceString(tool["name"])
toolType := getInterfaceString(tool["type"])
if !isWebSearchToolName(name, toolType) {
replaced = append(replaced, rawTool)
continue
}
replaced = append(replaced, map[string]interface{}{
"name": "web_search",
"description": minimalWebSearchDescription,
"input_schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "The search query to execute",
},
},
"required": []string{"query"},
"additionalProperties": false,
},
})
}
payload["tools"] = replaced
updated, err := json.Marshal(payload)
if err != nil {
return body, err
}
return updated, nil
}
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(claudePayload, &payload); err != nil {
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
}
rawMessages, ok := payload["messages"].([]interface{})
if !ok {
return claudePayload, fmt.Errorf("claude payload missing messages array")
}
assistantMsg := map[string]interface{}{
"role": "assistant",
"content": []interface{}{
map[string]interface{}{
"type": "tool_use",
"id": toolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": query},
},
},
}
userContent := []interface{}{
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolUseID,
"content": formatToolResultText(results),
},
}
if guidance := searchGuidanceText(); guidance != "" {
userContent = append(userContent, map[string]interface{}{
"type": "text",
"text": guidance,
})
}
userMsg := map[string]interface{}{
"role": "user",
"content": userContent,
}
rawMessages = append(rawMessages, assistantMsg, userMsg)
payload["messages"] = rawMessages
updated, err := json.Marshal(payload)
if err != nil {
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
}
return updated, nil
}
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
if len(searches) == 0 {
return responsePayload, nil
}
var response map[string]interface{}
if err := json.Unmarshal(responsePayload, &response); err != nil {
return responsePayload, err
}
content, _ := response["content"].([]interface{})
updated := make([]interface{}, 0, len(searches)*2+len(content))
for _, search := range searches {
updated = append(updated, map[string]interface{}{
"type": "server_tool_use",
"id": search.ToolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": search.Query},
})
updated = append(updated, map[string]interface{}{
"type": "web_search_tool_result",
"content": buildSearchResultContent(search.Results),
})
}
updated = append(updated, content...)
response["content"] = updated
encoded, err := json.Marshal(response)
if err != nil {
return responsePayload, err
}
return encoded, nil
}
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
content := make([]map[string]interface{}, 0)
if results == nil {
return content
}
for _, result := range results.Results {
snippet := ""
if result.Snippet != nil {
snippet = strings.TrimSpace(*result.Snippet)
}
content = append(content, map[string]interface{}{
"type": "web_search_result",
"title": result.Title,
"url": result.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
return content
}
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
content := gjson.GetBytes(responsePayload, "content")
if !content.IsArray() {
return "", "", false
}
for _, block := range content.Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if !isWebSearchToolName(name, "") {
continue
}
query = strings.TrimSpace(block.Get("input.query").String())
if query == "" {
continue
}
return block.Get("id").String(), query, true
}
return "", "", false
}
func isWebSearchToolName(name, toolType string) bool {
name = strings.ToLower(strings.TrimSpace(name))
toolType = strings.ToLower(strings.TrimSpace(toolType))
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
return true
}
switch name {
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
return true
default:
return false
}
}
func getInterfaceString(v interface{}) string {
if v == nil {
return ""
}
switch val := v.(type) {
case string:
return strings.TrimSpace(val)
default:
return strings.TrimSpace(fmt.Sprint(val))
}
}
func formatToolResultText(results *WebSearchResults) string {
if results == nil || len(results.Results) == 0 {
return "No search results found."
}
payload, err := json.MarshalIndent(results.Results, "", " ")
if err != nil {
return "Found search results, but failed to format them."
}
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
}
func searchGuidanceText() string {
now := time.Now()
return fmt.Sprintf(`<search_guidance>
Current date: %s (%s)
IMPORTANT: Evaluate the search results above carefully. If the results are:
- Mostly spam, SEO junk, or unrelated websites
- Missing actual information about the query topic
- Outdated or not matching the requested time frame
Then you MUST use the web_search tool again with a refined query. Try:
- Rephrasing in English for better coverage
- Using more specific keywords
- Adding date context
Do NOT apologize for bad results without first attempting a re-search.
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
}
@@ -0,0 +1,297 @@
package kiro
import (
"encoding/json"
"strings"
)
type BufferedStreamResult struct {
StopReason string
WebSearchQuery string
WebSearchToolUseID string
HasWebSearchToolUse bool
WebSearchToolUseIndex int
}
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
searchContent := make([]map[string]interface{}, 0)
if results != nil {
for _, result := range results.Results {
snippet := ""
if result.Snippet != nil {
snippet = strings.TrimSpace(*result.Snippet)
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": result.Title,
"url": result.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
inputJSON, _ := json.Marshal(map[string]string{"query": query})
events := []map[string]interface{}{
{
"type": "content_block_start",
"index": startIndex,
"content_block": map[string]interface{}{
"type": "server_tool_use",
"id": toolUseID,
"name": "web_search",
"input": map[string]interface{}{},
},
},
{
"type": "content_block_delta",
"index": startIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
},
{
"type": "content_block_stop",
"index": startIndex,
},
{
"type": "content_block_start",
"index": startIndex + 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"content": searchContent,
},
},
{
"type": "content_block_stop",
"index": startIndex + 1,
},
}
result := make([][]byte, 0, len(events))
for _, event := range events {
eventType, _ := event["type"].(string)
payload, _ := json.Marshal(event)
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
}
return result
}
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
var currentToolName string
currentToolIndex := -1
var toolInputBuilder strings.Builder
for _, chunk := range chunks {
lines := strings.Split(string(chunk), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "" || payload == "[DONE]" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
switch eventType, _ := event["type"].(string); eventType {
case "message_delta":
if delta, ok := event["delta"].(map[string]interface{}); ok {
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
result.StopReason = stopReason
}
}
case "content_block_start":
contentBlock, ok := event["content_block"].(map[string]interface{})
if !ok {
continue
}
blockType, _ := contentBlock["type"].(string)
if blockType != "tool_use" {
continue
}
currentToolName, _ = contentBlock["name"].(string)
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
if idx, ok := event["index"].(float64); ok {
currentToolIndex = int(idx)
}
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
}
toolInputBuilder.Reset()
case "content_block_delta":
if currentToolName == "" {
continue
}
delta, ok := event["delta"].(map[string]interface{})
if !ok {
continue
}
deltaType, _ := delta["type"].(string)
if deltaType != "input_json_delta" {
continue
}
if partialJSON, ok := delta["partial_json"].(string); ok {
toolInputBuilder.WriteString(partialJSON)
}
case "content_block_stop":
if !isWebSearchToolName(currentToolName, "") {
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
continue
}
result.HasWebSearchToolUse = true
result.WebSearchToolUseIndex = currentToolIndex
var input map[string]string
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
result.WebSearchQuery = strings.TrimSpace(input["query"])
}
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
}
}
}
return result
}
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
filtered := make([][]byte, 0, len(chunks))
for _, chunk := range chunks {
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
if shouldForward {
filtered = append(filtered, adjusted)
}
}
return filtered
}
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
return filterSSEChunk(chunk, -1, offset)
}
func MaxContentBlockIndex(chunks [][]byte) int {
maxIndex := -1
for _, chunk := range chunks {
lines := strings.Split(string(chunk), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "" || payload == "[DONE]" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
switch eventType, _ := event["type"].(string); eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
maxIndex = int(idx)
}
}
}
}
return maxIndex
}
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
lines := strings.Split(string(chunk), "\n")
var builder strings.Builder
hasContent := false
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "event: ") {
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
i++
continue
}
}
builder.WriteString(line + "\n")
hasContent = true
continue
}
if strings.HasPrefix(line, "data: ") {
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "[DONE]" {
continue
}
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
continue
}
adjusted := adjustEventPayload(payload, indexOffset)
if adjusted == "" {
continue
}
builder.WriteString("data: " + adjusted + "\n")
hasContent = true
continue
}
builder.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
if !hasContent {
return nil, false
}
return []byte(builder.String()), true
}
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
if payload == "" {
return false
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
return false
}
eventType, _ := event["type"].(string)
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
return true
}
if webSearchToolUseIndex < 0 {
return false
}
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
return true
}
return false
}
func adjustEventPayload(payload string, indexOffset int) string {
if payload == "" || indexOffset == 0 {
return payload
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
return payload
}
switch eventType, _ := event["type"].(string); eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + indexOffset
if adjusted, err := json.Marshal(event); err == nil {
return string(adjusted)
}
}
}
return payload
}
@@ -0,0 +1,73 @@
package kiro
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
snippet := "result snippet"
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
Results: []WebSearchResult{
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
},
}, 0)
require.Len(t, events, 5)
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
require.Contains(t, string(events[0]), `"input":{}`)
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
require.NotContains(t, string(events[3]), `"tool_use_id"`)
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
}
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
chunks := [][]byte{
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
}
result := AnalyzeBufferedStream(chunks)
require.True(t, result.HasWebSearchToolUse)
require.Equal(t, "golang concurrency", result.WebSearchQuery)
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
require.Equal(t, 1, result.WebSearchToolUseIndex)
require.Equal(t, "tool_use", result.StopReason)
}
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
chunks := [][]byte{
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
}
filtered := FilterChunksForClient(chunks, 1, 2)
require.NotEmpty(t, filtered)
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
require.NotContains(t, joined, `"type":"message_start"`)
require.NotContains(t, joined, `"type":"message_delta"`)
require.NotContains(t, joined, `"name":"web_search"`)
require.Contains(t, joined, `"index":2`)
require.Equal(t, 2, MaxContentBlockIndex(filtered))
}
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
require.False(t, shouldForward)
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
require.True(t, shouldForward)
require.Contains(t, string(adjusted), `"index":3`)
}
+138
View File
@@ -0,0 +1,138 @@
package kiro
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
body := []byte(`{
"tools":[{"type":"web_search_20250305","description":"old"}],
"messages":[{"role":"user","content":"golang"}]
}`)
updated, err := ReplaceWebSearchToolDescription(body)
require.NoError(t, err)
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
}
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
body := []byte(`{
"messages":[{"role":"user","content":"what is golang"}]
}`)
results := &WebSearchResults{
Results: []WebSearchResult{
{Title: "Go", URL: "https://go.dev"},
},
}
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
require.NoError(t, err)
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
}
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
response := []byte(`{
"content":[
{"type":"text","text":"let me search"},
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
]
}`)
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
require.True(t, ok)
require.Equal(t, "srvtoolu_next", toolUseID)
require.Equal(t, "golang concurrency", query)
}
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
response := []byte(`{
"id":"msg_1",
"type":"message",
"role":"assistant",
"model":"kiro",
"content":[{"type":"text","text":"final"}],
"stop_reason":"end_turn",
"usage":{"input_tokens":1,"output_tokens":1}
}`)
snippet := "result snippet"
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
{
ToolUseID: "srvtoolu_test",
Query: "golang",
Results: &WebSearchResults{
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
},
},
})
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(updated, &decoded))
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
}
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
resp := &MCPResponse{
Result: &struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
}{
Content: []struct {
Type string `json:"type"`
Text string `json:"text"`
}{
{
Type: "text",
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
},
},
},
}
results := ParseSearchResults(resp)
require.NotNil(t, results)
require.Len(t, results.Results, 1)
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
require.Equal(t, "doc-1", *results.Results[0].ID)
require.Equal(t, "go.dev", *results.Results[0].Domain)
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
require.True(t, *results.Results[0].PublicDomain)
}
func TestSearchGuidanceText_IsStructured(t *testing.T) {
guidance := searchGuidanceText()
require.Contains(t, guidance, "<search_guidance>")
require.Contains(t, guidance, "Current date:")
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
require.Contains(t, guidance, "Rephrasing in English for better coverage")
}
+479
View File
@@ -0,0 +1,479 @@
package kirocooldown
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
const (
MinRequestInterval = time.Second
MaxRequestInterval = 2 * time.Second
CooldownReason429 = "rate_limit_exceeded"
CooldownReasonSuspended = "account_suspended"
ShortCooldown = time.Minute
MaxCooldown = 5 * time.Minute
LongCooldown = 24 * time.Hour
redisTimeout = 3 * time.Second
activeTTL = 10 * time.Second
stateTTL = 25 * time.Hour
keyPrefix = "kiro:cooldown:"
)
var (
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
reserveRequestScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
local interval_ms = tonumber(ARGV[1])
local active_ttl_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
if cooldown_until_ms > now_ms then
return {1, cooldown_until_ms - now_ms, cooldown_reason}
end
if cooldown_until_ms > 0 then
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
end
local next_slot_ms = now_ms
if last_request_ms > 0 then
local candidate_ms = last_request_ms + interval_ms
if candidate_ms > now_ms then
next_slot_ms = candidate_ms
end
end
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
if fail_count > 0 or cooldown_until_ms > now_ms then
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
else
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
end
return {0, next_slot_ms - now_ms, ''}
`)
mark429Script = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
local short_cooldown_ms = tonumber(ARGV[1])
local max_cooldown_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
if cooldown_ms > max_cooldown_ms then
cooldown_ms = max_cooldown_ms
end
redis.call('HSET', KEYS[1],
'fail_count', fail_count,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[4]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
markSuccessScript = redis.NewScript(`
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', 0,
'cooldown_reason', ''
)
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
return 1
`)
markSuspendedScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local cooldown_ms = tonumber(ARGV[1])
local state_ttl_ms = tonumber(ARGV[2])
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[3]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
)
type Error struct {
remaining time.Duration
reason string
}
type State struct {
Active bool
Reason string
CooldownUntil time.Time
Remaining time.Duration
FailCount int
}
func NewError(remaining time.Duration, reason string) error {
return &Error{remaining: remaining, reason: reason}
}
func (e *Error) Error() string {
if e == nil {
return ""
}
if e.reason == "" {
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
}
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
}
func Calculate429Cooldown(retryCount int) time.Duration {
if retryCount < 0 {
retryCount = 0
}
cooldown := ShortCooldown * time.Duration(1<<retryCount)
if cooldown > MaxCooldown {
return MaxCooldown
}
return cooldown
}
type Store struct {
client *redis.Client
rngMu sync.Mutex
rng *rand.Rand
}
func NewStore(client *redis.Client) *Store {
return &Store{
client: client,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
values, err := reserveRequestScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
s.nextInterval().Milliseconds(),
activeTTL.Milliseconds(),
stateTTL.Milliseconds(),
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
}
parts, ok := values.([]interface{})
if !ok || len(parts) != 3 {
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
}
state, err := luaInt64(parts[0])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
}
waitMS, err := luaInt64(parts[1])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
}
reason, err := luaString(parts[2])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
}
if state == 1 {
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
}
if waitMS <= 0 {
return 0, nil
}
return time.Duration(waitMS) * time.Millisecond, nil
}
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
if err := s.validate(); err != nil {
return err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
if err := markSuccessScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
activeTTL.Milliseconds(),
).Err(); err != nil {
return fmt.Errorf("kiro cooldown mark success: %w", err)
}
return nil
}
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
result, err := mark429Script.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
ShortCooldown.Milliseconds(),
MaxCooldown.Milliseconds(),
stateTTL.Milliseconds(),
CooldownReason429,
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
}
cooldownMS, err := luaInt64(result)
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
}
return time.Duration(cooldownMS) * time.Millisecond, nil
}
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
result, err := markSuspendedScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
LongCooldown.Milliseconds(),
stateTTL.Milliseconds(),
CooldownReasonSuspended,
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
}
cooldownMS, err := luaInt64(result)
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
}
return time.Duration(cooldownMS) * time.Millisecond, nil
}
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
if err := s.validate(); err != nil {
return nil, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
values, err := s.client.HMGet(
cacheCtx,
RedisKey(tokenKey),
"cooldown_until_ms",
"cooldown_reason",
"fail_count",
).Result()
if err != nil {
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
}
if len(values) != 3 {
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
}
cooldownUntilMS, err := luaInt64(values[0])
if err != nil && values[0] != nil {
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
}
reason, err := luaString(values[1])
if err != nil {
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
}
failCount, err := luaInt64(values[2])
if err != nil && values[2] != nil {
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
}
if cooldownUntilMS <= 0 {
return nil, nil
}
cooldownUntil := time.UnixMilli(cooldownUntilMS)
remaining := time.Until(cooldownUntil)
if remaining <= 0 {
return nil, nil
}
return &State{
Active: true,
Reason: reason,
CooldownUntil: cooldownUntil,
Remaining: remaining,
FailCount: int(failCount),
}, nil
}
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
if err := s.validate(); err != nil {
return false, err
}
uniqueKeys := make([]string, 0, len(tokenKeys))
seen := make(map[string]struct{}, len(tokenKeys))
for _, tokenKey := range tokenKeys {
tokenKey = strings.TrimSpace(tokenKey)
if tokenKey == "" {
continue
}
redisKey := RedisKey(tokenKey)
if _, ok := seen[redisKey]; ok {
continue
}
seen[redisKey] = struct{}{}
uniqueKeys = append(uniqueKeys, redisKey)
}
if len(uniqueKeys) == 0 {
return false, nil
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
type candidate struct {
redisKey string
cooldownUntilMS int64
failCount int64
}
now := time.Now().UnixMilli()
var best *candidate
pipe := s.client.Pipeline()
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
for _, redisKey := range uniqueKeys {
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
}
if _, err := pipe.Exec(cacheCtx); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
}
for i, cmd := range cmds {
values, err := cmd.Result()
if err != nil {
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
}
if len(values) != 3 {
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
}
cooldownUntilMS, err := luaInt64(values[0])
if err != nil && values[0] != nil {
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
}
reason, err := luaString(values[1])
if err != nil {
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
}
failCount, err := luaInt64(values[2])
if err != nil && values[2] != nil {
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
}
if cooldownUntilMS <= now || reason != CooldownReason429 {
continue
}
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
if best == nil ||
current.cooldownUntilMS < best.cooldownUntilMS ||
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
best = current
}
}
if best == nil {
return false, nil
}
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
}
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
}
return true, nil
}
func RedisKey(tokenKey string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
digest := hex.EncodeToString(sum[:])
return keyPrefix + "{" + digest + "}"
}
func ActiveTTL() time.Duration {
return activeTTL
}
func StateTTL() time.Duration {
return stateTTL
}
func (s *Store) validate() error {
if s == nil || s.client == nil {
return ErrStoreUnavailable
}
return nil
}
func (s *Store) nextInterval() time.Duration {
s.rngMu.Lock()
defer s.rngMu.Unlock()
if MaxRequestInterval <= MinRequestInterval {
return MinRequestInterval
}
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
}
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
return context.WithTimeout(ctx, redisTimeout)
}
func luaInt64(v any) (int64, error) {
switch n := v.(type) {
case int64:
return n, nil
case int:
return int64(n), nil
case string:
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
case []byte:
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
default:
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
}
}
func luaString(v any) (string, error) {
switch s := v.(type) {
case string:
return s, nil
case []byte:
return string(s), nil
case nil:
return "", nil
default:
return "", fmt.Errorf("unsupported lua string type %T", v)
}
}
@@ -0,0 +1,32 @@
package kirocooldown
import (
"context"
"testing"
"github.com/redis/go-redis/v9"
)
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
if err != nil {
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
}
if cleared {
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
}
}
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
store := NewStore(nil)
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
if err == nil {
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
}
if cleared {
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
}
}
+15
View File
@@ -41,6 +41,9 @@ func RegisterAdminRoutes(
// Antigravity OAuth // Antigravity OAuth
registerAntigravityOAuthRoutes(admin, h) registerAntigravityOAuthRoutes(admin, h)
// Kiro OAuth / IDC
registerKiroOAuthRoutes(admin, h)
// 代理管理 // 代理管理
registerProxyRoutes(admin, h) registerProxyRoutes(admin, h)
@@ -295,6 +298,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Antigravity 默认模型映射 // Antigravity 默认模型映射
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping) accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
accounts.GET("/kiro/default-model-mapping", h.Admin.Account.GetKiroDefaultModelMapping)
// Claude OAuth routes // Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
@@ -347,6 +351,17 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
} }
} }
func registerKiroOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
kiro := admin.Group("/kiro")
{
kiro.POST("/oauth/auth-url", h.Admin.KiroOAuth.GenerateAuthURL)
kiro.POST("/oauth/idc-auth-url", h.Admin.KiroOAuth.GenerateIDCAuthURL)
kiro.POST("/oauth/exchange-code", h.Admin.KiroOAuth.ExchangeCode)
kiro.POST("/oauth/refresh-token", h.Admin.KiroOAuth.RefreshToken)
kiro.POST("/oauth/import-token", h.Admin.KiroOAuth.ImportToken)
}
}
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies := admin.Group("/proxies") proxies := admin.Group("/proxies")
{ {
+37 -13
View File
@@ -48,6 +48,13 @@ type Account struct {
TempUnschedulableUntil *time.Time TempUnschedulableUntil *time.Time
TempUnschedulableReason string TempUnschedulableReason string
KiroQuotaState string
KiroQuotaReason string
KiroQuotaResetAt *time.Time
KiroRuntimeState string
KiroRuntimeReason string
KiroRuntimeResetAt *time.Time
SessionWindowStart *time.Time SessionWindowStart *time.Time
SessionWindowEnd *time.Time SessionWindowEnd *time.Time
SessionWindowStatus string SessionWindowStatus string
@@ -164,6 +171,10 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini return a.Platform == PlatformGemini
} }
func (a *Account) IsKiro() bool {
return a.Platform == PlatformKiro
}
func (a *Account) GeminiOAuthType() string { func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth { if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return "" return ""
@@ -478,17 +489,17 @@ func (a *Account) GetModelMapping() map[string]string {
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string { func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
// Antigravity 平台使用默认映射 // 部分平台在未显式配置 model_mapping 时仍应使用默认映射
if a.Platform == domain.PlatformAntigravity { // 以限制可调度/可转发的模型集合。
return domain.DefaultAntigravityModelMapping if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
return defaults
} }
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整) // Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
return nil return nil
} }
if len(rawMapping) == 0 { if len(rawMapping) == 0 {
// Antigravity 平台使用默认映射 if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
if a.Platform == domain.PlatformAntigravity { return defaults
return domain.DefaultAntigravityModelMapping
} }
return nil return nil
} }
@@ -510,13 +521,23 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
return result return result
} }
// Antigravity 平台使用默认映射 if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
if a.Platform == domain.PlatformAntigravity { return defaults
return domain.DefaultAntigravityModelMapping
} }
return nil return nil
} }
func defaultModelMappingForPlatform(platform string) map[string]string {
switch platform {
case domain.PlatformAntigravity:
return domain.DefaultAntigravityModelMapping
case domain.PlatformKiro:
return domain.DefaultKiroModelMapping
default:
return nil
}
}
func mapPtr(m map[string]any) uintptr { func mapPtr(m map[string]any) uintptr {
if m == nil { if m == nil {
return 0 return 0
@@ -608,8 +629,8 @@ func resolveRequestedModelInMapping(mapping map[string]string, requestedModel st
return matchWildcardMappingResult(mapping, requestedModel) return matchWildcardMappingResult(mapping, requestedModel)
} }
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
// 如果未配置 mapping,返回 true(允许所有模型) // 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时也会先回退到默认映射。
func (a *Account) IsModelSupported(requestedModel string) bool { func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
@@ -622,8 +643,8 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized) return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
} }
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名 // 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时返回默认映射结果。
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel) mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel return mappedModel
@@ -725,6 +746,9 @@ func (a *Account) GetBaseURL() string {
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL == "" { if baseURL == "" {
if a.Platform == PlatformKiro {
return ""
}
return "https://api.anthropic.com" return "https://api.anthropic.com"
} }
if a.Platform == PlatformAntigravity { if a.Platform == PlatformAntigravity {
+2 -2
View File
@@ -180,7 +180,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
} }
} }
@@ -296,7 +296,7 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if err != nil { if err != nil {
return nil, err return nil, err
} }
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
} }
} }
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -65,6 +66,7 @@ type AccountTestService struct {
accountRepo AccountRepository accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider geminiTokenProvider *GeminiTokenProvider
claudeTokenProvider *ClaudeTokenProvider claudeTokenProvider *ClaudeTokenProvider
kiroTokenProvider *KiroTokenProvider
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
cfg *config.Config cfg *config.Config
@@ -76,6 +78,7 @@ func NewAccountTestService(
accountRepo AccountRepository, accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider, geminiTokenProvider *GeminiTokenProvider,
claudeTokenProvider *ClaudeTokenProvider, claudeTokenProvider *ClaudeTokenProvider,
kiroTokenProvider *KiroTokenProvider,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
cfg *config.Config, cfg *config.Config,
@@ -85,6 +88,7 @@ func NewAccountTestService(
accountRepo: accountRepo, accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider, geminiTokenProvider: geminiTokenProvider,
claudeTokenProvider: claudeTokenProvider, claudeTokenProvider: claudeTokenProvider,
kiroTokenProvider: kiroTokenProvider,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
cfg: cfg, cfg: cfg,
@@ -191,6 +195,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt) return s.routeAntigravityTest(c, account, modelID, prompt)
} }
if account.IsKiro() && account.Type == AccountTypeOAuth {
return s.testKiroAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID) return s.testClaudeAccountConnection(c, account, modelID)
} }
@@ -239,6 +247,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
} }
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
if baseURL == "" && account.Platform == PlatformKiro {
return s.sendErrorAndEnd(c, "Kiro API Key accounts require a Base URL")
}
if baseURL == "" { if baseURL == "" {
baseURL = "https://api.anthropic.com" baseURL = "https://api.anthropic.com"
} }
@@ -387,6 +398,149 @@ func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Con
return s.processClaudeStream(c, resp.Body) return s.processClaudeStream(c, resp.Body)
} }
func (s *AccountTestService) testKiroAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
testModelID := strings.TrimSpace(modelID)
if testModelID == "" {
testModelID = "claude-sonnet-4-6"
}
if mappedModel := account.GetMappedModel(testModelID); strings.TrimSpace(mappedModel) != "" {
testModelID = mappedModel
}
if account.Type != AccountTypeOAuth {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Kiro account type: %s", account.Type))
}
if s.kiroTokenProvider == nil {
return s.sendErrorAndEnd(c, "Kiro token provider not configured")
}
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get Kiro access token: %s", err.Error()))
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payload, err := createTestPayload(testModelID)
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create test payload")
}
payloadBytes, _ := json.Marshal(payload)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
resp, err := s.executeKiroTestUpstream(ctx, account, payloadBytes, testModelID, accessToken)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, formatKiroTestError(resp.StatusCode, body, testModelID, account))
}
pr, pw := io.Pipe()
go func() {
defer func() { _ = resp.Body.Close() }()
_, streamErr := kiropkg.StreamEventStreamAsAnthropic(ctx, resp.Body, pw, testModelID, estimateKiroInputTokens(payloadBytes))
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
}
_ = pw.Close()
}()
return s.processClaudeStream(c, pr)
}
func formatKiroTestError(statusCode int, body []byte, requestedModel string, account *Account) string {
return fmt.Sprintf("API returned %d: %s", statusCode, string(body))
}
func (s *AccountTestService) executeKiroTestUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string) (*http.Response, error) {
modelID := kiropkg.MapModel(mappedModel)
currentToken := token
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
if err != nil {
return nil, err
}
payload := buildResult.Payload
endpoints := buildKiroEndpoints(account)
proxyURL := kiroProxyURL(account)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
accountKey := buildKiroAccountKey(account)
maxRetries := 2
for idx, endpoint := range endpoints {
for attempt := 0; attempt <= maxRetries; attempt++ {
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
if err != nil {
return nil, err
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
if idx+1 < len(endpoints) {
_ = resp.Body.Close()
break
}
return resp, nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, readErr
}
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
if err != nil {
return nil, err
}
payload = buildResult.Payload
continue
}
}
resetHTTPResponseBody(resp, respBody)
return resp, nil
}
if resp.StatusCode == http.StatusBadRequest {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, readErr
}
resetHTTPResponseBody(resp, respBody)
return resp, nil
}
return resp, nil
}
}
return nil, fmt.Errorf("kiro upstream endpoints exhausted")
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke // testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account) region := bedrockRuntimeRegion(account)
@@ -0,0 +1,84 @@
//go:build unit
package service
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
)
func TestAccountTestService_KiroAPIKeyUsesGenericAnthropicCompatiblePath(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 19,
Name: "kiro-apikey-test",
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"base_url": "https://kiro-upstream.example.com",
"api_key": "kiro-api-key",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4-6",
},
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"invalid api key"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
req := upstream.requests[0]
require.Equal(t, "kiro-upstream.example.com", req.URL.Host)
require.Equal(t, "/v1/messages", req.URL.Path)
require.Equal(t, "kiro-api-key", req.Header.Get("x-api-key"))
require.Empty(t, req.Header.Get("Authorization"))
require.Equal(t, claude.APIKeyBetaHeader, req.Header.Get("anthropic-beta"))
}
func TestAccountTestService_KiroAPIKeyWithoutBaseURLErrors(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 20,
Name: "kiro-apikey-missing-base-url",
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "kiro-api-key",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: &queuedHTTPUpstream{},
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Contains(t, err.Error(), "Base URL")
}
@@ -0,0 +1,317 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountTestService_KiroUsesKiroUpstreamInsteadOfAnthropic(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 1,
Name: "kiro-test",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/TESTSOCIAL",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{1: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"Invalid bearer token"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "gpt-4o", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
req := upstream.requests[0]
require.Equal(t, "q.us-east-1.amazonaws.com", req.URL.Host)
require.Equal(t, "/generateAssistantResponse", req.URL.Path)
require.Equal(t, "Bearer kiro-access-token", req.Header.Get("Authorization"))
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
require.Empty(t, req.Header.Get("anthropic-version"))
require.NotContains(t, req.URL.Host, "api.anthropic.com")
}
func TestAccountTestService_Kiro429DoesNotFallbackToCodeWhispererEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 2,
Name: "kiro-fallback",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"api_region": "us-west-2",
"region": "us-west-2",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTFALLBACK",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{2: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusTooManyRequests, `{"message":"slow down"}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
require.Contains(t, err.Error(), "API returned 429")
}
func TestAccountTestService_KiroIDCWithoutProfileArnOmitsProfileArnAndUsesDefaultRuntimeRegion(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 5,
Name: "kiro-idc-default-region",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"auth_method": "idc",
"provider": "AWS",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{5: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
require.Equal(t, "q.us-east-1.amazonaws.com", upstream.requests[0].URL.Host)
body, readErr := io.ReadAll(upstream.requests[0].Body)
require.NoError(t, readErr)
require.NotContains(t, string(body), `"profileArn":`)
}
func TestAccountTestService_KiroInvalidModelErrorPassthrough(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 6,
Name: "kiro-invalid-model",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTINVALIDMODEL",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Equal(t, `API returned 400: {"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`, err.Error())
}
func TestAccountTestService_KiroInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 7,
Name: "kiro-invalid-model-refresh",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{7: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Contains(t, err.Error(), "API returned 400")
require.Len(t, upstream.requests, 1)
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
require.NoError(t, readErr)
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
}
func TestAccountTestService_KiroPreferredEndpointIsIgnored(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
account := &Account{
ID: 6,
Name: "kiro-preferred-endpoint",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"api_region": "us-west-2",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/PREFERRED",
"preferred_endpoint": "codewhisperer",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
},
}
svc := &AccountTestService{
accountRepo: repo,
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
httpUpstream: upstream,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
require.Error(t, err)
require.Len(t, upstream.requests, 1)
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
}
func TestBuildKiroPayloadForAccount_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
account := &Account{
ID: 3,
Name: "kiro-builder-id",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"auth_method": "idc",
"provider": "BuilderId",
"region": "us-east-1",
"client_id": "builder-client-id",
},
}
testPayload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(testPayload)
require.NoError(t, err)
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotContains(t, string(kiroPayload), `"profileArn":`)
}
func TestBuildKiroPayloadForAccount_KiroBuilderIDUsesCredentialProfileArn(t *testing.T) {
account := &Account{
ID: 33,
Name: "kiro-builder-id-cached",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"auth_method": "builder-id",
"provider": "BuilderId",
"region": "us-east-1",
"client_id": "builder-client-id",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED",
},
}
testPayload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(testPayload)
require.NoError(t, err)
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.Contains(t, string(kiroPayload), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED"`)
}
func TestBuildKiroPayloadForAccount_KiroEnterpriseIDCOmitsMissingProfileArn(t *testing.T) {
account := &Account{
ID: 4,
Name: "kiro-enterprise-idc",
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"auth_method": "idc",
"provider": "AWS",
"region": "us-east-1",
"client_id": "enterprise-client-id",
"start_url": "https://d-example.awsapps.com/start",
},
}
testPayload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(testPayload)
require.NoError(t, err)
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotContains(t, string(kiroPayload), `"profileArn":`)
}
@@ -103,10 +103,17 @@ type antigravityUsageCache struct {
timestamp time.Time timestamp time.Time
} }
// kiroUsageCache 缓存 Kiro 额度快照
type kiroUsageCache struct {
usageInfo *UsageInfo
timestamp time.Time
}
const ( const (
apiCacheTTL = 3 * time.Minute apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟 apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误) antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
kiroUsageErrorTTL = 1 * time.Minute // Kiro 错误缓存 TTL(可恢复错误)
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute openAIProbeCacheTTL = 10 * time.Minute
@@ -118,8 +125,10 @@ type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache antigravityCache sync.Map // accountID -> *antigravityUsageCache
kiroUsageCache sync.Map // accountID -> *kiroUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic) apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存 antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
kiroUsageFlight singleflight.Group // 防止同一 Kiro 账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time openAIProbeCache sync.Map // accountID -> time.Time
} }
@@ -176,6 +185,23 @@ type AICredit struct {
MinimumBalance float64 `json:"minimum_balance,omitempty"` MinimumBalance float64 `json:"minimum_balance,omitempty"`
} }
// KiroCreditProgress 表示 Kiro 主额度或 Bonus 的用量进度。
type KiroCreditProgress struct {
CurrentUsage float64 `json:"current_usage"`
UsageLimit float64 `json:"usage_limit"`
PercentageUsed float64 `json:"percentage_used"`
DaysRemaining int `json:"days_remaining,omitempty"`
ExpiryDate *time.Time `json:"expiry_date,omitempty"`
}
// KiroOverageInfo 表示 Kiro 账号的 overage 状态。
type KiroOverageInfo struct {
CurrentOverages float64 `json:"current_overages"`
OverageCharges float64 `json:"overage_charges"`
CurrencyCode string `json:"currency_code,omitempty"`
CurrencySymbol string `json:"currency_symbol,omitempty"`
}
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
Source string `json:"source,omitempty"` // "passive" or "active" Source string `json:"source,omitempty"` // "passive" or "active"
@@ -203,6 +229,21 @@ type UsageInfo struct {
// Antigravity AI Credits 余额 // Antigravity AI Credits 余额
AICredits []AICredit `json:"ai_credits,omitempty"` AICredits []AICredit `json:"ai_credits,omitempty"`
// Kiro Credits 额度与 overage 信息
KiroSubscriptionName string `json:"kiro_subscription_name,omitempty"`
KiroSubscriptionType string `json:"kiro_subscription_type,omitempty"`
KiroResetAt *time.Time `json:"kiro_reset_at,omitempty"`
KiroOveragesEnabled bool `json:"kiro_overages_enabled,omitempty"`
KiroCredit *KiroCreditProgress `json:"kiro_credit,omitempty"`
KiroBonus *KiroCreditProgress `json:"kiro_bonus,omitempty"`
KiroOverage *KiroOverageInfo `json:"kiro_overage,omitempty"`
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id) // Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"` ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
@@ -266,6 +307,7 @@ type AccountUsageService struct {
cache *UsageCache cache *UsageCache
identityCache IdentityCache identityCache IdentityCache
tlsFPProfileService *TLSFingerprintProfileService tlsFPProfileService *TLSFingerprintProfileService
kiroCooldownStore KiroCooldownStore
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
@@ -291,6 +333,13 @@ func NewAccountUsageService(
} }
} }
func (s *AccountUsageService) SetKiroCooldownStore(store KiroCooldownStore) *AccountUsageService {
if s != nil {
s.kiroCooldownStore = store
}
return s
}
// GetUsage 获取账号使用量 // GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟 // OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope // Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope
@@ -317,6 +366,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return usage, err return usage, err
} }
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
return s.getKiroUsage(ctx, account, "active", false)
}
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
usage, err := s.getAntigravityUsage(ctx, account) usage, err := s.getAntigravityUsage(ctx, account)
@@ -425,6 +478,13 @@ func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int
return nil, fmt.Errorf("get account failed: %w", err) return nil, fmt.Errorf("get account failed: %w", err)
} }
if account.Platform == PlatformKiro {
if account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("passive usage only supported for Kiro OAuth accounts")
}
return s.getKiroUsage(ctx, account, "passive", false)
}
if !account.IsAnthropicOAuthOrSetupToken() { if !account.IsAnthropicOAuthOrSetupToken() {
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts") return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
} }
@@ -0,0 +1,40 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccountUsageService_GetUsage_KiroAPIKeyUnsupported(t *testing.T) {
account := &Account{
ID: 9101,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
usage, err := svc.GetUsage(context.Background(), account.ID)
require.Nil(t, usage)
require.Error(t, err)
require.Contains(t, err.Error(), "does not support usage query")
}
func TestAccountUsageService_GetPassiveUsage_KiroAPIKeyUnsupported(t *testing.T) {
account := &Account{
ID: 9102,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
usage, err := svc.GetPassiveUsage(context.Background(), account.ID)
require.Nil(t, usage)
require.Error(t, err)
require.Contains(t, err.Error(), "Kiro OAuth")
}
+2 -2
View File
@@ -1448,7 +1448,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
} }
// require_oauth_only: 过滤掉 apikey 类型账号 // require_oauth_only: 过滤掉 apikey 类型账号
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
@@ -1728,7 +1728,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
// require_oauth_only: 过滤掉 apikey 类型账号 // require_oauth_only: 过滤掉 apikey 类型账号
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
@@ -37,6 +37,7 @@ const (
PlatformOpenAI = domain.PlatformOpenAI PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity PlatformAntigravity = domain.PlatformAntigravity
PlatformKiro = domain.PlatformKiro
) )
// Account type constants // Account type constants
+179 -12
View File
@@ -27,6 +27,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
@@ -56,6 +57,7 @@ const (
defaultModelsListCacheTTL = 15 * time.Second defaultModelsListCacheTTL = 15 * time.Second
postUsageBillingTimeout = 15 * time.Second postUsageBillingTimeout = 15 * time.Second
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY" debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
defaultKiroStreamKeepalive = 25 * time.Second
) )
const ( const (
@@ -70,6 +72,7 @@ const (
// ForceCacheBillingContextKey 强制缓存计费上下文键 // ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{} type forceCacheBillingKeyType struct{}
type kiroCooldownRecoveryAttemptedKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度 // accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct { type accountWithLoad struct {
@@ -78,6 +81,7 @@ type accountWithLoad struct {
} }
var ForceCacheBillingContextKey = forceCacheBillingKeyType{} var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
var kiroCooldownRecoveryAttemptedKey = kiroCooldownRecoveryAttemptedKeyType{}
var ( var (
windowCostPrefetchCacheHitTotal atomic.Int64 windowCostPrefetchCacheHitTotal atomic.Int64
@@ -554,6 +558,8 @@ type GatewayService struct {
deferredService *DeferredService deferredService *DeferredService
concurrencyService *ConcurrencyService concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider claudeTokenProvider *ClaudeTokenProvider
kiroTokenProvider *KiroTokenProvider
kiroCooldownStore KiroCooldownStore
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver userGroupRateResolver *userGroupRateResolver
@@ -592,6 +598,8 @@ func NewGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider, claudeTokenProvider *ClaudeTokenProvider,
kiroTokenProvider *KiroTokenProvider,
kiroCooldownStore KiroCooldownStore,
sessionLimitCache SessionLimitCache, sessionLimitCache SessionLimitCache,
rpmCache RPMCache, rpmCache RPMCache,
digestStore *DigestSessionStore, digestStore *DigestSessionStore,
@@ -624,6 +632,8 @@ func NewGatewayService(
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
deferredService: deferredService, deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider, claudeTokenProvider: claudeTokenProvider,
kiroTokenProvider: kiroTokenProvider,
kiroCooldownStore: kiroCooldownStore,
sessionLimitCache: sessionLimitCache, sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache, rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
@@ -1969,6 +1979,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
if len(candidates) == 0 { if len(candidates) == 0 {
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, useMixed) {
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
return s.SelectAccountWithLoadAwareness(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID)
}
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
} }
@@ -2348,14 +2362,91 @@ func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool
if account == nil { if account == nil {
return false return false
} }
return account.IsSchedulable() if !account.IsSchedulable() {
return false
}
return s.isKiroRuntimeSchedulable(context.Background(), account)
} }
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
if account == nil { if account == nil {
return false return false
} }
return account.IsSchedulableForModelWithContext(ctx, requestedModel) if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false
}
return s.isKiroRuntimeSchedulable(ctx, account)
}
func (s *GatewayService) isKiroRuntimeSchedulable(ctx context.Context, account *Account) bool {
if account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth || s == nil || s.kiroCooldownStore == nil {
return true
}
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(account))
if err != nil {
return true
}
return state == nil || !state.Active
}
func (s *GatewayService) tryRecoverKiroCooldownPool(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) bool {
if s == nil || s.kiroCooldownStore == nil || ctx.Value(kiroCooldownRecoveryAttemptedKey) == true {
return false
}
tokenKeys := s.kiroTransientCooldownRecoveryKeys(ctx, accounts, requestedModel, excludedIDs, allowMixedScheduling)
if len(tokenKeys) == 0 {
return false
}
cleared, err := s.kiroCooldownStore.ClearEarliestTransientCooldown(ctx, tokenKeys)
if err != nil {
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery failed: %v", err)
return false
}
if cleared {
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery cleared one transient cooldown")
}
return cleared
}
func (s *GatewayService) kiroTransientCooldownRecoveryKeys(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) []string {
tokenKeys := make([]string, 0, len(accounts))
eligible := 0
for i := range accounts {
acc := &accounts[i]
if acc == nil || acc.Platform != PlatformKiro || acc.Type != AccountTypeOAuth {
if allowMixedScheduling {
continue
}
return nil
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) ||
!s.isAccountSchedulableForWindowCost(ctx, acc, false) ||
!s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
eligible++
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc))
if err != nil || state == nil || !state.Active {
return nil
}
if state.Reason != kirocooldown.CooldownReason429 {
return nil
}
tokenKeys = append(tokenKeys, buildKiroAccountKey(acc))
}
if eligible == 0 || len(tokenKeys) != eligible {
return nil
}
return tokenKeys
} }
// isAccountInGroup checks if the account belongs to the specified group. // isAccountInGroup checks if the account belongs to the specified group.
@@ -3234,6 +3325,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if selected == nil { if selected == nil {
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, false) {
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
return s.selectAccountForModelWithPlatform(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
if requestedModel != "" { if requestedModel != "" {
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats)) return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
} }
@@ -3613,6 +3708,17 @@ func (s *GatewayService) diagnoseSelectionFailure(
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
return selectionFailureDiagnosis{Category: "excluded"} return selectionFailureDiagnosis{Category: "excluded"}
} }
if !acc.IsSchedulable() {
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
if acc.Platform == PlatformKiro && acc.Type == AccountTypeOAuth {
if state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc)); err == nil && state != nil && state.Active {
return selectionFailureDiagnosis{
Category: "unschedulable",
Detail: fmt.Sprintf("kiro_runtime_%s remaining=%s", state.Reason, state.Remaining.Truncate(time.Second)),
}
}
}
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
} }
@@ -3776,6 +3882,13 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
} }
return accessToken, "oauth", nil return accessToken, "oauth", nil
} }
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth && s.kiroTokenProvider != nil {
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 其他情况(Gemini 有自己的 TokenProvidersetup-token 类型等)直接从账号读取 // 其他情况(Gemini 有自己的 TokenProvidersetup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
@@ -4319,11 +4432,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request") return nil, fmt.Errorf("parse request: empty request")
} }
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body passthroughBody := parsed.Body
passthroughModel := parsed.Model passthroughModel := parsed.Model
@@ -4347,6 +4455,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return s.forwardBedrock(ctx, c, account, parsed, startTime) return s.forwardBedrock(ctx, c, account, parsed, startTime)
} }
if account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
return s.forwardKiroMessages(ctx, c, account, parsed, startTime)
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account. // Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil { if account.Platform == PlatformAnthropic && c != nil {
@@ -4439,7 +4556,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel := reqModel mappedModel := reqModel
mappingSource := "" mappingSource := ""
if account.Type == AccountTypeAPIKey { if account.Platform == PlatformKiro {
if next := account.GetMappedModel(reqModel); next != "" && next != reqModel {
mappedModel = next
mappingSource = "account"
}
} else if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel) mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
mappingSource = "account" mappingSource = "account"
@@ -5938,6 +6060,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
if baseURL == "" && account.Platform == PlatformKiro {
return nil, fmt.Errorf("kiro api key account requires base_url")
}
if baseURL != "" { if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL) validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil { if err != nil {
@@ -7199,10 +7324,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0) keepaliveInterval := s.streamKeepaliveIntervalForAccount(account)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 { if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval) keepaliveTicker = time.NewTicker(keepaliveInterval)
@@ -8241,6 +8363,9 @@ type recordUsageOpts struct {
// 长上下文计费(仅 Gemini 路径需要) // 长上下文计费(仅 Gemini 路径需要)
LongContextThreshold int LongContextThreshold int
LongContextMultiplier float64 LongContextMultiplier float64
// Kiro 账号在上游返回 auto 等无法定价模型时使用保守计费兜底。
IsKiroAccount bool
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -8377,6 +8502,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
} }
// 计算费用 // 计算费用
opts.IsKiroAccount = account != nil && account.Platform == PlatformKiro
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式 // 判断计费方式:订阅模式 vs 余额模式
@@ -8454,6 +8580,28 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
} }
const kiroConservativeFallbackBillingModel = "claude-opus-4-6"
func shouldUseKiroConservativeBillingFallback(result *ForwardResult, billingModel string, opts *recordUsageOpts) bool {
if result == nil {
return false
}
return opts != nil && opts.IsKiroAccount
}
func (s *GatewayService) calculateKiroConservativeTokenCost(tokens UsageTokens, multiplier float64) *CostBreakdown {
if s == nil || s.billingService == nil {
return nil
}
cost, err := s.billingService.CalculateCost(kiroConservativeFallbackBillingModel, tokens, multiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate conservative Kiro fallback cost failed: %v", err)
return nil
}
return cost
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -8557,6 +8705,12 @@ func (s *GatewayService) calculateTokenCost(
} }
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
if shouldUseKiroConservativeBillingFallback(result, billingModel, opts) {
if fallback := s.calculateKiroConservativeTokenCost(tokens, multiplier); fallback != nil {
logger.LegacyPrintf("service.gateway", "Using conservative Kiro fallback pricing for model=%s", billingModel)
return fallback
}
}
return &CostBreakdown{ActualCost: 0} return &CostBreakdown{ActualCost: 0}
} }
return cost return cost
@@ -9444,6 +9598,19 @@ func reconcileCachedTokens(usage map[string]any) bool {
return true return true
} }
func (s *GatewayService) streamKeepaliveIntervalForAccount(account *Account) time.Duration {
if account != nil && account.Platform == PlatformKiro {
if s != nil && s.cfg != nil && s.cfg.Gateway.KiroStreamKeepaliveInterval > 0 {
return time.Duration(s.cfg.Gateway.KiroStreamKeepaliveInterval) * time.Second
}
return defaultKiroStreamKeepalive
}
if s != nil && s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
return time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
return 0
}
const debugGatewayBodyDefaultFilename = "gateway_debug.log" const debugGatewayBodyDefaultFilename = "gateway_debug.log"
// initDebugGatewayBodyFile 初始化网关调试日志文件。 // initDebugGatewayBodyFile 初始化网关调试日志文件。
@@ -0,0 +1,222 @@
package service
import (
"errors"
"net"
"net/http"
"strings"
"github.com/tidwall/gjson"
)
const (
kiroErrorAuthError = "auth_error"
kiroErrorMonthlyRequest = "monthly_request_count"
kiroErrorProfileError = "profile_error"
kiroErrorQuotaExhausted = "quota_exhausted"
kiroErrorOverageExhausted = "overage_exhausted"
kiroErrorRateLimited = "rate_limited"
kiroErrorSuspended = "suspended"
kiroErrorUsageForbidden = "usage_forbidden"
kiroErrorUpstreamTransient = "upstream_transient"
kiroErrorBadRequestSchema = "bad_request_schema"
kiroErrorBadRequestToolPairing = "bad_request_tool_pairing"
kiroErrorBadRequestInvalidModel = "bad_request_invalid_model"
kiroErrorBadRequestAuth = "bad_request_auth"
kiroErrorBadRequestQuota = "bad_request_quota"
kiroErrorBadRequestUnknown = "bad_request_unknown"
kiroErrorRefreshTokenInvalid = "refresh_token_invalid"
kiroQuotaStateNormal = "normal"
kiroQuotaStateOverageActive = "overage_active"
kiroQuotaStateCreditsExhausted = "credits_exhausted"
kiroQuotaStateOverageExhausted = "overage_exhausted"
)
type kiroErrorClassification struct {
Category string
StatusCode int
Message string
}
func classifyKiroHTTPError(statusCode int, body string) kiroErrorClassification {
trimmed := strings.TrimSpace(body)
lower := strings.ToLower(trimmed)
switch {
case statusCode == http.StatusUnauthorized:
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusPaymentRequired && looksLikeKiroMonthlyRequestCountError(trimmed):
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusForbidden && isKiroSuspendedBody([]byte(trimmed)):
return kiroErrorClassification{Category: kiroErrorSuspended, StatusCode: statusCode, Message: trimmed}
case looksLikeKiroProfileError(lower):
return kiroErrorClassification{Category: kiroErrorProfileError, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusBadRequest:
return classifyKiroBadRequest(trimmed, lower)
case statusCode == http.StatusForbidden && isKiroTokenErrorBody([]byte(trimmed)):
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
case looksLikeKiroOverageExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorOverageExhausted, StatusCode: statusCode, Message: trimmed}
case looksLikeKiroQuotaExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusTooManyRequests:
return kiroErrorClassification{Category: kiroErrorRateLimited, StatusCode: statusCode, Message: trimmed}
case statusCode == http.StatusForbidden:
return kiroErrorClassification{Category: kiroErrorUsageForbidden, StatusCode: statusCode, Message: trimmed}
case statusCode >= 500:
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
default:
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
}
}
func classifyKiroError(err error) kiroErrorClassification {
if err == nil {
return kiroErrorClassification{}
}
var httpErr *kiroUsageHTTPError
if errors.As(err, &httpErr) && httpErr != nil {
return classifyKiroHTTPError(httpErr.StatusCode, httpErr.Body)
}
errStr := strings.TrimSpace(err.Error())
lower := strings.ToLower(errStr)
switch {
case looksLikeKiroInvalidGrantError(lower):
return kiroErrorClassification{Category: kiroErrorRefreshTokenInvalid, Message: errStr}
case looksLikeKiroMonthlyRequestCountError(errStr):
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, Message: errStr}
case looksLikeKiroProfileError(lower):
return kiroErrorClassification{Category: kiroErrorProfileError, Message: errStr}
case looksLikeKiroOverageExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorOverageExhausted, Message: errStr}
case looksLikeKiroQuotaExhaustedError(lower):
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, Message: errStr}
case strings.Contains(lower, "context deadline exceeded"),
strings.Contains(lower, "timeout"),
isNetErr(err):
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
default:
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
}
}
func classifyKiroBadRequest(trimmed, lower string) kiroErrorClassification {
switch {
case looksLikeKiroBadRequestSchemaError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestSchema, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroBadRequestToolPairingError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestToolPairing, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroBadRequestInvalidModelError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestInvalidModel, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroInvalidGrantError(lower) || looksLikeKiroBadRequestAuthError(lower):
return kiroErrorClassification{Category: kiroErrorBadRequestAuth, StatusCode: http.StatusBadRequest, Message: trimmed}
case looksLikeKiroQuotaExhaustedError(lower) || looksLikeKiroMonthlyRequestCountError(trimmed):
return kiroErrorClassification{Category: kiroErrorBadRequestQuota, StatusCode: http.StatusBadRequest, Message: trimmed}
default:
return kiroErrorClassification{Category: kiroErrorBadRequestUnknown, StatusCode: http.StatusBadRequest, Message: trimmed}
}
}
func looksLikeKiroBadRequestSchemaError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "schema") ||
strings.Contains(lower, "inputschema") ||
strings.Contains(lower, "improperly formed request") ||
strings.Contains(lower, "additionalproperties") ||
(strings.Contains(lower, "properties") && strings.Contains(lower, "required"))
}
func looksLikeKiroBadRequestToolPairingError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "tool_use") ||
strings.Contains(lower, "tool_result") ||
strings.Contains(lower, "tooluseid") ||
strings.Contains(lower, "toolresults") ||
strings.Contains(lower, "must be paired") ||
strings.Contains(lower, "missing tool result")
}
func looksLikeKiroBadRequestInvalidModelError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "invalid model") ||
strings.Contains(lower, "invalid_model_id") ||
strings.Contains(lower, "model not supported") ||
strings.Contains(lower, "unsupportedmodel") ||
strings.Contains(lower, "modelid")
}
func looksLikeKiroBadRequestAuthError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "invalid token") ||
strings.Contains(lower, "expired token") ||
strings.Contains(lower, "access token") ||
strings.Contains(lower, "refresh token")
}
func looksLikeKiroInvalidGrantError(lower string) bool {
return strings.Contains(lower, "invalid_grant")
}
func looksLikeKiroMonthlyRequestCountError(body string) bool {
trimmed := strings.TrimSpace(body)
if trimmed == "" {
return false
}
if strings.Contains(trimmed, "MONTHLY_REQUEST_COUNT") {
return true
}
if !gjson.Valid(trimmed) {
return false
}
return gjson.Get(trimmed, "reason").String() == "MONTHLY_REQUEST_COUNT" ||
gjson.Get(trimmed, "error.reason").String() == "MONTHLY_REQUEST_COUNT"
}
func looksLikeKiroProfileError(lower string) bool {
if lower == "" {
return false
}
return (strings.Contains(lower, "profilearn") && strings.Contains(lower, "required")) ||
(strings.Contains(lower, "profile arn") && strings.Contains(lower, "required")) ||
(strings.Contains(lower, "profile") && strings.Contains(lower, "not found")) ||
(strings.Contains(lower, "invalid profile")) ||
(strings.Contains(lower, "listavailableprofiles"))
}
func looksLikeKiroQuotaExhaustedError(lower string) bool {
if lower == "" {
return false
}
return (strings.Contains(lower, "credit") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "depleted"))) ||
(strings.Contains(lower, "quota") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "exceeded") || strings.Contains(lower, "depleted"))) ||
(strings.Contains(lower, "usage limit") && (strings.Contains(lower, "reached") || strings.Contains(lower, "exceeded"))) ||
(strings.Contains(lower, "resource has been exhausted"))
}
func looksLikeKiroOverageExhaustedError(lower string) bool {
if lower == "" {
return false
}
return strings.Contains(lower, "overage") &&
(strings.Contains(lower, "exhaust") ||
strings.Contains(lower, "disabled") ||
strings.Contains(lower, "not enabled") ||
strings.Contains(lower, "not allowed") ||
strings.Contains(lower, "limit"))
}
func isNetErr(err error) bool {
var netErr net.Error
return errors.As(err, &netErr)
}
@@ -0,0 +1,66 @@
package service
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestClassifyKiroHTTPErrorBadRequestCategories(t *testing.T) {
tests := []struct {
name string
body string
want string
}{
{
name: "schema",
body: `{"message":"Improperly formed request: inputSchema.properties must be an object"}`,
want: kiroErrorBadRequestSchema,
},
{
name: "tool pairing",
body: `{"message":"tool_use must be paired with a matching tool_result"}`,
want: kiroErrorBadRequestToolPairing,
},
{
name: "invalid model id",
body: `{"message":"invalid modelId: model not supported"}`,
want: kiroErrorBadRequestInvalidModel,
},
{
name: "invalid model upstream",
body: `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
want: kiroErrorBadRequestInvalidModel,
},
{
name: "invalid model reason",
body: `{"message":"model route unavailable","reason":"INVALID_MODEL_ID"}`,
want: kiroErrorBadRequestInvalidModel,
},
{
name: "auth",
body: `{"error":"invalid_grant","message":"Invalid refresh token provided"}`,
want: kiroErrorBadRequestAuth,
},
{
name: "quota",
body: `{"message":"resource has been exhausted"}`,
want: kiroErrorBadRequestQuota,
},
{
name: "unknown",
body: `{"message":"bad request"}`,
want: kiroErrorBadRequestUnknown,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
classification := classifyKiroHTTPError(http.StatusBadRequest, tt.body)
require.Equal(t, tt.want, classification.Category)
require.Equal(t, http.StatusBadRequest, classification.StatusCode)
require.Equal(t, tt.body, classification.Message)
})
}
}
@@ -0,0 +1,180 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/google/uuid"
)
func buildKiroAccountKey(account *Account) string {
if account == nil {
return ""
}
return kiropkg.BuildAccountKey(
account.GetCredential("client_id"),
account.GetCredential("client_id_hash"),
account.GetCredential("refresh_token"),
account.GetCredential("profile_arn"),
account.ID,
)
}
func buildKiroMachineID(account *Account) string {
if account == nil {
return kiropkg.BuildMachineID("", "", "account:nil")
}
for _, key := range []string{"machine_id", "machineId"} {
if machineID, ok := kiropkg.NormalizeMachineID(account.GetCredential(key)); ok {
return machineID
}
}
fallbackKey := buildKiroMachineIDFallbackKey(account)
if account.Type == AccountTypeAPIKey {
return kiropkg.BuildMachineID("", firstKiroCredential(account, "kiro_api_key", "kiroApiKey", "api_key"), fallbackKey)
}
return kiropkg.BuildMachineID(account.GetCredential("refresh_token"), "", fallbackKey)
}
func firstKiroCredential(account *Account, keys ...string) string {
if account == nil {
return ""
}
for _, key := range keys {
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
return value
}
}
return ""
}
func buildKiroMachineIDFallbackKey(account *Account) string {
if account == nil {
return "account:nil"
}
if account.ID > 0 {
return fmt.Sprintf("account:%d", account.ID)
}
for _, key := range []string{"client_id", "profile_arn"} {
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
return key + ":" + value
}
}
if name := strings.TrimSpace(account.Name); name != "" {
return "name:" + name
}
return "account:unknown"
}
func buildKiroRequestID(resp *http.Response) string {
if resp == nil {
return ""
}
if requestID := strings.TrimSpace(resp.Header.Get("x-request-id")); requestID != "" {
return requestID
}
if requestID := strings.TrimSpace(resp.Header.Get("x-amzn-requestid")); requestID != "" {
return requestID
}
return strings.TrimSpace(resp.Header.Get("x-amz-request-id"))
}
func isKiroInvalidModelIDBody(respBody []byte) bool {
var payload struct {
Reason string `json:"reason"`
Message string `json:"message"`
Error struct {
Reason string `json:"reason"`
Message string `json:"message"`
} `json:"error"`
}
if json.Unmarshal(respBody, &payload) != nil {
return looksLikeKiroBadRequestInvalidModelError(strings.ToLower(string(respBody)))
}
return strings.EqualFold(strings.TrimSpace(payload.Reason), "INVALID_MODEL_ID") ||
strings.EqualFold(strings.TrimSpace(payload.Error.Reason), "INVALID_MODEL_ID") ||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Message)) ||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Error.Message))
}
func isKiroSuspendedBody(respBody []byte) bool {
body := string(respBody)
return strings.Contains(body, "SUSPENDED") || strings.Contains(body, "TEMPORARILY_SUSPENDED")
}
func isKiroTokenErrorBody(respBody []byte) bool {
lower := strings.ToLower(string(respBody))
return strings.Contains(lower, "token") ||
strings.Contains(lower, "expired") ||
strings.Contains(lower, "invalid") ||
strings.Contains(lower, "unauthorized")
}
func kiroProxyURL(account *Account) string {
if account != nil && account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}
func kiroAPIRegion(account *Account) string {
if account == nil {
return kiroDefaultRegion
}
region := strings.TrimSpace(account.GetCredential("api_region"))
if region == "" {
region = kiroDefaultRegion
}
return region
}
func applyKiroConditionalHeaders(req *http.Request, account *Account) {
if req == nil || account == nil {
return
}
if strings.EqualFold(strings.TrimSpace(account.GetCredential("auth_method")), "external_idp") {
req.Header.Set("TokenType", "EXTERNAL_IDP")
}
if strings.EqualFold(strings.TrimSpace(account.GetCredential("provider")), "Internal") {
req.Header.Set("redirect-for-internal", "true")
}
}
func resolveKiroPayloadProfileArn(account *Account) string {
if account == nil {
return ""
}
return strings.TrimSpace(account.GetCredential("profile_arn"))
}
func newKiroJSONRequest(ctx context.Context, endpointURL string, payload []byte, token, accountKey, machineID, amzTarget string, account *Account) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
req.Header.Set("x-amzn-codewhisperer-optout", "true")
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
if amzTarget != "" {
req.Header.Set("X-Amz-Target", amzTarget)
}
if account != nil {
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
if profileArn != "" {
req.Header.Set("x-amzn-kiro-profile-arn", profileArn)
}
}
applyKiroConditionalHeaders(req, account)
return req, nil
}
@@ -0,0 +1,208 @@
//go:build unit
package service
import (
"context"
"net/http"
"strings"
"testing"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/stretchr/testify/require"
)
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
accountA := &Account{
ID: 99,
Credentials: map[string]any{
"access_token": "token-a",
},
}
accountB := &Account{
ID: 99,
Credentials: map[string]any{
"access_token": "token-b",
},
}
require.Equal(t, buildKiroAccountKey(accountA), buildKiroAccountKey(accountB))
}
func TestBuildKiroMachineIDPrefersExplicitCredential(t *testing.T) {
account := &Account{
ID: 101,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"machineId": "2582956e-cc88-4669-b546-07adbffcb894",
"refresh_token": "refresh-token",
},
}
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", buildKiroMachineID(account))
}
func TestBuildKiroMachineIDDerivesFromRefreshToken(t *testing.T) {
account := &Account{
ID: 102,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token",
},
}
require.Equal(t, kiropkg.BuildMachineID("refresh-token", "", "account:102"), buildKiroMachineID(account))
}
func TestBuildKiroMachineIDDerivesFromAPIKeyAccount(t *testing.T) {
account := &Account{
ID: 103,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"kiroApiKey": "kiro-api-key",
},
}
require.Equal(t, kiropkg.BuildMachineID("", "kiro-api-key", "account:103"), buildKiroMachineID(account))
}
func TestNewKiroJSONRequestAddsConditionalHeaders(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"auth_method": "external_idp",
"provider": "Internal",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER",
},
}
req, err := newKiroJSONRequest(
context.Background(),
"https://q.us-east-1.amazonaws.com/generateAssistantResponse",
[]byte(`{"ok":true}`),
"access-token",
"account-key",
buildKiroMachineID(account),
"",
account,
)
require.NoError(t, err)
require.Equal(t, "EXTERNAL_IDP", req.Header.Get("TokenType"))
require.Equal(t, "true", req.Header.Get("redirect-for-internal"))
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER", req.Header.Get("x-amzn-kiro-profile-arn"))
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
require.Equal(t, "true", req.Header.Get("x-amzn-codewhisperer-optout"))
require.Contains(t, req.Header.Get("User-Agent"), "aws-sdk-js/1.0.34")
require.Contains(t, req.Header.Get("User-Agent"), "md/nodejs#22.22.0")
require.Contains(t, req.Header.Get("User-Agent"), buildKiroMachineID(account))
require.Contains(t, req.Header.Get("X-Amz-User-Agent"), buildKiroMachineID(account))
require.True(t, strings.Contains(req.Header.Get("User-Agent"), "api/codewhispererstreaming#1.0.34"))
require.Empty(t, req.Header.Get("Anthropic-Beta"))
}
func TestIsKiroInvalidModelIDBodyRecognizesKnownForms(t *testing.T) {
tests := []string{
`{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`,
`{"message":"Invalid model. Please select a different model to continue."}`,
`API Error: 400 {"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
}
for _, body := range tests {
require.True(t, isKiroInvalidModelIDBody([]byte(body)), body)
}
}
func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
account := &Account{
ID: 7,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/test",
},
}
body := []byte(`{
"model":"claude-sonnet-4-6",
"messages":[{"role":"user","content":"hello"}]
}`)
headers := http.Header{}
headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-sonnet-4.6",
"kiro-access-token",
"claude-sonnet-4-6",
headers,
)
require.NoError(t, err)
require.NotContains(t, string(payload), "CHUNKED WRITE PROTOCOL")
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
}
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"api_region": "eu-west-1",
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
"region": "ap-northeast-1",
},
}
require.Equal(t, "eu-west-1", kiroAPIRegion(account))
}
func TestKiroAPIRegionIgnoresProfileARNRegionFallback(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
},
}
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
}
func TestKiroAPIRegionIgnoresOIDCRegionFallback(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"region": "ap-northeast-2",
},
}
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
}
func TestBuildKiroEndpointsUsesOnlyAmazonQEndpoint(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"api_region": "us-west-2",
"preferred_endpoint": "cw",
},
}
endpoints := buildKiroEndpoints(account)
require.Len(t, endpoints, 1)
require.Equal(t, "AmazonQ", endpoints[0].Name)
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
require.Empty(t, endpoints[0].AmzTarget)
}
func TestBuildKiroEndpointsIgnoresPreferredEndpoint(t *testing.T) {
for _, preferred := range []string{"codewhisperer", "cw", "unknown"} {
account := &Account{
Credentials: map[string]any{
"api_region": "us-west-2",
"preferred_endpoint": preferred,
},
}
endpoints := buildKiroEndpoints(account)
require.Len(t, endpoints, 1)
require.Equal(t, "AmazonQ", endpoints[0].Name)
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
}
}
@@ -0,0 +1,74 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestAccountKiroDefaultMappingRestrictsUnsupportedModels(t *testing.T) {
account := &Account{Platform: PlatformKiro}
require.False(t, account.IsModelSupported("gpt-4o"))
require.False(t, account.IsModelSupported("kiro-gpt-4o"))
require.False(t, account.IsModelSupported("auto"))
require.Equal(t, "claude-sonnet-4.6", account.GetMappedModel("claude-sonnet-4-6"))
}
func TestGatewayServiceCalculateTokenCost_KiroAutoUsesConservativeFallback(t *testing.T) {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
svc := NewGatewayService(
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
result := &ForwardResult{
Model: "auto",
UpstreamModel: "auto",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
}
expected, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
cost := svc.calculateTokenCost(context.Background(), result, &APIKey{}, "auto", 1.1, &recordUsageOpts{IsKiroAccount: true})
require.NotNil(t, cost)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-12)
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-12)
}
@@ -0,0 +1,369 @@
package service
import (
"context"
"fmt"
"strings"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
const (
// Kiro desktop social auth uses localhost loopback callbacks from a fixed
// allowlist. Use one of the bundled ports from the official client.
kiroSocialRedirectURI = "http://localhost:49153"
// AWS IAM Identity Center native/public clients require an explicit loopback IP redirect URI.
kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
)
type KiroOAuthService struct {
sessionStore *kiropkg.SessionStore
proxyRepo ProxyRepository
}
func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
return &KiroOAuthService{
sessionStore: kiropkg.NewSessionStore(),
proxyRepo: proxyRepo,
}
}
func (s *KiroOAuthService) Stop() {}
type KiroAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
}
type KiroIDCAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
ClientID string `json:"client_id"`
Region string `json:"region"`
StartURL string `json:"start_url"`
}
type KiroTokenInfo struct {
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ProfileArn string `json:"profile_arn,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
AuthMethod string `json:"auth_method,omitempty"`
Provider string `json:"provider,omitempty"`
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ClientIDHash string `json:"client_id_hash,omitempty"`
Email string `json:"email,omitempty"`
StartURL string `json:"start_url,omitempty"`
Region string `json:"region,omitempty"`
}
type KiroGenerateAuthURLInput struct {
ProxyID *int64
Provider string
}
type KiroExchangeCodeInput struct {
SessionID string
State string
Code string
CallbackPath string
LoginOption string
ProxyID *int64
}
type KiroGenerateIDCAuthURLInput struct {
ProxyID *int64
StartURL string
Region string
}
type KiroRefreshTokenInput struct {
RefreshToken string
AuthMethod string
Provider string
ClientID string
ClientSecret string
StartURL string
Region string
ProfileArn string
ProxyID *int64
}
type KiroImportTokenInput struct {
TokenJSON string
DeviceRegistrationJSON string
}
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
provider := strings.TrimSpace(input.Provider)
if provider == "" {
provider = string(kiropkg.SocialProviderGoogle)
}
if provider != string(kiropkg.SocialProviderGoogle) && provider != string(kiropkg.SocialProviderGitHub) {
return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
}
state, err := kiropkg.GenerateState()
if err != nil {
return nil, fmt.Errorf("generate state failed: %w", err)
}
codeVerifier, err := kiropkg.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("generate code verifier failed: %w", err)
}
sessionID := kiropkg.GenerateSessionID()
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
AuthType: "social",
Provider: provider,
RedirectURI: kiroSocialRedirectURI,
})
return &KiroAuthURLResult{
AuthURL: kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
SessionID: sessionID,
State: state,
}, nil
}
func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
return nil, fmt.Errorf("state invalid")
}
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
}
switch session.AuthType {
case "social":
token, err := kiropkg.CreateSocialToken(
ctx,
proxyURL,
input.Code,
session.CodeVerifier,
buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
)
if err != nil {
return nil, err
}
token.Provider = session.Provider
s.sessionStore.Delete(input.SessionID)
return toKiroTokenInfo(token), nil
case "idc":
token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
if err != nil {
return nil, err
}
s.sessionStore.Delete(input.SessionID)
return toKiroTokenInfo(token), nil
default:
return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
}
}
func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
option := strings.ToLower(strings.TrimSpace(loginOption))
if option == "" {
switch provider {
case string(kiropkg.SocialProviderGitHub):
option = "github"
case string(kiropkg.SocialProviderGoogle):
option = "google"
}
}
return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, option)
}
func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
startURL := strings.TrimSpace(input.StartURL)
if startURL == "" {
startURL = kiropkg.BuilderIDStartURL
}
region := strings.TrimSpace(input.Region)
if region == "" {
region = "us-east-1"
}
state, err := kiropkg.GenerateState()
if err != nil {
return nil, fmt.Errorf("generate state failed: %w", err)
}
codeVerifier, err := kiropkg.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("generate code verifier failed: %w", err)
}
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
reg, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
if err != nil {
return nil, err
}
sessionID := kiropkg.GenerateSessionID()
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
AuthType: "idc",
RedirectURI: kiroIDCRedirectURI,
ClientID: reg.ClientID,
ClientSecret: reg.ClientSecret,
Region: region,
StartURL: startURL,
})
return &KiroIDCAuthURLResult{
AuthURL: kiropkg.BuildIDCAuthURL(reg.ClientID, kiroIDCRedirectURI, state, kiropkg.GenerateCodeChallenge(codeVerifier), region),
SessionID: sessionID,
State: state,
ClientID: reg.ClientID,
Region: region,
StartURL: startURL,
}, nil
}
func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
if authMethod == "" {
authMethod = "social"
}
var token *kiropkg.TokenData
var err error
switch authMethod {
case "idc":
token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
default:
token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
}
if err != nil {
return nil, err
}
if token.ProfileArn == "" {
token.ProfileArn = input.ProfileArn
}
if token.ClientID == "" {
token.ClientID = input.ClientID
}
if token.ClientSecret == "" {
token.ClientSecret = input.ClientSecret
}
if token.StartURL == "" {
token.StartURL = input.StartURL
}
if token.Region == "" {
token.Region = input.Region
}
return toKiroTokenInfo(token), nil
}
func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("not a kiro oauth account")
}
return s.RefreshToken(ctx, &KiroRefreshTokenInput{
RefreshToken: account.GetCredential("refresh_token"),
AuthMethod: account.GetCredential("auth_method"),
Provider: account.GetCredential("provider"),
ClientID: account.GetCredential("client_id"),
ClientSecret: account.GetCredential("client_secret"),
StartURL: account.GetCredential("start_url"),
Region: account.GetCredential("region"),
ProfileArn: account.GetCredential("profile_arn"),
ProxyID: account.ProxyID,
})
}
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
if err != nil {
return nil, err
}
return toKiroTokenInfo(token), nil
}
func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
if tokenInfo == nil {
return map[string]any{}
}
creds := map[string]any{}
if tokenInfo.AccessToken != "" {
creds["access_token"] = tokenInfo.AccessToken
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.ProfileArn != "" {
creds["profile_arn"] = tokenInfo.ProfileArn
}
if tokenInfo.ExpiresAt != "" {
creds["expires_at"] = tokenInfo.ExpiresAt
}
if tokenInfo.AuthMethod != "" {
creds["auth_method"] = tokenInfo.AuthMethod
}
if tokenInfo.Provider != "" {
creds["provider"] = tokenInfo.Provider
}
if tokenInfo.ClientID != "" {
creds["client_id"] = tokenInfo.ClientID
}
if tokenInfo.ClientSecret != "" {
creds["client_secret"] = tokenInfo.ClientSecret
}
if tokenInfo.ClientIDHash != "" {
creds["client_id_hash"] = tokenInfo.ClientIDHash
}
if tokenInfo.Email != "" {
creds["email"] = tokenInfo.Email
}
if tokenInfo.StartURL != "" {
creds["start_url"] = tokenInfo.StartURL
}
if tokenInfo.Region != "" {
creds["region"] = tokenInfo.Region
}
return creds
}
func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
if token == nil {
return nil
}
return &KiroTokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
ProfileArn: token.ProfileArn,
ExpiresAt: token.ExpiresAt,
AuthMethod: token.AuthMethod,
Provider: token.Provider,
ClientID: token.ClientID,
ClientSecret: token.ClientSecret,
ClientIDHash: token.ClientIDHash,
Email: token.Email,
StartURL: token.StartURL,
Region: token.Region,
}
}
func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
if proxyID == nil || s.proxyRepo == nil {
return "", nil
}
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err != nil || proxy == nil {
return "", err
}
return proxy.URL(), nil
}
@@ -0,0 +1,51 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/stretchr/testify/require"
)
func TestKiroIDCAuthRedirectURIUsesLoopbackIP(t *testing.T) {
require.Equal(t, "http://127.0.0.1:9876/oauth/callback", kiroIDCRedirectURI)
}
func TestKiroSocialAuthRedirectURIUsesLoopbackIP(t *testing.T) {
require.Equal(t, "http://localhost:49153", kiroSocialRedirectURI)
}
func TestBuildKiroSocialExchangeRedirectURIUsesProviderDefault(t *testing.T) {
require.Equal(
t,
"http://localhost:49153/oauth/callback?login_option=github",
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "", ""),
)
}
func TestBuildKiroSocialExchangeRedirectURIPreservesParsedCallbackData(t *testing.T) {
require.Equal(
t,
"http://localhost:49153/signin/callback?login_option=google",
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "/signin/callback", "google"),
)
}
func TestKiroOAuthService_ExchangeCodeRejectsExpiredSession(t *testing.T) {
svc := NewKiroOAuthService(nil)
svc.sessionStore.Set("expired-session", &kiropkg.AuthSession{
State: "expected-state",
CreatedAt: time.Now().Add(-11 * time.Minute),
})
_, err := svc.ExchangeCode(context.Background(), &KiroExchangeCodeInput{
SessionID: "expired-session",
State: "expected-state",
Code: "auth-code",
})
require.EqualError(t, err, "session not found or expired")
}
+724
View File
@@ -0,0 +1,724 @@
package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
mathrand "math/rand"
"net/http"
"strings"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
type kiroEndpointConfig struct {
URL string
AmzTarget string
Name string
}
const kiroInvalidModelTempUnschedDuration = time.Minute
const (
kiroRetryBaseDelay = 200 * time.Millisecond
kiroRetryMaxDelay = 2 * time.Second
)
var kiroRetrySleep = sleepWithContext
func kiroRetryBackoffDelay(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := kiroRetryBaseDelay * time.Duration(1<<attempt)
if delay > kiroRetryMaxDelay {
delay = kiroRetryMaxDelay
}
jitterMax := delay / 4
if jitterMax <= 0 {
return delay
}
return delay + time.Duration(mathrand.Int63n(int64(jitterMax)+1))
}
func sleepKiroRetry(ctx context.Context, attempt int) error {
return kiroRetrySleep(ctx, kiroRetryBackoffDelay(attempt))
}
func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) {
if account == nil || parsed == nil {
return nil, fmt.Errorf("kiro forward: missing account or request")
}
originalModel := parsed.Model
mappedModel := originalModel
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
body := parsed.Body
if mappedModel != originalModel {
body = s.replaceModelInBody(body, mappedModel)
}
logger.L().Debug("gateway forward_kiro_messages: request prepared",
zap.Int64("account_id", account.ID),
zap.String("auth_method", strings.TrimSpace(account.GetCredential("auth_method"))),
zap.String("requested_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("has_profile_arn", strings.TrimSpace(account.GetCredential("profile_arn")) != ""),
)
if s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, body) {
parsedForEmulation := *parsed
parsedForEmulation.Body = body
return s.handleWebSearchEmulation(ctx, c, account, &parsedForEmulation)
}
if parsed.Stream {
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: failoverErr.StatusCode,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(err.Error()),
})
return nil, failoverErr
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
}
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel, false)
if err != nil {
return nil, err
}
if streamResult.usage == nil {
streamResult.usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *streamResult.usage,
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: streamResult.firstTokenMs,
ClientDisconnect: streamResult.clientDisconnect,
}, nil
}
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
if tokenType != "oauth" {
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
}
if isOnlyWebSearchToolInBody(body) {
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
switch {
case errors.Is(webSearchErr, errKiroWebSearchFallback):
case webSearchErr == nil:
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
c.Header("Content-Type", "application/json")
if webSearchResult.RequestID != "" {
c.Header("x-request-id", webSearchResult.RequestID)
}
c.Data(http.StatusOK, "application/json", webSearchResult.ResponseBody)
return &ForwardResult{
RequestID: webSearchResult.RequestID,
Usage: webSearchResult.Usage,
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
default:
var httpErr *kiroWebSearchHTTPError
if errors.As(webSearchErr, &httpErr) && httpErr.Response != nil {
return nil, s.handleKiroHTTPError(ctx, httpErr.Response, c, account, mappedModel, body)
}
var failoverErr *UpstreamFailoverError
if errors.As(webSearchErr, &failoverErr) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: failoverErr.StatusCode,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(webSearchErr.Error()),
})
return nil, failoverErr
}
safeErr := sanitizeUpstreamErrorMessage(webSearchErr.Error())
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
}
}
inputTokens := estimateKiroInputTokens(body)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: failoverErr.StatusCode,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(err.Error()),
})
return nil, failoverErr
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
}
parseResult, err := kiropkg.ParseNonStreamingEventStreamWithContext(resp.Body, mappedModel, requestCtx)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Failed to parse Kiro upstream response",
},
})
return nil, err
}
c.Header("Content-Type", "application/json")
if requestID := resp.Header.Get("x-request-id"); requestID != "" {
c.Header("x-request-id", requestID)
}
c.Data(http.StatusOK, "application/json", parseResult.ResponseBody)
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
Model: originalModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, 0, err
}
if tokenType != "oauth" {
return nil, 0, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
}
inputTokens := estimateKiroInputTokens(anthropicBody)
if isOnlyWebSearchToolInBody(anthropicBody) {
pr, pw := io.Pipe()
headers := make(http.Header)
headers.Set("Content-Type", "text/event-stream")
go func() {
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
}
_ = pw.Close()
}()
return &http.Response{
StatusCode: http.StatusOK,
Header: headers,
Body: pr,
}, inputTokens, nil
}
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
return nil, inputTokens, err
}
return nil, inputTokens, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return resp, inputTokens, nil
}
pr, pw := io.Pipe()
wrappedHeaders := resp.Header.Clone()
wrappedHeaders.Set("Content-Type", "text/event-stream")
if requestID := buildKiroRequestID(resp); requestID != "" {
wrappedHeaders.Set("x-request-id", requestID)
}
go func() {
defer func() { _ = resp.Body.Close() }()
_, streamErr := kiropkg.StreamEventStreamAsAnthropicWithContext(ctx, resp.Body, pw, mappedModel, inputTokens, requestCtx)
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
}
_ = pw.Close()
}()
return &http.Response{
StatusCode: resp.StatusCode,
Header: wrappedHeaders,
Body: pr,
}, inputTokens, nil
}
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
var requestCtx kiropkg.KiroRequestContext
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
return nil, requestCtx, failoverErr
}
return nil, requestCtx, err
}
modelID := kiropkg.MapModel(mappedModel)
currentToken := token
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
if err != nil {
return nil, requestCtx, err
}
payload := buildResult.Payload
requestCtx = buildResult.Context
endpoints := buildKiroEndpoints(account)
proxyURL := kiroProxyURL(account)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
accountKey := buildKiroAccountKey(account)
maxRetries := 2
for idx, endpoint := range endpoints {
for attempt := 0; attempt <= maxRetries; attempt++ {
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
if err != nil {
return nil, requestCtx, err
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
if attempt < maxRetries {
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
continue
}
return nil, requestCtx, err
}
if resp.StatusCode == http.StatusTooManyRequests {
cooldown, err := s.markKiro429(ctx, accountKey)
if err != nil {
_ = resp.Body.Close()
return nil, requestCtx, err
}
if idx+1 < len(endpoints) {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
break
}
resp.Header.Set("x-kiro-cooldown", cooldown.String())
return resp, requestCtx, nil
}
if resp.StatusCode == http.StatusRequestTimeout || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
if attempt < maxRetries {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
continue
}
if idx+1 < len(endpoints) {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
break
}
return resp, requestCtx, nil
}
if resp.StatusCode == http.StatusPaymentRequired {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, requestCtx, readErr
}
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
if classification.Category == kiroErrorMonthlyRequest {
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
}
return nil, requestCtx, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, requestCtx, readErr
}
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
return nil, requestCtx, err
}
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
if err != nil {
return nil, requestCtx, err
}
payload = buildResult.Payload
requestCtx = buildResult.Context
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, requestCtx, sleepErr
}
continue
}
if refreshErr != nil && isNonRetryableRefreshError(refreshErr) {
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
}
if classifyKiroHTTPError(resp.StatusCode, string(respBody)).Category == kiroErrorAuthError {
s.markKiroAuthTemporarilyUnavailable(ctx, account, resp.StatusCode, string(respBody))
}
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
if resp.StatusCode == http.StatusBadRequest {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, requestCtx, readErr
}
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
logKiroBadRequestClassification(classification, account, mappedModel, resp.Header, respBody)
resetHTTPResponseBody(resp, respBody)
return resp, requestCtx, nil
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
_ = resp.Body.Close()
return nil, requestCtx, err
}
}
return resp, requestCtx, nil
}
}
return nil, requestCtx, fmt.Errorf("kiro upstream endpoints exhausted")
}
func buildKiroEndpoints(account *Account) []kiroEndpointConfig {
region := kiroAPIRegion(account)
return []kiroEndpointConfig{
{
URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
Name: "AmazonQ",
},
}
}
func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) ([]byte, error) {
result, err := buildKiroPayloadForAccountWithRepo(ctx, nil, account, anthropicBody, modelID, token, requestModel, headers)
if err != nil {
return nil, err
}
return result.Payload, nil
}
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
profileArn := resolveKiroPayloadProfileArn(account)
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
}
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
until := time.Now().Add(10 * time.Minute)
reason := fmt.Sprintf("kiro auth failure (%d): %s", statusCode, strings.TrimSpace(body))
_ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason)
}
func (s *GatewayService) markKiroMonthlyRequestCountRateLimited(ctx context.Context, account *Account, body string) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
resetAt := nextKiroMonthlyResetUTC(time.Now())
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
logger.L().Warn("kiro monthly request count rate-limit failed",
zap.Int64("account_id", account.ID),
zap.Time("reset_at", resetAt),
zap.Error(err),
)
return
}
reason := "kiro monthly request count exhausted (402): MONTHLY_REQUEST_COUNT"
if trimmed := strings.TrimSpace(body); trimmed != "" {
reason = fmt.Sprintf("%s body=%s", reason, truncateForLog([]byte(trimmed), 512))
}
logger.L().Warn("kiro monthly request count rate-limited",
zap.Int64("account_id", account.ID),
zap.Time("reset_at", resetAt),
zap.String("reason", reason),
)
}
func nextKiroMonthlyResetUTC(now time.Time) time.Time {
utc := now.UTC()
year, month, _ := utc.Date()
return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC)
}
func resetHTTPResponseBody(resp *http.Response, body []byte) {
if resp == nil {
return
}
resp.Body = io.NopCloser(bytes.NewReader(body))
resp.ContentLength = int64(len(body))
}
func estimateKiroInputTokens(body []byte) int {
if len(body) == 0 {
return 0
}
if tokens := gjson.GetBytes(body, "metadata.input_tokens").Int(); tokens > 0 {
return int(tokens)
}
tokens := len(body) / 4
if tokens == 0 {
return 1
}
return tokens
}
func kiroUsageToClaude(usage kiropkg.Usage, fallbackInput int) ClaudeUsage {
inputTokens := usage.InputTokens
if inputTokens == 0 {
inputTokens = fallbackInput
}
return ClaudeUsage{
InputTokens: inputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
}
}
func (s *GatewayService) markKiroInvalidModelRateLimited(ctx context.Context, account *Account, mappedModel string) {
if s == nil || s.accountRepo == nil || account == nil || account.Type != AccountTypeOAuth {
return
}
resetAt := time.Now().Add(kiroInvalidModelTempUnschedDuration)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
logger.L().Warn("kiro invalid model rate-limit failed",
zap.Int64("account_id", account.ID),
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
zap.Time("reset_at", resetAt),
zap.Error(err),
)
return
}
logger.L().Warn("kiro invalid model rate-limited",
zap.Int64("account_id", account.ID),
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
zap.Time("reset_at", resetAt),
)
}
func (s *GatewayService) handleKiroHTTPError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, mappedModel string, requestBody []byte) error {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg == "" {
upstreamMsg = strings.TrimSpace(string(respBody))
}
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
if resp.StatusCode == http.StatusBadRequest {
logKiroBadRequestClassification(classification, account, "", resp.Header, respBody)
}
if classification.Category == kiroErrorMonthlyRequest {
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
}
if classification.Category == kiroErrorBadRequestInvalidModel && account != nil && account.Type == AccountTypeOAuth {
s.markKiroInvalidModelRateLimited(ctx, account, mappedModel)
event := s.buildKiroInvalidModelUpstreamEvent(account, resp, upstreamMsg, mappedModel, requestBody, c)
appendOpsUpstreamError(c, event)
return &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
if resp.StatusCode == http.StatusPaymentRequired || s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
})
c.JSON(mapUpstreamStatusCode(resp.StatusCode), gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": coalesceKiroErrorMessage(resp.StatusCode, upstreamMsg),
},
})
return fmt.Errorf("kiro upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
func (s *GatewayService) buildKiroInvalidModelUpstreamEvent(account *Account, resp *http.Response, upstreamMsg, mappedModel string, requestBody []byte, c *gin.Context) OpsUpstreamErrorEvent {
_ = s
requestedModel := strings.TrimSpace(gjson.GetBytes(requestBody, "model").String())
hasTools := gjson.GetBytes(requestBody, "tools").Exists()
hasAdaptiveThinking := strings.EqualFold(strings.TrimSpace(gjson.GetBytes(requestBody, "thinking.type").String()), "adaptive")
hasContext1MBeta := false
if c != nil {
hasContext1MBeta = strings.Contains(c.GetHeader("Anthropic-Beta"), "context-1m")
}
return OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
RequestedModel: requestedModel,
MappedModel: strings.TrimSpace(mappedModel),
KiroModelID: kiropkg.MapModel(mappedModel),
HasTools: hasTools,
HasAdaptiveThinking: hasAdaptiveThinking,
HasContext1MBeta: hasContext1MBeta,
}
}
func logKiroBadRequestClassification(classification kiroErrorClassification, account *Account, model string, headers http.Header, body []byte) {
if classification.StatusCode != http.StatusBadRequest {
return
}
var accountID int64
if account != nil {
accountID = account.ID
}
logger.L().Warn("kiro upstream bad request classified",
zap.String("category", classification.Category),
zap.Int("status", classification.StatusCode),
zap.Int64("account_id", accountID),
zap.String("model", strings.TrimSpace(model)),
zap.String("request_id", headers.Get("x-request-id")),
zap.String("body_excerpt", truncateForLog(body, 512)),
)
}
func coalesceKiroErrorMessage(statusCode int, upstreamMsg string) string {
if upstreamMsg != "" {
return upstreamMsg
}
switch statusCode {
case http.StatusTooManyRequests:
return "Rate limit exceeded"
case http.StatusForbidden:
return "Access denied"
case http.StatusUnauthorized:
return "Authentication failed"
default:
return "Upstream request failed"
}
}
@@ -0,0 +1,99 @@
package service
import (
"context"
"errors"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
)
var errKiroCooldownStoreUnavailable = errors.New("kiro cooldown store unavailable")
type KiroCooldownStore interface {
ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error)
MarkSuccess(ctx context.Context, tokenKey string) error
Mark429(ctx context.Context, tokenKey string) (time.Duration, error)
MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error)
GetState(ctx context.Context, tokenKey string) (*kirocooldown.State, error)
ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error)
}
func asKiroCooldownFailoverError(err error) *UpstreamFailoverError {
if err == nil {
return nil
}
var cooldownErr *kirocooldown.Error
if !errors.As(err, &cooldownErr) {
return nil
}
return &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseBody: []byte(cooldownErr.Error()),
}
}
func (s *GatewayService) checkAndWaitKiroCooldown(ctx context.Context, tokenKey string) error {
if s == nil || s.kiroCooldownStore == nil {
return errKiroCooldownStoreUnavailable
}
waitFor, err := s.kiroCooldownStore.ReserveRequest(ctx, tokenKey)
if err != nil {
return err
}
if waitFor <= 0 {
return nil
}
timer := time.NewTimer(waitFor)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ctx.Err()
case <-timer.C:
return nil
}
}
func (s *GatewayService) markKiroSuccess(ctx context.Context, tokenKey string) error {
if s == nil || s.kiroCooldownStore == nil {
return errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.MarkSuccess(ctx, tokenKey)
}
func (s *GatewayService) markKiro429(ctx context.Context, tokenKey string) (time.Duration, error) {
if s == nil || s.kiroCooldownStore == nil {
return 0, errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.Mark429(ctx, tokenKey)
}
func (s *GatewayService) markKiroSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
if s == nil || s.kiroCooldownStore == nil {
return 0, errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.MarkSuspended(ctx, tokenKey)
}
func (s *GatewayService) getKiroCooldownState(ctx context.Context, tokenKey string) (*kirocooldown.State, error) {
if s == nil || s.kiroCooldownStore == nil {
return nil, errKiroCooldownStoreUnavailable
}
return s.kiroCooldownStore.GetState(ctx, tokenKey)
}
func kiroRuntimeStateSnapshot(state *kirocooldown.State) (string, string, *time.Time) {
if state == nil || !state.Active {
return "", "", nil
}
resetAt := state.CooldownUntil
switch state.Reason {
case kirocooldown.CooldownReasonSuspended:
return "suspended", state.Reason, &resetAt
default:
return "cooldown", state.Reason, &resetAt
}
}
@@ -0,0 +1,192 @@
//go:build integration
package service
import (
"context"
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const kiroCooldownRedisImageTag = "redis:8.4-alpine"
func TestRedisKiroCooldownStoreSharesCooldownAcrossInstances(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
storeA := kirocooldown.NewStore(rdb)
storeB := kirocooldown.NewStore(rdb)
cooldown, err := storeA.Mark429(ctx, "token-shared")
require.NoError(t, err)
require.Equal(t, time.Minute, cooldown)
wait, err := storeB.ReserveRequest(ctx, "token-shared")
require.Zero(t, wait)
require.Error(t, err)
require.Contains(t, err.Error(), kirocooldown.CooldownReason429)
require.NoError(t, storeB.MarkSuccess(ctx, "token-shared"))
wait, err = storeA.ReserveRequest(ctx, "token-shared")
require.NoError(t, err)
require.GreaterOrEqual(t, wait, 0*time.Second)
}
func TestRedisKiroCooldownStoreSharesReservationAcrossInstances(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
storeA := kirocooldown.NewStore(rdb)
storeB := kirocooldown.NewStore(rdb)
wait, err := storeA.ReserveRequest(ctx, "token-rate")
require.NoError(t, err)
require.Zero(t, wait)
wait, err = storeB.ReserveRequest(ctx, "token-rate")
require.NoError(t, err)
require.Greater(t, wait, 0*time.Millisecond)
require.LessOrEqual(t, wait, kirocooldown.MaxRequestInterval)
}
func TestRedisKiroCooldownStoreSharesSuspendedStateAcrossInstances(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
storeA := kirocooldown.NewStore(rdb)
storeB := kirocooldown.NewStore(rdb)
cooldown, err := storeA.MarkSuspended(ctx, "token-suspended")
require.NoError(t, err)
require.Equal(t, kirocooldown.LongCooldown, cooldown)
wait, err := storeB.ReserveRequest(ctx, "token-suspended")
require.Zero(t, wait)
require.Error(t, err)
require.Contains(t, err.Error(), kirocooldown.CooldownReasonSuspended)
}
func TestRedisKiroCooldownStoreSuspendedResetsFailCount(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
store := kirocooldown.NewStore(rdb)
_, err := store.Mark429(ctx, "token-reset")
require.NoError(t, err)
_, err = store.Mark429(ctx, "token-reset")
require.NoError(t, err)
cooldown, err := store.MarkSuspended(ctx, "token-reset")
require.NoError(t, err)
require.Equal(t, kirocooldown.LongCooldown, cooldown)
cooldown, err = store.Mark429(ctx, "token-reset")
require.NoError(t, err)
require.Equal(t, time.Minute, cooldown)
}
func TestRedisKiroCooldownStoreReserveDifferentTokenIgnoresOldCooldown(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
store := kirocooldown.NewStore(rdb)
_, err := store.Mark429(ctx, "token-old")
require.NoError(t, err)
wait, err := store.ReserveRequest(ctx, "token-new")
require.NoError(t, err)
require.Zero(t, wait)
}
func TestRedisKiroCooldownStoreUsesExpectedTTLs(t *testing.T) {
ctx := context.Background()
rdb := startKiroCooldownRedis(t, ctx)
store := kirocooldown.NewStore(rdb)
_, err := store.ReserveRequest(ctx, "token-ttl-active")
require.NoError(t, err)
activeTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-active")).Result()
require.NoError(t, err)
require.Greater(t, activeTTL, 0*time.Second)
require.LessOrEqual(t, activeTTL, kirocooldown.ActiveTTL())
_, err = store.MarkSuspended(ctx, "token-ttl-state")
require.NoError(t, err)
stateTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-state")).Result()
require.NoError(t, err)
require.Greater(t, stateTTL, 24*time.Hour)
require.LessOrEqual(t, stateTTL, kirocooldown.StateTTL())
}
func startKiroCooldownRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
ensureKiroCooldownDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, kiroCooldownRedisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
host, err := redisContainer.Host(ctx)
require.NoError(t, err)
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", host, port.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}
func ensureKiroCooldownDockerAvailable(t *testing.T) {
t.Helper()
if kiroCooldownDockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过依赖 testcontainers 的 Kiro cooldown 集成测试")
}
func kiroCooldownDockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func kiroCooldownUserHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}
@@ -0,0 +1,583 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type stubKiroCooldownStore struct {
reserveWait time.Duration
reserveErr error
successErr error
mark429TTL time.Duration
mark429Err error
suspendedTTL time.Duration
suspendedErr error
state *kirocooldown.State
stateErr error
clearCalled bool
clearKeys []string
clearResult bool
clearErr error
}
type recordingKiroTempUnschedRepo struct {
mockAccountRepoForGemini
called bool
id int64
until time.Time
reason string
rateCalled bool
rateID int64
rateLimitReset time.Time
rateLimitedCall int
}
func (r *recordingKiroTempUnschedRepo) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
r.called = true
r.id = id
r.until = until
r.reason = reason
return nil
}
func (r *recordingKiroTempUnschedRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
r.rateCalled = true
r.rateID = id
r.rateLimitReset = resetAt
r.rateLimitedCall++
return nil
}
type recordingKiroErrorRepo struct {
recordingKiroTempUnschedRepo
setErrorCalls int
errorID int64
errorMsg string
}
func (r *recordingKiroErrorRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.errorID = id
r.errorMsg = errorMsg
return nil
}
func (s *stubKiroCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
return s.reserveWait, s.reserveErr
}
func (s *stubKiroCooldownStore) MarkSuccess(context.Context, string) error {
return s.successErr
}
func (s *stubKiroCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
return s.mark429TTL, s.mark429Err
}
func (s *stubKiroCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
return s.suspendedTTL, s.suspendedErr
}
func (s *stubKiroCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
if s.clearCalled && s.clearResult {
return nil, nil
}
return s.state, s.stateErr
}
func (s *stubKiroCooldownStore) ClearEarliestTransientCooldown(_ context.Context, tokenKeys []string) (bool, error) {
s.clearCalled = true
s.clearKeys = append([]string(nil), tokenKeys...)
return s.clearResult, s.clearErr
}
func TestCalculateKiro429Cooldown(t *testing.T) {
require.Equal(t, time.Minute, kirocooldown.Calculate429Cooldown(0))
require.Equal(t, 2*time.Minute, kirocooldown.Calculate429Cooldown(1))
require.Equal(t, 4*time.Minute, kirocooldown.Calculate429Cooldown(2))
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(3))
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(10))
}
func TestGatewayServiceCheckAndWaitKiroCooldownReturnsNilWithoutWait(t *testing.T) {
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{},
}
require.NoError(t, svc.checkAndWaitKiroCooldown(context.Background(), "token1"))
}
func TestGatewayServiceCheckAndWaitKiroCooldownPropagatesReserveError(t *testing.T) {
expected := errors.New("redis unavailable")
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{reserveErr: expected},
}
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
require.ErrorIs(t, err, expected)
}
func TestGatewayServiceCheckAndWaitKiroCooldownRequiresStore(t *testing.T) {
svc := &GatewayService{}
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
require.ErrorIs(t, err, errKiroCooldownStoreUnavailable)
}
func TestGatewayServiceCheckAndWaitKiroCooldownWaitsAndHonorsContext(t *testing.T) {
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{reserveWait: 200 * time.Millisecond},
}
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
err := svc.checkAndWaitKiroCooldown(ctx, "token1")
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestAsKiroCooldownFailoverError(t *testing.T) {
err := kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429)
var cooldownErr *kirocooldown.Error
require.ErrorAs(t, err, &cooldownErr)
failoverErr := asKiroCooldownFailoverError(err)
require.NotNil(t, failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
require.False(t, failoverErr.RetryableOnSameAccount)
}
func TestAsKiroCooldownFailoverErrorIgnoresNonCooldownErrors(t *testing.T) {
require.Nil(t, asKiroCooldownFailoverError(errors.New("redis unavailable")))
}
func TestGatewayServiceTryRecoverKiroCooldownPoolClearsOnlyTransientCooldown(t *testing.T) {
store := &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(time.Minute),
Remaining: time.Minute,
},
clearResult: true,
}
svc := &GatewayService{kiroCooldownStore: store}
accounts := []Account{
{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
},
}
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
require.True(t, recovered)
require.True(t, store.clearCalled)
require.Len(t, store.clearKeys, 1)
require.Equal(t, buildKiroAccountKey(&accounts[0]), store.clearKeys[0])
}
func TestGatewayServiceTryRecoverKiroCooldownPoolSkipsSuspended(t *testing.T) {
store := &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReasonSuspended,
CooldownUntil: time.Now().Add(time.Hour),
Remaining: time.Hour,
},
clearResult: true,
}
svc := &GatewayService{kiroCooldownStore: store}
accounts := []Account{
{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
},
}
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
require.False(t, recovered)
require.False(t, store.clearCalled)
}
func TestSelectAccountWithLoadAwarenessRecoversKiroCooldownPool(t *testing.T) {
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
account := Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
store := &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(time.Minute),
Remaining: time.Minute,
},
clearResult: true,
}
svc := &GatewayService{
accountRepo: &mockAccountRepoForGemini{accounts: []Account{account}},
concurrencyService: NewConcurrencyService(&mockConcurrencyCache{}),
cfg: cfg,
kiroCooldownStore: store,
tlsFPProfileService: &TLSFingerprintProfileService{},
}
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformKiro)
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "", nil, "", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, account.ID, result.Account.ID)
require.True(t, store.clearCalled)
require.Equal(t, []string{buildKiroAccountKey(&account)}, store.clearKeys)
}
func TestClassifyKiroHTTPErrorMonthlyRequestCount(t *testing.T) {
tests := []string{
`{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
`{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}`,
`API returned 402: {"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
}
for _, body := range tests {
classification := classifyKiroHTTPError(http.StatusPaymentRequired, body)
require.Equal(t, kiroErrorMonthlyRequest, classification.Category)
}
}
func TestClassifyKiroHTTPErrorPlain402IsTransient(t *testing.T) {
classification := classifyKiroHTTPError(http.StatusPaymentRequired, `{"message":"payment required"}`)
require.Equal(t, kiroErrorUpstreamTransient, classification.Category)
}
func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{
reserveErr: kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429),
},
}
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
require.False(t, failoverErr.RetryableOnSameAccount)
}
func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
},
}
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
},
}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
payload, err := createTestPayload("claude-opus-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
require.Len(t, upstream.requests, 1)
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
require.NoError(t, readErr)
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
}
func TestHandleKiroHTTPErrorOAuthInvalidModelRateLimitsAndFailovers(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Anthropic-Beta", "context-1m-2025-08-07")
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Name: "kiro-oauth",
}
repo := &recordingKiroTempUnschedRepo{}
svc := &GatewayService{accountRepo: repo}
requestBody := []byte(`{"model":"claude-opus-4-7","tools":[{"name":"search"}],"thinking":{"type":"adaptive"}}`)
resp := newJSONResponse(http.StatusBadRequest, `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`)
resp.Header.Set("x-request-id", "req-invalid-model")
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", requestBody)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "Invalid model")
require.False(t, failoverErr.RetryableOnSameAccount)
require.False(t, repo.called)
require.True(t, repo.rateCalled)
require.Equal(t, account.ID, repo.rateID)
require.WithinDuration(t, time.Now().Add(kiroInvalidModelTempUnschedDuration), repo.rateLimitReset, 5*time.Second)
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, PlatformKiro, events[0].Platform)
require.Equal(t, account.ID, events[0].AccountID)
require.Equal(t, account.Name, events[0].AccountName)
require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
require.Equal(t, "req-invalid-model", events[0].UpstreamRequestID)
require.Equal(t, "failover", events[0].Kind)
require.Equal(t, "claude-opus-4-7", events[0].RequestedModel)
require.Equal(t, "claude-opus-4.6", events[0].MappedModel)
require.Equal(t, "claude-opus-4.6", events[0].KiroModelID)
require.True(t, events[0].HasTools)
require.True(t, events[0].HasAdaptiveThinking)
require.True(t, events[0].HasContext1MBeta)
}
func TestHandleKiroHTTPErrorAPIKeyInvalidModelDoesNotFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
account := &Account{
ID: 43,
Platform: PlatformKiro,
Type: AccountTypeAPIKey,
}
repo := &recordingKiroTempUnschedRepo{}
svc := &GatewayService{accountRepo: repo}
resp := newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model. Please select a different model to continue."}`)
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", []byte(`{"model":"claude-opus-4-7"}`))
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.NotErrorAs(t, err, &failoverErr)
require.False(t, repo.called)
require.False(t, repo.rateCalled)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestNextKiroMonthlyResetUTC(t *testing.T) {
tests := []struct {
name string
now time.Time
want time.Time
}{
{
name: "middle of month",
now: time.Date(2026, time.April, 27, 10, 30, 45, 123, time.FixedZone("CST", 8*3600)),
want: time.Date(2026, time.May, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "december rolls year",
now: time.Date(2026, time.December, 31, 23, 59, 59, 0, time.UTC),
want: time.Date(2027, time.January, 1, 0, 0, 0, 0, time.UTC),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, nextKiroMonthlyResetUTC(tt.now))
})
}
}
func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailovers(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
repo := &recordingKiroTempUnschedRepo{}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusPaymentRequired, `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`),
},
}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
payload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "MONTHLY_REQUEST_COUNT")
require.False(t, repo.called)
require.True(t, repo.rateCalled)
require.Equal(t, account.ID, repo.rateID)
require.Equal(t, nextKiroMonthlyResetUTC(time.Now()), repo.rateLimitReset)
}
func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
repo := &recordingKiroTempUnschedRepo{}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusPaymentRequired, `{"message":"payment required"}`),
},
}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
}
payload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
require.False(t, repo.called)
require.False(t, repo.rateCalled)
}
func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedule(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"refresh_token": "old-refresh",
},
}
repo := &recordingKiroErrorRepo{
recordingKiroTempUnschedRepo: recordingKiroTempUnschedRepo{
mockAccountRepoForGemini: mockAccountRepoForGemini{
accountsByID: map[int64]*Account{account.ID: account},
},
},
}
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"message":"token expired"}`),
},
}
provider := NewKiroTokenProvider(repo, nil, nil)
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
svc := &GatewayService{
accountRepo: repo,
httpUpstream: upstream,
kiroCooldownStore: &stubKiroCooldownStore{},
tlsFPProfileService: &TLSFingerprintProfileService{},
kiroTokenProvider: provider,
}
payload, err := createTestPayload("claude-sonnet-4-6")
require.NoError(t, err)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, account.ID, repo.errorID)
require.Contains(t, repo.errorMsg, "invalid_grant")
require.False(t, repo.called, "non-retryable refresh errors should not mark temporary unschedulable")
}
func TestGatewayServiceIsAccountSchedulableForSelectionSkipsActiveKiroCooldown(t *testing.T) {
now := time.Now().Add(2 * time.Minute)
svc := &GatewayService{
kiroCooldownStore: &stubKiroCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: now,
Remaining: 2 * time.Minute,
},
},
}
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
}
require.False(t, svc.isAccountSchedulableForSelection(account))
}
@@ -0,0 +1,221 @@
package service
import (
"context"
"errors"
"strconv"
"strings"
"time"
)
const (
kiroTokenRefreshSkew = 3 * time.Minute
kiroTokenCacheSkew = 5 * time.Minute
)
type KiroTokenCache = GeminiTokenCache
type kiroAccountTokenRefresher interface {
RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error)
BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any
}
type KiroTokenProvider struct {
accountRepo AccountRepository
tokenCache KiroTokenCache
kiroOAuthService kiroAccountTokenRefresher
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewKiroTokenProvider(
accountRepo AccountRepository,
tokenCache KiroTokenCache,
kiroOAuthService *KiroOAuthService,
) *KiroTokenProvider {
return &KiroTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
kiroOAuthService: kiroOAuthService,
refreshPolicy: GeminiProviderRefreshPolicy(),
}
}
func (p *KiroTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
func (p *KiroTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *KiroTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return "", errors.New("not a kiro oauth account")
}
cacheKey := KiroTokenCacheKey(account)
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= kiroTokenRefreshSkew
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, kiroTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > kiroTokenCacheSkew:
ttl = until - kiroTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
}
return accessToken, nil
}
func KiroTokenCacheKey(account *Account) string {
if account == nil {
return "kiro:account:0"
}
if clientIDHash := strings.TrimSpace(account.GetCredential("client_id_hash")); clientIDHash != "" {
return "kiro:" + clientIDHash
}
if clientID := strings.TrimSpace(account.GetCredential("client_id")); clientID != "" {
return "kiro:client:" + clientID
}
return "kiro:account:" + strconv.FormatInt(account.ID, 10)
}
func (p *KiroTokenProvider) ForceRefreshAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return "", errors.New("not a kiro oauth account")
}
if p.kiroOAuthService == nil {
return "", errors.New("kiro oauth service is nil")
}
cacheKey := KiroTokenCacheKey(account)
lockHeld := false
if p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
lockHeld = true
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
}
}
if p.accountRepo != nil {
if latestAccount, err := p.accountRepo.GetByID(ctx, account.ID); err == nil && latestAccount != nil {
account = latestAccount
}
}
tokenInfo, err := p.kiroOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
if !lockHeld {
if latestAccount, stale := CheckTokenVersion(ctx, account, p.accountRepo); stale && latestAccount != nil {
account = latestAccount
if accessToken := strings.TrimSpace(account.GetCredential("access_token")); accessToken != "" {
_ = p.cacheAccessToken(ctx, account, accessToken)
return accessToken, nil
}
}
}
if isNonRetryableRefreshError(err) && p.accountRepo != nil {
errorMsg := "Token refresh failed (non-retryable): " + err.Error()
_ = p.accountRepo.SetError(ctx, account.ID, errorMsg)
}
return "", err
}
newCredentials := MergeCredentials(account.Credentials, p.kiroOAuthService.BuildAccountCredentials(tokenInfo))
newCredentials["_token_version"] = time.Now().UnixMilli()
if err := persistAccountCredentials(ctx, p.accountRepo, account, newCredentials); err != nil {
return "", err
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken == "" {
accessToken = strings.TrimSpace(tokenInfo.AccessToken)
}
if accessToken == "" {
return "", errors.New("access_token not found after kiro refresh")
}
if err := p.cacheAccessToken(ctx, account, accessToken); err != nil {
return "", err
}
return accessToken, nil
}
func (p *KiroTokenProvider) cacheAccessToken(ctx context.Context, account *Account, accessToken string) error {
if p.tokenCache == nil || account == nil || strings.TrimSpace(accessToken) == "" {
return nil
}
ttl := 30 * time.Minute
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > kiroTokenCacheSkew:
ttl = until - kiroTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
return p.tokenCache.SetAccessToken(ctx, KiroTokenCacheKey(account), accessToken, ttl)
}
@@ -0,0 +1,112 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
type kiroTokenProviderRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setErrorID int64
setErrorMsg string
}
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.setErrorID = id
r.setErrorMsg = errorMsg
return nil
}
type kiroTokenProviderSequenceRepo struct {
kiroTokenProviderRepo
accounts []*Account
reads int
}
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
if len(r.accounts) == 0 {
return nil, errors.New("account not found")
}
idx := r.reads
if idx >= len(r.accounts) {
idx = len(r.accounts) - 1
}
r.reads++
return r.accounts[idx], nil
}
type stubKiroAccountTokenRefresher struct {
tokenInfo *KiroTokenInfo
err error
}
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
if s.err != nil {
return nil, s.err
}
return s.tokenInfo, nil
}
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
if tokenInfo == nil {
return nil
}
return map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": tokenInfo.ExpiresAt,
}
}
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
account := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"refresh_token": "old-refresh"},
}
repo := &kiroTokenProviderRepo{
mockAccountRepoForGemini: mockAccountRepoForGemini{
accountsByID: map[int64]*Account{account.ID: account},
},
}
provider := NewKiroTokenProvider(repo, nil, nil)
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
token, err := provider.ForceRefreshAccessToken(context.Background(), account)
require.Error(t, err)
require.Empty(t, token)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
require.Contains(t, repo.setErrorMsg, "invalid_grant")
}
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
usedAccount := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"refresh_token": "old-refresh"},
}
latestAccount := &Account{
ID: 42,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
}
repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
provider := NewKiroTokenProvider(repo, nil, nil)
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
require.NoError(t, err)
require.Equal(t, "fresh-access", token)
require.Equal(t, 0, repo.setErrorCalls)
}
@@ -0,0 +1,47 @@
package service
import (
"context"
"time"
)
const kiroRefreshWindow = 15 * time.Minute
type KiroTokenRefresher struct {
kiroOAuthService *KiroOAuthService
}
func NewKiroTokenRefresher(kiroOAuthService *KiroOAuthService) *KiroTokenRefresher {
return &KiroTokenRefresher{
kiroOAuthService: kiroOAuthService,
}
}
func (r *KiroTokenRefresher) CacheKey(account *Account) string {
return KiroTokenCacheKey(account)
}
func (r *KiroTokenRefresher) CanRefresh(account *Account) bool {
return account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
}
func (r *KiroTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
return time.Until(*expiresAt) <= kiroRefreshWindow
}
func (r *KiroTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.kiroOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
newCredentials := r.kiroOAuthService.BuildAccountCredentials(tokenInfo)
return MergeCredentials(account.Credentials, newCredentials), nil
}
@@ -0,0 +1,608 @@
package service
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/google/uuid"
)
const (
kiroUsageOrigin = "AI_EDITOR"
kiroUsageResourceType = "AGENTIC_REQUEST"
kiroDefaultRegion = "us-east-1"
)
var resolveKiroRuntimeEndpoint = kiroRuntimeEndpoint
type kiroUsageLimitsResponse struct {
NextDateReset any `json:"nextDateReset"`
OverageConfiguration kiroOverageConfiguration `json:"overageConfiguration"`
SubscriptionInfo kiroSubscriptionInfo `json:"subscriptionInfo"`
UsageBreakdownList []kiroUsageBreakdown `json:"usageBreakdownList"`
}
type kiroOverageConfiguration struct {
OverageStatus string `json:"overageStatus"`
}
type kiroSubscriptionInfo struct {
SubscriptionTitle string `json:"subscriptionTitle"`
Type string `json:"type"`
}
type kiroUsageBreakdown struct {
Currency string `json:"currency"`
CurrentOverages *float64 `json:"currentOverages"`
CurrentOveragesWithPrecision *float64 `json:"currentOveragesWithPrecision"`
CurrentUsage *float64 `json:"currentUsage"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
DisplayName string `json:"displayName"`
DisplayNamePlural string `json:"displayNamePlural"`
FreeTrialInfo *kiroFreeTrialInfo `json:"freeTrialInfo"`
NextDateReset any `json:"nextDateReset"`
OverageCharges *float64 `json:"overageCharges"`
ResourceType string `json:"resourceType"`
UsageLimit *float64 `json:"usageLimit"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
}
type kiroFreeTrialInfo struct {
CurrentUsage *float64 `json:"currentUsage"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
FreeTrialExpiry any `json:"freeTrialExpiry"`
FreeTrialStatus string `json:"freeTrialStatus"`
UsageLimit *float64 `json:"usageLimit"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
}
type kiroUsageHTTPError struct {
StatusCode int
Body string
}
func (e *kiroUsageHTTPError) Error() string {
if e == nil {
return "kiro usage request failed"
}
if strings.TrimSpace(e.Body) == "" {
return fmt.Sprintf("kiro usage request failed (status %d)", e.StatusCode)
}
return fmt.Sprintf("kiro usage request failed (status %d): %s", e.StatusCode, e.Body)
}
func (s *AccountUsageService) getKiroUsage(ctx context.Context, account *Account, source string, forceRefresh bool) (*UsageInfo, error) {
now := time.Now()
if account == nil {
return &UsageInfo{
Source: source,
UpdatedAt: &now,
Error: "account is nil",
ErrorCode: errorCodeNetworkError,
}, nil
}
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return &UsageInfo{
Source: source,
UpdatedAt: &now,
}, nil
}
cached, hasCached := s.getCachedKiroUsage(account.ID)
if hasCached && (cached.ErrorCode != "" || cached.Error != "") {
cached.Source = source
s.attachKiroRuntimeState(ctx, account, cached)
return cached, nil
}
if !forceRefresh && hasCached {
cached.Source = source
s.attachKiroRuntimeState(ctx, account, cached)
return cached, nil
}
flightKey := fmt.Sprintf("kiro-usage:%d", account.ID)
result, fetchErr, _ := s.cache.kiroUsageFlight.Do(flightKey, func() (any, error) {
if !forceRefresh {
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
return usage, nil
}
}
usage, err := s.fetchAndCacheKiroUsage(ctx, account, source)
if err != nil {
return nil, err
}
return usage, nil
})
if fetchErr == nil {
if usage, ok := result.(*UsageInfo); ok && usage != nil {
usage.Source = source
s.attachKiroRuntimeState(ctx, account, usage)
if source == "active" {
s.tryClearRecoverableAccountError(ctx, account)
}
return usage, nil
}
}
degraded := buildKiroDegradedUsage(fetchErr)
degraded.Source = source
if hasCached {
cached.Error = degraded.Error
cached.ErrorCode = degraded.ErrorCode
cached.NeedsReauth = degraded.NeedsReauth
cached.KiroQuotaState = degraded.KiroQuotaState
cached.KiroQuotaReason = degraded.KiroQuotaReason
cached.KiroQuotaResetAt = degraded.KiroQuotaResetAt
cached.Source = source
s.attachKiroRuntimeState(ctx, account, cached)
return cached, nil
}
s.storeKiroUsageSnapshot(account.ID, degraded)
s.attachKiroRuntimeState(ctx, account, degraded)
return degraded, nil
}
func (s *AccountUsageService) fetchAndCacheKiroUsage(ctx context.Context, account *Account, source string) (*UsageInfo, error) {
token := strings.TrimSpace(account.GetCredential("access_token"))
if token == "" {
return nil, fmt.Errorf("no access token available")
}
region := kiroAPIRegion(account)
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
resp, err := s.requestKiroUsageLimits(ctx, account, region, profileArn, token)
if err != nil {
return nil, err
}
usage := mapKiroUsageToInfo(resp)
usage.Source = source
s.storeKiroUsageSnapshot(account.ID, usage)
return usage, nil
}
func (s *AccountUsageService) storeKiroUsageSnapshot(accountID int64, usage *UsageInfo) {
if s == nil || s.cache == nil || accountID <= 0 || usage == nil {
return
}
now := time.Now()
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
s.cache.kiroUsageCache.Store(accountID, &kiroUsageCache{
usageInfo: cloneUsageInfo(usage),
timestamp: now,
})
}
func (s *AccountUsageService) getCachedKiroUsage(accountID int64) (*UsageInfo, bool) {
if s == nil || s.cache == nil || accountID <= 0 {
return nil, false
}
cached, ok := s.cache.kiroUsageCache.Load(accountID)
if !ok {
return nil, false
}
cache, ok := cached.(*kiroUsageCache)
if !ok || cache == nil || cache.usageInfo == nil {
return nil, false
}
if time.Since(cache.timestamp) >= kiroCacheTTL(cache.usageInfo) {
return nil, false
}
return cloneUsageInfo(cache.usageInfo), true
}
func kiroCacheTTL(info *UsageInfo) time.Duration {
if info == nil {
return kiroUsageErrorTTL
}
if info.ErrorCode != "" || info.Error != "" {
return kiroUsageErrorTTL
}
return apiCacheTTL
}
func cloneUsageInfo(info *UsageInfo) *UsageInfo {
if info == nil {
return nil
}
cloned := *info
return &cloned
}
func (s *AccountUsageService) requestKiroUsageLimits(ctx context.Context, account *Account, region, profileArn, token string) (*kiroUsageLimitsResponse, error) {
endpoint := resolveKiroRuntimeEndpoint(region)
reqURL, err := url.Parse(endpoint + "/getUsageLimits")
if err != nil {
return nil, fmt.Errorf("build kiro usage url failed: %w", err)
}
q := reqURL.Query()
q.Set("origin", kiroUsageOrigin)
if profileArn = strings.TrimSpace(profileArn); profileArn != "" {
q.Set("profileArn", profileArn)
}
q.Set("resourceType", kiroUsageResourceType)
reqURL.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("create kiro usage request failed: %w", err)
}
s.applyKiroRuntimeHeaders(req, account, token)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: accountProxyURL(account),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: isLoopbackEndpoint(endpoint),
})
if err != nil {
return nil, fmt.Errorf("create kiro usage client failed: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("kiro usage request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("read kiro usage response failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
}
var parsed kiroUsageLimitsResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("decode kiro usage response failed: %w", err)
}
return &parsed, nil
}
func (s *AccountUsageService) applyKiroRuntimeHeaders(req *http.Request, account *Account, token string) {
if req == nil {
return
}
accountKey := buildKiroAccountKey(account)
machineID := buildKiroMachineID(account)
req.Header.Set("Accept", "*/*")
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
req.Header.Set("x-amzn-codewhisperer-optout", "true")
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
if account == nil {
return
}
applyKiroConditionalHeaders(req, account)
}
func accountProxyURL(account *Account) string {
if account == nil || account.ProxyID == nil || account.Proxy == nil {
return ""
}
return account.Proxy.URL()
}
func kiroRuntimeEndpoint(region string) string {
region = strings.TrimSpace(region)
if region == "" {
region = kiroDefaultRegion
}
switch region {
case "us-east-1":
return "https://q.us-east-1.amazonaws.com"
case "eu-central-1":
return "https://q.eu-central-1.amazonaws.com"
case "us-gov-east-1":
return "https://q-fips.us-gov-east-1.amazonaws.com"
case "us-gov-west-1":
return "https://q-fips.us-gov-west-1.amazonaws.com"
case "us-iso-east-1":
return "https://q.us-iso-east-1.c2s.ic.gov"
case "us-isob-east-1":
return "https://q.us-isob-east-1.sc2s.sgov.gov"
case "us-isof-south-1":
return "https://q.us-isof-south-1.csp.hci.ic.gov"
case "us-isof-east-1":
return "https://q.us-isof-east-1.csp.hci.ic.gov"
default:
if strings.HasPrefix(region, "us-gov-") {
return "https://q-fips." + region + ".amazonaws.com"
}
return "https://q." + region + ".amazonaws.com"
}
}
func isLoopbackEndpoint(raw string) bool {
parsed, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return false
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return false
}
if strings.EqualFold(host, "localhost") {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func mapKiroUsageToInfo(resp *kiroUsageLimitsResponse) *UsageInfo {
now := time.Now()
if resp == nil {
return &UsageInfo{UpdatedAt: &now}
}
info := &UsageInfo{
UpdatedAt: &now,
KiroSubscriptionName: strings.TrimSpace(resp.SubscriptionInfo.SubscriptionTitle),
KiroSubscriptionType: strings.TrimSpace(resp.SubscriptionInfo.Type),
KiroOveragesEnabled: strings.EqualFold(strings.TrimSpace(resp.OverageConfiguration.OverageStatus), "ENABLED"),
}
resetAt := parseKiroTimestamp(resp.NextDateReset)
if credit := selectKiroCreditBreakdown(resp.UsageBreakdownList); credit != nil {
if breakdownReset := parseKiroTimestamp(credit.NextDateReset); breakdownReset != nil {
resetAt = breakdownReset
}
info.KiroCredit = &KiroCreditProgress{
CurrentUsage: selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage),
UsageLimit: selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit),
PercentageUsed: percentageOrZero(selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage), selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit)),
}
info.KiroOverage = &KiroOverageInfo{
CurrentOverages: selectKiroFloat(credit.CurrentOveragesWithPrecision, credit.CurrentOverages),
OverageCharges: selectKiroFloat(credit.OverageCharges, nil),
CurrencyCode: strings.TrimSpace(credit.Currency),
CurrencySymbol: kiroCurrencySymbol(strings.TrimSpace(credit.Currency)),
}
if ft := credit.FreeTrialInfo; ft != nil && strings.EqualFold(strings.TrimSpace(ft.FreeTrialStatus), "ACTIVE") {
expiry := parseKiroTimestamp(ft.FreeTrialExpiry)
daysRemaining := 0
if expiry != nil {
daysRemaining = int(time.Until(*expiry).Hours() / 24)
if time.Until(*expiry)%(24*time.Hour) != 0 {
daysRemaining++
}
if daysRemaining < 0 {
daysRemaining = 0
}
}
current := selectKiroFloat(ft.CurrentUsageWithPrecision, ft.CurrentUsage)
limit := selectKiroFloat(ft.UsageLimitWithPrecision, ft.UsageLimit)
info.KiroBonus = &KiroCreditProgress{
CurrentUsage: current,
UsageLimit: limit,
PercentageUsed: percentageOrZero(current, limit),
DaysRemaining: daysRemaining,
ExpiryDate: expiry,
}
}
}
info.KiroResetAt = resetAt
setKiroQuotaStateFromUsage(info)
return info
}
func selectKiroCreditBreakdown(items []kiroUsageBreakdown) *kiroUsageBreakdown {
for i := range items {
if strings.EqualFold(strings.TrimSpace(items[i].ResourceType), "CREDIT") {
return &items[i]
}
}
if len(items) == 0 {
return nil
}
return &items[0]
}
func selectKiroFloat(preferred *float64, fallback *float64) float64 {
switch {
case preferred != nil:
return *preferred
case fallback != nil:
return *fallback
default:
return 0
}
}
func percentageOrZero(current, limit float64) float64 {
if limit <= 0 {
return 0
}
return current / limit * 100
}
func parseKiroTimestamp(raw any) *time.Time {
if raw == nil {
return nil
}
switch v := raw.(type) {
case string:
trimmed := strings.TrimSpace(v)
if trimmed == "" {
return nil
}
if parsed, err := time.Parse(time.RFC3339, trimmed); err == nil {
return &parsed
}
if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
return unixishToTime(i)
}
if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
return unixishFloatToTime(f)
}
case float64:
return unixishFloatToTime(v)
case int64:
return unixishToTime(v)
case int:
return unixishToTime(int64(v))
case json.Number:
if i, err := v.Int64(); err == nil {
return unixishToTime(i)
}
if f, err := v.Float64(); err == nil {
return unixishFloatToTime(f)
}
}
return nil
}
func unixishFloatToTime(v float64) *time.Time {
if v <= 0 {
return nil
}
if v >= 1e12 {
t := time.UnixMilli(int64(v))
return &t
}
t := time.Unix(int64(v), 0)
return &t
}
func unixishToTime(v int64) *time.Time {
if v <= 0 {
return nil
}
if v >= 1e12 {
t := time.UnixMilli(v)
return &t
}
t := time.Unix(v, 0)
return &t
}
func kiroCurrencySymbol(code string) string {
switch strings.ToUpper(strings.TrimSpace(code)) {
case "USD":
return "$"
default:
return ""
}
}
func buildKiroDegradedUsage(err error) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
Error: "usage API error",
ErrorCode: errorCodeNetworkError,
}
if err == nil {
return info
}
info.Error = fmt.Sprintf("usage API error: %v", err)
classification := classifyKiroError(err)
switch classification.Category {
case kiroErrorAuthError:
info.ErrorCode = errorCodeUnauthenticated
info.NeedsReauth = true
case kiroErrorRateLimited:
info.ErrorCode = errorCodeRateLimited
case kiroErrorQuotaExhausted:
info.ErrorCode = errorCodeNetworkError
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
info.KiroQuotaReason = classification.Message
case kiroErrorOverageExhausted:
info.ErrorCode = errorCodeNetworkError
info.KiroQuotaState = kiroQuotaStateOverageExhausted
info.KiroQuotaReason = classification.Message
case kiroErrorSuspended, kiroErrorUsageForbidden, kiroErrorProfileError:
info.ErrorCode = errorCodeForbidden
default:
info.ErrorCode = errorCodeNetworkError
}
return info
}
func (s *AccountUsageService) attachKiroRuntimeState(ctx context.Context, account *Account, usage *UsageInfo) {
if s == nil || usage == nil || account == nil || account.Platform != PlatformKiro || s.kiroCooldownStore == nil {
return
}
usage.KiroRuntimeState = ""
usage.KiroRuntimeReason = ""
usage.KiroRuntimeResetAt = nil
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
if err != nil || state == nil {
return
}
usage.KiroRuntimeState, usage.KiroRuntimeReason, usage.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
}
func (s *AccountUsageService) EnrichAccountWithKiroRuntimeState(ctx context.Context, account *Account) {
if s == nil || account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
return
}
account.KiroQuotaState = ""
account.KiroQuotaReason = ""
account.KiroQuotaResetAt = nil
account.KiroRuntimeState = ""
account.KiroRuntimeReason = ""
account.KiroRuntimeResetAt = nil
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
account.KiroQuotaState = usage.KiroQuotaState
account.KiroQuotaReason = usage.KiroQuotaReason
account.KiroQuotaResetAt = usage.KiroQuotaResetAt
}
if s.kiroCooldownStore == nil {
return
}
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
if err != nil || state == nil {
return
}
account.KiroRuntimeState, account.KiroRuntimeReason, account.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
}
func setKiroQuotaStateFromUsage(info *UsageInfo) {
if info == nil {
return
}
info.KiroQuotaState = ""
info.KiroQuotaReason = ""
info.KiroQuotaResetAt = nil
creditExhausted := false
if info.KiroCredit != nil && info.KiroCredit.UsageLimit > 0 {
creditExhausted = info.KiroCredit.CurrentUsage >= info.KiroCredit.UsageLimit
}
overageActive := info.KiroOverage != nil &&
(info.KiroOverage.CurrentOverages > 0 || info.KiroOverage.OverageCharges > 0)
switch {
case info.KiroOveragesEnabled && (overageActive || creditExhausted):
info.KiroQuotaState = kiroQuotaStateOverageActive
info.KiroQuotaReason = "overages_enabled"
info.KiroQuotaResetAt = info.KiroResetAt
case creditExhausted:
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
info.KiroQuotaReason = "credits_exhausted"
info.KiroQuotaResetAt = info.KiroResetAt
default:
info.KiroQuotaState = kiroQuotaStateNormal
}
}
+458
View File
@@ -0,0 +1,458 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
const kiroMaxWebSearchIterations = 5
var (
errKiroWebSearchFallback = errors.New("kiro web search fallback")
kiroWebSearchDescCache sync.Map
)
type kiroWebSearchExecution struct {
ResponseBody []byte
Usage ClaudeUsage
RequestID string
}
type kiroWebSearchHTTPError struct {
Response *http.Response
}
type kiroStreamChunkCollector struct {
chunks [][]byte
}
func (e *kiroWebSearchHTTPError) Error() string {
if e == nil || e.Response == nil {
return "kiro web search http error"
}
return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
}
func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
if len(p) > 0 {
w.chunks = append(w.chunks, append([]byte(nil), p...))
}
return len(p), nil
}
func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
collector := &kiroStreamChunkCollector{}
result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens)
if err != nil {
return nil, nil, err
}
return collector.chunks, result, nil
}
func writeSSEChunks(w io.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if len(chunk) == 0 {
continue
}
if _, err := w.Write(chunk); err != nil {
return err
}
}
return nil
}
func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
if strings.TrimSpace(msgID) == "" {
msgID = "msg_" + kiropkg.GenerateToolUseID()
}
if strings.TrimSpace(model) == "" {
model = "kiro"
}
payload, err := json.Marshal(map[string]any{
"type": "message_start",
"message": map[string]any{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": inputTokens,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
})
if err != nil {
return err
}
_, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n")
return err
}
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
) error {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
return errKiroWebSearchFallback
}
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
if err != nil {
currentBody = anthropicBody
}
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
nextContentBlockIndex := 0
if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil {
return err
}
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
s.prefetchKiroWebSearchDescription(ctx, account, token)
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
if strings.TrimSpace(nextToken) != "" {
token = nextToken
}
if mcpErr != nil {
results = nil
}
if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil {
return err
}
nextContentBlockIndex += 2
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
if err != nil {
return errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return &kiroWebSearchHTTPError{Response: resp}
}
chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) {
defer func() { _ = resp.Body.Close() }()
return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
}()
if streamErr != nil {
return streamErr
}
analysis := kiropkg.AnalyzeBufferedStream(chunks)
if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations {
filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex)
if err := writeSSEChunks(w, filtered); err != nil {
return err
}
if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex {
nextContentBlockIndex = maxIndex + 1
}
query = analysis.WebSearchQuery
if strings.TrimSpace(analysis.WebSearchToolUseID) == "" {
currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
} else {
currentToolUseID = analysis.WebSearchToolUseID
}
continue
}
for _, chunk := range chunks {
adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex)
if !shouldForward {
continue
}
if _, err := w.Write(adjusted); err != nil {
return err
}
}
return nil
}
return fmt.Errorf("kiro web search exceeded max iterations")
}
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
return nil, errKiroWebSearchFallback
}
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
if err != nil {
currentBody = anthropicBody
}
inputTokens := estimateKiroInputTokens(anthropicBody)
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
searches := make([]kiropkg.SearchIndicator, 0, 2)
requestID := ""
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
s.prefetchKiroWebSearchDescription(ctx, account, token)
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
if strings.TrimSpace(nextToken) != "" {
token = nextToken
}
if mcpErr != nil {
results = nil
}
searches = append(searches, kiropkg.SearchIndicator{
ToolUseID: currentToolUseID,
Query: query,
Results: results,
})
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
if err != nil {
return nil, errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, &kiroWebSearchHTTPError{Response: resp}
}
parseResult, parseErr := func() (*kiropkg.ParseResult, error) {
defer func() { _ = resp.Body.Close() }()
return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
}()
if parseErr != nil {
return nil, parseErr
}
if requestID == "" {
requestID = buildKiroRequestID(resp)
}
nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody)
if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations {
finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches)
if injectErr == nil {
parseResult.ResponseBody = finalBody
}
return &kiroWebSearchExecution{
ResponseBody: parseResult.ResponseBody,
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
RequestID: requestID,
}, nil
}
query = nextQuery
if strings.TrimSpace(nextToolUseID) == "" {
nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
}
currentToolUseID = nextToolUseID
}
return nil, fmt.Errorf("kiro web search exceeded max iterations")
}
func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
kiropkg.SetCachedWebSearchDescription(desc)
}
return
}
reqBody, _ := json.Marshal(kiropkg.MCPRequest{
ID: "tools_list",
JSONRPC: "2.0",
Method: "tools/list",
})
resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
if err != nil || resp == nil {
return
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return
}
var result kiropkg.MCPResponse
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
return
}
for _, tool := range result.Result.Tools {
if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
kiroWebSearchDescCache.Store(endpoint, tool.Description)
kiropkg.SetCachedWebSearchDescription(tool.Description)
return
}
}
}
func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
if err != nil {
return nil, token, err
}
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
if err != nil {
return nil, nextToken, err
}
if resp == nil {
return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nextToken, err
}
if resp.StatusCode != http.StatusOK {
return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var parsed kiropkg.MCPResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, nextToken, err
}
if parsed.Error != nil {
msg := "unknown error"
if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" {
msg = strings.TrimSpace(*parsed.Error.Message)
}
code := 0
if parsed.Error.Code != nil {
code = *parsed.Error.Code
}
return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
}
return kiropkg.ParseSearchResults(&parsed), nextToken, nil
}
func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
return kiropkg.MCPRequest{
ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
JSONRPC: "2.0",
Method: "tools/call",
Params: map[string]interface{}{
"name": "web_search",
"arguments": map[string]interface{}{
"query": query,
"_meta": map[string]interface{}{
"_isValid": true,
"_activePath": []string{"query"},
"_completedPaths": [][]string{{"query"}},
},
},
},
}
}
func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
currentToken := token
accountKey := buildKiroAccountKey(account)
proxyURL := kiroProxyURL(account)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
for attempt := 0; attempt < 3; attempt++ {
if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
return nil, currentToken, failoverErr
}
return nil, currentToken, err
}
req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
if err != nil {
return nil, currentToken, err
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return nil, currentToken, err
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, currentToken, readErr
}
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
return nil, currentToken, err
}
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) {
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
if s.kiroTokenProvider == nil {
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
if refreshErr != nil {
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
return resp, currentToken, nil
}
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, currentToken, sleepErr
}
continue
}
if resp.StatusCode == http.StatusTooManyRequests {
if _, err := s.markKiro429(ctx, accountKey); err != nil {
_ = resp.Body.Close()
return nil, currentToken, err
}
}
if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
if attempt < 2 {
_ = resp.Body.Close()
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
return nil, currentToken, sleepErr
}
continue
}
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
_ = resp.Body.Close()
return nil, currentToken, err
}
}
return resp, currentToken, nil
}
return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
}
@@ -0,0 +1,28 @@
//go:build unit
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildKiroWebSearchMCPRequest_UsesUnderscoredMetaKeys(t *testing.T) {
req := buildKiroWebSearchMCPRequest("golang concurrency")
body, err := json.Marshal(req)
require.NoError(t, err)
require.Equal(t, "tools/call", gjson.GetBytes(body, "method").String())
require.Equal(t, "web_search", gjson.GetBytes(body, "params.name").String())
require.Equal(t, "golang concurrency", gjson.GetBytes(body, "params.arguments.query").String())
require.True(t, gjson.GetBytes(body, "params.arguments._meta._isValid").Bool())
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._activePath.0").String())
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._completedPaths.0.0").String())
require.False(t, gjson.GetBytes(body, "params.arguments._meta.isValid").Exists())
require.False(t, gjson.GetBytes(body, "params.arguments._meta.activePath").Exists())
require.False(t, gjson.GetBytes(body, "params.arguments._meta.completedPaths").Exists())
}
@@ -89,6 +89,14 @@ type OpsUpstreamErrorEvent struct {
AccountID int64 `json:"account_id,omitempty"` AccountID int64 `json:"account_id,omitempty"`
AccountName string `json:"account_name,omitempty"` AccountName string `json:"account_name,omitempty"`
// Model diagnostics.
RequestedModel string `json:"requested_model,omitempty"`
MappedModel string `json:"mapped_model,omitempty"`
KiroModelID string `json:"kiro_model_id,omitempty"`
HasTools bool `json:"has_tools,omitempty"`
HasAdaptiveThinking bool `json:"has_adaptive_thinking,omitempty"`
HasContext1MBeta bool `json:"has_context_1m_beta,omitempty"`
// Outcome // Outcome
UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"`
@@ -42,6 +42,9 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键 // Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account)) keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey) keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformKiro:
keysToDelete = append(keysToDelete, KiroTokenCacheKey(account))
keysToDelete = append(keysToDelete, "kiro:"+accountIDKey)
case PlatformOpenAI: case PlatformOpenAI:
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account)) keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic: case PlatformAnthropic:
@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
) )
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间 // tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
@@ -44,6 +45,7 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService, antigravityOAuthService *AntigravityOAuthService,
kiroOAuthService *KiroOAuthService,
cacheInvalidator TokenCacheInvalidator, cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache, schedulerCache SchedulerCache,
cfg *config.Config, cfg *config.Config,
@@ -64,6 +66,7 @@ func NewTokenRefreshService(
claudeRefresher := NewClaudeTokenRefresher(oauthService) claudeRefresher := NewClaudeTokenRefresher(oauthService)
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService) geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService) agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
kiroRefresher := NewKiroTokenRefresher(kiroOAuthService)
// 注册平台特定的刷新器(TokenRefresher 接口) // 注册平台特定的刷新器(TokenRefresher 接口)
s.refreshers = []TokenRefresher{ s.refreshers = []TokenRefresher{
@@ -71,6 +74,7 @@ func NewTokenRefreshService(
openAIRefresher, openAIRefresher,
geminiRefresher, geminiRefresher,
agRefresher, agRefresher,
kiroRefresher,
} }
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法) // 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
@@ -79,6 +83,7 @@ func NewTokenRefreshService(
openAIRefresher, openAIRefresher,
geminiRefresher, geminiRefresher,
agRefresher, agRefresher,
kiroRefresher,
} }
return s return s
@@ -415,6 +420,10 @@ func isNonRetryableRefreshError(err error) bool {
if err == nil { if err == nil {
return false return false
} }
var kiroInvalidGrant *kiropkg.RefreshTokenInvalidError
if errors.As(err, &kiroInvalidGrant) {
return true
}
msg := strings.ToLower(err.Error()) msg := strings.ToLower(err.Error())
nonRetryable := []string{ nonRetryable := []string{
"invalid_grant", // refresh_token 已失效 "invalid_grant", // refresh_token 已失效
@@ -22,9 +22,9 @@ func optionalNonEqualStringPtr(value, compare string) *string {
func forwardResultBillingModel(requestedModel, upstreamModel string) string { func forwardResultBillingModel(requestedModel, upstreamModel string) string {
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" { if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
return trimmed return normalizeModelNameForPricing(trimmed)
} }
return strings.TrimSpace(upstreamModel) return normalizeModelNameForPricing(upstreamModel)
} }
func optionalInt64Ptr(v int64) *int64 { func optionalInt64Ptr(v int64) *int64 {
+23 -1
View File
@@ -8,6 +8,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -51,6 +52,7 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService, antigravityOAuthService *AntigravityOAuthService,
kiroOAuthService *KiroOAuthService,
cacheInvalidator TokenCacheInvalidator, cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache, schedulerCache SchedulerCache,
cfg *config.Config, cfg *config.Config,
@@ -59,7 +61,7 @@ func ProvideTokenRefreshService(
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
refreshAPI *OAuthRefreshAPI, refreshAPI *OAuthRefreshAPI,
) *TokenRefreshService { ) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 OpenAI privacy opt-out 依赖 // 注入 OpenAI privacy opt-out 依赖
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo) svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件) // 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
@@ -128,6 +130,23 @@ func ProvideAntigravityTokenProvider(
return p return p
} }
func ProvideKiroTokenProvider(
accountRepo AccountRepository,
tokenCache GeminiTokenCache,
kiroOAuthService *KiroOAuthService,
refreshAPI *OAuthRefreshAPI,
) *KiroTokenProvider {
p := NewKiroTokenProvider(accountRepo, tokenCache, kiroOAuthService)
executor := NewKiroTokenRefresher(kiroOAuthService)
p.SetRefreshAPI(refreshAPI, executor)
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
return p
}
func ProvideKiroCooldownStore(redisClient *redis.Client) KiroCooldownStore {
return kirocooldown.NewStore(redisClient)
}
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 // ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
svc := NewDashboardAggregationService(repo, timingWheel, cfg) svc := NewDashboardAggregationService(repo, timingWheel, cfg)
@@ -448,8 +467,11 @@ var ProviderSet = wire.NewSet(
NewCompositeTokenCacheInvalidator, NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService, NewAntigravityOAuthService,
NewKiroOAuthService,
ProvideOAuthRefreshAPI, ProvideOAuthRefreshAPI,
ProvideGeminiTokenProvider, ProvideGeminiTokenProvider,
ProvideKiroTokenProvider,
ProvideKiroCooldownStore,
NewGeminiMessagesCompatService, NewGeminiMessagesCompatService,
ProvideAntigravityTokenProvider, ProvideAntigravityTokenProvider,
ProvideOpenAITokenProvider, ProvideOpenAITokenProvider,