feat(backend): add kiro account support
This commit is contained in:
@@ -93,6 +93,7 @@ func provideCleanup(
|
|||||||
openaiOAuth *service.OpenAIOAuthService,
|
openaiOAuth *service.OpenAIOAuthService,
|
||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
|
kiroOAuth *service.KiroOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
backupSvc *service.BackupService,
|
backupSvc *service.BackupService,
|
||||||
@@ -216,6 +217,10 @@ func provideCleanup(
|
|||||||
antigravityOAuth.Stop()
|
antigravityOAuth.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"KiroOAuthService", func() error {
|
||||||
|
kiroOAuth.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"OpenAIWSPool", func() error {
|
{"OpenAIWSPool", func() error {
|
||||||
if openAIGateway != nil {
|
if openAIGateway != nil {
|
||||||
openAIGateway.CloseOpenAIWSPool()
|
openAIGateway.CloseOpenAIWSPool()
|
||||||
|
|||||||
@@ -146,13 +146,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||||
|
kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
|
||||||
|
kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
@@ -166,6 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||||
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
|
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
|
||||||
|
kiroOAuthHandler := admin.NewKiroOAuthHandler(kiroOAuthService)
|
||||||
proxyHandler := admin.NewProxyHandler(adminService)
|
proxyHandler := admin.NewProxyHandler(adminService)
|
||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
||||||
promoHandler := admin.NewPromoHandler(promoService)
|
promoHandler := admin.NewPromoHandler(promoService)
|
||||||
@@ -179,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
|
kiroCooldownStore := service.ProvideKiroCooldownStore(redisClient)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
channelRepository := repository.NewChannelRepository(db)
|
channelRepository := repository.NewChannelRepository(db)
|
||||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
||||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, kiroOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
@@ -256,13 +260,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -312,6 +316,7 @@ func provideCleanup(
|
|||||||
openaiOAuth *service.OpenAIOAuthService,
|
openaiOAuth *service.OpenAIOAuthService,
|
||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
|
kiroOAuth *service.KiroOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
backupSvc *service.BackupService,
|
backupSvc *service.BackupService,
|
||||||
@@ -434,6 +439,10 @@ func provideCleanup(
|
|||||||
antigravityOAuth.Stop()
|
antigravityOAuth.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"KiroOAuthService", func() error {
|
||||||
|
kiroOAuth.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"OpenAIWSPool", func() error {
|
{"OpenAIWSPool", func() error {
|
||||||
if openAIGateway != nil {
|
if openAIGateway != nil {
|
||||||
openAIGateway.CloseOpenAIWSPool()
|
openAIGateway.CloseOpenAIWSPool()
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
antigravityOAuthSvc,
|
antigravityOAuthSvc,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
cfg,
|
cfg,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
@@ -72,6 +73,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
openAIOAuthSvc,
|
openAIOAuthSvc,
|
||||||
geminiOAuthSvc,
|
geminiOAuthSvc,
|
||||||
antigravityOAuthSvc,
|
antigravityOAuthSvc,
|
||||||
|
nil, // kiroOAuth
|
||||||
nil, // openAIGateway
|
nil, // openAIGateway
|
||||||
nil, // scheduledTestRunner
|
nil, // scheduledTestRunner
|
||||||
nil, // backupSvc
|
nil, // backupSvc
|
||||||
|
|||||||
@@ -635,6 +635,8 @@ type GatewayConfig struct {
|
|||||||
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
|
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
|
||||||
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
|
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
|
||||||
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
|
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
|
||||||
|
// KiroStreamKeepaliveInterval: Kiro 流式 keepalive 间隔(秒),0使用默认 25 秒
|
||||||
|
KiroStreamKeepaliveInterval int `mapstructure:"kiro_stream_keepalive_interval"`
|
||||||
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
|
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
|
||||||
MaxLineSize int `mapstructure:"max_line_size"`
|
MaxLineSize int `mapstructure:"max_line_size"`
|
||||||
|
|
||||||
@@ -1689,6 +1691,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||||
|
viper.SetDefault("gateway.kiro_stream_keepalive_interval", 25)
|
||||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||||
@@ -2277,6 +2280,13 @@ func (c *Config) Validate() error {
|
|||||||
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
|
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
|
||||||
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
|
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.KiroStreamKeepaliveInterval < 0 {
|
||||||
|
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.KiroStreamKeepaliveInterval != 0 &&
|
||||||
|
(c.Gateway.KiroStreamKeepaliveInterval < 5 || c.Gateway.KiroStreamKeepaliveInterval > 30) {
|
||||||
|
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||||
|
}
|
||||||
// 兼容旧键 sticky_previous_response_ttl_seconds
|
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||||||
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||||
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ const (
|
|||||||
PlatformOpenAI = "openai"
|
PlatformOpenAI = "openai"
|
||||||
PlatformGemini = "gemini"
|
PlatformGemini = "gemini"
|
||||||
PlatformAntigravity = "antigravity"
|
PlatformAntigravity = "antigravity"
|
||||||
|
PlatformKiro = "kiro"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
@@ -116,6 +117,21 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultKiroModelMapping 是 Kiro 平台的默认模型映射。
|
||||||
|
// 键为对外暴露/允许请求的模型名,值为实际发送到 Kiro 上游的模型名。
|
||||||
|
var DefaultKiroModelMapping = map[string]string{
|
||||||
|
"claude-opus-4-6": "claude-opus-4.6",
|
||||||
|
"claude-opus-4-6-thinking": "claude-opus-4.6",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||||
|
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4.5",
|
||||||
|
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||||
|
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
|
||||||
|
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package domain
|
package domain
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
@@ -24,3 +27,54 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
expected := map[string]string{
|
||||||
|
"claude-opus-4-6": "claude-opus-4.6",
|
||||||
|
"claude-opus-4-6-thinking": "claude-opus-4.6",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||||
|
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4.5",
|
||||||
|
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||||
|
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
|
||||||
|
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(DefaultKiroModelMapping) != len(expected) {
|
||||||
|
t.Fatalf("expected %d Kiro mappings, got %d", len(expected), len(DefaultKiroModelMapping))
|
||||||
|
}
|
||||||
|
for model, want := range expected {
|
||||||
|
if got := DefaultKiroModelMapping[model]; got != want {
|
||||||
|
t.Fatalf("unexpected Kiro mapping for %q: got %q want %q", model, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range []string{
|
||||||
|
"claude-opus-4-5",
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4",
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-4",
|
||||||
|
"deepseek-3-2",
|
||||||
|
"minimax-m2-1",
|
||||||
|
"qwen3-coder-next",
|
||||||
|
"claude-opus-4-7",
|
||||||
|
"claude-sonnet-4-6-chat",
|
||||||
|
} {
|
||||||
|
if _, ok := DefaultKiroModelMapping[model]; ok {
|
||||||
|
t.Fatalf("did not expect %q to remain in DefaultKiroModelMapping", model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for model := range DefaultKiroModelMapping {
|
||||||
|
if strings.HasSuffix(model, "-agentic") {
|
||||||
|
t.Fatalf("did not expect agentic Kiro mapping %q", model)
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(model, "-chat") {
|
||||||
|
t.Fatalf("did not expect chat-only Kiro mapping %q", model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
@@ -179,6 +180,9 @@ type AccountWithConcurrency struct {
|
|||||||
const accountListGroupUngroupedQueryValue = "ungrouped"
|
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||||
|
|
||||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||||
|
if h.accountUsageService != nil {
|
||||||
|
h.accountUsageService.EnrichAccountWithKiroRuntimeState(ctx, account)
|
||||||
|
}
|
||||||
item := AccountWithConcurrency{
|
item := AccountWithConcurrency{
|
||||||
Account: dto.AccountFromService(account),
|
Account: dto.AccountFromService(account),
|
||||||
CurrentConcurrency: 0,
|
CurrentConcurrency: 0,
|
||||||
@@ -351,6 +355,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
result := make([]AccountWithConcurrency, len(accounts))
|
result := make([]AccountWithConcurrency, len(accounts))
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
acc := &accounts[i]
|
acc := &accounts[i]
|
||||||
|
if h.accountUsageService != nil {
|
||||||
|
h.accountUsageService.EnrichAccountWithKiroRuntimeState(c.Request.Context(), acc)
|
||||||
|
}
|
||||||
item := AccountWithConcurrency{
|
item := AccountWithConcurrency{
|
||||||
Account: dto.AccountFromService(acc),
|
Account: dto.AccountFromService(acc),
|
||||||
CurrentConcurrency: concurrencyCounts[acc.ID],
|
CurrentConcurrency: concurrencyCounts[acc.ID],
|
||||||
@@ -1913,6 +1920,18 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Kiro accounts
|
||||||
|
if account.Platform == service.PlatformKiro {
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if len(mapping) == 0 {
|
||||||
|
response.Success(c, kiropkg.DefaultModels)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, buildMappedKiroModels(mapping))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Claude/Anthropic accounts
|
// Handle Claude/Anthropic accounts
|
||||||
// For OAuth and Setup-Token accounts: return default models
|
// For OAuth and Setup-Token accounts: return default models
|
||||||
if account.IsOAuth() {
|
if account.IsOAuth() {
|
||||||
@@ -1954,6 +1973,28 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
response.Success(c, models)
|
response.Success(c, models)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildMappedKiroModels(mapping map[string]string) []kiropkg.Model {
|
||||||
|
models := make([]kiropkg.Model, 0, len(mapping))
|
||||||
|
for requestedModel := range mapping {
|
||||||
|
var found bool
|
||||||
|
for _, dm := range kiropkg.DefaultModels {
|
||||||
|
if dm.ID == requestedModel {
|
||||||
|
models = append(models, dm)
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
models = append(models, kiropkg.Model{
|
||||||
|
ID: requestedModel,
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: requestedModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||||
@@ -2166,6 +2207,12 @@ func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
|
|||||||
response.Success(c, domain.DefaultAntigravityModelMapping)
|
response.Success(c, domain.DefaultAntigravityModelMapping)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetKiroDefaultModelMapping 获取 Kiro 平台的默认模型映射
|
||||||
|
// GET /api/v1/admin/accounts/kiro/default-model-mapping
|
||||||
|
func (h *AccountHandler) GetKiroDefaultModelMapping(c *gin.Context) {
|
||||||
|
response.Success(c, domain.DefaultKiroModelMapping)
|
||||||
|
}
|
||||||
|
|
||||||
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
|
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
|
||||||
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
|
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
|
||||||
func sanitizeExtraBaseRPM(extra map[string]any) {
|
func sanitizeExtraBaseRPM(extra map[string]any) {
|
||||||
|
|||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KiroOAuthHandler struct {
|
||||||
|
kiroOAuthService *service.KiroOAuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKiroOAuthHandler(kiroOAuthService *service.KiroOAuthService) *KiroOAuthHandler {
|
||||||
|
return &KiroOAuthHandler{kiroOAuthService: kiroOAuthService}
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroGenerateAuthURLRequest struct {
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *KiroOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||||
|
var req KiroGenerateAuthURLRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "请求无效: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.kiroOAuthService.GenerateAuthURL(c.Request.Context(), &service.KiroGenerateAuthURLInput{
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
Provider: req.Provider,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "生成授权链接失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroGenerateIDCAuthURLRequest struct {
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
StartURL string `json:"start_url"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *KiroOAuthHandler) GenerateIDCAuthURL(c *gin.Context) {
|
||||||
|
var req KiroGenerateIDCAuthURLRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "请求无效: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.kiroOAuthService.GenerateIDCAuthURL(c.Request.Context(), &service.KiroGenerateIDCAuthURLInput{
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
StartURL: req.StartURL,
|
||||||
|
Region: req.Region,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "生成 IDC 授权链接失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroExchangeCodeRequest struct {
|
||||||
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
|
State string `json:"state" binding:"required"`
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
CallbackPath string `json:"callback_path"`
|
||||||
|
LoginOption string `json:"login_option"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *KiroOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||||
|
var req KiroExchangeCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "请求无效: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenInfo, err := h.kiroOAuthService.ExchangeCode(c.Request.Context(), &service.KiroExchangeCodeInput{
|
||||||
|
SessionID: req.SessionID,
|
||||||
|
State: req.State,
|
||||||
|
Code: req.Code,
|
||||||
|
CallbackPath: req.CallbackPath,
|
||||||
|
LoginOption: req.LoginOption,
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Token 交换失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroRefreshTokenRequest struct {
|
||||||
|
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||||
|
AuthMethod string `json:"auth_method"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
StartURL string `json:"start_url"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
ProfileArn string `json:"profile_arn"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
|
||||||
|
var req KiroRefreshTokenRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "请求无效: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenInfo, err := h.kiroOAuthService.RefreshToken(c.Request.Context(), &service.KiroRefreshTokenInput{
|
||||||
|
RefreshToken: req.RefreshToken,
|
||||||
|
AuthMethod: req.AuthMethod,
|
||||||
|
Provider: req.Provider,
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
ClientSecret: req.ClientSecret,
|
||||||
|
StartURL: req.StartURL,
|
||||||
|
Region: req.Region,
|
||||||
|
ProfileArn: req.ProfileArn,
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "刷新 Kiro Token 失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroImportTokenRequest struct {
|
||||||
|
TokenJSON string `json:"token_json" binding:"required"`
|
||||||
|
DeviceRegistrationJSON string `json:"device_registration_json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
|
||||||
|
var req KiroImportTokenRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "请求无效: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
|
||||||
|
TokenJSON: req.TokenJSON,
|
||||||
|
DeviceRegistrationJSON: req.DeviceRegistrationJSON,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
@@ -221,6 +221,12 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
OverloadUntil: a.OverloadUntil,
|
OverloadUntil: a.OverloadUntil,
|
||||||
TempUnschedulableUntil: a.TempUnschedulableUntil,
|
TempUnschedulableUntil: a.TempUnschedulableUntil,
|
||||||
TempUnschedulableReason: a.TempUnschedulableReason,
|
TempUnschedulableReason: a.TempUnschedulableReason,
|
||||||
|
KiroQuotaState: a.KiroQuotaState,
|
||||||
|
KiroQuotaReason: a.KiroQuotaReason,
|
||||||
|
KiroQuotaResetAt: a.KiroQuotaResetAt,
|
||||||
|
KiroRuntimeState: a.KiroRuntimeState,
|
||||||
|
KiroRuntimeReason: a.KiroRuntimeReason,
|
||||||
|
KiroRuntimeResetAt: a.KiroRuntimeResetAt,
|
||||||
SessionWindowStart: a.SessionWindowStart,
|
SessionWindowStart: a.SessionWindowStart,
|
||||||
SessionWindowEnd: a.SessionWindowEnd,
|
SessionWindowEnd: a.SessionWindowEnd,
|
||||||
SessionWindowStatus: a.SessionWindowStatus,
|
SessionWindowStatus: a.SessionWindowStatus,
|
||||||
|
|||||||
@@ -174,6 +174,12 @@ type Account struct {
|
|||||||
|
|
||||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||||
|
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
|
||||||
|
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
|
||||||
|
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
|
||||||
|
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
|
||||||
|
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
|
||||||
|
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
|
||||||
|
|
||||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type AdminHandlers struct {
|
|||||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||||
GeminiOAuth *admin.GeminiOAuthHandler
|
GeminiOAuth *admin.GeminiOAuthHandler
|
||||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||||
|
KiroOAuth *admin.KiroOAuthHandler
|
||||||
Proxy *admin.ProxyHandler
|
Proxy *admin.ProxyHandler
|
||||||
Redeem *admin.RedeemHandler
|
Redeem *admin.RedeemHandler
|
||||||
Promo *admin.PromoHandler
|
Promo *admin.PromoHandler
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
|
|||||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||||
|
kiroOAuthHandler *admin.KiroOAuthHandler,
|
||||||
proxyHandler *admin.ProxyHandler,
|
proxyHandler *admin.ProxyHandler,
|
||||||
redeemHandler *admin.RedeemHandler,
|
redeemHandler *admin.RedeemHandler,
|
||||||
promoHandler *admin.PromoHandler,
|
promoHandler *admin.PromoHandler,
|
||||||
@@ -51,6 +52,7 @@ func ProvideAdminHandlers(
|
|||||||
OpenAIOAuth: openaiOAuthHandler,
|
OpenAIOAuth: openaiOAuthHandler,
|
||||||
GeminiOAuth: geminiOAuthHandler,
|
GeminiOAuth: geminiOAuthHandler,
|
||||||
AntigravityOAuth: antigravityOAuthHandler,
|
AntigravityOAuth: antigravityOAuthHandler,
|
||||||
|
KiroOAuth: kiroOAuthHandler,
|
||||||
Proxy: proxyHandler,
|
Proxy: proxyHandler,
|
||||||
Redeem: redeemHandler,
|
Redeem: redeemHandler,
|
||||||
Promo: promoHandler,
|
Promo: promoHandler,
|
||||||
@@ -154,6 +156,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewOpenAIOAuthHandler,
|
admin.NewOpenAIOAuthHandler,
|
||||||
admin.NewGeminiOAuthHandler,
|
admin.NewGeminiOAuthHandler,
|
||||||
admin.NewAntigravityOAuthHandler,
|
admin.NewAntigravityOAuthHandler,
|
||||||
|
admin.NewKiroOAuthHandler,
|
||||||
admin.NewProxyHandler,
|
admin.NewProxyHandler,
|
||||||
admin.NewRedeemHandler,
|
admin.NewRedeemHandler,
|
||||||
admin.NewPromoHandler,
|
admin.NewPromoHandler,
|
||||||
|
|||||||
@@ -0,0 +1,258 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RuntimeFingerprint struct {
|
||||||
|
OIDCSDKVersion string
|
||||||
|
RuntimeSDKVersion string
|
||||||
|
StreamingSDKVersion string
|
||||||
|
OSType string
|
||||||
|
OSVersion string
|
||||||
|
NodeVersion string
|
||||||
|
KiroVersion string
|
||||||
|
KiroHash string
|
||||||
|
}
|
||||||
|
|
||||||
|
type runtimeFingerprintManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
fingerprints map[string]*RuntimeFingerprint
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
globalRuntimeFingerprintManager *runtimeFingerprintManager
|
||||||
|
globalRuntimeFingerprintManagerOnce sync.Once
|
||||||
|
|
||||||
|
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
|
||||||
|
runtimeSDKVersions = []string{"1.0.0"}
|
||||||
|
streamingSDKVersions = []string{"1.0.34"}
|
||||||
|
osTypes = []string{"darwin", "win32"}
|
||||||
|
osVersions = map[string][]string{
|
||||||
|
"darwin": {"24.6.0"},
|
||||||
|
"win32": {"10.0.22631"},
|
||||||
|
}
|
||||||
|
nodeVersions = []string{"22.22.0"}
|
||||||
|
kiroVersions = []string{
|
||||||
|
"0.11.132", "0.11.131", "0.11.130",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func globalRuntimeFingerprints() *runtimeFingerprintManager {
|
||||||
|
globalRuntimeFingerprintManagerOnce.Do(func() {
|
||||||
|
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
|
||||||
|
fingerprints: make(map[string]*RuntimeFingerprint),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return globalRuntimeFingerprintManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
|
||||||
|
lookupKey := fingerprintLookupKey(accountKey, "runtime")
|
||||||
|
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return fp
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||||
|
return fp
|
||||||
|
}
|
||||||
|
fp := generateRuntimeFingerprint(lookupKey, machineID)
|
||||||
|
m.fingerprints[lookupKey] = fp
|
||||||
|
return fp
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
|
||||||
|
hash := sha256.Sum256([]byte(accountKey))
|
||||||
|
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||||
|
rng := rand.New(rand.NewSource(seed))
|
||||||
|
|
||||||
|
osType := goOSToNodePlatform(runtime.GOOS)
|
||||||
|
if !containsString(osTypes, osType) {
|
||||||
|
osType = osTypes[rng.Intn(len(osTypes))]
|
||||||
|
}
|
||||||
|
osVersionPool := osVersions[osType]
|
||||||
|
if len(osVersionPool) == 0 {
|
||||||
|
osVersionPool = osVersions["darwin"]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RuntimeFingerprint{
|
||||||
|
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||||
|
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||||
|
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||||
|
OSType: osType,
|
||||||
|
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
|
||||||
|
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||||
|
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||||
|
KiroHash: machineID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func goOSToNodePlatform(goos string) string {
|
||||||
|
switch strings.TrimSpace(goos) {
|
||||||
|
case "windows":
|
||||||
|
return "win32"
|
||||||
|
default:
|
||||||
|
return strings.TrimSpace(goos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsString(items []string, target string) bool {
|
||||||
|
for _, item := range items {
|
||||||
|
if item == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
|
||||||
|
switch {
|
||||||
|
case strings.TrimSpace(clientIDHash) != "":
|
||||||
|
return clientIDHash
|
||||||
|
case strings.TrimSpace(clientID) != "":
|
||||||
|
return shortSHA(clientID)
|
||||||
|
case strings.TrimSpace(refreshToken) != "":
|
||||||
|
return shortSHA(refreshToken)
|
||||||
|
case strings.TrimSpace(profileArn) != "":
|
||||||
|
return shortSHA(profileArn)
|
||||||
|
case accountID > 0:
|
||||||
|
return shortSHA(fmt.Sprintf("account:%d", accountID))
|
||||||
|
default:
|
||||||
|
return shortSHA(uuid.NewString())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeMachineID(machineID string) (string, bool) {
|
||||||
|
trimmed := strings.TrimSpace(machineID)
|
||||||
|
if len(trimmed) == 64 && isHexString(trimmed) {
|
||||||
|
return strings.ToLower(trimmed), true
|
||||||
|
}
|
||||||
|
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
|
||||||
|
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
|
||||||
|
normalized := strings.ToLower(withoutDashes)
|
||||||
|
return normalized + normalized, true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
|
||||||
|
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
|
||||||
|
return sha256Hex("KotlinNativeAPI/" + refreshToken)
|
||||||
|
}
|
||||||
|
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
|
||||||
|
return sha256Hex("KiroAPIKey/" + apiKey)
|
||||||
|
}
|
||||||
|
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
|
||||||
|
return sha256Hex("KiroFallback/" + fallbackKey)
|
||||||
|
}
|
||||||
|
return sha256Hex("KiroFallback/default")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shortSHA(seed string) string {
|
||||||
|
sum := sha256.Sum256([]byte(seed))
|
||||||
|
return hex.EncodeToString(sum[:8])
|
||||||
|
}
|
||||||
|
|
||||||
|
func sha256Hex(seed string) string {
|
||||||
|
sum := sha256.Sum256([]byte(seed))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHexString(value string) bool {
|
||||||
|
for _, c := range value {
|
||||||
|
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
|
||||||
|
if normalized, ok := NormalizeMachineID(machineID); ok {
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
return BuildMachineID("", "", fallbackKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fingerprintLookupKey(accountKey, fallback string) string {
|
||||||
|
key := strings.TrimSpace(accountKey)
|
||||||
|
if key != "" {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildRuntimeUserAgent(accountKey, machineID string) string {
|
||||||
|
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||||
|
fp.StreamingSDKVersion,
|
||||||
|
fp.OSType,
|
||||||
|
fp.OSVersion,
|
||||||
|
fp.NodeVersion,
|
||||||
|
fp.StreamingSDKVersion,
|
||||||
|
fp.KiroVersion,
|
||||||
|
fp.KiroHash,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
|
||||||
|
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||||
|
fp.StreamingSDKVersion,
|
||||||
|
fp.KiroVersion,
|
||||||
|
fp.KiroHash,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
|
||||||
|
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
|
||||||
|
return map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
|
||||||
|
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
|
||||||
|
"amz-sdk-invocation-id": uuid.NewString(),
|
||||||
|
"amz-sdk-request": "attempt=1; max=4",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
|
||||||
|
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
|
||||||
|
return map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
|
||||||
|
"Accept": "application/json, text/plain, */*",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
delay := baseDelay << attempt
|
||||||
|
if delay > maxDelay {
|
||||||
|
delay = maxDelay
|
||||||
|
}
|
||||||
|
const jitterFactor = 0.3
|
||||||
|
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
|
||||||
|
return time.Duration(float64(delay) * jitter)
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildLoginHeadersStable(t *testing.T) {
|
||||||
|
headers1 := BuildLoginHeaders("", "")
|
||||||
|
headers2 := BuildLoginHeaders("", "")
|
||||||
|
|
||||||
|
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||||
|
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
|
||||||
|
require.Equal(t, "application/json", headers1["Content-Type"])
|
||||||
|
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
|
||||||
|
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
|
||||||
|
machineIDA := BuildMachineID("refresh-a", "", "")
|
||||||
|
machineIDB := BuildMachineID("refresh-b", "", "")
|
||||||
|
headers1 := BuildLoginHeaders("account-a", machineIDA)
|
||||||
|
headers2 := BuildLoginHeaders("account-a", machineIDA)
|
||||||
|
headers3 := BuildLoginHeaders("account-a", machineIDB)
|
||||||
|
|
||||||
|
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||||
|
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||||
|
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
|
||||||
|
require.Contains(t, headers1["User-Agent"], machineIDA)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
|
||||||
|
machineID := BuildMachineID("", "", "oidc-machine")
|
||||||
|
headers1 := BuildOIDCHeaders("account-a", machineID)
|
||||||
|
headers2 := BuildOIDCHeaders("account-a", machineID)
|
||||||
|
headers3 := BuildOIDCHeaders("account-b", machineID)
|
||||||
|
|
||||||
|
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||||
|
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||||
|
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
|
||||||
|
key1 := BuildAccountKey("", "", "", "", 42)
|
||||||
|
key2 := BuildAccountKey("", "", "", "", 42)
|
||||||
|
key3 := BuildAccountKey("", "", "", "", 43)
|
||||||
|
|
||||||
|
require.Equal(t, key1, key2)
|
||||||
|
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
|
||||||
|
require.NotEqual(t, key1, key3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildMachineID(t *testing.T) {
|
||||||
|
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
|
||||||
|
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
|
||||||
|
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
|
||||||
|
|
||||||
|
fallback1 := BuildMachineID("", "", "account:1")
|
||||||
|
fallback2 := BuildMachineID("", "", "account:1")
|
||||||
|
fallback3 := BuildMachineID("", "", "account:2")
|
||||||
|
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
|
||||||
|
require.Equal(t, fallback1, fallback2)
|
||||||
|
require.NotEqual(t, fallback1, fallback3)
|
||||||
|
require.Len(t, fallback1, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeMachineID(t *testing.T) {
|
||||||
|
hex64 := strings.Repeat("A", 64)
|
||||||
|
normalized, ok := NormalizeMachineID(hex64)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, strings.ToLower(hex64), normalized)
|
||||||
|
|
||||||
|
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
|
||||||
|
|
||||||
|
_, ok = NormalizeMachineID("not-a-machine-id")
|
||||||
|
require.False(t, ok)
|
||||||
|
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectedKiroMachineID(seed string) string {
|
||||||
|
sum := sha256.Sum256([]byte(seed))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultModels = []Model{
|
||||||
|
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
|
||||||
|
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
|
||||||
|
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
|
||||||
|
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
|
||||||
|
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
|
||||||
|
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
|
||||||
|
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
|
||||||
|
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
|
||||||
|
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
|
||||||
|
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
|
||||||
|
ids := make([]string, 0, len(DefaultModels))
|
||||||
|
for _, model := range DefaultModels {
|
||||||
|
ids = append(ids, model.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, []string{
|
||||||
|
"claude-opus-4-6",
|
||||||
|
"claude-opus-4-6-thinking",
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-6-thinking",
|
||||||
|
"claude-opus-4-5-20251101",
|
||||||
|
"claude-opus-4-5-20251101-thinking",
|
||||||
|
"claude-sonnet-4-5-20250929",
|
||||||
|
"claude-sonnet-4-5-20250929-thinking",
|
||||||
|
"claude-haiku-4-5-20251001",
|
||||||
|
"claude-haiku-4-5-20251001-thinking",
|
||||||
|
}, ids)
|
||||||
|
|
||||||
|
require.Contains(t, ids, "claude-sonnet-4-6")
|
||||||
|
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
|
||||||
|
require.NotContains(t, ids, "auto")
|
||||||
|
require.NotContains(t, ids, "claude-sonnet-4")
|
||||||
|
require.NotContains(t, ids, "gpt-4o")
|
||||||
|
require.NotContains(t, ids, "deepseek-3-2")
|
||||||
|
require.NotContains(t, ids, "minimax-m2-1")
|
||||||
|
require.NotContains(t, ids, "qwen3-coder-next")
|
||||||
|
require.NotContains(t, ids, "claude-opus-4-7")
|
||||||
|
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
|
||||||
|
for _, id := range ids {
|
||||||
|
require.NotContains(t, id, "kiro-")
|
||||||
|
require.NotContains(t, id, "-agentic")
|
||||||
|
require.NotContains(t, id, "-chat")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,511 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
socialAuthPortalURL = "https://app.kiro.dev"
|
||||||
|
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||||
|
defaultIDCRegion = "us-east-1"
|
||||||
|
BuilderIDStartURL = "https://view.awsapps.com/start"
|
||||||
|
sessionTTL = 10 * time.Minute
|
||||||
|
sessionCleanupEvery = 32
|
||||||
|
sessionCleanupMin = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
socialAuthEndpointURL = socialAuthEndpoint
|
||||||
|
oidcEndpointOverride = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
type SocialProvider string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SocialProviderGoogle SocialProvider = "Google"
|
||||||
|
SocialProviderGitHub SocialProvider = "Github"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthSession struct {
|
||||||
|
State string
|
||||||
|
CodeVerifier string
|
||||||
|
ProxyURL string
|
||||||
|
CreatedAt time.Time
|
||||||
|
AuthType string
|
||||||
|
Provider string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
Region string
|
||||||
|
StartURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
data map[string]*AuthSession
|
||||||
|
setCount uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSessionStore() *SessionStore {
|
||||||
|
return &SessionStore{data: make(map[string]*AuthSession)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
session, ok := s.data[id]
|
||||||
|
if ok && sessionExpired(session, now) {
|
||||||
|
delete(s.data, id)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return session, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Set(id string, session *AuthSession) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.setCount++
|
||||||
|
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
|
||||||
|
s.pruneExpiredLocked(time.Now())
|
||||||
|
}
|
||||||
|
s.data[id] = session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Delete(id string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.data, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
|
||||||
|
for id, session := range s.data {
|
||||||
|
if sessionExpired(session, now) {
|
||||||
|
delete(s.data, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionExpired(session *AuthSession, now time.Time) bool {
|
||||||
|
if session == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if session.CreatedAt.IsZero() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return now.After(session.CreatedAt.Add(sessionTTL))
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenData struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ProfileArn string `json:"profileArn,omitempty"`
|
||||||
|
ExpiresAt string `json:"expiresAt,omitempty"`
|
||||||
|
AuthMethod string `json:"authMethod,omitempty"`
|
||||||
|
Provider string `json:"provider,omitempty"`
|
||||||
|
ClientID string `json:"clientId,omitempty"`
|
||||||
|
ClientSecret string `json:"clientSecret,omitempty"`
|
||||||
|
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
StartURL string `json:"startUrl,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type socialTokenResponse struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
ExpiresIn int `json:"expiresIn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type registerClientResponse struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type createTokenResponse struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
ExpiresIn int `json:"expiresIn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type userInfoResponse struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type deviceRegistration struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshTokenInvalidError struct {
|
||||||
|
StatusCode int
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RefreshTokenInvalidError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(e.Body)
|
||||||
|
if body == "" {
|
||||||
|
return "kiro refresh token invalid (invalid_grant)"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateSessionID() string {
|
||||||
|
return uuid.NewString()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateState() (string, error) {
|
||||||
|
return randomURLSafe(16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateCodeVerifier() (string, error) {
|
||||||
|
return randomURLSafe(32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomURLSafe(n int) (string, error) {
|
||||||
|
buf := make([]byte, n)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateCodeChallenge(verifier string) string {
|
||||||
|
sum := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("state", state)
|
||||||
|
params.Set("code_challenge", codeChallenge)
|
||||||
|
params.Set("code_challenge_method", "S256")
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("redirect_from", "KiroIDE")
|
||||||
|
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
|
||||||
|
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
|
||||||
|
if redirectURI == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
path := strings.TrimSpace(callbackPath)
|
||||||
|
if path == "" {
|
||||||
|
path = "/oauth/callback"
|
||||||
|
} else if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
fullRedirectURI := redirectURI + path
|
||||||
|
if option := strings.TrimSpace(loginOption); option != "" {
|
||||||
|
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
|
||||||
|
}
|
||||||
|
return fullRedirectURI
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"code": code,
|
||||||
|
"code_verifier": codeVerifier,
|
||||||
|
"redirect_uri": redirectURI,
|
||||||
|
}
|
||||||
|
var resp socialTokenResponse
|
||||||
|
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
expiresIn := resp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
return &TokenData{
|
||||||
|
AccessToken: resp.AccessToken,
|
||||||
|
RefreshToken: resp.RefreshToken,
|
||||||
|
ProfileArn: resp.ProfileArn,
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Region: defaultIDCRegion,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
}
|
||||||
|
var resp socialTokenResponse
|
||||||
|
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
|
||||||
|
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
expiresIn := resp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
return &TokenData{
|
||||||
|
AccessToken: resp.AccessToken,
|
||||||
|
RefreshToken: resp.RefreshToken,
|
||||||
|
ProfileArn: resp.ProfileArn,
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: provider,
|
||||||
|
Region: defaultIDCRegion,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
payload := map[string]any{
|
||||||
|
"clientName": "Kiro IDE",
|
||||||
|
"clientType": "public",
|
||||||
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||||
|
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||||
|
"redirectUris": []string{redirectURI},
|
||||||
|
"issuerUrl": issuerURL,
|
||||||
|
}
|
||||||
|
var resp registerClientResponse
|
||||||
|
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
|
||||||
|
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("client_id", clientID)
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("scopes", strings.Join([]string{
|
||||||
|
"codewhisperer:completions",
|
||||||
|
"codewhisperer:analysis",
|
||||||
|
"codewhisperer:conversations",
|
||||||
|
"codewhisperer:transformations",
|
||||||
|
"codewhisperer:taskassist",
|
||||||
|
}, " "))
|
||||||
|
params.Set("state", state)
|
||||||
|
params.Set("code_challenge", codeChallenge)
|
||||||
|
params.Set("code_challenge_method", "S256")
|
||||||
|
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"code": code,
|
||||||
|
"codeVerifier": codeVerifier,
|
||||||
|
"redirectUri": redirectURI,
|
||||||
|
"grantType": "authorization_code",
|
||||||
|
}
|
||||||
|
var resp createTokenResponse
|
||||||
|
accountKey := BuildAccountKey(clientID, "", "", "", 0)
|
||||||
|
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
|
||||||
|
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
expiresIn := resp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
token := &TokenData{
|
||||||
|
AccessToken: resp.AccessToken,
|
||||||
|
RefreshToken: resp.RefreshToken,
|
||||||
|
ProfileArn: resp.ProfileArn,
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
AuthMethod: "idc",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
|
}
|
||||||
|
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
"grantType": "refresh_token",
|
||||||
|
}
|
||||||
|
var resp createTokenResponse
|
||||||
|
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
|
||||||
|
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
|
||||||
|
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
expiresIn := resp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
token := &TokenData{
|
||||||
|
AccessToken: resp.AccessToken,
|
||||||
|
RefreshToken: resp.RefreshToken,
|
||||||
|
ProfileArn: resp.ProfileArn,
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
AuthMethod: "idc",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
|
}
|
||||||
|
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
|
||||||
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var resp userInfoResponse
|
||||||
|
headers := map[string]string{
|
||||||
|
"Authorization": "Bearer " + accessToken,
|
||||||
|
}
|
||||||
|
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(resp.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
|
||||||
|
var token TokenData
|
||||||
|
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
|
||||||
|
}
|
||||||
|
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
||||||
|
if strings.TrimSpace(token.AccessToken) == "" {
|
||||||
|
return nil, fmt.Errorf("access token is empty")
|
||||||
|
}
|
||||||
|
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
|
||||||
|
var reg deviceRegistration
|
||||||
|
if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse device registration: %w", err)
|
||||||
|
}
|
||||||
|
if reg.ClientID != "" {
|
||||||
|
token.ClientID = reg.ClientID
|
||||||
|
}
|
||||||
|
if reg.ClientSecret != "" {
|
||||||
|
token.ClientSecret = reg.ClientSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOIDCEndpoint(region string) string {
|
||||||
|
if strings.TrimSpace(oidcEndpointOverride) != "" {
|
||||||
|
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
|
||||||
|
}
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||||
|
}
|
||||||
|
|
||||||
|
func oidcHeaders(accountKey, machineID string) map[string]string {
|
||||||
|
headers := BuildOIDCHeaders(accountKey, machineID)
|
||||||
|
if headers["amz-sdk-invocation-id"] == "" {
|
||||||
|
headers["amz-sdk-invocation-id"] = uuid.NewString()
|
||||||
|
}
|
||||||
|
if headers["amz-sdk-request"] == "" {
|
||||||
|
headers["amz-sdk-request"] = "attempt=1; max=4"
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
|
||||||
|
client, err := newHTTPClient(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var body io.Reader
|
||||||
|
if payload != nil {
|
||||||
|
encoded, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
body = bytes.NewReader(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload != nil {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
for key, value := range extraHeaders {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
bodyText := strings.TrimSpace(string(respBody))
|
||||||
|
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
|
||||||
|
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
|
||||||
|
}
|
||||||
|
if out == nil || len(respBody) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return json.Unmarshal(respBody, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
|
||||||
|
_, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
transport := &http.Transport{}
|
||||||
|
if parsed != nil {
|
||||||
|
transport.Proxy = http.ProxyURL(parsed)
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Transport: transport,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "/refreshToken", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
previous := socialAuthEndpointURL
|
||||||
|
socialAuthEndpointURL = server.URL
|
||||||
|
t.Cleanup(func() { socialAuthEndpointURL = previous })
|
||||||
|
|
||||||
|
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var invalid *RefreshTokenInvalidError
|
||||||
|
require.True(t, errors.As(err, &invalid))
|
||||||
|
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||||
|
require.Contains(t, invalid.Body, "invalid_grant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "/token", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
previous := oidcEndpointOverride
|
||||||
|
oidcEndpointOverride = server.URL
|
||||||
|
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||||
|
|
||||||
|
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var invalid *RefreshTokenInvalidError
|
||||||
|
require.True(t, errors.As(err, &invalid))
|
||||||
|
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||||
|
require.Contains(t, invalid.Body, "invalid_grant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
|
||||||
|
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/token":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||||
|
case "/userinfo":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
previous := oidcEndpointOverride
|
||||||
|
oidcEndpointOverride = server.URL
|
||||||
|
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||||
|
|
||||||
|
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, profileArn, token.ProfileArn)
|
||||||
|
require.Equal(t, "kiro@example.com", token.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
|
||||||
|
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/token":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||||
|
case "/userinfo":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
previous := oidcEndpointOverride
|
||||||
|
oidcEndpointOverride = server.URL
|
||||||
|
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||||
|
|
||||||
|
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, profileArn, token.ProfileArn)
|
||||||
|
require.Equal(t, "kiro@example.com", token.Email)
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
|
||||||
|
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
|
||||||
|
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSocialTokenRedirectURI(t *testing.T) {
|
||||||
|
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
|
||||||
|
want := "http://localhost:49153/oauth/callback?login_option=github"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
|
||||||
|
store := NewSessionStore()
|
||||||
|
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
|
||||||
|
|
||||||
|
session, ok := store.Get("expired")
|
||||||
|
if ok || session != nil {
|
||||||
|
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
|
||||||
|
}
|
||||||
|
if _, exists := store.data["expired"]; exists {
|
||||||
|
t.Fatalf("expired session should be deleted from the store")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
|
||||||
|
store := NewSessionStore()
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < sessionCleanupMin; i++ {
|
||||||
|
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
|
||||||
|
}
|
||||||
|
store.setCount = sessionCleanupEvery - 1
|
||||||
|
|
||||||
|
store.Set("fresh", &AuthSession{CreatedAt: now})
|
||||||
|
|
||||||
|
if len(store.data) != 1 {
|
||||||
|
t.Fatalf("store size = %d, want 1", len(store.data))
|
||||||
|
}
|
||||||
|
if _, ok := store.data["fresh"]; !ok {
|
||||||
|
t.Fatalf("fresh session should remain after pruning")
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,368 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
|
||||||
|
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
|
||||||
|
|
||||||
|
var cachedWebSearchDescription atomic.Value // stores string
|
||||||
|
|
||||||
|
type MCPRequest struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params interface{} `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MCPResponse struct {
|
||||||
|
Result *struct {
|
||||||
|
Content []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
Tools []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
} `json:"tools"`
|
||||||
|
} `json:"result,omitempty"`
|
||||||
|
Error *struct {
|
||||||
|
Code *int `json:"code,omitempty"`
|
||||||
|
Message *string `json:"message,omitempty"`
|
||||||
|
} `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebSearchResults struct {
|
||||||
|
Results []WebSearchResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebSearchResult struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Snippet *string `json:"snippet,omitempty"`
|
||||||
|
PublishedDate *int64 `json:"publishedDate,omitempty"`
|
||||||
|
ID *string `json:"id,omitempty"`
|
||||||
|
Domain *string `json:"domain,omitempty"`
|
||||||
|
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
|
||||||
|
PublicDomain *bool `json:"publicDomain,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SearchIndicator struct {
|
||||||
|
ToolUseID string
|
||||||
|
Query string
|
||||||
|
Results *WebSearchResults
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCachedWebSearchDescription() string {
|
||||||
|
if v := cachedWebSearchDescription.Load(); v != nil {
|
||||||
|
return strings.TrimSpace(v.(string))
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetCachedWebSearchDescription(desc string) {
|
||||||
|
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildMcpEndpoint(region string) string {
|
||||||
|
if strings.TrimSpace(region) == "" {
|
||||||
|
region = "us-east-1"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
|
||||||
|
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, item := range resp.Result.Content {
|
||||||
|
if item.Type != "" && item.Type != "text" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var results WebSearchResults
|
||||||
|
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
|
||||||
|
return &results
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractSearchQuery(body []byte) string {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
arr := messages.Array()
|
||||||
|
for i := len(arr) - 1; i >= 0; i-- {
|
||||||
|
msg := arr[i]
|
||||||
|
if msg.Get("role").String() != "user" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text := extractSearchText(msg.Get("content"))
|
||||||
|
const prefix = "Perform a web search for the query: "
|
||||||
|
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||||
|
if text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSearchText(content gjson.Result) string {
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
if !content.IsArray() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, block := range content.Array() {
|
||||||
|
if block.Get("type").String() == "text" {
|
||||||
|
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateToolUseID() string {
|
||||||
|
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||||
|
var payload map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
rawTools, ok := payload["tools"].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
replaced := make([]interface{}, 0, len(rawTools))
|
||||||
|
for _, rawTool := range rawTools {
|
||||||
|
tool, ok := rawTool.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
replaced = append(replaced, rawTool)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := getInterfaceString(tool["name"])
|
||||||
|
toolType := getInterfaceString(tool["type"])
|
||||||
|
if !isWebSearchToolName(name, toolType) {
|
||||||
|
replaced = append(replaced, rawTool)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
replaced = append(replaced, map[string]interface{}{
|
||||||
|
"name": "web_search",
|
||||||
|
"description": minimalWebSearchDescription,
|
||||||
|
"input_schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"query": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The search query to execute",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"query"},
|
||||||
|
"additionalProperties": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
payload["tools"] = replaced
|
||||||
|
updated, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
|
||||||
|
var payload map[string]interface{}
|
||||||
|
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||||
|
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawMessages, ok := payload["messages"].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return claudePayload, fmt.Errorf("claude payload missing messages array")
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantMsg := map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolUseID,
|
||||||
|
"name": "web_search",
|
||||||
|
"input": map[string]interface{}{"query": query},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
userContent := []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": toolUseID,
|
||||||
|
"content": formatToolResultText(results),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if guidance := searchGuidanceText(); guidance != "" {
|
||||||
|
userContent = append(userContent, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": guidance,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
userMsg := map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": userContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
rawMessages = append(rawMessages, assistantMsg, userMsg)
|
||||||
|
payload["messages"] = rawMessages
|
||||||
|
updated, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
|
||||||
|
}
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
|
||||||
|
if len(searches) == 0 {
|
||||||
|
return responsePayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var response map[string]interface{}
|
||||||
|
if err := json.Unmarshal(responsePayload, &response); err != nil {
|
||||||
|
return responsePayload, err
|
||||||
|
}
|
||||||
|
content, _ := response["content"].([]interface{})
|
||||||
|
updated := make([]interface{}, 0, len(searches)*2+len(content))
|
||||||
|
for _, search := range searches {
|
||||||
|
updated = append(updated, map[string]interface{}{
|
||||||
|
"type": "server_tool_use",
|
||||||
|
"id": search.ToolUseID,
|
||||||
|
"name": "web_search",
|
||||||
|
"input": map[string]interface{}{"query": search.Query},
|
||||||
|
})
|
||||||
|
updated = append(updated, map[string]interface{}{
|
||||||
|
"type": "web_search_tool_result",
|
||||||
|
"content": buildSearchResultContent(search.Results),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
updated = append(updated, content...)
|
||||||
|
response["content"] = updated
|
||||||
|
|
||||||
|
encoded, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return responsePayload, err
|
||||||
|
}
|
||||||
|
return encoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
|
||||||
|
content := make([]map[string]interface{}, 0)
|
||||||
|
if results == nil {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
for _, result := range results.Results {
|
||||||
|
snippet := ""
|
||||||
|
if result.Snippet != nil {
|
||||||
|
snippet = strings.TrimSpace(*result.Snippet)
|
||||||
|
}
|
||||||
|
content = append(content, map[string]interface{}{
|
||||||
|
"type": "web_search_result",
|
||||||
|
"title": result.Title,
|
||||||
|
"url": result.URL,
|
||||||
|
"encrypted_content": snippet,
|
||||||
|
"page_age": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
|
||||||
|
content := gjson.GetBytes(responsePayload, "content")
|
||||||
|
if !content.IsArray() {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
for _, block := range content.Array() {
|
||||||
|
if block.Get("type").String() != "tool_use" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := block.Get("name").String()
|
||||||
|
if !isWebSearchToolName(name, "") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
query = strings.TrimSpace(block.Get("input.query").String())
|
||||||
|
if query == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return block.Get("id").String(), query, true
|
||||||
|
}
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWebSearchToolName(name, toolType string) bool {
|
||||||
|
name = strings.ToLower(strings.TrimSpace(name))
|
||||||
|
toolType = strings.ToLower(strings.TrimSpace(toolType))
|
||||||
|
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch name {
|
||||||
|
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInterfaceString(v interface{}) string {
|
||||||
|
if v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(val)
|
||||||
|
default:
|
||||||
|
return strings.TrimSpace(fmt.Sprint(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatToolResultText(results *WebSearchResults) string {
|
||||||
|
if results == nil || len(results.Results) == 0 {
|
||||||
|
return "No search results found."
|
||||||
|
}
|
||||||
|
payload, err := json.MarshalIndent(results.Results, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return "Found search results, but failed to format them."
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
func searchGuidanceText() string {
|
||||||
|
now := time.Now()
|
||||||
|
return fmt.Sprintf(`<search_guidance>
|
||||||
|
Current date: %s (%s)
|
||||||
|
|
||||||
|
IMPORTANT: Evaluate the search results above carefully. If the results are:
|
||||||
|
- Mostly spam, SEO junk, or unrelated websites
|
||||||
|
- Missing actual information about the query topic
|
||||||
|
- Outdated or not matching the requested time frame
|
||||||
|
|
||||||
|
Then you MUST use the web_search tool again with a refined query. Try:
|
||||||
|
- Rephrasing in English for better coverage
|
||||||
|
- Using more specific keywords
|
||||||
|
- Adding date context
|
||||||
|
|
||||||
|
Do NOT apologize for bad results without first attempting a re-search.
|
||||||
|
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
|
||||||
|
}
|
||||||
@@ -0,0 +1,297 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BufferedStreamResult struct {
|
||||||
|
StopReason string
|
||||||
|
WebSearchQuery string
|
||||||
|
WebSearchToolUseID string
|
||||||
|
HasWebSearchToolUse bool
|
||||||
|
WebSearchToolUseIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
|
||||||
|
searchContent := make([]map[string]interface{}, 0)
|
||||||
|
if results != nil {
|
||||||
|
for _, result := range results.Results {
|
||||||
|
snippet := ""
|
||||||
|
if result.Snippet != nil {
|
||||||
|
snippet = strings.TrimSpace(*result.Snippet)
|
||||||
|
}
|
||||||
|
searchContent = append(searchContent, map[string]interface{}{
|
||||||
|
"type": "web_search_result",
|
||||||
|
"title": result.Title,
|
||||||
|
"url": result.URL,
|
||||||
|
"encrypted_content": snippet,
|
||||||
|
"page_age": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||||
|
|
||||||
|
events := []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": startIndex,
|
||||||
|
"content_block": map[string]interface{}{
|
||||||
|
"type": "server_tool_use",
|
||||||
|
"id": toolUseID,
|
||||||
|
"name": "web_search",
|
||||||
|
"input": map[string]interface{}{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": startIndex,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"type": "input_json_delta",
|
||||||
|
"partial_json": string(inputJSON),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": startIndex,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": startIndex + 1,
|
||||||
|
"content_block": map[string]interface{}{
|
||||||
|
"type": "web_search_tool_result",
|
||||||
|
"content": searchContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": startIndex + 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([][]byte, 0, len(events))
|
||||||
|
for _, event := range events {
|
||||||
|
eventType, _ := event["type"].(string)
|
||||||
|
payload, _ := json.Marshal(event)
|
||||||
|
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||||
|
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||||
|
var currentToolName string
|
||||||
|
currentToolIndex := -1
|
||||||
|
var toolInputBuilder strings.Builder
|
||||||
|
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
lines := strings.Split(string(chunk), "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||||
|
if payload == "" || payload == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var event map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch eventType, _ := event["type"].(string); eventType {
|
||||||
|
case "message_delta":
|
||||||
|
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||||
|
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
|
||||||
|
result.StopReason = stopReason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_start":
|
||||||
|
contentBlock, ok := event["content_block"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
blockType, _ := contentBlock["type"].(string)
|
||||||
|
if blockType != "tool_use" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentToolName, _ = contentBlock["name"].(string)
|
||||||
|
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
|
||||||
|
if idx, ok := event["index"].(float64); ok {
|
||||||
|
currentToolIndex = int(idx)
|
||||||
|
}
|
||||||
|
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
|
||||||
|
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
|
||||||
|
}
|
||||||
|
toolInputBuilder.Reset()
|
||||||
|
case "content_block_delta":
|
||||||
|
if currentToolName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deltaType, _ := delta["type"].(string)
|
||||||
|
if deltaType != "input_json_delta" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||||
|
toolInputBuilder.WriteString(partialJSON)
|
||||||
|
}
|
||||||
|
case "content_block_stop":
|
||||||
|
if !isWebSearchToolName(currentToolName, "") {
|
||||||
|
currentToolName = ""
|
||||||
|
currentToolIndex = -1
|
||||||
|
toolInputBuilder.Reset()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result.HasWebSearchToolUse = true
|
||||||
|
result.WebSearchToolUseIndex = currentToolIndex
|
||||||
|
var input map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
|
||||||
|
result.WebSearchQuery = strings.TrimSpace(input["query"])
|
||||||
|
}
|
||||||
|
currentToolName = ""
|
||||||
|
currentToolIndex = -1
|
||||||
|
toolInputBuilder.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
|
||||||
|
filtered := make([][]byte, 0, len(chunks))
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
|
||||||
|
if shouldForward {
|
||||||
|
filtered = append(filtered, adjusted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||||
|
return filterSSEChunk(chunk, -1, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MaxContentBlockIndex(chunks [][]byte) int {
|
||||||
|
maxIndex := -1
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
lines := strings.Split(string(chunk), "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||||
|
if payload == "" || payload == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var event map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch eventType, _ := event["type"].(string); eventType {
|
||||||
|
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||||
|
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
|
||||||
|
maxIndex = int(idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
|
||||||
|
lines := strings.Split(string(chunk), "\n")
|
||||||
|
var builder strings.Builder
|
||||||
|
hasContent := false
|
||||||
|
|
||||||
|
for i := 0; i < len(lines); i++ {
|
||||||
|
line := lines[i]
|
||||||
|
if strings.HasPrefix(line, "event: ") {
|
||||||
|
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||||
|
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
|
||||||
|
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.WriteString(line + "\n")
|
||||||
|
hasContent = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "data: ") {
|
||||||
|
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||||
|
if payload == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
adjusted := adjustEventPayload(payload, indexOffset)
|
||||||
|
if adjusted == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
builder.WriteString("data: " + adjusted + "\n")
|
||||||
|
hasContent = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteString(line + "\n")
|
||||||
|
if strings.TrimSpace(line) != "" {
|
||||||
|
hasContent = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasContent {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return []byte(builder.String()), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
|
||||||
|
if payload == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var event map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
eventType, _ := event["type"].(string)
|
||||||
|
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if webSearchToolUseIndex < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func adjustEventPayload(payload string, indexOffset int) string {
|
||||||
|
if payload == "" || indexOffset == 0 {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
var event map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
switch eventType, _ := event["type"].(string); eventType {
|
||||||
|
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||||
|
if idx, ok := event["index"].(float64); ok {
|
||||||
|
event["index"] = int(idx) + indexOffset
|
||||||
|
if adjusted, err := json.Marshal(event); err == nil {
|
||||||
|
return string(adjusted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
|
||||||
|
snippet := "result snippet"
|
||||||
|
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
|
||||||
|
Results: []WebSearchResult{
|
||||||
|
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
|
||||||
|
},
|
||||||
|
}, 0)
|
||||||
|
|
||||||
|
require.Len(t, events, 5)
|
||||||
|
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
|
||||||
|
require.Contains(t, string(events[0]), `"input":{}`)
|
||||||
|
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
|
||||||
|
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
|
||||||
|
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
|
||||||
|
require.NotContains(t, string(events[3]), `"tool_use_id"`)
|
||||||
|
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||||
|
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||||
|
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||||
|
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||||
|
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result := AnalyzeBufferedStream(chunks)
|
||||||
|
require.True(t, result.HasWebSearchToolUse)
|
||||||
|
require.Equal(t, "golang concurrency", result.WebSearchQuery)
|
||||||
|
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
|
||||||
|
require.Equal(t, 1, result.WebSearchToolUseIndex)
|
||||||
|
require.Equal(t, "tool_use", result.StopReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||||
|
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
|
||||||
|
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
|
||||||
|
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
|
||||||
|
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||||
|
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||||
|
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||||
|
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := FilterChunksForClient(chunks, 1, 2)
|
||||||
|
require.NotEmpty(t, filtered)
|
||||||
|
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
|
||||||
|
require.NotContains(t, joined, `"type":"message_start"`)
|
||||||
|
require.NotContains(t, joined, `"type":"message_delta"`)
|
||||||
|
require.NotContains(t, joined, `"name":"web_search"`)
|
||||||
|
require.Contains(t, joined, `"index":2`)
|
||||||
|
require.Equal(t, 2, MaxContentBlockIndex(filtered))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
|
||||||
|
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
|
||||||
|
require.False(t, shouldForward)
|
||||||
|
|
||||||
|
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
|
||||||
|
require.True(t, shouldForward)
|
||||||
|
require.Contains(t, string(adjusted), `"index":3`)
|
||||||
|
}
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"tools":[{"type":"web_search_20250305","description":"old"}],
|
||||||
|
"messages":[{"role":"user","content":"golang"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
updated, err := ReplaceWebSearchToolDescription(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
|
||||||
|
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
|
||||||
|
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
|
||||||
|
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
|
||||||
|
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
|
||||||
|
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[{"role":"user","content":"what is golang"}]
|
||||||
|
}`)
|
||||||
|
results := &WebSearchResults{
|
||||||
|
Results: []WebSearchResult{
|
||||||
|
{Title: "Go", URL: "https://go.dev"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
|
||||||
|
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
|
||||||
|
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
|
||||||
|
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
|
||||||
|
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
|
||||||
|
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
|
||||||
|
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
|
||||||
|
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
|
||||||
|
response := []byte(`{
|
||||||
|
"content":[
|
||||||
|
{"type":"text","text":"let me search"},
|
||||||
|
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "srvtoolu_next", toolUseID)
|
||||||
|
require.Equal(t, "golang concurrency", query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
|
||||||
|
response := []byte(`{
|
||||||
|
"id":"msg_1",
|
||||||
|
"type":"message",
|
||||||
|
"role":"assistant",
|
||||||
|
"model":"kiro",
|
||||||
|
"content":[{"type":"text","text":"final"}],
|
||||||
|
"stop_reason":"end_turn",
|
||||||
|
"usage":{"input_tokens":1,"output_tokens":1}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
snippet := "result snippet"
|
||||||
|
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
|
||||||
|
{
|
||||||
|
ToolUseID: "srvtoolu_test",
|
||||||
|
Query: "golang",
|
||||||
|
Results: &WebSearchResults{
|
||||||
|
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var decoded map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(updated, &decoded))
|
||||||
|
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
|
||||||
|
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
|
||||||
|
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
|
||||||
|
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
|
||||||
|
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
|
||||||
|
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
|
||||||
|
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
|
||||||
|
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
|
||||||
|
resp := &MCPResponse{
|
||||||
|
Result: &struct {
|
||||||
|
Content []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
Tools []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
} `json:"tools"`
|
||||||
|
}{
|
||||||
|
Content: []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Type: "text",
|
||||||
|
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
results := ParseSearchResults(resp)
|
||||||
|
require.NotNil(t, results)
|
||||||
|
require.Len(t, results.Results, 1)
|
||||||
|
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
|
||||||
|
require.Equal(t, "doc-1", *results.Results[0].ID)
|
||||||
|
require.Equal(t, "go.dev", *results.Results[0].Domain)
|
||||||
|
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
|
||||||
|
require.True(t, *results.Results[0].PublicDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchGuidanceText_IsStructured(t *testing.T) {
|
||||||
|
guidance := searchGuidanceText()
|
||||||
|
require.Contains(t, guidance, "<search_guidance>")
|
||||||
|
require.Contains(t, guidance, "Current date:")
|
||||||
|
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
|
||||||
|
require.Contains(t, guidance, "Rephrasing in English for better coverage")
|
||||||
|
}
|
||||||
@@ -0,0 +1,479 @@
|
|||||||
|
package kirocooldown
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MinRequestInterval = time.Second
|
||||||
|
MaxRequestInterval = 2 * time.Second
|
||||||
|
|
||||||
|
CooldownReason429 = "rate_limit_exceeded"
|
||||||
|
CooldownReasonSuspended = "account_suspended"
|
||||||
|
|
||||||
|
ShortCooldown = time.Minute
|
||||||
|
MaxCooldown = 5 * time.Minute
|
||||||
|
LongCooldown = 24 * time.Hour
|
||||||
|
|
||||||
|
redisTimeout = 3 * time.Second
|
||||||
|
activeTTL = 10 * time.Second
|
||||||
|
stateTTL = 25 * time.Hour
|
||||||
|
keyPrefix = "kiro:cooldown:"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||||
|
|
||||||
|
reserveRequestScript = redis.NewScript(`
|
||||||
|
local t = redis.call('TIME')
|
||||||
|
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||||
|
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
|
||||||
|
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
|
||||||
|
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
|
||||||
|
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
|
||||||
|
local interval_ms = tonumber(ARGV[1])
|
||||||
|
local active_ttl_ms = tonumber(ARGV[2])
|
||||||
|
local state_ttl_ms = tonumber(ARGV[3])
|
||||||
|
|
||||||
|
if cooldown_until_ms > now_ms then
|
||||||
|
return {1, cooldown_until_ms - now_ms, cooldown_reason}
|
||||||
|
end
|
||||||
|
|
||||||
|
if cooldown_until_ms > 0 then
|
||||||
|
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
|
||||||
|
end
|
||||||
|
|
||||||
|
local next_slot_ms = now_ms
|
||||||
|
if last_request_ms > 0 then
|
||||||
|
local candidate_ms = last_request_ms + interval_ms
|
||||||
|
if candidate_ms > now_ms then
|
||||||
|
next_slot_ms = candidate_ms
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
|
||||||
|
if fail_count > 0 or cooldown_until_ms > now_ms then
|
||||||
|
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||||
|
else
|
||||||
|
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
|
||||||
|
end
|
||||||
|
return {0, next_slot_ms - now_ms, ''}
|
||||||
|
`)
|
||||||
|
|
||||||
|
mark429Script = redis.NewScript(`
|
||||||
|
local t = redis.call('TIME')
|
||||||
|
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||||
|
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
|
||||||
|
local short_cooldown_ms = tonumber(ARGV[1])
|
||||||
|
local max_cooldown_ms = tonumber(ARGV[2])
|
||||||
|
local state_ttl_ms = tonumber(ARGV[3])
|
||||||
|
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
|
||||||
|
if cooldown_ms > max_cooldown_ms then
|
||||||
|
cooldown_ms = max_cooldown_ms
|
||||||
|
end
|
||||||
|
redis.call('HSET', KEYS[1],
|
||||||
|
'fail_count', fail_count,
|
||||||
|
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||||
|
'cooldown_reason', ARGV[4]
|
||||||
|
)
|
||||||
|
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||||
|
return cooldown_ms
|
||||||
|
`)
|
||||||
|
|
||||||
|
markSuccessScript = redis.NewScript(`
|
||||||
|
redis.call('HSET', KEYS[1],
|
||||||
|
'fail_count', 0,
|
||||||
|
'cooldown_until_ms', 0,
|
||||||
|
'cooldown_reason', ''
|
||||||
|
)
|
||||||
|
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
markSuspendedScript = redis.NewScript(`
|
||||||
|
local t = redis.call('TIME')
|
||||||
|
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||||
|
local cooldown_ms = tonumber(ARGV[1])
|
||||||
|
local state_ttl_ms = tonumber(ARGV[2])
|
||||||
|
redis.call('HSET', KEYS[1],
|
||||||
|
'fail_count', 0,
|
||||||
|
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||||
|
'cooldown_reason', ARGV[3]
|
||||||
|
)
|
||||||
|
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||||
|
return cooldown_ms
|
||||||
|
`)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
remaining time.Duration
|
||||||
|
reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
Active bool
|
||||||
|
Reason string
|
||||||
|
CooldownUntil time.Time
|
||||||
|
Remaining time.Duration
|
||||||
|
FailCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewError(remaining time.Duration, reason string) error {
|
||||||
|
return &Error{remaining: remaining, reason: reason}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if e.reason == "" {
|
||||||
|
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Calculate429Cooldown(retryCount int) time.Duration {
|
||||||
|
if retryCount < 0 {
|
||||||
|
retryCount = 0
|
||||||
|
}
|
||||||
|
cooldown := ShortCooldown * time.Duration(1<<retryCount)
|
||||||
|
if cooldown > MaxCooldown {
|
||||||
|
return MaxCooldown
|
||||||
|
}
|
||||||
|
return cooldown
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
client *redis.Client
|
||||||
|
rngMu sync.Mutex
|
||||||
|
rng *rand.Rand
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore(client *redis.Client) *Store {
|
||||||
|
return &Store{
|
||||||
|
client: client,
|
||||||
|
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||||
|
if err := s.validate(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
values, err := reserveRequestScript.Run(
|
||||||
|
cacheCtx,
|
||||||
|
s.client,
|
||||||
|
[]string{RedisKey(tokenKey)},
|
||||||
|
s.nextInterval().Milliseconds(),
|
||||||
|
activeTTL.Milliseconds(),
|
||||||
|
stateTTL.Milliseconds(),
|
||||||
|
).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
|
||||||
|
}
|
||||||
|
parts, ok := values.([]interface{})
|
||||||
|
if !ok || len(parts) != 3 {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
|
||||||
|
}
|
||||||
|
state, err := luaInt64(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
|
||||||
|
}
|
||||||
|
waitMS, err := luaInt64(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
|
||||||
|
}
|
||||||
|
reason, err := luaString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
|
||||||
|
}
|
||||||
|
if state == 1 {
|
||||||
|
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
|
||||||
|
}
|
||||||
|
if waitMS <= 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return time.Duration(waitMS) * time.Millisecond, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
|
||||||
|
if err := s.validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||||
|
defer cancel()
|
||||||
|
if err := markSuccessScript.Run(
|
||||||
|
cacheCtx,
|
||||||
|
s.client,
|
||||||
|
[]string{RedisKey(tokenKey)},
|
||||||
|
activeTTL.Milliseconds(),
|
||||||
|
).Err(); err != nil {
|
||||||
|
return fmt.Errorf("kiro cooldown mark success: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||||
|
if err := s.validate(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||||
|
defer cancel()
|
||||||
|
result, err := mark429Script.Run(
|
||||||
|
cacheCtx,
|
||||||
|
s.client,
|
||||||
|
[]string{RedisKey(tokenKey)},
|
||||||
|
ShortCooldown.Milliseconds(),
|
||||||
|
MaxCooldown.Milliseconds(),
|
||||||
|
stateTTL.Milliseconds(),
|
||||||
|
CooldownReason429,
|
||||||
|
).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||||
|
}
|
||||||
|
cooldownMS, err := luaInt64(result)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||||
|
}
|
||||||
|
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||||
|
if err := s.validate(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||||
|
defer cancel()
|
||||||
|
result, err := markSuspendedScript.Run(
|
||||||
|
cacheCtx,
|
||||||
|
s.client,
|
||||||
|
[]string{RedisKey(tokenKey)},
|
||||||
|
LongCooldown.Milliseconds(),
|
||||||
|
stateTTL.Milliseconds(),
|
||||||
|
CooldownReasonSuspended,
|
||||||
|
).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||||
|
}
|
||||||
|
cooldownMS, err := luaInt64(result)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||||
|
}
|
||||||
|
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
|
||||||
|
if err := s.validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
values, err := s.client.HMGet(
|
||||||
|
cacheCtx,
|
||||||
|
RedisKey(tokenKey),
|
||||||
|
"cooldown_until_ms",
|
||||||
|
"cooldown_reason",
|
||||||
|
"fail_count",
|
||||||
|
).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
|
||||||
|
}
|
||||||
|
if len(values) != 3 {
|
||||||
|
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
|
||||||
|
}
|
||||||
|
|
||||||
|
cooldownUntilMS, err := luaInt64(values[0])
|
||||||
|
if err != nil && values[0] != nil {
|
||||||
|
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
|
||||||
|
}
|
||||||
|
reason, err := luaString(values[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
|
||||||
|
}
|
||||||
|
failCount, err := luaInt64(values[2])
|
||||||
|
if err != nil && values[2] != nil {
|
||||||
|
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
|
||||||
|
}
|
||||||
|
if cooldownUntilMS <= 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cooldownUntil := time.UnixMilli(cooldownUntilMS)
|
||||||
|
remaining := time.Until(cooldownUntil)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &State{
|
||||||
|
Active: true,
|
||||||
|
Reason: reason,
|
||||||
|
CooldownUntil: cooldownUntil,
|
||||||
|
Remaining: remaining,
|
||||||
|
FailCount: int(failCount),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
|
||||||
|
if err := s.validate(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
uniqueKeys := make([]string, 0, len(tokenKeys))
|
||||||
|
seen := make(map[string]struct{}, len(tokenKeys))
|
||||||
|
for _, tokenKey := range tokenKeys {
|
||||||
|
tokenKey = strings.TrimSpace(tokenKey)
|
||||||
|
if tokenKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
redisKey := RedisKey(tokenKey)
|
||||||
|
if _, ok := seen[redisKey]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[redisKey] = struct{}{}
|
||||||
|
uniqueKeys = append(uniqueKeys, redisKey)
|
||||||
|
}
|
||||||
|
if len(uniqueKeys) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
type candidate struct {
|
||||||
|
redisKey string
|
||||||
|
cooldownUntilMS int64
|
||||||
|
failCount int64
|
||||||
|
}
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
|
var best *candidate
|
||||||
|
|
||||||
|
pipe := s.client.Pipeline()
|
||||||
|
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
|
||||||
|
for _, redisKey := range uniqueKeys {
|
||||||
|
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
|
||||||
|
}
|
||||||
|
if _, err := pipe.Exec(cacheCtx); err != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, cmd := range cmds {
|
||||||
|
values, err := cmd.Result()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
|
||||||
|
}
|
||||||
|
if len(values) != 3 {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
|
||||||
|
}
|
||||||
|
cooldownUntilMS, err := luaInt64(values[0])
|
||||||
|
if err != nil && values[0] != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
|
||||||
|
}
|
||||||
|
reason, err := luaString(values[1])
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
|
||||||
|
}
|
||||||
|
failCount, err := luaInt64(values[2])
|
||||||
|
if err != nil && values[2] != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
|
||||||
|
}
|
||||||
|
if cooldownUntilMS <= now || reason != CooldownReason429 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
|
||||||
|
if best == nil ||
|
||||||
|
current.cooldownUntilMS < best.cooldownUntilMS ||
|
||||||
|
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
|
||||||
|
best = current
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if best == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
|
||||||
|
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RedisKey(tokenKey string) string {
|
||||||
|
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
|
||||||
|
digest := hex.EncodeToString(sum[:])
|
||||||
|
return keyPrefix + "{" + digest + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
func ActiveTTL() time.Duration {
|
||||||
|
return activeTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
func StateTTL() time.Duration {
|
||||||
|
return stateTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) validate() error {
|
||||||
|
if s == nil || s.client == nil {
|
||||||
|
return ErrStoreUnavailable
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) nextInterval() time.Duration {
|
||||||
|
s.rngMu.Lock()
|
||||||
|
defer s.rngMu.Unlock()
|
||||||
|
if MaxRequestInterval <= MinRequestInterval {
|
||||||
|
return MinRequestInterval
|
||||||
|
}
|
||||||
|
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithTimeout(ctx, redisTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func luaInt64(v any) (int64, error) {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case int64:
|
||||||
|
return n, nil
|
||||||
|
case int:
|
||||||
|
return int64(n), nil
|
||||||
|
case string:
|
||||||
|
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
|
||||||
|
case []byte:
|
||||||
|
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func luaString(v any) (string, error) {
|
||||||
|
switch s := v.(type) {
|
||||||
|
case string:
|
||||||
|
return s, nil
|
||||||
|
case []byte:
|
||||||
|
return string(s), nil
|
||||||
|
case nil:
|
||||||
|
return "", nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unsupported lua string type %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package kirocooldown
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
|
||||||
|
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
|
||||||
|
|
||||||
|
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
|
||||||
|
}
|
||||||
|
if cleared {
|
||||||
|
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
|
||||||
|
store := NewStore(nil)
|
||||||
|
|
||||||
|
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
|
||||||
|
}
|
||||||
|
if cleared {
|
||||||
|
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -41,6 +41,9 @@ func RegisterAdminRoutes(
|
|||||||
// Antigravity OAuth
|
// Antigravity OAuth
|
||||||
registerAntigravityOAuthRoutes(admin, h)
|
registerAntigravityOAuthRoutes(admin, h)
|
||||||
|
|
||||||
|
// Kiro OAuth / IDC
|
||||||
|
registerKiroOAuthRoutes(admin, h)
|
||||||
|
|
||||||
// 代理管理
|
// 代理管理
|
||||||
registerProxyRoutes(admin, h)
|
registerProxyRoutes(admin, h)
|
||||||
|
|
||||||
@@ -295,6 +298,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
|
|
||||||
// Antigravity 默认模型映射
|
// Antigravity 默认模型映射
|
||||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||||
|
accounts.GET("/kiro/default-model-mapping", h.Admin.Account.GetKiroDefaultModelMapping)
|
||||||
|
|
||||||
// Claude OAuth routes
|
// Claude OAuth routes
|
||||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||||
@@ -347,6 +351,17 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerKiroOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
kiro := admin.Group("/kiro")
|
||||||
|
{
|
||||||
|
kiro.POST("/oauth/auth-url", h.Admin.KiroOAuth.GenerateAuthURL)
|
||||||
|
kiro.POST("/oauth/idc-auth-url", h.Admin.KiroOAuth.GenerateIDCAuthURL)
|
||||||
|
kiro.POST("/oauth/exchange-code", h.Admin.KiroOAuth.ExchangeCode)
|
||||||
|
kiro.POST("/oauth/refresh-token", h.Admin.KiroOAuth.RefreshToken)
|
||||||
|
kiro.POST("/oauth/import-token", h.Admin.KiroOAuth.ImportToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
proxies := admin.Group("/proxies")
|
proxies := admin.Group("/proxies")
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -48,6 +48,13 @@ type Account struct {
|
|||||||
TempUnschedulableUntil *time.Time
|
TempUnschedulableUntil *time.Time
|
||||||
TempUnschedulableReason string
|
TempUnschedulableReason string
|
||||||
|
|
||||||
|
KiroQuotaState string
|
||||||
|
KiroQuotaReason string
|
||||||
|
KiroQuotaResetAt *time.Time
|
||||||
|
KiroRuntimeState string
|
||||||
|
KiroRuntimeReason string
|
||||||
|
KiroRuntimeResetAt *time.Time
|
||||||
|
|
||||||
SessionWindowStart *time.Time
|
SessionWindowStart *time.Time
|
||||||
SessionWindowEnd *time.Time
|
SessionWindowEnd *time.Time
|
||||||
SessionWindowStatus string
|
SessionWindowStatus string
|
||||||
@@ -164,6 +171,10 @@ func (a *Account) IsGemini() bool {
|
|||||||
return a.Platform == PlatformGemini
|
return a.Platform == PlatformGemini
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) IsKiro() bool {
|
||||||
|
return a.Platform == PlatformKiro
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) GeminiOAuthType() string {
|
func (a *Account) GeminiOAuthType() string {
|
||||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||||
return ""
|
return ""
|
||||||
@@ -478,17 +489,17 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
|
|
||||||
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
||||||
if a.Credentials == nil {
|
if a.Credentials == nil {
|
||||||
// Antigravity 平台使用默认映射
|
// 部分平台在未显式配置 model_mapping 时仍应使用默认映射,
|
||||||
if a.Platform == domain.PlatformAntigravity {
|
// 以限制可调度/可转发的模型集合。
|
||||||
return domain.DefaultAntigravityModelMapping
|
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||||
|
return defaults
|
||||||
}
|
}
|
||||||
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(rawMapping) == 0 {
|
if len(rawMapping) == 0 {
|
||||||
// Antigravity 平台使用默认映射
|
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||||
if a.Platform == domain.PlatformAntigravity {
|
return defaults
|
||||||
return domain.DefaultAntigravityModelMapping
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -510,13 +521,23 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Antigravity 平台使用默认映射
|
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||||
if a.Platform == domain.PlatformAntigravity {
|
return defaults
|
||||||
return domain.DefaultAntigravityModelMapping
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func defaultModelMappingForPlatform(platform string) map[string]string {
|
||||||
|
switch platform {
|
||||||
|
case domain.PlatformAntigravity:
|
||||||
|
return domain.DefaultAntigravityModelMapping
|
||||||
|
case domain.PlatformKiro:
|
||||||
|
return domain.DefaultKiroModelMapping
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func mapPtr(m map[string]any) uintptr {
|
func mapPtr(m map[string]any) uintptr {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return 0
|
return 0
|
||||||
@@ -608,8 +629,8 @@ func resolveRequestedModelInMapping(mapping map[string]string, requestedModel st
|
|||||||
return matchWildcardMappingResult(mapping, requestedModel)
|
return matchWildcardMappingResult(mapping, requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)。
|
||||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时也会先回退到默认映射。
|
||||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
@@ -622,8 +643,8 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
|
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)。
|
||||||
// 如果未配置 mapping,返回原始模型名
|
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时返回默认映射结果。
|
||||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||||
return mappedModel
|
return mappedModel
|
||||||
@@ -725,6 +746,9 @@ func (a *Account) GetBaseURL() string {
|
|||||||
}
|
}
|
||||||
baseURL := a.GetCredential("base_url")
|
baseURL := a.GetCredential("base_url")
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
|
if a.Platform == PlatformKiro {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return "https://api.anthropic.com"
|
return "https://api.anthropic.com"
|
||||||
}
|
}
|
||||||
if a.Platform == PlatformAntigravity {
|
if a.Platform == PlatformAntigravity {
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
|
||||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -296,7 +296,7 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
|
||||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -65,6 +66,7 @@ type AccountTestService struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
geminiTokenProvider *GeminiTokenProvider
|
geminiTokenProvider *GeminiTokenProvider
|
||||||
claudeTokenProvider *ClaudeTokenProvider
|
claudeTokenProvider *ClaudeTokenProvider
|
||||||
|
kiroTokenProvider *KiroTokenProvider
|
||||||
antigravityGatewayService *AntigravityGatewayService
|
antigravityGatewayService *AntigravityGatewayService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -76,6 +78,7 @@ func NewAccountTestService(
|
|||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
geminiTokenProvider *GeminiTokenProvider,
|
geminiTokenProvider *GeminiTokenProvider,
|
||||||
claudeTokenProvider *ClaudeTokenProvider,
|
claudeTokenProvider *ClaudeTokenProvider,
|
||||||
|
kiroTokenProvider *KiroTokenProvider,
|
||||||
antigravityGatewayService *AntigravityGatewayService,
|
antigravityGatewayService *AntigravityGatewayService,
|
||||||
httpUpstream HTTPUpstream,
|
httpUpstream HTTPUpstream,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
@@ -85,6 +88,7 @@ func NewAccountTestService(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
geminiTokenProvider: geminiTokenProvider,
|
geminiTokenProvider: geminiTokenProvider,
|
||||||
claudeTokenProvider: claudeTokenProvider,
|
claudeTokenProvider: claudeTokenProvider,
|
||||||
|
kiroTokenProvider: kiroTokenProvider,
|
||||||
antigravityGatewayService: antigravityGatewayService,
|
antigravityGatewayService: antigravityGatewayService,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@@ -191,6 +195,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.IsKiro() && account.Type == AccountTypeOAuth {
|
||||||
|
return s.testKiroAccountConnection(c, account, modelID)
|
||||||
|
}
|
||||||
|
|
||||||
return s.testClaudeAccountConnection(c, account, modelID)
|
return s.testClaudeAccountConnection(c, account, modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +247,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseURL := account.GetBaseURL()
|
baseURL := account.GetBaseURL()
|
||||||
|
if baseURL == "" && account.Platform == PlatformKiro {
|
||||||
|
return s.sendErrorAndEnd(c, "Kiro API Key accounts require a Base URL")
|
||||||
|
}
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://api.anthropic.com"
|
baseURL = "https://api.anthropic.com"
|
||||||
}
|
}
|
||||||
@@ -387,6 +398,149 @@ func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Con
|
|||||||
return s.processClaudeStream(c, resp.Body)
|
return s.processClaudeStream(c, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) testKiroAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
testModelID := strings.TrimSpace(modelID)
|
||||||
|
if testModelID == "" {
|
||||||
|
testModelID = "claude-sonnet-4-6"
|
||||||
|
}
|
||||||
|
if mappedModel := account.GetMappedModel(testModelID); strings.TrimSpace(mappedModel) != "" {
|
||||||
|
testModelID = mappedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Type != AccountTypeOAuth {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Kiro account type: %s", account.Type))
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.kiroTokenProvider == nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Kiro token provider not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get Kiro access token: %s", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
payload, err := createTestPayload(testModelID)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||||
|
}
|
||||||
|
payloadBytes, _ := json.Marshal(payload)
|
||||||
|
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||||
|
|
||||||
|
resp, err := s.executeKiroTestUpstream(ctx, account, payloadBytes, testModelID, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return s.sendErrorAndEnd(c, formatKiroTestError(resp.StatusCode, body, testModelID, account))
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
_, streamErr := kiropkg.StreamEventStreamAsAnthropic(ctx, resp.Body, pw, testModelID, estimateKiroInputTokens(payloadBytes))
|
||||||
|
if streamErr != nil {
|
||||||
|
_ = pw.CloseWithError(streamErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = pw.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s.processClaudeStream(c, pr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatKiroTestError(statusCode int, body []byte, requestedModel string, account *Account) string {
|
||||||
|
return fmt.Sprintf("API returned %d: %s", statusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) executeKiroTestUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string) (*http.Response, error) {
|
||||||
|
modelID := kiropkg.MapModel(mappedModel)
|
||||||
|
currentToken := token
|
||||||
|
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload := buildResult.Payload
|
||||||
|
|
||||||
|
endpoints := buildKiroEndpoints(account)
|
||||||
|
proxyURL := kiroProxyURL(account)
|
||||||
|
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||||
|
accountKey := buildKiroAccountKey(account)
|
||||||
|
maxRetries := 2
|
||||||
|
for idx, endpoint := range endpoints {
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
|
||||||
|
if idx+1 < len(endpoints) {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
|
||||||
|
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||||
|
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||||
|
currentToken = refreshedToken
|
||||||
|
accountKey = buildKiroAccountKey(account)
|
||||||
|
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload = buildResult.Payload
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resetHTTPResponseBody(resp, respBody)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, readErr
|
||||||
|
}
|
||||||
|
resetHTTPResponseBody(resp, respBody)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("kiro upstream endpoints exhausted")
|
||||||
|
}
|
||||||
|
|
||||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||||
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||||
region := bedrockRuntimeRegion(account)
|
region := bedrockRuntimeRegion(account)
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroAPIKeyUsesGenericAnthropicCompatiblePath(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 19,
|
||||||
|
Name: "kiro-apikey-test",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"base_url": "https://kiro-upstream.example.com",
|
||||||
|
"api_key": "kiro-api-key",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"invalid api key"}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
httpUpstream: upstream,
|
||||||
|
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
|
||||||
|
req := upstream.requests[0]
|
||||||
|
require.Equal(t, "kiro-upstream.example.com", req.URL.Host)
|
||||||
|
require.Equal(t, "/v1/messages", req.URL.Path)
|
||||||
|
require.Equal(t, "kiro-api-key", req.Header.Get("x-api-key"))
|
||||||
|
require.Empty(t, req.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, claude.APIKeyBetaHeader, req.Header.Get("anthropic-beta"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroAPIKeyWithoutBaseURLErrors(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 20,
|
||||||
|
Name: "kiro-apikey-missing-base-url",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "kiro-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
httpUpstream: &queuedHTTPUpstream{},
|
||||||
|
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "Base URL")
|
||||||
|
}
|
||||||
@@ -0,0 +1,317 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroUsesKiroUpstreamInsteadOfAnthropic(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "kiro-test",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "kiro-access-token",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/TESTSOCIAL",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{1: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"Invalid bearer token"}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "gpt-4o", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
|
||||||
|
req := upstream.requests[0]
|
||||||
|
require.Equal(t, "q.us-east-1.amazonaws.com", req.URL.Host)
|
||||||
|
require.Equal(t, "/generateAssistantResponse", req.URL.Path)
|
||||||
|
require.Equal(t, "Bearer kiro-access-token", req.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
|
||||||
|
require.Empty(t, req.Header.Get("anthropic-version"))
|
||||||
|
require.NotContains(t, req.URL.Host, "api.anthropic.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_Kiro429DoesNotFallbackToCodeWhispererEndpoint(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 2,
|
||||||
|
Name: "kiro-fallback",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "kiro-access-token",
|
||||||
|
"api_region": "us-west-2",
|
||||||
|
"region": "us-west-2",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTFALLBACK",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{2: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusTooManyRequests, `{"message":"slow down"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
|
||||||
|
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
|
||||||
|
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
|
||||||
|
require.Contains(t, err.Error(), "API returned 429")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroIDCWithoutProfileArnOmitsProfileArnAndUsesDefaultRuntimeRegion(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 5,
|
||||||
|
Name: "kiro-idc-default-region",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "kiro-access-token",
|
||||||
|
"auth_method": "idc",
|
||||||
|
"provider": "AWS",
|
||||||
|
"region": "ap-northeast-2",
|
||||||
|
"start_url": "https://d-example.awsapps.com/start",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{5: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
require.Equal(t, "q.us-east-1.amazonaws.com", upstream.requests[0].URL.Host)
|
||||||
|
body, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.NotContains(t, string(body), `"profileArn":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroInvalidModelErrorPassthrough(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 6,
|
||||||
|
Name: "kiro-invalid-model",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "kiro-access-token",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTINVALIDMODEL",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, `API returned 400: {"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Name: "kiro-invalid-model-refresh",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "kiro-access-token",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{7: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "API returned 400")
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
|
||||||
|
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
|
||||||
|
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_KiroPreferredEndpointIsIgnored(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 6,
|
||||||
|
Name: "kiro-preferred-endpoint",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "kiro-access-token",
|
||||||
|
"api_region": "us-west-2",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/PREFERRED",
|
||||||
|
"preferred_endpoint": "codewhisperer",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
accountRepo: repo,
|
||||||
|
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
|
||||||
|
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroPayloadForAccount_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 3,
|
||||||
|
Name: "kiro-builder-id",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"auth_method": "idc",
|
||||||
|
"provider": "BuilderId",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"client_id": "builder-client-id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(testPayload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotContains(t, string(kiroPayload), `"profileArn":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroPayloadForAccount_KiroBuilderIDUsesCredentialProfileArn(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 33,
|
||||||
|
Name: "kiro-builder-id-cached",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"auth_method": "builder-id",
|
||||||
|
"provider": "BuilderId",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"client_id": "builder-client-id",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(testPayload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, string(kiroPayload), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroPayloadForAccount_KiroEnterpriseIDCOmitsMissingProfileArn(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 4,
|
||||||
|
Name: "kiro-enterprise-idc",
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"auth_method": "idc",
|
||||||
|
"provider": "AWS",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"client_id": "enterprise-client-id",
|
||||||
|
"start_url": "https://d-example.awsapps.com/start",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(testPayload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotContains(t, string(kiroPayload), `"profileArn":`)
|
||||||
|
}
|
||||||
@@ -103,10 +103,17 @@ type antigravityUsageCache struct {
|
|||||||
timestamp time.Time
|
timestamp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// kiroUsageCache 缓存 Kiro 额度快照
|
||||||
|
type kiroUsageCache struct {
|
||||||
|
usageInfo *UsageInfo
|
||||||
|
timestamp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
apiCacheTTL = 3 * time.Minute
|
apiCacheTTL = 3 * time.Minute
|
||||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||||
|
kiroUsageErrorTTL = 1 * time.Minute // Kiro 错误缓存 TTL(可恢复错误)
|
||||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||||
windowStatsCacheTTL = 1 * time.Minute
|
windowStatsCacheTTL = 1 * time.Minute
|
||||||
openAIProbeCacheTTL = 10 * time.Minute
|
openAIProbeCacheTTL = 10 * time.Minute
|
||||||
@@ -118,8 +125,10 @@ type UsageCache struct {
|
|||||||
apiCache sync.Map // accountID -> *apiUsageCache
|
apiCache sync.Map // accountID -> *apiUsageCache
|
||||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||||
|
kiroUsageCache sync.Map // accountID -> *kiroUsageCache
|
||||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||||
|
kiroUsageFlight singleflight.Group // 防止同一 Kiro 账号的并发请求击穿缓存
|
||||||
openAIProbeCache sync.Map // accountID -> time.Time
|
openAIProbeCache sync.Map // accountID -> time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,6 +185,23 @@ type AICredit struct {
|
|||||||
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KiroCreditProgress 表示 Kiro 主额度或 Bonus 的用量进度。
|
||||||
|
type KiroCreditProgress struct {
|
||||||
|
CurrentUsage float64 `json:"current_usage"`
|
||||||
|
UsageLimit float64 `json:"usage_limit"`
|
||||||
|
PercentageUsed float64 `json:"percentage_used"`
|
||||||
|
DaysRemaining int `json:"days_remaining,omitempty"`
|
||||||
|
ExpiryDate *time.Time `json:"expiry_date,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroOverageInfo 表示 Kiro 账号的 overage 状态。
|
||||||
|
type KiroOverageInfo struct {
|
||||||
|
CurrentOverages float64 `json:"current_overages"`
|
||||||
|
OverageCharges float64 `json:"overage_charges"`
|
||||||
|
CurrencyCode string `json:"currency_code,omitempty"`
|
||||||
|
CurrencySymbol string `json:"currency_symbol,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
Source string `json:"source,omitempty"` // "passive" or "active"
|
Source string `json:"source,omitempty"` // "passive" or "active"
|
||||||
@@ -203,6 +229,21 @@ type UsageInfo struct {
|
|||||||
// Antigravity AI Credits 余额
|
// Antigravity AI Credits 余额
|
||||||
AICredits []AICredit `json:"ai_credits,omitempty"`
|
AICredits []AICredit `json:"ai_credits,omitempty"`
|
||||||
|
|
||||||
|
// Kiro Credits 额度与 overage 信息
|
||||||
|
KiroSubscriptionName string `json:"kiro_subscription_name,omitempty"`
|
||||||
|
KiroSubscriptionType string `json:"kiro_subscription_type,omitempty"`
|
||||||
|
KiroResetAt *time.Time `json:"kiro_reset_at,omitempty"`
|
||||||
|
KiroOveragesEnabled bool `json:"kiro_overages_enabled,omitempty"`
|
||||||
|
KiroCredit *KiroCreditProgress `json:"kiro_credit,omitempty"`
|
||||||
|
KiroBonus *KiroCreditProgress `json:"kiro_bonus,omitempty"`
|
||||||
|
KiroOverage *KiroOverageInfo `json:"kiro_overage,omitempty"`
|
||||||
|
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
|
||||||
|
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
|
||||||
|
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
|
||||||
|
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
|
||||||
|
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
|
||||||
|
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
|
||||||
|
|
||||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||||
|
|
||||||
@@ -266,6 +307,7 @@ type AccountUsageService struct {
|
|||||||
cache *UsageCache
|
cache *UsageCache
|
||||||
identityCache IdentityCache
|
identityCache IdentityCache
|
||||||
tlsFPProfileService *TLSFingerprintProfileService
|
tlsFPProfileService *TLSFingerprintProfileService
|
||||||
|
kiroCooldownStore KiroCooldownStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
@@ -291,6 +333,13 @@ func NewAccountUsageService(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) SetKiroCooldownStore(store KiroCooldownStore) *AccountUsageService {
|
||||||
|
if s != nil {
|
||||||
|
s.kiroCooldownStore = store
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// GetUsage 获取账号使用量
|
// GetUsage 获取账号使用量
|
||||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
|
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
|
||||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||||
@@ -317,6 +366,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return usage, err
|
return usage, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||||
|
return s.getKiroUsage(ctx, account, "active", false)
|
||||||
|
}
|
||||||
|
|
||||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
usage, err := s.getAntigravityUsage(ctx, account)
|
usage, err := s.getAntigravityUsage(ctx, account)
|
||||||
@@ -425,6 +478,13 @@ func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int
|
|||||||
return nil, fmt.Errorf("get account failed: %w", err)
|
return nil, fmt.Errorf("get account failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Platform == PlatformKiro {
|
||||||
|
if account.Type != AccountTypeOAuth {
|
||||||
|
return nil, fmt.Errorf("passive usage only supported for Kiro OAuth accounts")
|
||||||
|
}
|
||||||
|
return s.getKiroUsage(ctx, account, "passive", false)
|
||||||
|
}
|
||||||
|
|
||||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||||
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
|
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountUsageService_GetUsage_KiroAPIKeyUnsupported(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 9101,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||||
|
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||||
|
|
||||||
|
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||||
|
require.Nil(t, usage)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "does not support usage query")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountUsageService_GetPassiveUsage_KiroAPIKeyUnsupported(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 9102,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||||
|
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||||
|
|
||||||
|
usage, err := svc.GetPassiveUsage(context.Background(), account.ID)
|
||||||
|
require.Nil(t, usage)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "Kiro OAuth")
|
||||||
|
}
|
||||||
@@ -1448,7 +1448,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
|
||||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||||
@@ -1728,7 +1728,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
}
|
}
|
||||||
|
|
||||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
|
||||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ const (
|
|||||||
PlatformOpenAI = domain.PlatformOpenAI
|
PlatformOpenAI = domain.PlatformOpenAI
|
||||||
PlatformGemini = domain.PlatformGemini
|
PlatformGemini = domain.PlatformGemini
|
||||||
PlatformAntigravity = domain.PlatformAntigravity
|
PlatformAntigravity = domain.PlatformAntigravity
|
||||||
|
PlatformKiro = domain.PlatformKiro
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
@@ -56,6 +57,7 @@ const (
|
|||||||
defaultModelsListCacheTTL = 15 * time.Second
|
defaultModelsListCacheTTL = 15 * time.Second
|
||||||
postUsageBillingTimeout = 15 * time.Second
|
postUsageBillingTimeout = 15 * time.Second
|
||||||
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
|
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
|
||||||
|
defaultKiroStreamKeepalive = 25 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -70,6 +72,7 @@ const (
|
|||||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||||
type forceCacheBillingKeyType struct{}
|
type forceCacheBillingKeyType struct{}
|
||||||
|
type kiroCooldownRecoveryAttemptedKeyType struct{}
|
||||||
|
|
||||||
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
|
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
|
||||||
type accountWithLoad struct {
|
type accountWithLoad struct {
|
||||||
@@ -78,6 +81,7 @@ type accountWithLoad struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
||||||
|
var kiroCooldownRecoveryAttemptedKey = kiroCooldownRecoveryAttemptedKeyType{}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
windowCostPrefetchCacheHitTotal atomic.Int64
|
windowCostPrefetchCacheHitTotal atomic.Int64
|
||||||
@@ -554,6 +558,8 @@ type GatewayService struct {
|
|||||||
deferredService *DeferredService
|
deferredService *DeferredService
|
||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
claudeTokenProvider *ClaudeTokenProvider
|
claudeTokenProvider *ClaudeTokenProvider
|
||||||
|
kiroTokenProvider *KiroTokenProvider
|
||||||
|
kiroCooldownStore KiroCooldownStore
|
||||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
userGroupRateResolver *userGroupRateResolver
|
userGroupRateResolver *userGroupRateResolver
|
||||||
@@ -592,6 +598,8 @@ func NewGatewayService(
|
|||||||
httpUpstream HTTPUpstream,
|
httpUpstream HTTPUpstream,
|
||||||
deferredService *DeferredService,
|
deferredService *DeferredService,
|
||||||
claudeTokenProvider *ClaudeTokenProvider,
|
claudeTokenProvider *ClaudeTokenProvider,
|
||||||
|
kiroTokenProvider *KiroTokenProvider,
|
||||||
|
kiroCooldownStore KiroCooldownStore,
|
||||||
sessionLimitCache SessionLimitCache,
|
sessionLimitCache SessionLimitCache,
|
||||||
rpmCache RPMCache,
|
rpmCache RPMCache,
|
||||||
digestStore *DigestSessionStore,
|
digestStore *DigestSessionStore,
|
||||||
@@ -624,6 +632,8 @@ func NewGatewayService(
|
|||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
claudeTokenProvider: claudeTokenProvider,
|
claudeTokenProvider: claudeTokenProvider,
|
||||||
|
kiroTokenProvider: kiroTokenProvider,
|
||||||
|
kiroCooldownStore: kiroCooldownStore,
|
||||||
sessionLimitCache: sessionLimitCache,
|
sessionLimitCache: sessionLimitCache,
|
||||||
rpmCache: rpmCache,
|
rpmCache: rpmCache,
|
||||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||||
@@ -1969,6 +1979,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(candidates) == 0 {
|
if len(candidates) == 0 {
|
||||||
|
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, useMixed) {
|
||||||
|
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
|
||||||
|
return s.SelectAccountWithLoadAwareness(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID)
|
||||||
|
}
|
||||||
return nil, ErrNoAvailableAccounts
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2348,14 +2362,91 @@ func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool
|
|||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return account.IsSchedulable()
|
if !account.IsSchedulable() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.isKiroRuntimeSchedulable(context.Background(), account)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
|
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
|
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.isKiroRuntimeSchedulable(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) isKiroRuntimeSchedulable(ctx context.Context, account *Account) bool {
|
||||||
|
if account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth || s == nil || s.kiroCooldownStore == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(account))
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return state == nil || !state.Active
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) tryRecoverKiroCooldownPool(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) bool {
|
||||||
|
if s == nil || s.kiroCooldownStore == nil || ctx.Value(kiroCooldownRecoveryAttemptedKey) == true {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tokenKeys := s.kiroTransientCooldownRecoveryKeys(ctx, accounts, requestedModel, excludedIDs, allowMixedScheduling)
|
||||||
|
if len(tokenKeys) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cleared, err := s.kiroCooldownStore.ClearEarliestTransientCooldown(ctx, tokenKeys)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery failed: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if cleared {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery cleared one transient cooldown")
|
||||||
|
}
|
||||||
|
return cleared
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) kiroTransientCooldownRecoveryKeys(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) []string {
|
||||||
|
tokenKeys := make([]string, 0, len(accounts))
|
||||||
|
eligible := 0
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
if acc == nil || acc.Platform != PlatformKiro || acc.Type != AccountTypeOAuth {
|
||||||
|
if allowMixedScheduling {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountSchedulableForQuota(acc) ||
|
||||||
|
!s.isAccountSchedulableForWindowCost(ctx, acc, false) ||
|
||||||
|
!s.isAccountSchedulableForRPM(ctx, acc, false) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
eligible++
|
||||||
|
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc))
|
||||||
|
if err != nil || state == nil || !state.Active {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if state.Reason != kirocooldown.CooldownReason429 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tokenKeys = append(tokenKeys, buildKiroAccountKey(acc))
|
||||||
|
}
|
||||||
|
if eligible == 0 || len(tokenKeys) != eligible {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tokenKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
// isAccountInGroup checks if the account belongs to the specified group.
|
// isAccountInGroup checks if the account belongs to the specified group.
|
||||||
@@ -3234,6 +3325,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
|
|
||||||
if selected == nil {
|
if selected == nil {
|
||||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
||||||
|
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, false) {
|
||||||
|
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
|
||||||
|
return s.selectAccountForModelWithPlatform(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
|
}
|
||||||
if requestedModel != "" {
|
if requestedModel != "" {
|
||||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||||
}
|
}
|
||||||
@@ -3613,6 +3708,17 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
return selectionFailureDiagnosis{Category: "excluded"}
|
return selectionFailureDiagnosis{Category: "excluded"}
|
||||||
}
|
}
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||||
|
}
|
||||||
|
if acc.Platform == PlatformKiro && acc.Type == AccountTypeOAuth {
|
||||||
|
if state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc)); err == nil && state != nil && state.Active {
|
||||||
|
return selectionFailureDiagnosis{
|
||||||
|
Category: "unschedulable",
|
||||||
|
Detail: fmt.Sprintf("kiro_runtime_%s remaining=%s", state.Reason, state.Remaining.Truncate(time.Second)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||||
}
|
}
|
||||||
@@ -3776,6 +3882,13 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
|
|||||||
}
|
}
|
||||||
return accessToken, "oauth", nil
|
return accessToken, "oauth", nil
|
||||||
}
|
}
|
||||||
|
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth && s.kiroTokenProvider != nil {
|
||||||
|
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return accessToken, "oauth", nil
|
||||||
|
}
|
||||||
|
|
||||||
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
|
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
|
||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
@@ -4319,11 +4432,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
return nil, fmt.Errorf("parse request: empty request")
|
return nil, fmt.Errorf("parse request: empty request")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
|
||||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
|
||||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
|
||||||
}
|
|
||||||
|
|
||||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||||
passthroughBody := parsed.Body
|
passthroughBody := parsed.Body
|
||||||
passthroughModel := parsed.Model
|
passthroughModel := parsed.Model
|
||||||
@@ -4347,6 +4455,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
return s.forwardBedrock(ctx, c, account, parsed, startTime)
|
return s.forwardBedrock(ctx, c, account, parsed, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||||
|
return s.forwardKiroMessages(ctx, c, account, parsed, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||||
|
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
||||||
|
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||||
|
}
|
||||||
|
|
||||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||||
if account.Platform == PlatformAnthropic && c != nil {
|
if account.Platform == PlatformAnthropic && c != nil {
|
||||||
@@ -4439,7 +4556,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
|
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
|
||||||
mappedModel := reqModel
|
mappedModel := reqModel
|
||||||
mappingSource := ""
|
mappingSource := ""
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Platform == PlatformKiro {
|
||||||
|
if next := account.GetMappedModel(reqModel); next != "" && next != reqModel {
|
||||||
|
mappedModel = next
|
||||||
|
mappingSource = "account"
|
||||||
|
}
|
||||||
|
} else if account.Type == AccountTypeAPIKey {
|
||||||
mappedModel = account.GetMappedModel(reqModel)
|
mappedModel = account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
mappingSource = "account"
|
mappingSource = "account"
|
||||||
@@ -5938,6 +6060,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
targetURL := claudeAPIURL
|
targetURL := claudeAPIURL
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
baseURL := account.GetBaseURL()
|
baseURL := account.GetBaseURL()
|
||||||
|
if baseURL == "" && account.Platform == PlatformKiro {
|
||||||
|
return nil, fmt.Errorf("kiro api key account requires base_url")
|
||||||
|
}
|
||||||
if baseURL != "" {
|
if baseURL != "" {
|
||||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -7199,10 +7324,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||||
keepaliveInterval := time.Duration(0)
|
keepaliveInterval := s.streamKeepaliveIntervalForAccount(account)
|
||||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
|
||||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
|
||||||
}
|
|
||||||
var keepaliveTicker *time.Ticker
|
var keepaliveTicker *time.Ticker
|
||||||
if keepaliveInterval > 0 {
|
if keepaliveInterval > 0 {
|
||||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
@@ -8241,6 +8363,9 @@ type recordUsageOpts struct {
|
|||||||
// 长上下文计费(仅 Gemini 路径需要)
|
// 长上下文计费(仅 Gemini 路径需要)
|
||||||
LongContextThreshold int
|
LongContextThreshold int
|
||||||
LongContextMultiplier float64
|
LongContextMultiplier float64
|
||||||
|
|
||||||
|
// Kiro 账号在上游返回 auto 等无法定价模型时使用保守计费兜底。
|
||||||
|
IsKiroAccount bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||||
@@ -8377,6 +8502,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 计算费用
|
// 计算费用
|
||||||
|
opts.IsKiroAccount = account != nil && account.Platform == PlatformKiro
|
||||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||||
|
|
||||||
// 判断计费方式:订阅模式 vs 余额模式
|
// 判断计费方式:订阅模式 vs 余额模式
|
||||||
@@ -8454,6 +8580,28 @@ func (s *GatewayService) calculateRecordUsageCost(
|
|||||||
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const kiroConservativeFallbackBillingModel = "claude-opus-4-6"
|
||||||
|
|
||||||
|
func shouldUseKiroConservativeBillingFallback(result *ForwardResult, billingModel string, opts *recordUsageOpts) bool {
|
||||||
|
if result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return opts != nil && opts.IsKiroAccount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) calculateKiroConservativeTokenCost(tokens UsageTokens, multiplier float64) *CostBreakdown {
|
||||||
|
if s == nil || s.billingService == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cost, err := s.billingService.CalculateCost(kiroConservativeFallbackBillingModel, tokens, multiplier)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Calculate conservative Kiro fallback cost failed: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
|
||||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||||
@@ -8557,6 +8705,12 @@ func (s *GatewayService) calculateTokenCost(
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||||
|
if shouldUseKiroConservativeBillingFallback(result, billingModel, opts) {
|
||||||
|
if fallback := s.calculateKiroConservativeTokenCost(tokens, multiplier); fallback != nil {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Using conservative Kiro fallback pricing for model=%s", billingModel)
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
return &CostBreakdown{ActualCost: 0}
|
return &CostBreakdown{ActualCost: 0}
|
||||||
}
|
}
|
||||||
return cost
|
return cost
|
||||||
@@ -9444,6 +9598,19 @@ func reconcileCachedTokens(usage map[string]any) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) streamKeepaliveIntervalForAccount(account *Account) time.Duration {
|
||||||
|
if account != nil && account.Platform == PlatformKiro {
|
||||||
|
if s != nil && s.cfg != nil && s.cfg.Gateway.KiroStreamKeepaliveInterval > 0 {
|
||||||
|
return time.Duration(s.cfg.Gateway.KiroStreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
return defaultKiroStreamKeepalive
|
||||||
|
}
|
||||||
|
if s != nil && s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
return time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
|
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
|
||||||
|
|
||||||
// initDebugGatewayBodyFile 初始化网关调试日志文件。
|
// initDebugGatewayBodyFile 初始化网关调试日志文件。
|
||||||
|
|||||||
@@ -0,0 +1,222 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
kiroErrorAuthError = "auth_error"
|
||||||
|
kiroErrorMonthlyRequest = "monthly_request_count"
|
||||||
|
kiroErrorProfileError = "profile_error"
|
||||||
|
kiroErrorQuotaExhausted = "quota_exhausted"
|
||||||
|
kiroErrorOverageExhausted = "overage_exhausted"
|
||||||
|
kiroErrorRateLimited = "rate_limited"
|
||||||
|
kiroErrorSuspended = "suspended"
|
||||||
|
kiroErrorUsageForbidden = "usage_forbidden"
|
||||||
|
kiroErrorUpstreamTransient = "upstream_transient"
|
||||||
|
kiroErrorBadRequestSchema = "bad_request_schema"
|
||||||
|
kiroErrorBadRequestToolPairing = "bad_request_tool_pairing"
|
||||||
|
kiroErrorBadRequestInvalidModel = "bad_request_invalid_model"
|
||||||
|
kiroErrorBadRequestAuth = "bad_request_auth"
|
||||||
|
kiroErrorBadRequestQuota = "bad_request_quota"
|
||||||
|
kiroErrorBadRequestUnknown = "bad_request_unknown"
|
||||||
|
kiroErrorRefreshTokenInvalid = "refresh_token_invalid"
|
||||||
|
|
||||||
|
kiroQuotaStateNormal = "normal"
|
||||||
|
kiroQuotaStateOverageActive = "overage_active"
|
||||||
|
kiroQuotaStateCreditsExhausted = "credits_exhausted"
|
||||||
|
kiroQuotaStateOverageExhausted = "overage_exhausted"
|
||||||
|
)
|
||||||
|
|
||||||
|
type kiroErrorClassification struct {
|
||||||
|
Category string
|
||||||
|
StatusCode int
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyKiroHTTPError(statusCode int, body string) kiroErrorClassification {
|
||||||
|
trimmed := strings.TrimSpace(body)
|
||||||
|
lower := strings.ToLower(trimmed)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case statusCode == http.StatusUnauthorized:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case statusCode == http.StatusPaymentRequired && looksLikeKiroMonthlyRequestCountError(trimmed):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case statusCode == http.StatusForbidden && isKiroSuspendedBody([]byte(trimmed)):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorSuspended, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case looksLikeKiroProfileError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorProfileError, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case statusCode == http.StatusBadRequest:
|
||||||
|
return classifyKiroBadRequest(trimmed, lower)
|
||||||
|
case statusCode == http.StatusForbidden && isKiroTokenErrorBody([]byte(trimmed)):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case looksLikeKiroOverageExhaustedError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorOverageExhausted, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case looksLikeKiroQuotaExhaustedError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case statusCode == http.StatusTooManyRequests:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorRateLimited, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case statusCode == http.StatusForbidden:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorUsageForbidden, StatusCode: statusCode, Message: trimmed}
|
||||||
|
case statusCode >= 500:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
|
||||||
|
default:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyKiroError(err error) kiroErrorClassification {
|
||||||
|
if err == nil {
|
||||||
|
return kiroErrorClassification{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var httpErr *kiroUsageHTTPError
|
||||||
|
if errors.As(err, &httpErr) && httpErr != nil {
|
||||||
|
return classifyKiroHTTPError(httpErr.StatusCode, httpErr.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
errStr := strings.TrimSpace(err.Error())
|
||||||
|
lower := strings.ToLower(errStr)
|
||||||
|
switch {
|
||||||
|
case looksLikeKiroInvalidGrantError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorRefreshTokenInvalid, Message: errStr}
|
||||||
|
case looksLikeKiroMonthlyRequestCountError(errStr):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, Message: errStr}
|
||||||
|
case looksLikeKiroProfileError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorProfileError, Message: errStr}
|
||||||
|
case looksLikeKiroOverageExhaustedError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorOverageExhausted, Message: errStr}
|
||||||
|
case looksLikeKiroQuotaExhaustedError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, Message: errStr}
|
||||||
|
case strings.Contains(lower, "context deadline exceeded"),
|
||||||
|
strings.Contains(lower, "timeout"),
|
||||||
|
isNetErr(err):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
|
||||||
|
default:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyKiroBadRequest(trimmed, lower string) kiroErrorClassification {
|
||||||
|
switch {
|
||||||
|
case looksLikeKiroBadRequestSchemaError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorBadRequestSchema, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||||
|
case looksLikeKiroBadRequestToolPairingError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorBadRequestToolPairing, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||||
|
case looksLikeKiroBadRequestInvalidModelError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorBadRequestInvalidModel, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||||
|
case looksLikeKiroInvalidGrantError(lower) || looksLikeKiroBadRequestAuthError(lower):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorBadRequestAuth, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||||
|
case looksLikeKiroQuotaExhaustedError(lower) || looksLikeKiroMonthlyRequestCountError(trimmed):
|
||||||
|
return kiroErrorClassification{Category: kiroErrorBadRequestQuota, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||||
|
default:
|
||||||
|
return kiroErrorClassification{Category: kiroErrorBadRequestUnknown, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroBadRequestSchemaError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(lower, "schema") ||
|
||||||
|
strings.Contains(lower, "inputschema") ||
|
||||||
|
strings.Contains(lower, "improperly formed request") ||
|
||||||
|
strings.Contains(lower, "additionalproperties") ||
|
||||||
|
(strings.Contains(lower, "properties") && strings.Contains(lower, "required"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroBadRequestToolPairingError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(lower, "tool_use") ||
|
||||||
|
strings.Contains(lower, "tool_result") ||
|
||||||
|
strings.Contains(lower, "tooluseid") ||
|
||||||
|
strings.Contains(lower, "toolresults") ||
|
||||||
|
strings.Contains(lower, "must be paired") ||
|
||||||
|
strings.Contains(lower, "missing tool result")
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroBadRequestInvalidModelError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(lower, "invalid model") ||
|
||||||
|
strings.Contains(lower, "invalid_model_id") ||
|
||||||
|
strings.Contains(lower, "model not supported") ||
|
||||||
|
strings.Contains(lower, "unsupportedmodel") ||
|
||||||
|
strings.Contains(lower, "modelid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroBadRequestAuthError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(lower, "invalid token") ||
|
||||||
|
strings.Contains(lower, "expired token") ||
|
||||||
|
strings.Contains(lower, "access token") ||
|
||||||
|
strings.Contains(lower, "refresh token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroInvalidGrantError(lower string) bool {
|
||||||
|
return strings.Contains(lower, "invalid_grant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroMonthlyRequestCountError(body string) bool {
|
||||||
|
trimmed := strings.TrimSpace(body)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(trimmed, "MONTHLY_REQUEST_COUNT") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !gjson.Valid(trimmed) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return gjson.Get(trimmed, "reason").String() == "MONTHLY_REQUEST_COUNT" ||
|
||||||
|
gjson.Get(trimmed, "error.reason").String() == "MONTHLY_REQUEST_COUNT"
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroProfileError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return (strings.Contains(lower, "profilearn") && strings.Contains(lower, "required")) ||
|
||||||
|
(strings.Contains(lower, "profile arn") && strings.Contains(lower, "required")) ||
|
||||||
|
(strings.Contains(lower, "profile") && strings.Contains(lower, "not found")) ||
|
||||||
|
(strings.Contains(lower, "invalid profile")) ||
|
||||||
|
(strings.Contains(lower, "listavailableprofiles"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroQuotaExhaustedError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return (strings.Contains(lower, "credit") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "depleted"))) ||
|
||||||
|
(strings.Contains(lower, "quota") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "exceeded") || strings.Contains(lower, "depleted"))) ||
|
||||||
|
(strings.Contains(lower, "usage limit") && (strings.Contains(lower, "reached") || strings.Contains(lower, "exceeded"))) ||
|
||||||
|
(strings.Contains(lower, "resource has been exhausted"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeKiroOverageExhaustedError(lower string) bool {
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(lower, "overage") &&
|
||||||
|
(strings.Contains(lower, "exhaust") ||
|
||||||
|
strings.Contains(lower, "disabled") ||
|
||||||
|
strings.Contains(lower, "not enabled") ||
|
||||||
|
strings.Contains(lower, "not allowed") ||
|
||||||
|
strings.Contains(lower, "limit"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNetErr(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
return errors.As(err, &netErr)
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClassifyKiroHTTPErrorBadRequestCategories(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "schema",
|
||||||
|
body: `{"message":"Improperly formed request: inputSchema.properties must be an object"}`,
|
||||||
|
want: kiroErrorBadRequestSchema,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool pairing",
|
||||||
|
body: `{"message":"tool_use must be paired with a matching tool_result"}`,
|
||||||
|
want: kiroErrorBadRequestToolPairing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid model id",
|
||||||
|
body: `{"message":"invalid modelId: model not supported"}`,
|
||||||
|
want: kiroErrorBadRequestInvalidModel,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid model upstream",
|
||||||
|
body: `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
|
||||||
|
want: kiroErrorBadRequestInvalidModel,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid model reason",
|
||||||
|
body: `{"message":"model route unavailable","reason":"INVALID_MODEL_ID"}`,
|
||||||
|
want: kiroErrorBadRequestInvalidModel,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth",
|
||||||
|
body: `{"error":"invalid_grant","message":"Invalid refresh token provided"}`,
|
||||||
|
want: kiroErrorBadRequestAuth,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "quota",
|
||||||
|
body: `{"message":"resource has been exhausted"}`,
|
||||||
|
want: kiroErrorBadRequestQuota,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown",
|
||||||
|
body: `{"message":"bad request"}`,
|
||||||
|
want: kiroErrorBadRequestUnknown,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
classification := classifyKiroHTTPError(http.StatusBadRequest, tt.body)
|
||||||
|
require.Equal(t, tt.want, classification.Category)
|
||||||
|
require.Equal(t, http.StatusBadRequest, classification.StatusCode)
|
||||||
|
require.Equal(t, tt.body, classification.Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildKiroAccountKey(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return kiropkg.BuildAccountKey(
|
||||||
|
account.GetCredential("client_id"),
|
||||||
|
account.GetCredential("client_id_hash"),
|
||||||
|
account.GetCredential("refresh_token"),
|
||||||
|
account.GetCredential("profile_arn"),
|
||||||
|
account.ID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroMachineID(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return kiropkg.BuildMachineID("", "", "account:nil")
|
||||||
|
}
|
||||||
|
for _, key := range []string{"machine_id", "machineId"} {
|
||||||
|
if machineID, ok := kiropkg.NormalizeMachineID(account.GetCredential(key)); ok {
|
||||||
|
return machineID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fallbackKey := buildKiroMachineIDFallbackKey(account)
|
||||||
|
if account.Type == AccountTypeAPIKey {
|
||||||
|
return kiropkg.BuildMachineID("", firstKiroCredential(account, "kiro_api_key", "kiroApiKey", "api_key"), fallbackKey)
|
||||||
|
}
|
||||||
|
return kiropkg.BuildMachineID(account.GetCredential("refresh_token"), "", fallbackKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstKiroCredential(account *Account, keys ...string) string {
|
||||||
|
if account == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroMachineIDFallbackKey(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return "account:nil"
|
||||||
|
}
|
||||||
|
if account.ID > 0 {
|
||||||
|
return fmt.Sprintf("account:%d", account.ID)
|
||||||
|
}
|
||||||
|
for _, key := range []string{"client_id", "profile_arn"} {
|
||||||
|
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
|
||||||
|
return key + ":" + value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if name := strings.TrimSpace(account.Name); name != "" {
|
||||||
|
return "name:" + name
|
||||||
|
}
|
||||||
|
return "account:unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroRequestID(resp *http.Response) string {
|
||||||
|
if resp == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if requestID := strings.TrimSpace(resp.Header.Get("x-request-id")); requestID != "" {
|
||||||
|
return requestID
|
||||||
|
}
|
||||||
|
if requestID := strings.TrimSpace(resp.Header.Get("x-amzn-requestid")); requestID != "" {
|
||||||
|
return requestID
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(resp.Header.Get("x-amz-request-id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isKiroInvalidModelIDBody(respBody []byte) bool {
|
||||||
|
var payload struct {
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error struct {
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(respBody, &payload) != nil {
|
||||||
|
return looksLikeKiroBadRequestInvalidModelError(strings.ToLower(string(respBody)))
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(payload.Reason), "INVALID_MODEL_ID") ||
|
||||||
|
strings.EqualFold(strings.TrimSpace(payload.Error.Reason), "INVALID_MODEL_ID") ||
|
||||||
|
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Message)) ||
|
||||||
|
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Error.Message))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isKiroSuspendedBody(respBody []byte) bool {
|
||||||
|
body := string(respBody)
|
||||||
|
return strings.Contains(body, "SUSPENDED") || strings.Contains(body, "TEMPORARILY_SUSPENDED")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isKiroTokenErrorBody(respBody []byte) bool {
|
||||||
|
lower := strings.ToLower(string(respBody))
|
||||||
|
return strings.Contains(lower, "token") ||
|
||||||
|
strings.Contains(lower, "expired") ||
|
||||||
|
strings.Contains(lower, "invalid") ||
|
||||||
|
strings.Contains(lower, "unauthorized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroProxyURL(account *Account) string {
|
||||||
|
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
return account.Proxy.URL()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroAPIRegion(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return kiroDefaultRegion
|
||||||
|
}
|
||||||
|
region := strings.TrimSpace(account.GetCredential("api_region"))
|
||||||
|
if region == "" {
|
||||||
|
region = kiroDefaultRegion
|
||||||
|
}
|
||||||
|
return region
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyKiroConditionalHeaders(req *http.Request, account *Account) {
|
||||||
|
if req == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(account.GetCredential("auth_method")), "external_idp") {
|
||||||
|
req.Header.Set("TokenType", "EXTERNAL_IDP")
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(account.GetCredential("provider")), "Internal") {
|
||||||
|
req.Header.Set("redirect-for-internal", "true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveKiroPayloadProfileArn(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKiroJSONRequest(ctx context.Context, endpointURL string, payload []byte, token, accountKey, machineID, amzTarget string, account *Account) (*http.Request, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "*/*")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||||
|
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||||
|
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||||
|
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||||
|
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||||
|
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||||
|
if amzTarget != "" {
|
||||||
|
req.Header.Set("X-Amz-Target", amzTarget)
|
||||||
|
}
|
||||||
|
if account != nil {
|
||||||
|
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||||
|
if profileArn != "" {
|
||||||
|
req.Header.Set("x-amzn-kiro-profile-arn", profileArn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
applyKiroConditionalHeaders(req, account)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
|
||||||
|
accountA := &Account{
|
||||||
|
ID: 99,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-a",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
accountB := &Account{
|
||||||
|
ID: 99,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token-b",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, buildKiroAccountKey(accountA), buildKiroAccountKey(accountB))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroMachineIDPrefersExplicitCredential(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 101,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"machineId": "2582956e-cc88-4669-b546-07adbffcb894",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", buildKiroMachineID(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroMachineIDDerivesFromRefreshToken(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 102,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, kiropkg.BuildMachineID("refresh-token", "", "account:102"), buildKiroMachineID(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroMachineIDDerivesFromAPIKeyAccount(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 103,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"kiroApiKey": "kiro-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, kiropkg.BuildMachineID("", "kiro-api-key", "account:103"), buildKiroMachineID(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewKiroJSONRequestAddsConditionalHeaders(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"auth_method": "external_idp",
|
||||||
|
"provider": "Internal",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := newKiroJSONRequest(
|
||||||
|
context.Background(),
|
||||||
|
"https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||||
|
[]byte(`{"ok":true}`),
|
||||||
|
"access-token",
|
||||||
|
"account-key",
|
||||||
|
buildKiroMachineID(account),
|
||||||
|
"",
|
||||||
|
account,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "EXTERNAL_IDP", req.Header.Get("TokenType"))
|
||||||
|
require.Equal(t, "true", req.Header.Get("redirect-for-internal"))
|
||||||
|
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER", req.Header.Get("x-amzn-kiro-profile-arn"))
|
||||||
|
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
|
||||||
|
require.Equal(t, "true", req.Header.Get("x-amzn-codewhisperer-optout"))
|
||||||
|
require.Contains(t, req.Header.Get("User-Agent"), "aws-sdk-js/1.0.34")
|
||||||
|
require.Contains(t, req.Header.Get("User-Agent"), "md/nodejs#22.22.0")
|
||||||
|
require.Contains(t, req.Header.Get("User-Agent"), buildKiroMachineID(account))
|
||||||
|
require.Contains(t, req.Header.Get("X-Amz-User-Agent"), buildKiroMachineID(account))
|
||||||
|
require.True(t, strings.Contains(req.Header.Get("User-Agent"), "api/codewhispererstreaming#1.0.34"))
|
||||||
|
require.Empty(t, req.Header.Get("Anthropic-Beta"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsKiroInvalidModelIDBodyRecognizesKnownForms(t *testing.T) {
|
||||||
|
tests := []string{
|
||||||
|
`{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`,
|
||||||
|
`{"message":"Invalid model. Please select a different model to continue."}`,
|
||||||
|
`API Error: 400 {"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, body := range tests {
|
||||||
|
require.True(t, isKiroInvalidModelIDBody([]byte(body)), body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body := []byte(`{
|
||||||
|
"model":"claude-sonnet-4-6",
|
||||||
|
"messages":[{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||||
|
|
||||||
|
payload, err := buildKiroPayloadForAccount(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
body,
|
||||||
|
"claude-sonnet-4.6",
|
||||||
|
"kiro-access-token",
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
headers,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotContains(t, string(payload), "CHUNKED WRITE PROTOCOL")
|
||||||
|
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_region": "eu-west-1",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
|
||||||
|
"region": "ap-northeast-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, "eu-west-1", kiroAPIRegion(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroAPIRegionIgnoresProfileARNRegionFallback(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroAPIRegionIgnoresOIDCRegionFallback(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"region": "ap-northeast-2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroEndpointsUsesOnlyAmazonQEndpoint(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_region": "us-west-2",
|
||||||
|
"preferred_endpoint": "cw",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoints := buildKiroEndpoints(account)
|
||||||
|
require.Len(t, endpoints, 1)
|
||||||
|
require.Equal(t, "AmazonQ", endpoints[0].Name)
|
||||||
|
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
|
||||||
|
require.Empty(t, endpoints[0].AmzTarget)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroEndpointsIgnoresPreferredEndpoint(t *testing.T) {
|
||||||
|
for _, preferred := range []string{"codewhisperer", "cw", "unknown"} {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_region": "us-west-2",
|
||||||
|
"preferred_endpoint": preferred,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoints := buildKiroEndpoints(account)
|
||||||
|
require.Len(t, endpoints, 1)
|
||||||
|
require.Equal(t, "AmazonQ", endpoints[0].Name)
|
||||||
|
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountKiroDefaultMappingRestrictsUnsupportedModels(t *testing.T) {
|
||||||
|
account := &Account{Platform: PlatformKiro}
|
||||||
|
|
||||||
|
require.False(t, account.IsModelSupported("gpt-4o"))
|
||||||
|
require.False(t, account.IsModelSupported("kiro-gpt-4o"))
|
||||||
|
require.False(t, account.IsModelSupported("auto"))
|
||||||
|
require.Equal(t, "claude-sonnet-4.6", account.GetMappedModel("claude-sonnet-4-6"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceCalculateTokenCost_KiroAutoUsesConservativeFallback(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Default.RateMultiplier = 1.1
|
||||||
|
|
||||||
|
svc := NewGatewayService(
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
cfg,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
NewBillingService(cfg, nil),
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result := &ForwardResult{
|
||||||
|
Model: "auto",
|
||||||
|
UpstreamModel: "auto",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 20,
|
||||||
|
OutputTokens: 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expected, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||||
|
InputTokens: 20,
|
||||||
|
OutputTokens: 10,
|
||||||
|
}, 1.1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cost := svc.calculateTokenCost(context.Background(), result, &APIKey{}, "auto", 1.1, &recordUsageOpts{IsKiroAccount: true})
|
||||||
|
require.NotNil(t, cost)
|
||||||
|
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-12)
|
||||||
|
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-12)
|
||||||
|
}
|
||||||
@@ -0,0 +1,369 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Kiro desktop social auth uses localhost loopback callbacks from a fixed
|
||||||
|
// allowlist. Use one of the bundled ports from the official client.
|
||||||
|
kiroSocialRedirectURI = "http://localhost:49153"
|
||||||
|
// AWS IAM Identity Center native/public clients require an explicit loopback IP redirect URI.
|
||||||
|
kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KiroOAuthService struct {
|
||||||
|
sessionStore *kiropkg.SessionStore
|
||||||
|
proxyRepo ProxyRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
|
||||||
|
return &KiroOAuthService{
|
||||||
|
sessionStore: kiropkg.NewSessionStore(),
|
||||||
|
proxyRepo: proxyRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) Stop() {}
|
||||||
|
|
||||||
|
type KiroAuthURLResult struct {
|
||||||
|
AuthURL string `json:"auth_url"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
State string `json:"state"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroIDCAuthURLResult struct {
|
||||||
|
AuthURL string `json:"auth_url"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
State string `json:"state"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
StartURL string `json:"start_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroTokenInfo struct {
|
||||||
|
AccessToken string `json:"access_token,omitempty"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
ProfileArn string `json:"profile_arn,omitempty"`
|
||||||
|
ExpiresAt string `json:"expires_at,omitempty"`
|
||||||
|
AuthMethod string `json:"auth_method,omitempty"`
|
||||||
|
Provider string `json:"provider,omitempty"`
|
||||||
|
ClientID string `json:"client_id,omitempty"`
|
||||||
|
ClientSecret string `json:"client_secret,omitempty"`
|
||||||
|
ClientIDHash string `json:"client_id_hash,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
StartURL string `json:"start_url,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroGenerateAuthURLInput struct {
|
||||||
|
ProxyID *int64
|
||||||
|
Provider string
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroExchangeCodeInput struct {
|
||||||
|
SessionID string
|
||||||
|
State string
|
||||||
|
Code string
|
||||||
|
CallbackPath string
|
||||||
|
LoginOption string
|
||||||
|
ProxyID *int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroGenerateIDCAuthURLInput struct {
|
||||||
|
ProxyID *int64
|
||||||
|
StartURL string
|
||||||
|
Region string
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroRefreshTokenInput struct {
|
||||||
|
RefreshToken string
|
||||||
|
AuthMethod string
|
||||||
|
Provider string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
StartURL string
|
||||||
|
Region string
|
||||||
|
ProfileArn string
|
||||||
|
ProxyID *int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroImportTokenInput struct {
|
||||||
|
TokenJSON string
|
||||||
|
DeviceRegistrationJSON string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
|
||||||
|
provider := strings.TrimSpace(input.Provider)
|
||||||
|
if provider == "" {
|
||||||
|
provider = string(kiropkg.SocialProviderGoogle)
|
||||||
|
}
|
||||||
|
if provider != string(kiropkg.SocialProviderGoogle) && provider != string(kiropkg.SocialProviderGitHub) {
|
||||||
|
return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
|
||||||
|
}
|
||||||
|
state, err := kiropkg.GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate state failed: %w", err)
|
||||||
|
}
|
||||||
|
codeVerifier, err := kiropkg.GenerateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate code verifier failed: %w", err)
|
||||||
|
}
|
||||||
|
sessionID := kiropkg.GenerateSessionID()
|
||||||
|
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||||
|
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
|
||||||
|
State: state,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
AuthType: "social",
|
||||||
|
Provider: provider,
|
||||||
|
RedirectURI: kiroSocialRedirectURI,
|
||||||
|
})
|
||||||
|
return &KiroAuthURLResult{
|
||||||
|
AuthURL: kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
|
||||||
|
SessionID: sessionID,
|
||||||
|
State: state,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
|
||||||
|
session, ok := s.sessionStore.Get(input.SessionID)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("session not found or expired")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||||
|
return nil, fmt.Errorf("state invalid")
|
||||||
|
}
|
||||||
|
proxyURL := session.ProxyURL
|
||||||
|
if input.ProxyID != nil {
|
||||||
|
proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch session.AuthType {
|
||||||
|
case "social":
|
||||||
|
token, err := kiropkg.CreateSocialToken(
|
||||||
|
ctx,
|
||||||
|
proxyURL,
|
||||||
|
input.Code,
|
||||||
|
session.CodeVerifier,
|
||||||
|
buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token.Provider = session.Provider
|
||||||
|
s.sessionStore.Delete(input.SessionID)
|
||||||
|
return toKiroTokenInfo(token), nil
|
||||||
|
case "idc":
|
||||||
|
token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.sessionStore.Delete(input.SessionID)
|
||||||
|
return toKiroTokenInfo(token), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
|
||||||
|
option := strings.ToLower(strings.TrimSpace(loginOption))
|
||||||
|
if option == "" {
|
||||||
|
switch provider {
|
||||||
|
case string(kiropkg.SocialProviderGitHub):
|
||||||
|
option = "github"
|
||||||
|
case string(kiropkg.SocialProviderGoogle):
|
||||||
|
option = "google"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, option)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
|
||||||
|
startURL := strings.TrimSpace(input.StartURL)
|
||||||
|
if startURL == "" {
|
||||||
|
startURL = kiropkg.BuilderIDStartURL
|
||||||
|
}
|
||||||
|
region := strings.TrimSpace(input.Region)
|
||||||
|
if region == "" {
|
||||||
|
region = "us-east-1"
|
||||||
|
}
|
||||||
|
state, err := kiropkg.GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate state failed: %w", err)
|
||||||
|
}
|
||||||
|
codeVerifier, err := kiropkg.GenerateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate code verifier failed: %w", err)
|
||||||
|
}
|
||||||
|
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||||
|
reg, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sessionID := kiropkg.GenerateSessionID()
|
||||||
|
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
|
||||||
|
State: state,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
AuthType: "idc",
|
||||||
|
RedirectURI: kiroIDCRedirectURI,
|
||||||
|
ClientID: reg.ClientID,
|
||||||
|
ClientSecret: reg.ClientSecret,
|
||||||
|
Region: region,
|
||||||
|
StartURL: startURL,
|
||||||
|
})
|
||||||
|
return &KiroIDCAuthURLResult{
|
||||||
|
AuthURL: kiropkg.BuildIDCAuthURL(reg.ClientID, kiroIDCRedirectURI, state, kiropkg.GenerateCodeChallenge(codeVerifier), region),
|
||||||
|
SessionID: sessionID,
|
||||||
|
State: state,
|
||||||
|
ClientID: reg.ClientID,
|
||||||
|
Region: region,
|
||||||
|
StartURL: startURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
|
||||||
|
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||||
|
authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
|
||||||
|
if authMethod == "" {
|
||||||
|
authMethod = "social"
|
||||||
|
}
|
||||||
|
|
||||||
|
var token *kiropkg.TokenData
|
||||||
|
var err error
|
||||||
|
switch authMethod {
|
||||||
|
case "idc":
|
||||||
|
token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
|
||||||
|
default:
|
||||||
|
token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if token.ProfileArn == "" {
|
||||||
|
token.ProfileArn = input.ProfileArn
|
||||||
|
}
|
||||||
|
if token.ClientID == "" {
|
||||||
|
token.ClientID = input.ClientID
|
||||||
|
}
|
||||||
|
if token.ClientSecret == "" {
|
||||||
|
token.ClientSecret = input.ClientSecret
|
||||||
|
}
|
||||||
|
if token.StartURL == "" {
|
||||||
|
token.StartURL = input.StartURL
|
||||||
|
}
|
||||||
|
if token.Region == "" {
|
||||||
|
token.Region = input.Region
|
||||||
|
}
|
||||||
|
return toKiroTokenInfo(token), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
|
||||||
|
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||||
|
return nil, fmt.Errorf("not a kiro oauth account")
|
||||||
|
}
|
||||||
|
return s.RefreshToken(ctx, &KiroRefreshTokenInput{
|
||||||
|
RefreshToken: account.GetCredential("refresh_token"),
|
||||||
|
AuthMethod: account.GetCredential("auth_method"),
|
||||||
|
Provider: account.GetCredential("provider"),
|
||||||
|
ClientID: account.GetCredential("client_id"),
|
||||||
|
ClientSecret: account.GetCredential("client_secret"),
|
||||||
|
StartURL: account.GetCredential("start_url"),
|
||||||
|
Region: account.GetCredential("region"),
|
||||||
|
ProfileArn: account.GetCredential("profile_arn"),
|
||||||
|
ProxyID: account.ProxyID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
|
||||||
|
token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return toKiroTokenInfo(token), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
||||||
|
if tokenInfo == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := map[string]any{}
|
||||||
|
if tokenInfo.AccessToken != "" {
|
||||||
|
creds["access_token"] = tokenInfo.AccessToken
|
||||||
|
}
|
||||||
|
if tokenInfo.RefreshToken != "" {
|
||||||
|
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||||
|
}
|
||||||
|
if tokenInfo.ProfileArn != "" {
|
||||||
|
creds["profile_arn"] = tokenInfo.ProfileArn
|
||||||
|
}
|
||||||
|
if tokenInfo.ExpiresAt != "" {
|
||||||
|
creds["expires_at"] = tokenInfo.ExpiresAt
|
||||||
|
}
|
||||||
|
if tokenInfo.AuthMethod != "" {
|
||||||
|
creds["auth_method"] = tokenInfo.AuthMethod
|
||||||
|
}
|
||||||
|
if tokenInfo.Provider != "" {
|
||||||
|
creds["provider"] = tokenInfo.Provider
|
||||||
|
}
|
||||||
|
if tokenInfo.ClientID != "" {
|
||||||
|
creds["client_id"] = tokenInfo.ClientID
|
||||||
|
}
|
||||||
|
if tokenInfo.ClientSecret != "" {
|
||||||
|
creds["client_secret"] = tokenInfo.ClientSecret
|
||||||
|
}
|
||||||
|
if tokenInfo.ClientIDHash != "" {
|
||||||
|
creds["client_id_hash"] = tokenInfo.ClientIDHash
|
||||||
|
}
|
||||||
|
if tokenInfo.Email != "" {
|
||||||
|
creds["email"] = tokenInfo.Email
|
||||||
|
}
|
||||||
|
if tokenInfo.StartURL != "" {
|
||||||
|
creds["start_url"] = tokenInfo.StartURL
|
||||||
|
}
|
||||||
|
if tokenInfo.Region != "" {
|
||||||
|
creds["region"] = tokenInfo.Region
|
||||||
|
}
|
||||||
|
|
||||||
|
return creds
|
||||||
|
}
|
||||||
|
|
||||||
|
func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
|
||||||
|
if token == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &KiroTokenInfo{
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
ProfileArn: token.ProfileArn,
|
||||||
|
ExpiresAt: token.ExpiresAt,
|
||||||
|
AuthMethod: token.AuthMethod,
|
||||||
|
Provider: token.Provider,
|
||||||
|
ClientID: token.ClientID,
|
||||||
|
ClientSecret: token.ClientSecret,
|
||||||
|
ClientIDHash: token.ClientIDHash,
|
||||||
|
Email: token.Email,
|
||||||
|
StartURL: token.StartURL,
|
||||||
|
Region: token.Region,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||||
|
if proxyID == nil || s.proxyRepo == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||||
|
if err != nil || proxy == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return proxy.URL(), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKiroIDCAuthRedirectURIUsesLoopbackIP(t *testing.T) {
|
||||||
|
require.Equal(t, "http://127.0.0.1:9876/oauth/callback", kiroIDCRedirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroSocialAuthRedirectURIUsesLoopbackIP(t *testing.T) {
|
||||||
|
require.Equal(t, "http://localhost:49153", kiroSocialRedirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroSocialExchangeRedirectURIUsesProviderDefault(t *testing.T) {
|
||||||
|
require.Equal(
|
||||||
|
t,
|
||||||
|
"http://localhost:49153/oauth/callback?login_option=github",
|
||||||
|
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "", ""),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKiroSocialExchangeRedirectURIPreservesParsedCallbackData(t *testing.T) {
|
||||||
|
require.Equal(
|
||||||
|
t,
|
||||||
|
"http://localhost:49153/signin/callback?login_option=google",
|
||||||
|
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "/signin/callback", "google"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroOAuthService_ExchangeCodeRejectsExpiredSession(t *testing.T) {
|
||||||
|
svc := NewKiroOAuthService(nil)
|
||||||
|
svc.sessionStore.Set("expired-session", &kiropkg.AuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CreatedAt: time.Now().Add(-11 * time.Minute),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := svc.ExchangeCode(context.Background(), &KiroExchangeCodeInput{
|
||||||
|
SessionID: "expired-session",
|
||||||
|
State: "expected-state",
|
||||||
|
Code: "auth-code",
|
||||||
|
})
|
||||||
|
require.EqualError(t, err, "session not found or expired")
|
||||||
|
}
|
||||||
@@ -0,0 +1,724 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
mathrand "math/rand"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type kiroEndpointConfig struct {
|
||||||
|
URL string
|
||||||
|
AmzTarget string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
const kiroInvalidModelTempUnschedDuration = time.Minute
|
||||||
|
|
||||||
|
const (
|
||||||
|
kiroRetryBaseDelay = 200 * time.Millisecond
|
||||||
|
kiroRetryMaxDelay = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var kiroRetrySleep = sleepWithContext
|
||||||
|
|
||||||
|
func kiroRetryBackoffDelay(attempt int) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
delay := kiroRetryBaseDelay * time.Duration(1<<attempt)
|
||||||
|
if delay > kiroRetryMaxDelay {
|
||||||
|
delay = kiroRetryMaxDelay
|
||||||
|
}
|
||||||
|
jitterMax := delay / 4
|
||||||
|
if jitterMax <= 0 {
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
return delay + time.Duration(mathrand.Int63n(int64(jitterMax)+1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepKiroRetry(ctx context.Context, attempt int) error {
|
||||||
|
return kiroRetrySleep(ctx, kiroRetryBackoffDelay(attempt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) {
|
||||||
|
if account == nil || parsed == nil {
|
||||||
|
return nil, fmt.Errorf("kiro forward: missing account or request")
|
||||||
|
}
|
||||||
|
|
||||||
|
originalModel := parsed.Model
|
||||||
|
mappedModel := originalModel
|
||||||
|
if next := account.GetMappedModel(originalModel); next != "" {
|
||||||
|
mappedModel = next
|
||||||
|
}
|
||||||
|
body := parsed.Body
|
||||||
|
if mappedModel != originalModel {
|
||||||
|
body = s.replaceModelInBody(body, mappedModel)
|
||||||
|
}
|
||||||
|
logger.L().Debug("gateway forward_kiro_messages: request prepared",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("auth_method", strings.TrimSpace(account.GetCredential("auth_method"))),
|
||||||
|
zap.String("requested_model", originalModel),
|
||||||
|
zap.String("mapped_model", mappedModel),
|
||||||
|
zap.Bool("has_profile_arn", strings.TrimSpace(account.GetCredential("profile_arn")) != ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
if s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, body) {
|
||||||
|
parsedForEmulation := *parsed
|
||||||
|
parsedForEmulation.Body = body
|
||||||
|
return s.handleWebSearchEmulation(ctx, c, account, &parsedForEmulation)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.Stream {
|
||||||
|
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: failoverErr.StatusCode,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||||
|
})
|
||||||
|
return nil, failoverErr
|
||||||
|
}
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
|
||||||
|
}
|
||||||
|
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||||
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if streamResult.usage == nil {
|
||||||
|
streamResult.usage = &ClaudeUsage{}
|
||||||
|
}
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Usage: *streamResult.usage,
|
||||||
|
Model: originalModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
Stream: true,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: streamResult.firstTokenMs,
|
||||||
|
ClientDisconnect: streamResult.clientDisconnect,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tokenType != "oauth" {
|
||||||
|
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||||
|
}
|
||||||
|
if isOnlyWebSearchToolInBody(body) {
|
||||||
|
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
|
||||||
|
switch {
|
||||||
|
case errors.Is(webSearchErr, errKiroWebSearchFallback):
|
||||||
|
case webSearchErr == nil:
|
||||||
|
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
if webSearchResult.RequestID != "" {
|
||||||
|
c.Header("x-request-id", webSearchResult.RequestID)
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", webSearchResult.ResponseBody)
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: webSearchResult.RequestID,
|
||||||
|
Usage: webSearchResult.Usage,
|
||||||
|
Model: originalModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
Stream: false,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
var httpErr *kiroWebSearchHTTPError
|
||||||
|
if errors.As(webSearchErr, &httpErr) && httpErr.Response != nil {
|
||||||
|
return nil, s.handleKiroHTTPError(ctx, httpErr.Response, c, account, mappedModel, body)
|
||||||
|
}
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
if errors.As(webSearchErr, &failoverErr) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: failoverErr.StatusCode,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(webSearchErr.Error()),
|
||||||
|
})
|
||||||
|
return nil, failoverErr
|
||||||
|
}
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(webSearchErr.Error())
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := estimateKiroInputTokens(body)
|
||||||
|
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: failoverErr.StatusCode,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||||
|
})
|
||||||
|
return nil, failoverErr
|
||||||
|
}
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
parseResult, err := kiropkg.ParseNonStreamingEventStreamWithContext(resp.Body, mappedModel, requestCtx)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Failed to parse Kiro upstream response",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
if requestID := resp.Header.Get("x-request-id"); requestID != "" {
|
||||||
|
c.Header("x-request-id", requestID)
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", parseResult.ResponseBody)
|
||||||
|
|
||||||
|
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||||
|
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
|
||||||
|
Model: originalModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
Stream: false,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
|
||||||
|
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if tokenType != "oauth" {
|
||||||
|
return nil, 0, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := estimateKiroInputTokens(anthropicBody)
|
||||||
|
if isOnlyWebSearchToolInBody(anthropicBody) {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("Content-Type", "text/event-stream")
|
||||||
|
go func() {
|
||||||
|
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
|
||||||
|
if streamErr != nil {
|
||||||
|
_ = pw.CloseWithError(streamErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = pw.Close()
|
||||||
|
}()
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: headers,
|
||||||
|
Body: pr,
|
||||||
|
}, inputTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
return nil, inputTokens, err
|
||||||
|
}
|
||||||
|
return nil, inputTokens, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return resp, inputTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
wrappedHeaders := resp.Header.Clone()
|
||||||
|
wrappedHeaders.Set("Content-Type", "text/event-stream")
|
||||||
|
if requestID := buildKiroRequestID(resp); requestID != "" {
|
||||||
|
wrappedHeaders.Set("x-request-id", requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
_, streamErr := kiropkg.StreamEventStreamAsAnthropicWithContext(ctx, resp.Body, pw, mappedModel, inputTokens, requestCtx)
|
||||||
|
if streamErr != nil {
|
||||||
|
_ = pw.CloseWithError(streamErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = pw.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: wrappedHeaders,
|
||||||
|
Body: pr,
|
||||||
|
}, inputTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||||
|
var requestCtx kiropkg.KiroRequestContext
|
||||||
|
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
|
||||||
|
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||||
|
return nil, requestCtx, failoverErr
|
||||||
|
}
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID := kiropkg.MapModel(mappedModel)
|
||||||
|
currentToken := token
|
||||||
|
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
payload := buildResult.Payload
|
||||||
|
requestCtx = buildResult.Context
|
||||||
|
|
||||||
|
endpoints := buildKiroEndpoints(account)
|
||||||
|
proxyURL := kiroProxyURL(account)
|
||||||
|
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||||
|
accountKey := buildKiroAccountKey(account)
|
||||||
|
maxRetries := 2
|
||||||
|
|
||||||
|
for idx, endpoint := range endpoints {
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||||
|
if err != nil {
|
||||||
|
if attempt < maxRetries {
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, requestCtx, sleepErr
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
cooldown, err := s.markKiro429(ctx, accountKey)
|
||||||
|
if err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
if idx+1 < len(endpoints) {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, requestCtx, sleepErr
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
resp.Header.Set("x-kiro-cooldown", cooldown.String())
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusRequestTimeout || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
|
||||||
|
if attempt < maxRetries {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, requestCtx, sleepErr
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if idx+1 < len(endpoints) {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, requestCtx, sleepErr
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusPaymentRequired {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, requestCtx, readErr
|
||||||
|
}
|
||||||
|
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||||
|
if classification.Category == kiroErrorMonthlyRequest {
|
||||||
|
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
|
||||||
|
}
|
||||||
|
return nil, requestCtx, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
ResponseHeaders: resp.Header.Clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, requestCtx, readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
|
||||||
|
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
resetHTTPResponseBody(resp, respBody)
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
|
||||||
|
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||||
|
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||||
|
currentToken = refreshedToken
|
||||||
|
accountKey = buildKiroAccountKey(account)
|
||||||
|
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
payload = buildResult.Payload
|
||||||
|
requestCtx = buildResult.Context
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, requestCtx, sleepErr
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if refreshErr != nil && isNonRetryableRefreshError(refreshErr) {
|
||||||
|
resetHTTPResponseBody(resp, respBody)
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if classifyKiroHTTPError(resp.StatusCode, string(respBody)).Category == kiroErrorAuthError {
|
||||||
|
s.markKiroAuthTemporarilyUnavailable(ctx, account, resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
resetHTTPResponseBody(resp, respBody)
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, requestCtx, readErr
|
||||||
|
}
|
||||||
|
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||||
|
logKiroBadRequestClassification(classification, account, mappedModel, resp.Header, respBody)
|
||||||
|
resetHTTPResponseBody(resp, respBody)
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, requestCtx, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp, requestCtx, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, requestCtx, fmt.Errorf("kiro upstream endpoints exhausted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroEndpoints(account *Account) []kiroEndpointConfig {
|
||||||
|
region := kiroAPIRegion(account)
|
||||||
|
return []kiroEndpointConfig{
|
||||||
|
{
|
||||||
|
URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
|
||||||
|
Name: "AmazonQ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) ([]byte, error) {
|
||||||
|
result, err := buildKiroPayloadForAccountWithRepo(ctx, nil, account, anthropicBody, modelID, token, requestModel, headers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result.Payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
|
||||||
|
profileArn := resolveKiroPayloadProfileArn(account)
|
||||||
|
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
|
||||||
|
if s == nil || s.accountRepo == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
until := time.Now().Add(10 * time.Minute)
|
||||||
|
reason := fmt.Sprintf("kiro auth failure (%d): %s", statusCode, strings.TrimSpace(body))
|
||||||
|
_ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) markKiroMonthlyRequestCountRateLimited(ctx context.Context, account *Account, body string) {
|
||||||
|
if s == nil || s.accountRepo == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resetAt := nextKiroMonthlyResetUTC(time.Now())
|
||||||
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
|
logger.L().Warn("kiro monthly request count rate-limit failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Time("reset_at", resetAt),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reason := "kiro monthly request count exhausted (402): MONTHLY_REQUEST_COUNT"
|
||||||
|
if trimmed := strings.TrimSpace(body); trimmed != "" {
|
||||||
|
reason = fmt.Sprintf("%s body=%s", reason, truncateForLog([]byte(trimmed), 512))
|
||||||
|
}
|
||||||
|
logger.L().Warn("kiro monthly request count rate-limited",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Time("reset_at", resetAt),
|
||||||
|
zap.String("reason", reason),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextKiroMonthlyResetUTC(now time.Time) time.Time {
|
||||||
|
utc := now.UTC()
|
||||||
|
year, month, _ := utc.Date()
|
||||||
|
return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetHTTPResponseBody(resp *http.Response, body []byte) {
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
resp.ContentLength = int64(len(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimateKiroInputTokens(body []byte) int {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if tokens := gjson.GetBytes(body, "metadata.input_tokens").Int(); tokens > 0 {
|
||||||
|
return int(tokens)
|
||||||
|
}
|
||||||
|
tokens := len(body) / 4
|
||||||
|
if tokens == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroUsageToClaude(usage kiropkg.Usage, fallbackInput int) ClaudeUsage {
|
||||||
|
inputTokens := usage.InputTokens
|
||||||
|
if inputTokens == 0 {
|
||||||
|
inputTokens = fallbackInput
|
||||||
|
}
|
||||||
|
return ClaudeUsage{
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: usage.OutputTokens,
|
||||||
|
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) markKiroInvalidModelRateLimited(ctx context.Context, account *Account, mappedModel string) {
|
||||||
|
if s == nil || s.accountRepo == nil || account == nil || account.Type != AccountTypeOAuth {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resetAt := time.Now().Add(kiroInvalidModelTempUnschedDuration)
|
||||||
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
|
logger.L().Warn("kiro invalid model rate-limit failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
|
||||||
|
zap.Time("reset_at", resetAt),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.L().Warn("kiro invalid model rate-limited",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
|
||||||
|
zap.Time("reset_at", resetAt),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) handleKiroHTTPError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, mappedModel string, requestBody []byte) error {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
upstreamMsg = strings.TrimSpace(string(respBody))
|
||||||
|
}
|
||||||
|
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
logKiroBadRequestClassification(classification, account, "", resp.Header, respBody)
|
||||||
|
}
|
||||||
|
if classification.Category == kiroErrorMonthlyRequest {
|
||||||
|
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
|
||||||
|
}
|
||||||
|
if classification.Category == kiroErrorBadRequestInvalidModel && account != nil && account.Type == AccountTypeOAuth {
|
||||||
|
s.markKiroInvalidModelRateLimited(ctx, account, mappedModel)
|
||||||
|
event := s.buildKiroInvalidModelUpstreamEvent(account, resp, upstreamMsg, mappedModel, requestBody, c)
|
||||||
|
appendOpsUpstreamError(c, event)
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
ResponseHeaders: resp.Header.Clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusPaymentRequired || s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
})
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
}
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
ResponseHeaders: resp.Header.Clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
})
|
||||||
|
c.JSON(mapUpstreamStatusCode(resp.StatusCode), gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": coalesceKiroErrorMessage(resp.StatusCode, upstreamMsg),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return fmt.Errorf("kiro upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) buildKiroInvalidModelUpstreamEvent(account *Account, resp *http.Response, upstreamMsg, mappedModel string, requestBody []byte, c *gin.Context) OpsUpstreamErrorEvent {
|
||||||
|
_ = s
|
||||||
|
requestedModel := strings.TrimSpace(gjson.GetBytes(requestBody, "model").String())
|
||||||
|
hasTools := gjson.GetBytes(requestBody, "tools").Exists()
|
||||||
|
hasAdaptiveThinking := strings.EqualFold(strings.TrimSpace(gjson.GetBytes(requestBody, "thinking.type").String()), "adaptive")
|
||||||
|
hasContext1MBeta := false
|
||||||
|
if c != nil {
|
||||||
|
hasContext1MBeta = strings.Contains(c.GetHeader("Anthropic-Beta"), "context-1m")
|
||||||
|
}
|
||||||
|
return OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
RequestedModel: requestedModel,
|
||||||
|
MappedModel: strings.TrimSpace(mappedModel),
|
||||||
|
KiroModelID: kiropkg.MapModel(mappedModel),
|
||||||
|
HasTools: hasTools,
|
||||||
|
HasAdaptiveThinking: hasAdaptiveThinking,
|
||||||
|
HasContext1MBeta: hasContext1MBeta,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logKiroBadRequestClassification(classification kiroErrorClassification, account *Account, model string, headers http.Header, body []byte) {
|
||||||
|
if classification.StatusCode != http.StatusBadRequest {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var accountID int64
|
||||||
|
if account != nil {
|
||||||
|
accountID = account.ID
|
||||||
|
}
|
||||||
|
logger.L().Warn("kiro upstream bad request classified",
|
||||||
|
zap.String("category", classification.Category),
|
||||||
|
zap.Int("status", classification.StatusCode),
|
||||||
|
zap.Int64("account_id", accountID),
|
||||||
|
zap.String("model", strings.TrimSpace(model)),
|
||||||
|
zap.String("request_id", headers.Get("x-request-id")),
|
||||||
|
zap.String("body_excerpt", truncateForLog(body, 512)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func coalesceKiroErrorMessage(statusCode int, upstreamMsg string) string {
|
||||||
|
if upstreamMsg != "" {
|
||||||
|
return upstreamMsg
|
||||||
|
}
|
||||||
|
switch statusCode {
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return "Rate limit exceeded"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return "Access denied"
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return "Authentication failed"
|
||||||
|
default:
|
||||||
|
return "Upstream request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errKiroCooldownStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||||
|
|
||||||
|
type KiroCooldownStore interface {
|
||||||
|
ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||||
|
MarkSuccess(ctx context.Context, tokenKey string) error
|
||||||
|
Mark429(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||||
|
MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||||
|
GetState(ctx context.Context, tokenKey string) (*kirocooldown.State, error)
|
||||||
|
ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func asKiroCooldownFailoverError(err error) *UpstreamFailoverError {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var cooldownErr *kirocooldown.Error
|
||||||
|
if !errors.As(err, &cooldownErr) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
ResponseBody: []byte(cooldownErr.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) checkAndWaitKiroCooldown(ctx context.Context, tokenKey string) error {
|
||||||
|
if s == nil || s.kiroCooldownStore == nil {
|
||||||
|
return errKiroCooldownStoreUnavailable
|
||||||
|
}
|
||||||
|
waitFor, err := s.kiroCooldownStore.ReserveRequest(ctx, tokenKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if waitFor <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(waitFor)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
if !timer.Stop() {
|
||||||
|
<-timer.C
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) markKiroSuccess(ctx context.Context, tokenKey string) error {
|
||||||
|
if s == nil || s.kiroCooldownStore == nil {
|
||||||
|
return errKiroCooldownStoreUnavailable
|
||||||
|
}
|
||||||
|
return s.kiroCooldownStore.MarkSuccess(ctx, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) markKiro429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||||
|
if s == nil || s.kiroCooldownStore == nil {
|
||||||
|
return 0, errKiroCooldownStoreUnavailable
|
||||||
|
}
|
||||||
|
return s.kiroCooldownStore.Mark429(ctx, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) markKiroSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||||
|
if s == nil || s.kiroCooldownStore == nil {
|
||||||
|
return 0, errKiroCooldownStoreUnavailable
|
||||||
|
}
|
||||||
|
return s.kiroCooldownStore.MarkSuspended(ctx, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) getKiroCooldownState(ctx context.Context, tokenKey string) (*kirocooldown.State, error) {
|
||||||
|
if s == nil || s.kiroCooldownStore == nil {
|
||||||
|
return nil, errKiroCooldownStoreUnavailable
|
||||||
|
}
|
||||||
|
return s.kiroCooldownStore.GetState(ctx, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroRuntimeStateSnapshot(state *kirocooldown.State) (string, string, *time.Time) {
|
||||||
|
if state == nil || !state.Active {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
resetAt := state.CooldownUntil
|
||||||
|
switch state.Reason {
|
||||||
|
case kirocooldown.CooldownReasonSuspended:
|
||||||
|
return "suspended", state.Reason, &resetAt
|
||||||
|
default:
|
||||||
|
return "cooldown", state.Reason, &resetAt
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,192 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const kiroCooldownRedisImageTag = "redis:8.4-alpine"
|
||||||
|
|
||||||
|
func TestRedisKiroCooldownStoreSharesCooldownAcrossInstances(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startKiroCooldownRedis(t, ctx)
|
||||||
|
storeA := kirocooldown.NewStore(rdb)
|
||||||
|
storeB := kirocooldown.NewStore(rdb)
|
||||||
|
|
||||||
|
cooldown, err := storeA.Mark429(ctx, "token-shared")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, time.Minute, cooldown)
|
||||||
|
|
||||||
|
wait, err := storeB.ReserveRequest(ctx, "token-shared")
|
||||||
|
require.Zero(t, wait)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), kirocooldown.CooldownReason429)
|
||||||
|
|
||||||
|
require.NoError(t, storeB.MarkSuccess(ctx, "token-shared"))
|
||||||
|
|
||||||
|
wait, err = storeA.ReserveRequest(ctx, "token-shared")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.GreaterOrEqual(t, wait, 0*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKiroCooldownStoreSharesReservationAcrossInstances(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startKiroCooldownRedis(t, ctx)
|
||||||
|
storeA := kirocooldown.NewStore(rdb)
|
||||||
|
storeB := kirocooldown.NewStore(rdb)
|
||||||
|
|
||||||
|
wait, err := storeA.ReserveRequest(ctx, "token-rate")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, wait)
|
||||||
|
|
||||||
|
wait, err = storeB.ReserveRequest(ctx, "token-rate")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, wait, 0*time.Millisecond)
|
||||||
|
require.LessOrEqual(t, wait, kirocooldown.MaxRequestInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKiroCooldownStoreSharesSuspendedStateAcrossInstances(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startKiroCooldownRedis(t, ctx)
|
||||||
|
storeA := kirocooldown.NewStore(rdb)
|
||||||
|
storeB := kirocooldown.NewStore(rdb)
|
||||||
|
|
||||||
|
cooldown, err := storeA.MarkSuspended(ctx, "token-suspended")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, kirocooldown.LongCooldown, cooldown)
|
||||||
|
|
||||||
|
wait, err := storeB.ReserveRequest(ctx, "token-suspended")
|
||||||
|
require.Zero(t, wait)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), kirocooldown.CooldownReasonSuspended)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKiroCooldownStoreSuspendedResetsFailCount(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startKiroCooldownRedis(t, ctx)
|
||||||
|
store := kirocooldown.NewStore(rdb)
|
||||||
|
|
||||||
|
_, err := store.Mark429(ctx, "token-reset")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = store.Mark429(ctx, "token-reset")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cooldown, err := store.MarkSuspended(ctx, "token-reset")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, kirocooldown.LongCooldown, cooldown)
|
||||||
|
|
||||||
|
cooldown, err = store.Mark429(ctx, "token-reset")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, time.Minute, cooldown)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKiroCooldownStoreReserveDifferentTokenIgnoresOldCooldown(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startKiroCooldownRedis(t, ctx)
|
||||||
|
store := kirocooldown.NewStore(rdb)
|
||||||
|
|
||||||
|
_, err := store.Mark429(ctx, "token-old")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wait, err := store.ReserveRequest(ctx, "token-new")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKiroCooldownStoreUsesExpectedTTLs(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startKiroCooldownRedis(t, ctx)
|
||||||
|
store := kirocooldown.NewStore(rdb)
|
||||||
|
|
||||||
|
_, err := store.ReserveRequest(ctx, "token-ttl-active")
|
||||||
|
require.NoError(t, err)
|
||||||
|
activeTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-active")).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, activeTTL, 0*time.Second)
|
||||||
|
require.LessOrEqual(t, activeTTL, kirocooldown.ActiveTTL())
|
||||||
|
|
||||||
|
_, err = store.MarkSuspended(ctx, "token-ttl-state")
|
||||||
|
require.NoError(t, err)
|
||||||
|
stateTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-state")).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, stateTTL, 24*time.Hour)
|
||||||
|
require.LessOrEqual(t, stateTTL, kirocooldown.StateTTL())
|
||||||
|
}
|
||||||
|
|
||||||
|
func startKiroCooldownRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||||
|
t.Helper()
|
||||||
|
ensureKiroCooldownDockerAvailable(t)
|
||||||
|
|
||||||
|
redisContainer, err := tcredis.Run(ctx, kiroCooldownRedisImageTag)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = redisContainer.Terminate(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
host, err := redisContainer.Host(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", host, port.Int()),
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
require.NoError(t, rdb.Ping(ctx).Err())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
return rdb
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureKiroCooldownDockerAvailable(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if kiroCooldownDockerAvailable() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Skip("Docker 未启用,跳过依赖 testcontainers 的 Kiro cooldown 集成测试")
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroCooldownDockerAvailable() bool {
|
||||||
|
if os.Getenv("DOCKER_HOST") != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
socketCandidates := []string{
|
||||||
|
"/var/run/docker.sock",
|
||||||
|
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||||
|
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||||
|
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||||
|
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, socket := range socketCandidates {
|
||||||
|
if socket == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(socket); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroCooldownUserHomeDir() string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return home
|
||||||
|
}
|
||||||
@@ -0,0 +1,583 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubKiroCooldownStore struct {
|
||||||
|
reserveWait time.Duration
|
||||||
|
reserveErr error
|
||||||
|
successErr error
|
||||||
|
mark429TTL time.Duration
|
||||||
|
mark429Err error
|
||||||
|
suspendedTTL time.Duration
|
||||||
|
suspendedErr error
|
||||||
|
state *kirocooldown.State
|
||||||
|
stateErr error
|
||||||
|
clearCalled bool
|
||||||
|
clearKeys []string
|
||||||
|
clearResult bool
|
||||||
|
clearErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordingKiroTempUnschedRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
called bool
|
||||||
|
id int64
|
||||||
|
until time.Time
|
||||||
|
reason string
|
||||||
|
rateCalled bool
|
||||||
|
rateID int64
|
||||||
|
rateLimitReset time.Time
|
||||||
|
rateLimitedCall int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recordingKiroTempUnschedRepo) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
r.called = true
|
||||||
|
r.id = id
|
||||||
|
r.until = until
|
||||||
|
r.reason = reason
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recordingKiroTempUnschedRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||||
|
r.rateCalled = true
|
||||||
|
r.rateID = id
|
||||||
|
r.rateLimitReset = resetAt
|
||||||
|
r.rateLimitedCall++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordingKiroErrorRepo struct {
|
||||||
|
recordingKiroTempUnschedRepo
|
||||||
|
setErrorCalls int
|
||||||
|
errorID int64
|
||||||
|
errorMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recordingKiroErrorRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||||
|
r.setErrorCalls++
|
||||||
|
r.errorID = id
|
||||||
|
r.errorMsg = errorMsg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||||
|
return s.reserveWait, s.reserveErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroCooldownStore) MarkSuccess(context.Context, string) error {
|
||||||
|
return s.successErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||||
|
return s.mark429TTL, s.mark429Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||||
|
return s.suspendedTTL, s.suspendedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||||
|
if s.clearCalled && s.clearResult {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.state, s.stateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroCooldownStore) ClearEarliestTransientCooldown(_ context.Context, tokenKeys []string) (bool, error) {
|
||||||
|
s.clearCalled = true
|
||||||
|
s.clearKeys = append([]string(nil), tokenKeys...)
|
||||||
|
return s.clearResult, s.clearErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateKiro429Cooldown(t *testing.T) {
|
||||||
|
require.Equal(t, time.Minute, kirocooldown.Calculate429Cooldown(0))
|
||||||
|
require.Equal(t, 2*time.Minute, kirocooldown.Calculate429Cooldown(1))
|
||||||
|
require.Equal(t, 4*time.Minute, kirocooldown.Calculate429Cooldown(2))
|
||||||
|
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(3))
|
||||||
|
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceCheckAndWaitKiroCooldownReturnsNilWithoutWait(t *testing.T) {
|
||||||
|
svc := &GatewayService{
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, svc.checkAndWaitKiroCooldown(context.Background(), "token1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceCheckAndWaitKiroCooldownPropagatesReserveError(t *testing.T) {
|
||||||
|
expected := errors.New("redis unavailable")
|
||||||
|
svc := &GatewayService{
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{reserveErr: expected},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
|
||||||
|
require.ErrorIs(t, err, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceCheckAndWaitKiroCooldownRequiresStore(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
|
||||||
|
require.ErrorIs(t, err, errKiroCooldownStoreUnavailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceCheckAndWaitKiroCooldownWaitsAndHonorsContext(t *testing.T) {
|
||||||
|
svc := &GatewayService{
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{reserveWait: 200 * time.Millisecond},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := svc.checkAndWaitKiroCooldown(ctx, "token1")
|
||||||
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAsKiroCooldownFailoverError(t *testing.T) {
|
||||||
|
err := kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429)
|
||||||
|
|
||||||
|
var cooldownErr *kirocooldown.Error
|
||||||
|
require.ErrorAs(t, err, &cooldownErr)
|
||||||
|
|
||||||
|
failoverErr := asKiroCooldownFailoverError(err)
|
||||||
|
require.NotNil(t, failoverErr)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||||
|
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
|
||||||
|
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAsKiroCooldownFailoverErrorIgnoresNonCooldownErrors(t *testing.T) {
|
||||||
|
require.Nil(t, asKiroCooldownFailoverError(errors.New("redis unavailable")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceTryRecoverKiroCooldownPoolClearsOnlyTransientCooldown(t *testing.T) {
|
||||||
|
store := &stubKiroCooldownStore{
|
||||||
|
state: &kirocooldown.State{
|
||||||
|
Active: true,
|
||||||
|
Reason: kirocooldown.CooldownReason429,
|
||||||
|
CooldownUntil: time.Now().Add(time.Minute),
|
||||||
|
Remaining: time.Minute,
|
||||||
|
},
|
||||||
|
clearResult: true,
|
||||||
|
}
|
||||||
|
svc := &GatewayService{kiroCooldownStore: store}
|
||||||
|
accounts := []Account{
|
||||||
|
{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
|
||||||
|
require.True(t, recovered)
|
||||||
|
require.True(t, store.clearCalled)
|
||||||
|
require.Len(t, store.clearKeys, 1)
|
||||||
|
require.Equal(t, buildKiroAccountKey(&accounts[0]), store.clearKeys[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceTryRecoverKiroCooldownPoolSkipsSuspended(t *testing.T) {
|
||||||
|
store := &stubKiroCooldownStore{
|
||||||
|
state: &kirocooldown.State{
|
||||||
|
Active: true,
|
||||||
|
Reason: kirocooldown.CooldownReasonSuspended,
|
||||||
|
CooldownUntil: time.Now().Add(time.Hour),
|
||||||
|
Remaining: time.Hour,
|
||||||
|
},
|
||||||
|
clearResult: true,
|
||||||
|
}
|
||||||
|
svc := &GatewayService{kiroCooldownStore: store}
|
||||||
|
accounts := []Account{
|
||||||
|
{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
|
||||||
|
require.False(t, recovered)
|
||||||
|
require.False(t, store.clearCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectAccountWithLoadAwarenessRecoversKiroCooldownPool(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||||
|
|
||||||
|
account := Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
store := &stubKiroCooldownStore{
|
||||||
|
state: &kirocooldown.State{
|
||||||
|
Active: true,
|
||||||
|
Reason: kirocooldown.CooldownReason429,
|
||||||
|
CooldownUntil: time.Now().Add(time.Minute),
|
||||||
|
Remaining: time.Minute,
|
||||||
|
},
|
||||||
|
clearResult: true,
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: &mockAccountRepoForGemini{accounts: []Account{account}},
|
||||||
|
concurrencyService: NewConcurrencyService(&mockConcurrencyCache{}),
|
||||||
|
cfg: cfg,
|
||||||
|
kiroCooldownStore: store,
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformKiro)
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "", nil, "", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, account.ID, result.Account.ID)
|
||||||
|
require.True(t, store.clearCalled)
|
||||||
|
require.Equal(t, []string{buildKiroAccountKey(&account)}, store.clearKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyKiroHTTPErrorMonthlyRequestCount(t *testing.T) {
|
||||||
|
tests := []string{
|
||||||
|
`{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
|
||||||
|
`{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}`,
|
||||||
|
`API returned 402: {"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, body := range tests {
|
||||||
|
classification := classifyKiroHTTPError(http.StatusPaymentRequired, body)
|
||||||
|
require.Equal(t, kiroErrorMonthlyRequest, classification.Category)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyKiroHTTPErrorPlain402IsTransient(t *testing.T) {
|
||||||
|
classification := classifyKiroHTTPError(http.StatusPaymentRequired, `{"message":"payment required"}`)
|
||||||
|
require.Equal(t, kiroErrorUpstreamTransient, classification.Category)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
|
||||||
|
svc := &GatewayService{
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{
|
||||||
|
reserveErr: kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||||
|
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
|
||||||
|
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
httpUpstream: upstream,
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := createTestPayload("claude-opus-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
require.Len(t, upstream.requests, 1)
|
||||||
|
|
||||||
|
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
|
||||||
|
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleKiroHTTPErrorOAuthInvalidModelRateLimitsAndFailovers(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "context-1m-2025-08-07")
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Name: "kiro-oauth",
|
||||||
|
}
|
||||||
|
repo := &recordingKiroTempUnschedRepo{}
|
||||||
|
svc := &GatewayService{accountRepo: repo}
|
||||||
|
requestBody := []byte(`{"model":"claude-opus-4-7","tools":[{"name":"search"}],"thinking":{"type":"adaptive"}}`)
|
||||||
|
resp := newJSONResponse(http.StatusBadRequest, `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`)
|
||||||
|
resp.Header.Set("x-request-id", "req-invalid-model")
|
||||||
|
|
||||||
|
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", requestBody)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), "Invalid model")
|
||||||
|
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||||
|
|
||||||
|
require.False(t, repo.called)
|
||||||
|
require.True(t, repo.rateCalled)
|
||||||
|
require.Equal(t, account.ID, repo.rateID)
|
||||||
|
require.WithinDuration(t, time.Now().Add(kiroInvalidModelTempUnschedDuration), repo.rateLimitReset, 5*time.Second)
|
||||||
|
|
||||||
|
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 1)
|
||||||
|
require.Equal(t, PlatformKiro, events[0].Platform)
|
||||||
|
require.Equal(t, account.ID, events[0].AccountID)
|
||||||
|
require.Equal(t, account.Name, events[0].AccountName)
|
||||||
|
require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
|
||||||
|
require.Equal(t, "req-invalid-model", events[0].UpstreamRequestID)
|
||||||
|
require.Equal(t, "failover", events[0].Kind)
|
||||||
|
require.Equal(t, "claude-opus-4-7", events[0].RequestedModel)
|
||||||
|
require.Equal(t, "claude-opus-4.6", events[0].MappedModel)
|
||||||
|
require.Equal(t, "claude-opus-4.6", events[0].KiroModelID)
|
||||||
|
require.True(t, events[0].HasTools)
|
||||||
|
require.True(t, events[0].HasAdaptiveThinking)
|
||||||
|
require.True(t, events[0].HasContext1MBeta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleKiroHTTPErrorAPIKeyInvalidModelDoesNotFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
account := &Account{
|
||||||
|
ID: 43,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
repo := &recordingKiroTempUnschedRepo{}
|
||||||
|
svc := &GatewayService{accountRepo: repo}
|
||||||
|
resp := newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model. Please select a different model to continue."}`)
|
||||||
|
|
||||||
|
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", []byte(`{"model":"claude-opus-4-7"}`))
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.NotErrorAs(t, err, &failoverErr)
|
||||||
|
require.False(t, repo.called)
|
||||||
|
require.False(t, repo.rateCalled)
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextKiroMonthlyResetUTC(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
now time.Time
|
||||||
|
want time.Time
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "middle of month",
|
||||||
|
now: time.Date(2026, time.April, 27, 10, 30, 45, 123, time.FixedZone("CST", 8*3600)),
|
||||||
|
want: time.Date(2026, time.May, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "december rolls year",
|
||||||
|
now: time.Date(2026, time.December, 31, 23, 59, 59, 0, time.UTC),
|
||||||
|
want: time.Date(2027, time.January, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, nextKiroMonthlyResetUTC(tt.now))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailovers(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
repo := &recordingKiroTempUnschedRepo{}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusPaymentRequired, `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
httpUpstream: upstream,
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), "MONTHLY_REQUEST_COUNT")
|
||||||
|
require.False(t, repo.called)
|
||||||
|
require.True(t, repo.rateCalled)
|
||||||
|
require.Equal(t, account.ID, repo.rateID)
|
||||||
|
require.Equal(t, nextKiroMonthlyResetUTC(time.Now()), repo.rateLimitReset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
repo := &recordingKiroTempUnschedRepo{}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusPaymentRequired, `{"message":"payment required"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
httpUpstream: upstream,
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
|
||||||
|
require.False(t, repo.called)
|
||||||
|
require.False(t, repo.rateCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedule(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"refresh_token": "old-refresh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &recordingKiroErrorRepo{
|
||||||
|
recordingKiroTempUnschedRepo: recordingKiroTempUnschedRepo{
|
||||||
|
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
||||||
|
accountsByID: map[int64]*Account{account.ID: account},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusUnauthorized, `{"message":"token expired"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||||
|
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
httpUpstream: upstream,
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||||
|
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||||
|
kiroTokenProvider: provider,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||||
|
require.NoError(t, err)
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Equal(t, account.ID, repo.errorID)
|
||||||
|
require.Contains(t, repo.errorMsg, "invalid_grant")
|
||||||
|
require.False(t, repo.called, "non-retryable refresh errors should not mark temporary unschedulable")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceIsAccountSchedulableForSelectionSkipsActiveKiroCooldown(t *testing.T) {
|
||||||
|
now := time.Now().Add(2 * time.Minute)
|
||||||
|
svc := &GatewayService{
|
||||||
|
kiroCooldownStore: &stubKiroCooldownStore{
|
||||||
|
state: &kirocooldown.State{
|
||||||
|
Active: true,
|
||||||
|
Reason: kirocooldown.CooldownReason429,
|
||||||
|
CooldownUntil: now,
|
||||||
|
Remaining: 2 * time.Minute,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
require.False(t, svc.isAccountSchedulableForSelection(account))
|
||||||
|
}
|
||||||
@@ -0,0 +1,221 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
kiroTokenRefreshSkew = 3 * time.Minute
|
||||||
|
kiroTokenCacheSkew = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type KiroTokenCache = GeminiTokenCache
|
||||||
|
|
||||||
|
type kiroAccountTokenRefresher interface {
|
||||||
|
RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error)
|
||||||
|
BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type KiroTokenProvider struct {
|
||||||
|
accountRepo AccountRepository
|
||||||
|
tokenCache KiroTokenCache
|
||||||
|
kiroOAuthService kiroAccountTokenRefresher
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKiroTokenProvider(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
tokenCache KiroTokenCache,
|
||||||
|
kiroOAuthService *KiroOAuthService,
|
||||||
|
) *KiroTokenProvider {
|
||||||
|
return &KiroTokenProvider{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
tokenCache: tokenCache,
|
||||||
|
kiroOAuthService: kiroOAuthService,
|
||||||
|
refreshPolicy: GeminiProviderRefreshPolicy(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KiroTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KiroTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KiroTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", errors.New("account is nil")
|
||||||
|
}
|
||||||
|
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||||
|
return "", errors.New("not a kiro oauth account")
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := KiroTokenCacheKey(account)
|
||||||
|
if p.tokenCache != nil {
|
||||||
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= kiroTokenRefreshSkew
|
||||||
|
|
||||||
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, kiroTokenRefreshSkew)
|
||||||
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if result.LockHeld {
|
||||||
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
|
if lockErr == nil && locked {
|
||||||
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := account.GetCredential("access_token")
|
||||||
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
|
return "", errors.New("access_token not found in credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.tokenCache != nil {
|
||||||
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
|
if isStale && latestAccount != nil {
|
||||||
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
|
return "", errors.New("access_token not found after version check")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ttl := 30 * time.Minute
|
||||||
|
if expiresAt != nil {
|
||||||
|
until := time.Until(*expiresAt)
|
||||||
|
switch {
|
||||||
|
case until > kiroTokenCacheSkew:
|
||||||
|
ttl = until - kiroTokenCacheSkew
|
||||||
|
case until > 0:
|
||||||
|
ttl = until
|
||||||
|
default:
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func KiroTokenCacheKey(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return "kiro:account:0"
|
||||||
|
}
|
||||||
|
if clientIDHash := strings.TrimSpace(account.GetCredential("client_id_hash")); clientIDHash != "" {
|
||||||
|
return "kiro:" + clientIDHash
|
||||||
|
}
|
||||||
|
if clientID := strings.TrimSpace(account.GetCredential("client_id")); clientID != "" {
|
||||||
|
return "kiro:client:" + clientID
|
||||||
|
}
|
||||||
|
return "kiro:account:" + strconv.FormatInt(account.ID, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KiroTokenProvider) ForceRefreshAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", errors.New("account is nil")
|
||||||
|
}
|
||||||
|
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||||
|
return "", errors.New("not a kiro oauth account")
|
||||||
|
}
|
||||||
|
if p.kiroOAuthService == nil {
|
||||||
|
return "", errors.New("kiro oauth service is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := KiroTokenCacheKey(account)
|
||||||
|
lockHeld := false
|
||||||
|
if p.tokenCache != nil {
|
||||||
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
|
if lockErr == nil && locked {
|
||||||
|
lockHeld = true
|
||||||
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.accountRepo != nil {
|
||||||
|
if latestAccount, err := p.accountRepo.GetByID(ctx, account.ID); err == nil && latestAccount != nil {
|
||||||
|
account = latestAccount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := p.kiroOAuthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
if !lockHeld {
|
||||||
|
if latestAccount, stale := CheckTokenVersion(ctx, account, p.accountRepo); stale && latestAccount != nil {
|
||||||
|
account = latestAccount
|
||||||
|
if accessToken := strings.TrimSpace(account.GetCredential("access_token")); accessToken != "" {
|
||||||
|
_ = p.cacheAccessToken(ctx, account, accessToken)
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isNonRetryableRefreshError(err) && p.accountRepo != nil {
|
||||||
|
errorMsg := "Token refresh failed (non-retryable): " + err.Error()
|
||||||
|
_ = p.accountRepo.SetError(ctx, account.ID, errorMsg)
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
newCredentials := MergeCredentials(account.Credentials, p.kiroOAuthService.BuildAccountCredentials(tokenInfo))
|
||||||
|
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||||
|
if err := persistAccountCredentials(ctx, p.accountRepo, account, newCredentials); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||||
|
if accessToken == "" {
|
||||||
|
accessToken = strings.TrimSpace(tokenInfo.AccessToken)
|
||||||
|
}
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", errors.New("access_token not found after kiro refresh")
|
||||||
|
}
|
||||||
|
if err := p.cacheAccessToken(ctx, account, accessToken); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KiroTokenProvider) cacheAccessToken(ctx context.Context, account *Account, accessToken string) error {
|
||||||
|
if p.tokenCache == nil || account == nil || strings.TrimSpace(accessToken) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ttl := 30 * time.Minute
|
||||||
|
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||||
|
until := time.Until(*expiresAt)
|
||||||
|
switch {
|
||||||
|
case until > kiroTokenCacheSkew:
|
||||||
|
ttl = until - kiroTokenCacheSkew
|
||||||
|
case until > 0:
|
||||||
|
ttl = until
|
||||||
|
default:
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.tokenCache.SetAccessToken(ctx, KiroTokenCacheKey(account), accessToken, ttl)
|
||||||
|
}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type kiroTokenProviderRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
setErrorCalls int
|
||||||
|
setErrorID int64
|
||||||
|
setErrorMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||||
|
r.setErrorCalls++
|
||||||
|
r.setErrorID = id
|
||||||
|
r.setErrorMsg = errorMsg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroTokenProviderSequenceRepo struct {
|
||||||
|
kiroTokenProviderRepo
|
||||||
|
accounts []*Account
|
||||||
|
reads int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||||
|
if len(r.accounts) == 0 {
|
||||||
|
return nil, errors.New("account not found")
|
||||||
|
}
|
||||||
|
idx := r.reads
|
||||||
|
if idx >= len(r.accounts) {
|
||||||
|
idx = len(r.accounts) - 1
|
||||||
|
}
|
||||||
|
r.reads++
|
||||||
|
return r.accounts[idx], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubKiroAccountTokenRefresher struct {
|
||||||
|
tokenInfo *KiroTokenInfo
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return s.tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
||||||
|
if tokenInfo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"access_token": tokenInfo.AccessToken,
|
||||||
|
"expires_at": tokenInfo.ExpiresAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
||||||
|
}
|
||||||
|
repo := &kiroTokenProviderRepo{
|
||||||
|
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
||||||
|
accountsByID: map[int64]*Account{account.ID: account},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||||
|
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||||
|
|
||||||
|
token, err := provider.ForceRefreshAccessToken(context.Background(), account)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, token)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Equal(t, account.ID, repo.setErrorID)
|
||||||
|
require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
|
||||||
|
require.Contains(t, repo.setErrorMsg, "invalid_grant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
|
||||||
|
usedAccount := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
||||||
|
}
|
||||||
|
latestAccount := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformKiro,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
|
||||||
|
}
|
||||||
|
repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
|
||||||
|
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||||
|
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||||
|
|
||||||
|
token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "fresh-access", token)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls)
|
||||||
|
}
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const kiroRefreshWindow = 15 * time.Minute
|
||||||
|
|
||||||
|
type KiroTokenRefresher struct {
|
||||||
|
kiroOAuthService *KiroOAuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKiroTokenRefresher(kiroOAuthService *KiroOAuthService) *KiroTokenRefresher {
|
||||||
|
return &KiroTokenRefresher{
|
||||||
|
kiroOAuthService: kiroOAuthService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *KiroTokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return KiroTokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *KiroTokenRefresher) CanRefresh(account *Account) bool {
|
||||||
|
return account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *KiroTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||||
|
if !r.CanRefresh(account) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
|
if expiresAt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Until(*expiresAt) <= kiroRefreshWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *KiroTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||||
|
tokenInfo, err := r.kiroOAuthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newCredentials := r.kiroOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
return MergeCredentials(account.Credentials, newCredentials), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,608 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
kiroUsageOrigin = "AI_EDITOR"
|
||||||
|
kiroUsageResourceType = "AGENTIC_REQUEST"
|
||||||
|
|
||||||
|
kiroDefaultRegion = "us-east-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
var resolveKiroRuntimeEndpoint = kiroRuntimeEndpoint
|
||||||
|
|
||||||
|
type kiroUsageLimitsResponse struct {
|
||||||
|
NextDateReset any `json:"nextDateReset"`
|
||||||
|
OverageConfiguration kiroOverageConfiguration `json:"overageConfiguration"`
|
||||||
|
SubscriptionInfo kiroSubscriptionInfo `json:"subscriptionInfo"`
|
||||||
|
UsageBreakdownList []kiroUsageBreakdown `json:"usageBreakdownList"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroOverageConfiguration struct {
|
||||||
|
OverageStatus string `json:"overageStatus"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroSubscriptionInfo struct {
|
||||||
|
SubscriptionTitle string `json:"subscriptionTitle"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroUsageBreakdown struct {
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
CurrentOverages *float64 `json:"currentOverages"`
|
||||||
|
CurrentOveragesWithPrecision *float64 `json:"currentOveragesWithPrecision"`
|
||||||
|
CurrentUsage *float64 `json:"currentUsage"`
|
||||||
|
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
DisplayNamePlural string `json:"displayNamePlural"`
|
||||||
|
FreeTrialInfo *kiroFreeTrialInfo `json:"freeTrialInfo"`
|
||||||
|
NextDateReset any `json:"nextDateReset"`
|
||||||
|
OverageCharges *float64 `json:"overageCharges"`
|
||||||
|
ResourceType string `json:"resourceType"`
|
||||||
|
UsageLimit *float64 `json:"usageLimit"`
|
||||||
|
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroFreeTrialInfo struct {
|
||||||
|
CurrentUsage *float64 `json:"currentUsage"`
|
||||||
|
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
|
||||||
|
FreeTrialExpiry any `json:"freeTrialExpiry"`
|
||||||
|
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||||
|
UsageLimit *float64 `json:"usageLimit"`
|
||||||
|
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroUsageHTTPError struct {
|
||||||
|
StatusCode int
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *kiroUsageHTTPError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return "kiro usage request failed"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.Body) == "" {
|
||||||
|
return fmt.Sprintf("kiro usage request failed (status %d)", e.StatusCode)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("kiro usage request failed (status %d): %s", e.StatusCode, e.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) getKiroUsage(ctx context.Context, account *Account, source string, forceRefresh bool) (*UsageInfo, error) {
|
||||||
|
now := time.Now()
|
||||||
|
if account == nil {
|
||||||
|
return &UsageInfo{
|
||||||
|
Source: source,
|
||||||
|
UpdatedAt: &now,
|
||||||
|
Error: "account is nil",
|
||||||
|
ErrorCode: errorCodeNetworkError,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||||
|
return &UsageInfo{
|
||||||
|
Source: source,
|
||||||
|
UpdatedAt: &now,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cached, hasCached := s.getCachedKiroUsage(account.ID)
|
||||||
|
if hasCached && (cached.ErrorCode != "" || cached.Error != "") {
|
||||||
|
cached.Source = source
|
||||||
|
s.attachKiroRuntimeState(ctx, account, cached)
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
if !forceRefresh && hasCached {
|
||||||
|
cached.Source = source
|
||||||
|
s.attachKiroRuntimeState(ctx, account, cached)
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flightKey := fmt.Sprintf("kiro-usage:%d", account.ID)
|
||||||
|
result, fetchErr, _ := s.cache.kiroUsageFlight.Do(flightKey, func() (any, error) {
|
||||||
|
if !forceRefresh {
|
||||||
|
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
usage, err := s.fetchAndCacheKiroUsage(ctx, account, source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return usage, nil
|
||||||
|
})
|
||||||
|
if fetchErr == nil {
|
||||||
|
if usage, ok := result.(*UsageInfo); ok && usage != nil {
|
||||||
|
usage.Source = source
|
||||||
|
s.attachKiroRuntimeState(ctx, account, usage)
|
||||||
|
if source == "active" {
|
||||||
|
s.tryClearRecoverableAccountError(ctx, account)
|
||||||
|
}
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
degraded := buildKiroDegradedUsage(fetchErr)
|
||||||
|
degraded.Source = source
|
||||||
|
if hasCached {
|
||||||
|
cached.Error = degraded.Error
|
||||||
|
cached.ErrorCode = degraded.ErrorCode
|
||||||
|
cached.NeedsReauth = degraded.NeedsReauth
|
||||||
|
cached.KiroQuotaState = degraded.KiroQuotaState
|
||||||
|
cached.KiroQuotaReason = degraded.KiroQuotaReason
|
||||||
|
cached.KiroQuotaResetAt = degraded.KiroQuotaResetAt
|
||||||
|
cached.Source = source
|
||||||
|
s.attachKiroRuntimeState(ctx, account, cached)
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
s.storeKiroUsageSnapshot(account.ID, degraded)
|
||||||
|
s.attachKiroRuntimeState(ctx, account, degraded)
|
||||||
|
return degraded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) fetchAndCacheKiroUsage(ctx context.Context, account *Account, source string) (*UsageInfo, error) {
|
||||||
|
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||||
|
if token == "" {
|
||||||
|
return nil, fmt.Errorf("no access token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
region := kiroAPIRegion(account)
|
||||||
|
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||||
|
|
||||||
|
resp, err := s.requestKiroUsageLimits(ctx, account, region, profileArn, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := mapKiroUsageToInfo(resp)
|
||||||
|
usage.Source = source
|
||||||
|
s.storeKiroUsageSnapshot(account.ID, usage)
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) storeKiroUsageSnapshot(accountID int64, usage *UsageInfo) {
|
||||||
|
if s == nil || s.cache == nil || accountID <= 0 || usage == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if usage.UpdatedAt == nil {
|
||||||
|
usage.UpdatedAt = &now
|
||||||
|
}
|
||||||
|
s.cache.kiroUsageCache.Store(accountID, &kiroUsageCache{
|
||||||
|
usageInfo: cloneUsageInfo(usage),
|
||||||
|
timestamp: now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) getCachedKiroUsage(accountID int64) (*UsageInfo, bool) {
|
||||||
|
if s == nil || s.cache == nil || accountID <= 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
cached, ok := s.cache.kiroUsageCache.Load(accountID)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
cache, ok := cached.(*kiroUsageCache)
|
||||||
|
if !ok || cache == nil || cache.usageInfo == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if time.Since(cache.timestamp) >= kiroCacheTTL(cache.usageInfo) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return cloneUsageInfo(cache.usageInfo), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroCacheTTL(info *UsageInfo) time.Duration {
|
||||||
|
if info == nil {
|
||||||
|
return kiroUsageErrorTTL
|
||||||
|
}
|
||||||
|
if info.ErrorCode != "" || info.Error != "" {
|
||||||
|
return kiroUsageErrorTTL
|
||||||
|
}
|
||||||
|
return apiCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneUsageInfo(info *UsageInfo) *UsageInfo {
|
||||||
|
if info == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *info
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) requestKiroUsageLimits(ctx context.Context, account *Account, region, profileArn, token string) (*kiroUsageLimitsResponse, error) {
|
||||||
|
endpoint := resolveKiroRuntimeEndpoint(region)
|
||||||
|
reqURL, err := url.Parse(endpoint + "/getUsageLimits")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build kiro usage url failed: %w", err)
|
||||||
|
}
|
||||||
|
q := reqURL.Query()
|
||||||
|
q.Set("origin", kiroUsageOrigin)
|
||||||
|
if profileArn = strings.TrimSpace(profileArn); profileArn != "" {
|
||||||
|
q.Set("profileArn", profileArn)
|
||||||
|
}
|
||||||
|
q.Set("resourceType", kiroUsageResourceType)
|
||||||
|
reqURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create kiro usage request failed: %w", err)
|
||||||
|
}
|
||||||
|
s.applyKiroRuntimeHeaders(req, account, token)
|
||||||
|
|
||||||
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
|
ProxyURL: accountProxyURL(account),
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: isLoopbackEndpoint(endpoint),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create kiro usage client failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kiro usage request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read kiro usage response failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed kiroUsageLimitsResponse
|
||||||
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode kiro usage response failed: %w", err)
|
||||||
|
}
|
||||||
|
return &parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) applyKiroRuntimeHeaders(req *http.Request, account *Account, token string) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accountKey := buildKiroAccountKey(account)
|
||||||
|
machineID := buildKiroMachineID(account)
|
||||||
|
req.Header.Set("Accept", "*/*")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||||
|
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||||
|
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||||
|
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||||
|
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||||
|
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||||
|
|
||||||
|
if account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
applyKiroConditionalHeaders(req, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func accountProxyURL(account *Account) string {
|
||||||
|
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroRuntimeEndpoint(region string) string {
|
||||||
|
region = strings.TrimSpace(region)
|
||||||
|
if region == "" {
|
||||||
|
region = kiroDefaultRegion
|
||||||
|
}
|
||||||
|
switch region {
|
||||||
|
case "us-east-1":
|
||||||
|
return "https://q.us-east-1.amazonaws.com"
|
||||||
|
case "eu-central-1":
|
||||||
|
return "https://q.eu-central-1.amazonaws.com"
|
||||||
|
case "us-gov-east-1":
|
||||||
|
return "https://q-fips.us-gov-east-1.amazonaws.com"
|
||||||
|
case "us-gov-west-1":
|
||||||
|
return "https://q-fips.us-gov-west-1.amazonaws.com"
|
||||||
|
case "us-iso-east-1":
|
||||||
|
return "https://q.us-iso-east-1.c2s.ic.gov"
|
||||||
|
case "us-isob-east-1":
|
||||||
|
return "https://q.us-isob-east-1.sc2s.sgov.gov"
|
||||||
|
case "us-isof-south-1":
|
||||||
|
return "https://q.us-isof-south-1.csp.hci.ic.gov"
|
||||||
|
case "us-isof-east-1":
|
||||||
|
return "https://q.us-isof-east-1.csp.hci.ic.gov"
|
||||||
|
default:
|
||||||
|
if strings.HasPrefix(region, "us-gov-") {
|
||||||
|
return "https://q-fips." + region + ".amazonaws.com"
|
||||||
|
}
|
||||||
|
return "https://q." + region + ".amazonaws.com"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLoopbackEndpoint(raw string) bool {
|
||||||
|
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
host := strings.TrimSpace(parsed.Hostname())
|
||||||
|
if host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.EqualFold(host, "localhost") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
return ip != nil && ip.IsLoopback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapKiroUsageToInfo(resp *kiroUsageLimitsResponse) *UsageInfo {
|
||||||
|
now := time.Now()
|
||||||
|
if resp == nil {
|
||||||
|
return &UsageInfo{UpdatedAt: &now}
|
||||||
|
}
|
||||||
|
info := &UsageInfo{
|
||||||
|
UpdatedAt: &now,
|
||||||
|
KiroSubscriptionName: strings.TrimSpace(resp.SubscriptionInfo.SubscriptionTitle),
|
||||||
|
KiroSubscriptionType: strings.TrimSpace(resp.SubscriptionInfo.Type),
|
||||||
|
KiroOveragesEnabled: strings.EqualFold(strings.TrimSpace(resp.OverageConfiguration.OverageStatus), "ENABLED"),
|
||||||
|
}
|
||||||
|
|
||||||
|
resetAt := parseKiroTimestamp(resp.NextDateReset)
|
||||||
|
if credit := selectKiroCreditBreakdown(resp.UsageBreakdownList); credit != nil {
|
||||||
|
if breakdownReset := parseKiroTimestamp(credit.NextDateReset); breakdownReset != nil {
|
||||||
|
resetAt = breakdownReset
|
||||||
|
}
|
||||||
|
info.KiroCredit = &KiroCreditProgress{
|
||||||
|
CurrentUsage: selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage),
|
||||||
|
UsageLimit: selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit),
|
||||||
|
PercentageUsed: percentageOrZero(selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage), selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit)),
|
||||||
|
}
|
||||||
|
info.KiroOverage = &KiroOverageInfo{
|
||||||
|
CurrentOverages: selectKiroFloat(credit.CurrentOveragesWithPrecision, credit.CurrentOverages),
|
||||||
|
OverageCharges: selectKiroFloat(credit.OverageCharges, nil),
|
||||||
|
CurrencyCode: strings.TrimSpace(credit.Currency),
|
||||||
|
CurrencySymbol: kiroCurrencySymbol(strings.TrimSpace(credit.Currency)),
|
||||||
|
}
|
||||||
|
if ft := credit.FreeTrialInfo; ft != nil && strings.EqualFold(strings.TrimSpace(ft.FreeTrialStatus), "ACTIVE") {
|
||||||
|
expiry := parseKiroTimestamp(ft.FreeTrialExpiry)
|
||||||
|
daysRemaining := 0
|
||||||
|
if expiry != nil {
|
||||||
|
daysRemaining = int(time.Until(*expiry).Hours() / 24)
|
||||||
|
if time.Until(*expiry)%(24*time.Hour) != 0 {
|
||||||
|
daysRemaining++
|
||||||
|
}
|
||||||
|
if daysRemaining < 0 {
|
||||||
|
daysRemaining = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current := selectKiroFloat(ft.CurrentUsageWithPrecision, ft.CurrentUsage)
|
||||||
|
limit := selectKiroFloat(ft.UsageLimitWithPrecision, ft.UsageLimit)
|
||||||
|
info.KiroBonus = &KiroCreditProgress{
|
||||||
|
CurrentUsage: current,
|
||||||
|
UsageLimit: limit,
|
||||||
|
PercentageUsed: percentageOrZero(current, limit),
|
||||||
|
DaysRemaining: daysRemaining,
|
||||||
|
ExpiryDate: expiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info.KiroResetAt = resetAt
|
||||||
|
setKiroQuotaStateFromUsage(info)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectKiroCreditBreakdown(items []kiroUsageBreakdown) *kiroUsageBreakdown {
|
||||||
|
for i := range items {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(items[i].ResourceType), "CREDIT") {
|
||||||
|
return &items[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &items[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectKiroFloat(preferred *float64, fallback *float64) float64 {
|
||||||
|
switch {
|
||||||
|
case preferred != nil:
|
||||||
|
return *preferred
|
||||||
|
case fallback != nil:
|
||||||
|
return *fallback
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func percentageOrZero(current, limit float64) float64 {
|
||||||
|
if limit <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return current / limit * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseKiroTimestamp(raw any) *time.Time {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
trimmed := strings.TrimSpace(v)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if parsed, err := time.Parse(time.RFC3339, trimmed); err == nil {
|
||||||
|
return &parsed
|
||||||
|
}
|
||||||
|
if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
||||||
|
return unixishToTime(i)
|
||||||
|
}
|
||||||
|
if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
|
||||||
|
return unixishFloatToTime(f)
|
||||||
|
}
|
||||||
|
case float64:
|
||||||
|
return unixishFloatToTime(v)
|
||||||
|
case int64:
|
||||||
|
return unixishToTime(v)
|
||||||
|
case int:
|
||||||
|
return unixishToTime(int64(v))
|
||||||
|
case json.Number:
|
||||||
|
if i, err := v.Int64(); err == nil {
|
||||||
|
return unixishToTime(i)
|
||||||
|
}
|
||||||
|
if f, err := v.Float64(); err == nil {
|
||||||
|
return unixishFloatToTime(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unixishFloatToTime(v float64) *time.Time {
|
||||||
|
if v <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if v >= 1e12 {
|
||||||
|
t := time.UnixMilli(int64(v))
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
t := time.Unix(int64(v), 0)
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
func unixishToTime(v int64) *time.Time {
|
||||||
|
if v <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if v >= 1e12 {
|
||||||
|
t := time.UnixMilli(v)
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
t := time.Unix(v, 0)
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
func kiroCurrencySymbol(code string) string {
|
||||||
|
switch strings.ToUpper(strings.TrimSpace(code)) {
|
||||||
|
case "USD":
|
||||||
|
return "$"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroDegradedUsage(err error) *UsageInfo {
|
||||||
|
now := time.Now()
|
||||||
|
info := &UsageInfo{
|
||||||
|
UpdatedAt: &now,
|
||||||
|
Error: "usage API error",
|
||||||
|
ErrorCode: errorCodeNetworkError,
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
info.Error = fmt.Sprintf("usage API error: %v", err)
|
||||||
|
|
||||||
|
classification := classifyKiroError(err)
|
||||||
|
switch classification.Category {
|
||||||
|
case kiroErrorAuthError:
|
||||||
|
info.ErrorCode = errorCodeUnauthenticated
|
||||||
|
info.NeedsReauth = true
|
||||||
|
case kiroErrorRateLimited:
|
||||||
|
info.ErrorCode = errorCodeRateLimited
|
||||||
|
case kiroErrorQuotaExhausted:
|
||||||
|
info.ErrorCode = errorCodeNetworkError
|
||||||
|
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
|
||||||
|
info.KiroQuotaReason = classification.Message
|
||||||
|
case kiroErrorOverageExhausted:
|
||||||
|
info.ErrorCode = errorCodeNetworkError
|
||||||
|
info.KiroQuotaState = kiroQuotaStateOverageExhausted
|
||||||
|
info.KiroQuotaReason = classification.Message
|
||||||
|
case kiroErrorSuspended, kiroErrorUsageForbidden, kiroErrorProfileError:
|
||||||
|
info.ErrorCode = errorCodeForbidden
|
||||||
|
default:
|
||||||
|
info.ErrorCode = errorCodeNetworkError
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) attachKiroRuntimeState(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||||
|
if s == nil || usage == nil || account == nil || account.Platform != PlatformKiro || s.kiroCooldownStore == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usage.KiroRuntimeState = ""
|
||||||
|
usage.KiroRuntimeReason = ""
|
||||||
|
usage.KiroRuntimeResetAt = nil
|
||||||
|
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
|
||||||
|
if err != nil || state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usage.KiroRuntimeState, usage.KiroRuntimeReason, usage.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) EnrichAccountWithKiroRuntimeState(ctx context.Context, account *Account) {
|
||||||
|
if s == nil || account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.KiroQuotaState = ""
|
||||||
|
account.KiroQuotaReason = ""
|
||||||
|
account.KiroQuotaResetAt = nil
|
||||||
|
account.KiroRuntimeState = ""
|
||||||
|
account.KiroRuntimeReason = ""
|
||||||
|
account.KiroRuntimeResetAt = nil
|
||||||
|
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
|
||||||
|
account.KiroQuotaState = usage.KiroQuotaState
|
||||||
|
account.KiroQuotaReason = usage.KiroQuotaReason
|
||||||
|
account.KiroQuotaResetAt = usage.KiroQuotaResetAt
|
||||||
|
}
|
||||||
|
if s.kiroCooldownStore == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
|
||||||
|
if err != nil || state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.KiroRuntimeState, account.KiroRuntimeReason, account.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setKiroQuotaStateFromUsage(info *UsageInfo) {
|
||||||
|
if info == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info.KiroQuotaState = ""
|
||||||
|
info.KiroQuotaReason = ""
|
||||||
|
info.KiroQuotaResetAt = nil
|
||||||
|
|
||||||
|
creditExhausted := false
|
||||||
|
if info.KiroCredit != nil && info.KiroCredit.UsageLimit > 0 {
|
||||||
|
creditExhausted = info.KiroCredit.CurrentUsage >= info.KiroCredit.UsageLimit
|
||||||
|
}
|
||||||
|
overageActive := info.KiroOverage != nil &&
|
||||||
|
(info.KiroOverage.CurrentOverages > 0 || info.KiroOverage.OverageCharges > 0)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case info.KiroOveragesEnabled && (overageActive || creditExhausted):
|
||||||
|
info.KiroQuotaState = kiroQuotaStateOverageActive
|
||||||
|
info.KiroQuotaReason = "overages_enabled"
|
||||||
|
info.KiroQuotaResetAt = info.KiroResetAt
|
||||||
|
case creditExhausted:
|
||||||
|
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
|
||||||
|
info.KiroQuotaReason = "credits_exhausted"
|
||||||
|
info.KiroQuotaResetAt = info.KiroResetAt
|
||||||
|
default:
|
||||||
|
info.KiroQuotaState = kiroQuotaStateNormal
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,458 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
|
)
|
||||||
|
|
||||||
|
const kiroMaxWebSearchIterations = 5
|
||||||
|
|
||||||
|
var (
|
||||||
|
errKiroWebSearchFallback = errors.New("kiro web search fallback")
|
||||||
|
kiroWebSearchDescCache sync.Map
|
||||||
|
)
|
||||||
|
|
||||||
|
type kiroWebSearchExecution struct {
|
||||||
|
ResponseBody []byte
|
||||||
|
Usage ClaudeUsage
|
||||||
|
RequestID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroWebSearchHTTPError struct {
|
||||||
|
Response *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
type kiroStreamChunkCollector struct {
|
||||||
|
chunks [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *kiroWebSearchHTTPError) Error() string {
|
||||||
|
if e == nil || e.Response == nil {
|
||||||
|
return "kiro web search http error"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
|
||||||
|
if len(p) > 0 {
|
||||||
|
w.chunks = append(w.chunks, append([]byte(nil), p...))
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
|
||||||
|
collector := &kiroStreamChunkCollector{}
|
||||||
|
result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return collector.chunks, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSEChunks(w io.Writer, chunks [][]byte) error {
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := w.Write(chunk); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
|
||||||
|
if strings.TrimSpace(msgID) == "" {
|
||||||
|
msgID = "msg_" + kiropkg.GenerateToolUseID()
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
model = "kiro"
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(map[string]any{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": map[string]any{
|
||||||
|
"id": msgID,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": model,
|
||||||
|
"content": []any{},
|
||||||
|
"stop_reason": nil,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"input_tokens": inputTokens,
|
||||||
|
"output_tokens": 0,
|
||||||
|
"cache_creation_input_tokens": 0,
|
||||||
|
"cache_read_input_tokens": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||||
|
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||||
|
) error {
|
||||||
|
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return errKiroWebSearchFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
|
||||||
|
if err != nil {
|
||||||
|
currentBody = anthropicBody
|
||||||
|
}
|
||||||
|
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||||
|
nextContentBlockIndex := 0
|
||||||
|
|
||||||
|
if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
|
||||||
|
s.prefetchKiroWebSearchDescription(ctx, account, token)
|
||||||
|
|
||||||
|
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
|
||||||
|
if strings.TrimSpace(nextToken) != "" {
|
||||||
|
token = nextToken
|
||||||
|
}
|
||||||
|
if mcpErr != nil {
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nextContentBlockIndex += 2
|
||||||
|
|
||||||
|
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
|
||||||
|
if err != nil {
|
||||||
|
return errKiroWebSearchFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return &kiroWebSearchHTTPError{Response: resp}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) {
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
|
||||||
|
}()
|
||||||
|
if streamErr != nil {
|
||||||
|
return streamErr
|
||||||
|
}
|
||||||
|
|
||||||
|
analysis := kiropkg.AnalyzeBufferedStream(chunks)
|
||||||
|
if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations {
|
||||||
|
filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex)
|
||||||
|
if err := writeSSEChunks(w, filtered); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex {
|
||||||
|
nextContentBlockIndex = maxIndex + 1
|
||||||
|
}
|
||||||
|
query = analysis.WebSearchQuery
|
||||||
|
if strings.TrimSpace(analysis.WebSearchToolUseID) == "" {
|
||||||
|
currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||||
|
} else {
|
||||||
|
currentToolUseID = analysis.WebSearchToolUseID
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex)
|
||||||
|
if !shouldForward {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := w.Write(adjusted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("kiro web search exceeded max iterations")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||||
|
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, errKiroWebSearchFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
|
||||||
|
if err != nil {
|
||||||
|
currentBody = anthropicBody
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := estimateKiroInputTokens(anthropicBody)
|
||||||
|
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||||
|
searches := make([]kiropkg.SearchIndicator, 0, 2)
|
||||||
|
requestID := ""
|
||||||
|
|
||||||
|
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
|
||||||
|
s.prefetchKiroWebSearchDescription(ctx, account, token)
|
||||||
|
|
||||||
|
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
|
||||||
|
if strings.TrimSpace(nextToken) != "" {
|
||||||
|
token = nextToken
|
||||||
|
}
|
||||||
|
if mcpErr != nil {
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
searches = append(searches, kiropkg.SearchIndicator{
|
||||||
|
ToolUseID: currentToolUseID,
|
||||||
|
Query: query,
|
||||||
|
Results: results,
|
||||||
|
})
|
||||||
|
|
||||||
|
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errKiroWebSearchFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, &kiroWebSearchHTTPError{Response: resp}
|
||||||
|
}
|
||||||
|
|
||||||
|
parseResult, parseErr := func() (*kiropkg.ParseResult, error) {
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
|
||||||
|
}()
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = buildKiroRequestID(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody)
|
||||||
|
if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations {
|
||||||
|
finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches)
|
||||||
|
if injectErr == nil {
|
||||||
|
parseResult.ResponseBody = finalBody
|
||||||
|
}
|
||||||
|
return &kiroWebSearchExecution{
|
||||||
|
ResponseBody: parseResult.ResponseBody,
|
||||||
|
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
|
||||||
|
RequestID: requestID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query = nextQuery
|
||||||
|
if strings.TrimSpace(nextToolUseID) == "" {
|
||||||
|
nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||||
|
}
|
||||||
|
currentToolUseID = nextToolUseID
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("kiro web search exceeded max iterations")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
|
||||||
|
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
|
||||||
|
if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
|
||||||
|
if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
|
||||||
|
kiropkg.SetCachedWebSearchDescription(desc)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, _ := json.Marshal(kiropkg.MCPRequest{
|
||||||
|
ID: "tools_list",
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Method: "tools/list",
|
||||||
|
})
|
||||||
|
resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
|
||||||
|
if err != nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var result kiropkg.MCPResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, tool := range result.Result.Tools {
|
||||||
|
if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
|
||||||
|
kiroWebSearchDescCache.Store(endpoint, tool.Description)
|
||||||
|
kiropkg.SetCachedWebSearchDescription(tool.Description)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
|
||||||
|
reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
|
||||||
|
if err != nil {
|
||||||
|
return nil, token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
|
||||||
|
resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nextToken, err
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nextToken, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed kiropkg.MCPResponse
|
||||||
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||||
|
return nil, nextToken, err
|
||||||
|
}
|
||||||
|
if parsed.Error != nil {
|
||||||
|
msg := "unknown error"
|
||||||
|
if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" {
|
||||||
|
msg = strings.TrimSpace(*parsed.Error.Message)
|
||||||
|
}
|
||||||
|
code := 0
|
||||||
|
if parsed.Error.Code != nil {
|
||||||
|
code = *parsed.Error.Code
|
||||||
|
}
|
||||||
|
return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return kiropkg.ParseSearchResults(&parsed), nextToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
|
||||||
|
return kiropkg.MCPRequest{
|
||||||
|
ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Method: "tools/call",
|
||||||
|
Params: map[string]interface{}{
|
||||||
|
"name": "web_search",
|
||||||
|
"arguments": map[string]interface{}{
|
||||||
|
"query": query,
|
||||||
|
"_meta": map[string]interface{}{
|
||||||
|
"_isValid": true,
|
||||||
|
"_activePath": []string{"query"},
|
||||||
|
"_completedPaths": [][]string{{"query"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
|
||||||
|
currentToken := token
|
||||||
|
accountKey := buildKiroAccountKey(account)
|
||||||
|
proxyURL := kiroProxyURL(account)
|
||||||
|
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||||
|
|
||||||
|
for attempt := 0; attempt < 3; attempt++ {
|
||||||
|
if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
|
||||||
|
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||||
|
return nil, currentToken, failoverErr
|
||||||
|
}
|
||||||
|
return nil, currentToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, currentToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, currentToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, currentToken, readErr
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
|
||||||
|
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
|
||||||
|
return nil, currentToken, err
|
||||||
|
}
|
||||||
|
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||||
|
return resp, currentToken, nil
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) {
|
||||||
|
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||||
|
return resp, currentToken, nil
|
||||||
|
}
|
||||||
|
if s.kiroTokenProvider == nil {
|
||||||
|
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||||
|
return resp, currentToken, nil
|
||||||
|
}
|
||||||
|
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||||
|
if refreshErr != nil {
|
||||||
|
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||||
|
return resp, currentToken, nil
|
||||||
|
}
|
||||||
|
currentToken = refreshedToken
|
||||||
|
accountKey = buildKiroAccountKey(account)
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, currentToken, sleepErr
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if _, err := s.markKiro429(ctx, accountKey); err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, currentToken, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
|
||||||
|
if attempt < 2 {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||||
|
return nil, currentToken, sleepErr
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, currentToken, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, currentToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildKiroWebSearchMCPRequest_UsesUnderscoredMetaKeys(t *testing.T) {
|
||||||
|
req := buildKiroWebSearchMCPRequest("golang concurrency")
|
||||||
|
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "tools/call", gjson.GetBytes(body, "method").String())
|
||||||
|
require.Equal(t, "web_search", gjson.GetBytes(body, "params.name").String())
|
||||||
|
require.Equal(t, "golang concurrency", gjson.GetBytes(body, "params.arguments.query").String())
|
||||||
|
require.True(t, gjson.GetBytes(body, "params.arguments._meta._isValid").Bool())
|
||||||
|
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._activePath.0").String())
|
||||||
|
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._completedPaths.0.0").String())
|
||||||
|
require.False(t, gjson.GetBytes(body, "params.arguments._meta.isValid").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(body, "params.arguments._meta.activePath").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(body, "params.arguments._meta.completedPaths").Exists())
|
||||||
|
}
|
||||||
@@ -89,6 +89,14 @@ type OpsUpstreamErrorEvent struct {
|
|||||||
AccountID int64 `json:"account_id,omitempty"`
|
AccountID int64 `json:"account_id,omitempty"`
|
||||||
AccountName string `json:"account_name,omitempty"`
|
AccountName string `json:"account_name,omitempty"`
|
||||||
|
|
||||||
|
// Model diagnostics.
|
||||||
|
RequestedModel string `json:"requested_model,omitempty"`
|
||||||
|
MappedModel string `json:"mapped_model,omitempty"`
|
||||||
|
KiroModelID string `json:"kiro_model_id,omitempty"`
|
||||||
|
HasTools bool `json:"has_tools,omitempty"`
|
||||||
|
HasAdaptiveThinking bool `json:"has_adaptive_thinking,omitempty"`
|
||||||
|
HasContext1MBeta bool `json:"has_context_1m_beta,omitempty"`
|
||||||
|
|
||||||
// Outcome
|
// Outcome
|
||||||
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
||||||
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
|
|||||||
// Antigravity 同样可能有两种缓存键
|
// Antigravity 同样可能有两种缓存键
|
||||||
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
||||||
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
||||||
|
case PlatformKiro:
|
||||||
|
keysToDelete = append(keysToDelete, KiroTokenCacheKey(account))
|
||||||
|
keysToDelete = append(keysToDelete, "kiro:"+accountIDKey)
|
||||||
case PlatformOpenAI:
|
case PlatformOpenAI:
|
||||||
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
||||||
case PlatformAnthropic:
|
case PlatformAnthropic:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
|
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
|
||||||
@@ -44,6 +45,7 @@ func NewTokenRefreshService(
|
|||||||
openaiOAuthService *OpenAIOAuthService,
|
openaiOAuthService *OpenAIOAuthService,
|
||||||
geminiOAuthService *GeminiOAuthService,
|
geminiOAuthService *GeminiOAuthService,
|
||||||
antigravityOAuthService *AntigravityOAuthService,
|
antigravityOAuthService *AntigravityOAuthService,
|
||||||
|
kiroOAuthService *KiroOAuthService,
|
||||||
cacheInvalidator TokenCacheInvalidator,
|
cacheInvalidator TokenCacheInvalidator,
|
||||||
schedulerCache SchedulerCache,
|
schedulerCache SchedulerCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
@@ -64,6 +66,7 @@ func NewTokenRefreshService(
|
|||||||
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||||
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||||
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||||
|
kiroRefresher := NewKiroTokenRefresher(kiroOAuthService)
|
||||||
|
|
||||||
// 注册平台特定的刷新器(TokenRefresher 接口)
|
// 注册平台特定的刷新器(TokenRefresher 接口)
|
||||||
s.refreshers = []TokenRefresher{
|
s.refreshers = []TokenRefresher{
|
||||||
@@ -71,6 +74,7 @@ func NewTokenRefreshService(
|
|||||||
openAIRefresher,
|
openAIRefresher,
|
||||||
geminiRefresher,
|
geminiRefresher,
|
||||||
agRefresher,
|
agRefresher,
|
||||||
|
kiroRefresher,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
||||||
@@ -79,6 +83,7 @@ func NewTokenRefreshService(
|
|||||||
openAIRefresher,
|
openAIRefresher,
|
||||||
geminiRefresher,
|
geminiRefresher,
|
||||||
agRefresher,
|
agRefresher,
|
||||||
|
kiroRefresher,
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -415,6 +420,10 @@ func isNonRetryableRefreshError(err error) bool {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
var kiroInvalidGrant *kiropkg.RefreshTokenInvalidError
|
||||||
|
if errors.As(err, &kiroInvalidGrant) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
msg := strings.ToLower(err.Error())
|
msg := strings.ToLower(err.Error())
|
||||||
nonRetryable := []string{
|
nonRetryable := []string{
|
||||||
"invalid_grant", // refresh_token 已失效
|
"invalid_grant", // refresh_token 已失效
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ func optionalNonEqualStringPtr(value, compare string) *string {
|
|||||||
|
|
||||||
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||||
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
|
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
|
||||||
return trimmed
|
return normalizeModelNameForPricing(trimmed)
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(upstreamModel)
|
return normalizeModelNameForPricing(upstreamModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func optionalInt64Ptr(v int64) *int64 {
|
func optionalInt64Ptr(v int64) *int64 {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@@ -51,6 +52,7 @@ func ProvideTokenRefreshService(
|
|||||||
openaiOAuthService *OpenAIOAuthService,
|
openaiOAuthService *OpenAIOAuthService,
|
||||||
geminiOAuthService *GeminiOAuthService,
|
geminiOAuthService *GeminiOAuthService,
|
||||||
antigravityOAuthService *AntigravityOAuthService,
|
antigravityOAuthService *AntigravityOAuthService,
|
||||||
|
kiroOAuthService *KiroOAuthService,
|
||||||
cacheInvalidator TokenCacheInvalidator,
|
cacheInvalidator TokenCacheInvalidator,
|
||||||
schedulerCache SchedulerCache,
|
schedulerCache SchedulerCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
@@ -59,7 +61,7 @@ func ProvideTokenRefreshService(
|
|||||||
proxyRepo ProxyRepository,
|
proxyRepo ProxyRepository,
|
||||||
refreshAPI *OAuthRefreshAPI,
|
refreshAPI *OAuthRefreshAPI,
|
||||||
) *TokenRefreshService {
|
) *TokenRefreshService {
|
||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||||
// 注入 OpenAI privacy opt-out 依赖
|
// 注入 OpenAI privacy opt-out 依赖
|
||||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||||
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
||||||
@@ -128,6 +130,23 @@ func ProvideAntigravityTokenProvider(
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ProvideKiroTokenProvider(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
tokenCache GeminiTokenCache,
|
||||||
|
kiroOAuthService *KiroOAuthService,
|
||||||
|
refreshAPI *OAuthRefreshAPI,
|
||||||
|
) *KiroTokenProvider {
|
||||||
|
p := NewKiroTokenProvider(accountRepo, tokenCache, kiroOAuthService)
|
||||||
|
executor := NewKiroTokenRefresher(kiroOAuthService)
|
||||||
|
p.SetRefreshAPI(refreshAPI, executor)
|
||||||
|
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProvideKiroCooldownStore(redisClient *redis.Client) KiroCooldownStore {
|
||||||
|
return kirocooldown.NewStore(redisClient)
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||||
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||||
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||||
@@ -448,8 +467,11 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewCompositeTokenCacheInvalidator,
|
NewCompositeTokenCacheInvalidator,
|
||||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||||
NewAntigravityOAuthService,
|
NewAntigravityOAuthService,
|
||||||
|
NewKiroOAuthService,
|
||||||
ProvideOAuthRefreshAPI,
|
ProvideOAuthRefreshAPI,
|
||||||
ProvideGeminiTokenProvider,
|
ProvideGeminiTokenProvider,
|
||||||
|
ProvideKiroTokenProvider,
|
||||||
|
ProvideKiroCooldownStore,
|
||||||
NewGeminiMessagesCompatService,
|
NewGeminiMessagesCompatService,
|
||||||
ProvideAntigravityTokenProvider,
|
ProvideAntigravityTokenProvider,
|
||||||
ProvideOpenAITokenProvider,
|
ProvideOpenAITokenProvider,
|
||||||
|
|||||||
Reference in New Issue
Block a user