diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 9bfa2717..3c07b96f 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -93,6 +93,7 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
+ kiroOAuth *service.KiroOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
@@ -216,6 +217,10 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
+ {"KiroOAuthService", func() error {
+ kiroOAuth.Stop()
+ return nil
+ }},
{"OpenAIWSPool", func() error {
if openAIGateway != nil {
openAIGateway.CloseOpenAIWSPool()
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 40f0191c..23002178 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -146,13 +146,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
+ kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
+ kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
- accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
+ accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -166,6 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
+ kiroOAuthHandler := admin.NewKiroOAuthHandler(kiroOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
@@ -179,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
+ kiroCooldownStore := service.ProvideKiroCooldownStore(redisClient)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, kiroOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -256,13 +260,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -312,6 +316,7 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
+ kiroOAuth *service.KiroOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
@@ -434,6 +439,10 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
+ {"KiroOAuthService", func() error {
+ kiroOAuth.Stop()
+ return nil
+ }},
{"OpenAIWSPool", func() error {
if openAIGateway != nil {
openAIGateway.CloseOpenAIWSPool()
diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go
index 5ccd67fb..59acdb49 100644
--- a/backend/cmd/server/wire_gen_test.go
+++ b/backend/cmd/server/wire_gen_test.go
@@ -36,6 +36,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
antigravityOAuthSvc,
nil,
nil,
+ nil,
cfg,
nil,
)
@@ -72,6 +73,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
openAIOAuthSvc,
geminiOAuthSvc,
antigravityOAuthSvc,
+ nil, // kiroOAuth
nil, // openAIGateway
nil, // scheduledTestRunner
nil, // backupSvc
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 87263db0..26d51121 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -635,6 +635,8 @@ type GatewayConfig struct {
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
+ // KiroStreamKeepaliveInterval: Kiro 流式 keepalive 间隔(秒),0使用默认 25 秒
+ KiroStreamKeepaliveInterval int `mapstructure:"kiro_stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
MaxLineSize int `mapstructure:"max_line_size"`
@@ -1689,6 +1691,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
+ viper.SetDefault("gateway.kiro_stream_keepalive_interval", 25)
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
@@ -2277,6 +2280,13 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
+ if c.Gateway.KiroStreamKeepaliveInterval < 0 {
+ return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be non-negative")
+ }
+ if c.Gateway.KiroStreamKeepaliveInterval != 0 &&
+ (c.Gateway.KiroStreamKeepaliveInterval < 5 || c.Gateway.KiroStreamKeepaliveInterval > 30) {
+ return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be 0 or between 5-30 seconds")
+ }
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index 27c543dd..e73c1e57 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
+ PlatformKiro = "kiro"
)
// Account type constants
@@ -116,6 +117,21 @@ var DefaultAntigravityModelMapping = map[string]string{
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
+// DefaultKiroModelMapping 是 Kiro 平台的默认模型映射。
+// 键为对外暴露/允许请求的模型名,值为实际发送到 Kiro 上游的模型名。
+var DefaultKiroModelMapping = map[string]string{
+ "claude-opus-4-6": "claude-opus-4.6",
+ "claude-opus-4-6-thinking": "claude-opus-4.6",
+ "claude-sonnet-4-6": "claude-sonnet-4.6",
+ "claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
+ "claude-opus-4-5-20251101": "claude-opus-4.5",
+ "claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
+ "claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
+ "claude-haiku-4-5-20251001": "claude-haiku-4.5",
+ "claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
+}
+
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go
index de66137f..a53cc802 100644
--- a/backend/internal/domain/constants_test.go
+++ b/backend/internal/domain/constants_test.go
@@ -1,6 +1,9 @@
package domain
-import "testing"
+import (
+ "strings"
+ "testing"
+)
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
t.Parallel()
@@ -24,3 +27,54 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
}
}
}
+
+func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
+ t.Parallel()
+
+ expected := map[string]string{
+ "claude-opus-4-6": "claude-opus-4.6",
+ "claude-opus-4-6-thinking": "claude-opus-4.6",
+ "claude-sonnet-4-6": "claude-sonnet-4.6",
+ "claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
+ "claude-opus-4-5-20251101": "claude-opus-4.5",
+ "claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
+ "claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
+ "claude-haiku-4-5-20251001": "claude-haiku-4.5",
+ "claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
+ }
+
+ if len(DefaultKiroModelMapping) != len(expected) {
+ t.Fatalf("expected %d Kiro mappings, got %d", len(expected), len(DefaultKiroModelMapping))
+ }
+ for model, want := range expected {
+ if got := DefaultKiroModelMapping[model]; got != want {
+ t.Fatalf("unexpected Kiro mapping for %q: got %q want %q", model, got, want)
+ }
+ }
+
+ for _, model := range []string{
+ "claude-opus-4-5",
+ "claude-sonnet-4-5",
+ "claude-sonnet-4",
+ "gpt-4o",
+ "gpt-4",
+ "deepseek-3-2",
+ "minimax-m2-1",
+ "qwen3-coder-next",
+ "claude-opus-4-7",
+ "claude-sonnet-4-6-chat",
+ } {
+ if _, ok := DefaultKiroModelMapping[model]; ok {
+ t.Fatalf("did not expect %q to remain in DefaultKiroModelMapping", model)
+ }
+ }
+ for model := range DefaultKiroModelMapping {
+ if strings.HasSuffix(model, "-agentic") {
+ t.Fatalf("did not expect agentic Kiro mapping %q", model)
+ }
+ if strings.HasSuffix(model, "-chat") {
+ t.Fatalf("did not expect chat-only Kiro mapping %q", model)
+ }
+ }
+}
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 2d00ccc6..78aa57a0 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -179,6 +180,9 @@ type AccountWithConcurrency struct {
const accountListGroupUngroupedQueryValue = "ungrouped"
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
+ if h.accountUsageService != nil {
+ h.accountUsageService.EnrichAccountWithKiroRuntimeState(ctx, account)
+ }
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
CurrentConcurrency: 0,
@@ -351,6 +355,9 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
acc := &accounts[i]
+ if h.accountUsageService != nil {
+ h.accountUsageService.EnrichAccountWithKiroRuntimeState(c.Request.Context(), acc)
+ }
item := AccountWithConcurrency{
Account: dto.AccountFromService(acc),
CurrentConcurrency: concurrencyCounts[acc.ID],
@@ -1913,6 +1920,18 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
+ // Handle Kiro accounts
+ if account.Platform == service.PlatformKiro {
+ mapping := account.GetModelMapping()
+ if len(mapping) == 0 {
+ response.Success(c, kiropkg.DefaultModels)
+ return
+ }
+
+ response.Success(c, buildMappedKiroModels(mapping))
+ return
+ }
+
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
@@ -1954,6 +1973,28 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models)
}
+func buildMappedKiroModels(mapping map[string]string) []kiropkg.Model {
+ models := make([]kiropkg.Model, 0, len(mapping))
+ for requestedModel := range mapping {
+ var found bool
+ for _, dm := range kiropkg.DefaultModels {
+ if dm.ID == requestedModel {
+ models = append(models, dm)
+ found = true
+ break
+ }
+ }
+ if !found {
+ models = append(models, kiropkg.Model{
+ ID: requestedModel,
+ Type: "model",
+ DisplayName: requestedModel,
+ })
+ }
+ }
+ return models
+}
+
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
@@ -2166,6 +2207,12 @@ func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
+// GetKiroDefaultModelMapping 获取 Kiro 平台的默认模型映射
+// GET /api/v1/admin/accounts/kiro/default-model-mapping
+func (h *AccountHandler) GetKiroDefaultModelMapping(c *gin.Context) {
+ response.Success(c, domain.DefaultKiroModelMapping)
+}
+
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) {
diff --git a/backend/internal/handler/admin/kiro_oauth_handler.go b/backend/internal/handler/admin/kiro_oauth_handler.go
new file mode 100644
index 00000000..fc6727b8
--- /dev/null
+++ b/backend/internal/handler/admin/kiro_oauth_handler.go
@@ -0,0 +1,149 @@
+package admin
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+)
+
+type KiroOAuthHandler struct {
+ kiroOAuthService *service.KiroOAuthService
+}
+
+func NewKiroOAuthHandler(kiroOAuthService *service.KiroOAuthService) *KiroOAuthHandler {
+ return &KiroOAuthHandler{kiroOAuthService: kiroOAuthService}
+}
+
+type KiroGenerateAuthURLRequest struct {
+ ProxyID *int64 `json:"proxy_id"`
+ Provider string `json:"provider"`
+}
+
+func (h *KiroOAuthHandler) GenerateAuthURL(c *gin.Context) {
+ var req KiroGenerateAuthURLRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "请求无效: "+err.Error())
+ return
+ }
+ result, err := h.kiroOAuthService.GenerateAuthURL(c.Request.Context(), &service.KiroGenerateAuthURLInput{
+ ProxyID: req.ProxyID,
+ Provider: req.Provider,
+ })
+ if err != nil {
+ response.BadRequest(c, "生成授权链接失败: "+err.Error())
+ return
+ }
+ response.Success(c, result)
+}
+
+type KiroGenerateIDCAuthURLRequest struct {
+ ProxyID *int64 `json:"proxy_id"`
+ StartURL string `json:"start_url"`
+ Region string `json:"region"`
+}
+
+func (h *KiroOAuthHandler) GenerateIDCAuthURL(c *gin.Context) {
+ var req KiroGenerateIDCAuthURLRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "请求无效: "+err.Error())
+ return
+ }
+ result, err := h.kiroOAuthService.GenerateIDCAuthURL(c.Request.Context(), &service.KiroGenerateIDCAuthURLInput{
+ ProxyID: req.ProxyID,
+ StartURL: req.StartURL,
+ Region: req.Region,
+ })
+ if err != nil {
+ response.BadRequest(c, "生成 IDC 授权链接失败: "+err.Error())
+ return
+ }
+ response.Success(c, result)
+}
+
+type KiroExchangeCodeRequest struct {
+ SessionID string `json:"session_id" binding:"required"`
+ State string `json:"state" binding:"required"`
+ Code string `json:"code" binding:"required"`
+ CallbackPath string `json:"callback_path"`
+ LoginOption string `json:"login_option"`
+ ProxyID *int64 `json:"proxy_id"`
+}
+
+func (h *KiroOAuthHandler) ExchangeCode(c *gin.Context) {
+ var req KiroExchangeCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "请求无效: "+err.Error())
+ return
+ }
+ tokenInfo, err := h.kiroOAuthService.ExchangeCode(c.Request.Context(), &service.KiroExchangeCodeInput{
+ SessionID: req.SessionID,
+ State: req.State,
+ Code: req.Code,
+ CallbackPath: req.CallbackPath,
+ LoginOption: req.LoginOption,
+ ProxyID: req.ProxyID,
+ })
+ if err != nil {
+ response.BadRequest(c, "Token 交换失败: "+err.Error())
+ return
+ }
+ response.Success(c, tokenInfo)
+}
+
+type KiroRefreshTokenRequest struct {
+ RefreshToken string `json:"refresh_token" binding:"required"`
+ AuthMethod string `json:"auth_method"`
+ Provider string `json:"provider"`
+ ClientID string `json:"client_id"`
+ ClientSecret string `json:"client_secret"`
+ StartURL string `json:"start_url"`
+ Region string `json:"region"`
+ ProfileArn string `json:"profile_arn"`
+ ProxyID *int64 `json:"proxy_id"`
+}
+
+func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
+ var req KiroRefreshTokenRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "请求无效: "+err.Error())
+ return
+ }
+ tokenInfo, err := h.kiroOAuthService.RefreshToken(c.Request.Context(), &service.KiroRefreshTokenInput{
+ RefreshToken: req.RefreshToken,
+ AuthMethod: req.AuthMethod,
+ Provider: req.Provider,
+ ClientID: req.ClientID,
+ ClientSecret: req.ClientSecret,
+ StartURL: req.StartURL,
+ Region: req.Region,
+ ProfileArn: req.ProfileArn,
+ ProxyID: req.ProxyID,
+ })
+ if err != nil {
+ response.BadRequest(c, "刷新 Kiro Token 失败: "+err.Error())
+ return
+ }
+ response.Success(c, tokenInfo)
+}
+
+type KiroImportTokenRequest struct {
+ TokenJSON string `json:"token_json" binding:"required"`
+ DeviceRegistrationJSON string `json:"device_registration_json"`
+}
+
+func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
+ var req KiroImportTokenRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "请求无效: "+err.Error())
+ return
+ }
+ tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
+ TokenJSON: req.TokenJSON,
+ DeviceRegistrationJSON: req.DeviceRegistrationJSON,
+ })
+ if err != nil {
+ response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
+ return
+ }
+ response.Success(c, tokenInfo)
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index f7503c2e..a9cc8044 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -221,6 +221,12 @@ func AccountFromServiceShallow(a *service.Account) *Account {
OverloadUntil: a.OverloadUntil,
TempUnschedulableUntil: a.TempUnschedulableUntil,
TempUnschedulableReason: a.TempUnschedulableReason,
+ KiroQuotaState: a.KiroQuotaState,
+ KiroQuotaReason: a.KiroQuotaReason,
+ KiroQuotaResetAt: a.KiroQuotaResetAt,
+ KiroRuntimeState: a.KiroRuntimeState,
+ KiroRuntimeReason: a.KiroRuntimeReason,
+ KiroRuntimeResetAt: a.KiroRuntimeResetAt,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 5cc2f8e4..6c57f377 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -174,6 +174,12 @@ type Account struct {
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
+ KiroQuotaState string `json:"kiro_quota_state,omitempty"`
+ KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
+ KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
+ KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
+ KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
+ KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index 13e3ac88..3253b9ce 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -17,6 +17,7 @@ type AdminHandlers struct {
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
AntigravityOAuth *admin.AntigravityOAuthHandler
+ KiroOAuth *admin.KiroOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index a8725875..aef8768a 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
+ kiroOAuthHandler *admin.KiroOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
@@ -51,6 +52,7 @@ func ProvideAdminHandlers(
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
AntigravityOAuth: antigravityOAuthHandler,
+ KiroOAuth: kiroOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Promo: promoHandler,
@@ -154,6 +156,7 @@ var ProviderSet = wire.NewSet(
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
admin.NewAntigravityOAuthHandler,
+ admin.NewKiroOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewPromoHandler,
diff --git a/backend/internal/pkg/kiro/fingerprint.go b/backend/internal/pkg/kiro/fingerprint.go
new file mode 100644
index 00000000..461b3411
--- /dev/null
+++ b/backend/internal/pkg/kiro/fingerprint.go
@@ -0,0 +1,258 @@
+package kiro
+
+import (
+ "crypto/sha256"
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "math/rand"
+ "runtime"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+type RuntimeFingerprint struct {
+ OIDCSDKVersion string
+ RuntimeSDKVersion string
+ StreamingSDKVersion string
+ OSType string
+ OSVersion string
+ NodeVersion string
+ KiroVersion string
+ KiroHash string
+}
+
+type runtimeFingerprintManager struct {
+ mu sync.RWMutex
+ fingerprints map[string]*RuntimeFingerprint
+}
+
+var (
+ globalRuntimeFingerprintManager *runtimeFingerprintManager
+ globalRuntimeFingerprintManagerOnce sync.Once
+
+ oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
+ runtimeSDKVersions = []string{"1.0.0"}
+ streamingSDKVersions = []string{"1.0.34"}
+ osTypes = []string{"darwin", "win32"}
+ osVersions = map[string][]string{
+ "darwin": {"24.6.0"},
+ "win32": {"10.0.22631"},
+ }
+ nodeVersions = []string{"22.22.0"}
+ kiroVersions = []string{
+ "0.11.132", "0.11.131", "0.11.130",
+ }
+)
+
+func globalRuntimeFingerprints() *runtimeFingerprintManager {
+ globalRuntimeFingerprintManagerOnce.Do(func() {
+ globalRuntimeFingerprintManager = &runtimeFingerprintManager{
+ fingerprints: make(map[string]*RuntimeFingerprint),
+ }
+ })
+ return globalRuntimeFingerprintManager
+}
+
+func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
+ lookupKey := fingerprintLookupKey(accountKey, "runtime")
+ machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
+
+ m.mu.RLock()
+ if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
+ m.mu.RUnlock()
+ return fp
+ }
+ m.mu.RUnlock()
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
+ return fp
+ }
+ fp := generateRuntimeFingerprint(lookupKey, machineID)
+ m.fingerprints[lookupKey] = fp
+ return fp
+}
+
+func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
+ hash := sha256.Sum256([]byte(accountKey))
+ seed := int64(binary.BigEndian.Uint64(hash[:8]))
+ rng := rand.New(rand.NewSource(seed))
+
+ osType := goOSToNodePlatform(runtime.GOOS)
+ if !containsString(osTypes, osType) {
+ osType = osTypes[rng.Intn(len(osTypes))]
+ }
+ osVersionPool := osVersions[osType]
+ if len(osVersionPool) == 0 {
+ osVersionPool = osVersions["darwin"]
+ }
+
+ return &RuntimeFingerprint{
+ OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
+ RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
+ StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
+ OSType: osType,
+ OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
+ NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
+ KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
+ KiroHash: machineID,
+ }
+}
+
+func goOSToNodePlatform(goos string) string {
+ switch strings.TrimSpace(goos) {
+ case "windows":
+ return "win32"
+ default:
+ return strings.TrimSpace(goos)
+ }
+}
+
+func containsString(items []string, target string) bool {
+ for _, item := range items {
+ if item == target {
+ return true
+ }
+ }
+ return false
+}
+
+func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
+ switch {
+ case strings.TrimSpace(clientIDHash) != "":
+ return clientIDHash
+ case strings.TrimSpace(clientID) != "":
+ return shortSHA(clientID)
+ case strings.TrimSpace(refreshToken) != "":
+ return shortSHA(refreshToken)
+ case strings.TrimSpace(profileArn) != "":
+ return shortSHA(profileArn)
+ case accountID > 0:
+ return shortSHA(fmt.Sprintf("account:%d", accountID))
+ default:
+ return shortSHA(uuid.NewString())
+ }
+}
+
+func NormalizeMachineID(machineID string) (string, bool) {
+ trimmed := strings.TrimSpace(machineID)
+ if len(trimmed) == 64 && isHexString(trimmed) {
+ return strings.ToLower(trimmed), true
+ }
+ withoutDashes := strings.ReplaceAll(trimmed, "-", "")
+ if len(withoutDashes) == 32 && isHexString(withoutDashes) {
+ normalized := strings.ToLower(withoutDashes)
+ return normalized + normalized, true
+ }
+ return "", false
+}
+
+func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
+ if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
+ return sha256Hex("KotlinNativeAPI/" + refreshToken)
+ }
+ if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
+ return sha256Hex("KiroAPIKey/" + apiKey)
+ }
+ if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
+ return sha256Hex("KiroFallback/" + fallbackKey)
+ }
+ return sha256Hex("KiroFallback/default")
+}
+
+func shortSHA(seed string) string {
+ sum := sha256.Sum256([]byte(seed))
+ return hex.EncodeToString(sum[:8])
+}
+
+func sha256Hex(seed string) string {
+ sum := sha256.Sum256([]byte(seed))
+ return hex.EncodeToString(sum[:])
+}
+
+func isHexString(value string) bool {
+ for _, c := range value {
+ if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
+ return false
+ }
+ }
+ return true
+}
+
+func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
+ if normalized, ok := NormalizeMachineID(machineID); ok {
+ return normalized
+ }
+ return BuildMachineID("", "", fallbackKey)
+}
+
+func fingerprintLookupKey(accountKey, fallback string) string {
+ key := strings.TrimSpace(accountKey)
+ if key != "" {
+ return key
+ }
+ return fallback
+}
+
+func BuildRuntimeUserAgent(accountKey, machineID string) string {
+ fp := globalRuntimeFingerprints().Get(accountKey, machineID)
+ return fmt.Sprintf(
+ "aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
+ fp.StreamingSDKVersion,
+ fp.OSType,
+ fp.OSVersion,
+ fp.NodeVersion,
+ fp.StreamingSDKVersion,
+ fp.KiroVersion,
+ fp.KiroHash,
+ )
+}
+
+func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
+ fp := globalRuntimeFingerprints().Get(accountKey, machineID)
+ return fmt.Sprintf(
+ "aws-sdk-js/%s KiroIDE-%s-%s",
+ fp.StreamingSDKVersion,
+ fp.KiroVersion,
+ fp.KiroHash,
+ )
+}
+
+func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
+ fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
+ return map[string]string{
+ "Content-Type": "application/json",
+ "x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
+ "User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
+ "amz-sdk-invocation-id": uuid.NewString(),
+ "amz-sdk-request": "attempt=1; max=4",
+ }
+}
+
+func BuildLoginHeaders(accountKey, machineID string) map[string]string {
+ fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
+ return map[string]string{
+ "Content-Type": "application/json",
+ "User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
+ "Accept": "application/json, text/plain, */*",
+ }
+}
+
+func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
+ if attempt < 0 {
+ attempt = 0
+ }
+ delay := baseDelay << attempt
+ if delay > maxDelay {
+ delay = maxDelay
+ }
+ const jitterFactor = 0.3
+ seed := rand.New(rand.NewSource(time.Now().UnixNano()))
+ jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
+ return time.Duration(float64(delay) * jitter)
+}
diff --git a/backend/internal/pkg/kiro/fingerprint_test.go b/backend/internal/pkg/kiro/fingerprint_test.go
new file mode 100644
index 00000000..16eec045
--- /dev/null
+++ b/backend/internal/pkg/kiro/fingerprint_test.go
@@ -0,0 +1,91 @@
+package kiro
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildLoginHeadersStable(t *testing.T) {
+ headers1 := BuildLoginHeaders("", "")
+ headers2 := BuildLoginHeaders("", "")
+
+ require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
+ require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
+ require.Equal(t, "application/json", headers1["Content-Type"])
+ require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
+ require.Contains(t, headers1["User-Agent"], "KiroIDE-")
+}
+
+func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
+ machineIDA := BuildMachineID("refresh-a", "", "")
+ machineIDB := BuildMachineID("refresh-b", "", "")
+ headers1 := BuildLoginHeaders("account-a", machineIDA)
+ headers2 := BuildLoginHeaders("account-a", machineIDA)
+ headers3 := BuildLoginHeaders("account-a", machineIDB)
+
+ require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
+ require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
+ require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
+ require.Contains(t, headers1["User-Agent"], machineIDA)
+}
+
+func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
+ machineID := BuildMachineID("", "", "oidc-machine")
+ headers1 := BuildOIDCHeaders("account-a", machineID)
+ headers2 := BuildOIDCHeaders("account-a", machineID)
+ headers3 := BuildOIDCHeaders("account-b", machineID)
+
+ require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
+ require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
+ require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
+}
+
+func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
+ key1 := BuildAccountKey("", "", "", "", 42)
+ key2 := BuildAccountKey("", "", "", "", 42)
+ key3 := BuildAccountKey("", "", "", "", 43)
+
+ require.Equal(t, key1, key2)
+ require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
+ require.NotEqual(t, key1, key3)
+}
+
+func TestBuildMachineID(t *testing.T) {
+ require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
+ require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
+ require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
+
+ fallback1 := BuildMachineID("", "", "account:1")
+ fallback2 := BuildMachineID("", "", "account:1")
+ fallback3 := BuildMachineID("", "", "account:2")
+ require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
+ require.Equal(t, fallback1, fallback2)
+ require.NotEqual(t, fallback1, fallback3)
+ require.Len(t, fallback1, 64)
+}
+
+func TestNormalizeMachineID(t *testing.T) {
+ hex64 := strings.Repeat("A", 64)
+ normalized, ok := NormalizeMachineID(hex64)
+ require.True(t, ok)
+ require.Equal(t, strings.ToLower(hex64), normalized)
+
+ normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
+ require.True(t, ok)
+ require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
+
+ _, ok = NormalizeMachineID("not-a-machine-id")
+ require.False(t, ok)
+ _, ok = NormalizeMachineID(strings.Repeat("g", 64))
+ require.False(t, ok)
+}
+
+func expectedKiroMachineID(seed string) string {
+ sum := sha256.Sum256([]byte(seed))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/pkg/kiro/models.go b/backend/internal/pkg/kiro/models.go
new file mode 100644
index 00000000..ca2c5dc7
--- /dev/null
+++ b/backend/internal/pkg/kiro/models.go
@@ -0,0 +1,21 @@
+package kiro
+
+type Model struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ DisplayName string `json:"display_name"`
+ CreatedAt string `json:"created_at"`
+}
+
+var DefaultModels = []Model{
+ {ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
+ {ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
+ {ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
+ {ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
+ {ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
+ {ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
+ {ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
+ {ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
+ {ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
+ {ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
+}
diff --git a/backend/internal/pkg/kiro/models_test.go b/backend/internal/pkg/kiro/models_test.go
new file mode 100644
index 00000000..95451090
--- /dev/null
+++ b/backend/internal/pkg/kiro/models_test.go
@@ -0,0 +1,43 @@
+package kiro
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
+ ids := make([]string, 0, len(DefaultModels))
+ for _, model := range DefaultModels {
+ ids = append(ids, model.ID)
+ }
+
+ require.Equal(t, []string{
+ "claude-opus-4-6",
+ "claude-opus-4-6-thinking",
+ "claude-sonnet-4-6",
+ "claude-sonnet-4-6-thinking",
+ "claude-opus-4-5-20251101",
+ "claude-opus-4-5-20251101-thinking",
+ "claude-sonnet-4-5-20250929",
+ "claude-sonnet-4-5-20250929-thinking",
+ "claude-haiku-4-5-20251001",
+ "claude-haiku-4-5-20251001-thinking",
+ }, ids)
+
+ require.Contains(t, ids, "claude-sonnet-4-6")
+ require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
+ require.NotContains(t, ids, "auto")
+ require.NotContains(t, ids, "claude-sonnet-4")
+ require.NotContains(t, ids, "gpt-4o")
+ require.NotContains(t, ids, "deepseek-3-2")
+ require.NotContains(t, ids, "minimax-m2-1")
+ require.NotContains(t, ids, "qwen3-coder-next")
+ require.NotContains(t, ids, "claude-opus-4-7")
+ require.NotContains(t, ids, "claude-sonnet-4-6-chat")
+ for _, id := range ids {
+ require.NotContains(t, id, "kiro-")
+ require.NotContains(t, id, "-agentic")
+ require.NotContains(t, id, "-chat")
+ }
+}
diff --git a/backend/internal/pkg/kiro/oauth.go b/backend/internal/pkg/kiro/oauth.go
new file mode 100644
index 00000000..2a6e1338
--- /dev/null
+++ b/backend/internal/pkg/kiro/oauth.go
@@ -0,0 +1,511 @@
+package kiro
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
+ "github.com/google/uuid"
+)
+
+const (
+ socialAuthPortalURL = "https://app.kiro.dev"
+ socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
+ defaultIDCRegion = "us-east-1"
+ BuilderIDStartURL = "https://view.awsapps.com/start"
+ sessionTTL = 10 * time.Minute
+ sessionCleanupEvery = 32
+ sessionCleanupMin = 32
+)
+
+var (
+ socialAuthEndpointURL = socialAuthEndpoint
+ oidcEndpointOverride = ""
+)
+
+type SocialProvider string
+
+const (
+ SocialProviderGoogle SocialProvider = "Google"
+ SocialProviderGitHub SocialProvider = "Github"
+)
+
+type AuthSession struct {
+ State string
+ CodeVerifier string
+ ProxyURL string
+ CreatedAt time.Time
+ AuthType string
+ Provider string
+ RedirectURI string
+ ClientID string
+ ClientSecret string
+ Region string
+ StartURL string
+}
+
+type SessionStore struct {
+ mu sync.RWMutex
+ data map[string]*AuthSession
+ setCount uint64
+}
+
+func NewSessionStore() *SessionStore {
+ return &SessionStore{data: make(map[string]*AuthSession)}
+}
+
+func (s *SessionStore) Get(id string) (*AuthSession, bool) {
+ now := time.Now()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ session, ok := s.data[id]
+ if ok && sessionExpired(session, now) {
+ delete(s.data, id)
+ return nil, false
+ }
+ return session, ok
+}
+
+func (s *SessionStore) Set(id string, session *AuthSession) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.setCount++
+ if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
+ s.pruneExpiredLocked(time.Now())
+ }
+ s.data[id] = session
+}
+
+func (s *SessionStore) Delete(id string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.data, id)
+}
+
+func (s *SessionStore) pruneExpiredLocked(now time.Time) {
+ for id, session := range s.data {
+ if sessionExpired(session, now) {
+ delete(s.data, id)
+ }
+ }
+}
+
+func sessionExpired(session *AuthSession, now time.Time) bool {
+ if session == nil {
+ return true
+ }
+ if session.CreatedAt.IsZero() {
+ return true
+ }
+ return now.After(session.CreatedAt.Add(sessionTTL))
+}
+
+type TokenData struct {
+ AccessToken string `json:"accessToken"`
+ RefreshToken string `json:"refreshToken"`
+ ProfileArn string `json:"profileArn,omitempty"`
+ ExpiresAt string `json:"expiresAt,omitempty"`
+ AuthMethod string `json:"authMethod,omitempty"`
+ Provider string `json:"provider,omitempty"`
+ ClientID string `json:"clientId,omitempty"`
+ ClientSecret string `json:"clientSecret,omitempty"`
+ ClientIDHash string `json:"clientIdHash,omitempty"`
+ Email string `json:"email,omitempty"`
+ StartURL string `json:"startUrl,omitempty"`
+ Region string `json:"region,omitempty"`
+}
+
+type socialTokenResponse struct {
+ AccessToken string `json:"accessToken"`
+ RefreshToken string `json:"refreshToken"`
+ ProfileArn string `json:"profileArn"`
+ ExpiresIn int `json:"expiresIn"`
+}
+
+type registerClientResponse struct {
+ ClientID string `json:"clientId"`
+ ClientSecret string `json:"clientSecret"`
+}
+
+type createTokenResponse struct {
+ AccessToken string `json:"accessToken"`
+ RefreshToken string `json:"refreshToken"`
+ ProfileArn string `json:"profileArn"`
+ ExpiresIn int `json:"expiresIn"`
+}
+
+type userInfoResponse struct {
+ Email string `json:"email"`
+}
+
+type deviceRegistration struct {
+ ClientID string `json:"clientId"`
+ ClientSecret string `json:"clientSecret"`
+}
+
+type RefreshTokenInvalidError struct {
+ StatusCode int
+ Body string
+}
+
+func (e *RefreshTokenInvalidError) Error() string {
+ if e == nil {
+ return ""
+ }
+ body := strings.TrimSpace(e.Body)
+ if body == "" {
+ return "kiro refresh token invalid (invalid_grant)"
+ }
+ return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
+}
+
+func GenerateSessionID() string {
+ return uuid.NewString()
+}
+
+func GenerateState() (string, error) {
+ return randomURLSafe(16)
+}
+
+func GenerateCodeVerifier() (string, error) {
+ return randomURLSafe(32)
+}
+
+func randomURLSafe(n int) (string, error) {
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(buf), nil
+}
+
+func GenerateCodeChallenge(verifier string) string {
+ sum := sha256.Sum256([]byte(verifier))
+ return base64.RawURLEncoding.EncodeToString(sum[:])
+}
+
+func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
+ params := url.Values{}
+ params.Set("state", state)
+ params.Set("code_challenge", codeChallenge)
+ params.Set("code_challenge_method", "S256")
+ params.Set("redirect_uri", redirectURI)
+ params.Set("redirect_from", "KiroIDE")
+ return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
+}
+
+func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
+ redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
+ if redirectURI == "" {
+ return ""
+ }
+ path := strings.TrimSpace(callbackPath)
+ if path == "" {
+ path = "/oauth/callback"
+ } else if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ fullRedirectURI := redirectURI + path
+ if option := strings.TrimSpace(loginOption); option != "" {
+ return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
+ }
+ return fullRedirectURI
+}
+
+func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
+ payload := map[string]string{
+ "code": code,
+ "code_verifier": codeVerifier,
+ "redirect_uri": redirectURI,
+ }
+ var resp socialTokenResponse
+ if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
+ return nil, err
+ }
+ expiresIn := resp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ return &TokenData{
+ AccessToken: resp.AccessToken,
+ RefreshToken: resp.RefreshToken,
+ ProfileArn: resp.ProfileArn,
+ ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
+ AuthMethod: "social",
+ Region: defaultIDCRegion,
+ }, nil
+}
+
+func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
+ payload := map[string]string{
+ "refreshToken": refreshToken,
+ }
+ var resp socialTokenResponse
+ accountKey := BuildAccountKey("", "", refreshToken, "", 0)
+ if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
+ return nil, err
+ }
+ expiresIn := resp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ return &TokenData{
+ AccessToken: resp.AccessToken,
+ RefreshToken: resp.RefreshToken,
+ ProfileArn: resp.ProfileArn,
+ ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
+ AuthMethod: "social",
+ Provider: provider,
+ Region: defaultIDCRegion,
+ }, nil
+}
+
+func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
+ if region == "" {
+ region = defaultIDCRegion
+ }
+ payload := map[string]any{
+ "clientName": "Kiro IDE",
+ "clientType": "public",
+ "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
+ "grantTypes": []string{"authorization_code", "refresh_token"},
+ "redirectUris": []string{redirectURI},
+ "issuerUrl": issuerURL,
+ }
+ var resp registerClientResponse
+ headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
+ if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
+ if region == "" {
+ region = defaultIDCRegion
+ }
+ params := url.Values{}
+ params.Set("response_type", "code")
+ params.Set("client_id", clientID)
+ params.Set("redirect_uri", redirectURI)
+ params.Set("scopes", strings.Join([]string{
+ "codewhisperer:completions",
+ "codewhisperer:analysis",
+ "codewhisperer:conversations",
+ "codewhisperer:transformations",
+ "codewhisperer:taskassist",
+ }, " "))
+ params.Set("state", state)
+ params.Set("code_challenge", codeChallenge)
+ params.Set("code_challenge_method", "S256")
+ return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
+}
+
+func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
+ if region == "" {
+ region = defaultIDCRegion
+ }
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "code": code,
+ "codeVerifier": codeVerifier,
+ "redirectUri": redirectURI,
+ "grantType": "authorization_code",
+ }
+ var resp createTokenResponse
+ accountKey := BuildAccountKey(clientID, "", "", "", 0)
+ headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
+ if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
+ return nil, err
+ }
+ expiresIn := resp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ token := &TokenData{
+ AccessToken: resp.AccessToken,
+ RefreshToken: resp.RefreshToken,
+ ProfileArn: resp.ProfileArn,
+ ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
+ AuthMethod: "idc",
+ Provider: "AWS",
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ StartURL: startURL,
+ Region: region,
+ }
+ token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
+ return token, nil
+}
+
+func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
+ if region == "" {
+ region = defaultIDCRegion
+ }
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "refreshToken": refreshToken,
+ "grantType": "refresh_token",
+ }
+ var resp createTokenResponse
+ accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
+ headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
+ if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
+ return nil, err
+ }
+ expiresIn := resp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ token := &TokenData{
+ AccessToken: resp.AccessToken,
+ RefreshToken: resp.RefreshToken,
+ ProfileArn: resp.ProfileArn,
+ ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
+ AuthMethod: "idc",
+ Provider: "AWS",
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ StartURL: startURL,
+ Region: region,
+ }
+ token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
+ return token, nil
+}
+
+func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
+ if strings.TrimSpace(accessToken) == "" {
+ return ""
+ }
+ var resp userInfoResponse
+ headers := map[string]string{
+ "Authorization": "Bearer " + accessToken,
+ }
+ if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
+ return ""
+ }
+ return strings.TrimSpace(resp.Email)
+}
+
+func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
+ var token TokenData
+ if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
+ return nil, fmt.Errorf("failed to parse kiro token: %w", err)
+ }
+ token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
+ if strings.TrimSpace(token.AccessToken) == "" {
+ return nil, fmt.Errorf("access token is empty")
+ }
+ if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
+ var reg deviceRegistration
+ if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil {
+ return nil, fmt.Errorf("failed to parse device registration: %w", err)
+ }
+ if reg.ClientID != "" {
+ token.ClientID = reg.ClientID
+ }
+ if reg.ClientSecret != "" {
+ token.ClientSecret = reg.ClientSecret
+ }
+ }
+ return &token, nil
+}
+
+func getOIDCEndpoint(region string) string {
+ if strings.TrimSpace(oidcEndpointOverride) != "" {
+ return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
+ }
+ if region == "" {
+ region = defaultIDCRegion
+ }
+ return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
+}
+
+func oidcHeaders(accountKey, machineID string) map[string]string {
+ headers := BuildOIDCHeaders(accountKey, machineID)
+ if headers["amz-sdk-invocation-id"] == "" {
+ headers["amz-sdk-invocation-id"] = uuid.NewString()
+ }
+ if headers["amz-sdk-request"] == "" {
+ headers["amz-sdk-request"] = "attempt=1; max=4"
+ }
+ return headers
+}
+
+func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
+ client, err := newHTTPClient(proxyURL)
+ if err != nil {
+ return err
+ }
+
+ var body io.Reader
+ if payload != nil {
+ encoded, err := json.Marshal(payload)
+ if err != nil {
+ return err
+ }
+ body = bytes.NewReader(encoded)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
+ if err != nil {
+ return err
+ }
+
+ if payload != nil {
+ req.Header.Set("Content-Type", "application/json")
+ }
+ for key, value := range extraHeaders {
+ req.Header.Set(key, value)
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ bodyText := strings.TrimSpace(string(respBody))
+ if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
+ return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
+ }
+ return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
+ }
+ if out == nil || len(respBody) == 0 {
+ return nil
+ }
+ return json.Unmarshal(respBody, out)
+}
+
+func newHTTPClient(rawProxyURL string) (*http.Client, error) {
+ _, parsed, err := proxyurl.Parse(rawProxyURL)
+ if err != nil {
+ return nil, err
+ }
+ transport := &http.Transport{}
+ if parsed != nil {
+ transport.Proxy = http.ProxyURL(parsed)
+ }
+ return &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: transport,
+ }, nil
+}
diff --git a/backend/internal/pkg/kiro/oauth_invalid_grant_test.go b/backend/internal/pkg/kiro/oauth_invalid_grant_test.go
new file mode 100644
index 00000000..6803ae6d
--- /dev/null
+++ b/backend/internal/pkg/kiro/oauth_invalid_grant_test.go
@@ -0,0 +1,105 @@
+package kiro
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "/refreshToken", r.URL.Path)
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
+ }))
+ defer server.Close()
+
+ previous := socialAuthEndpointURL
+ socialAuthEndpointURL = server.URL
+ t.Cleanup(func() { socialAuthEndpointURL = previous })
+
+ _, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
+ require.Error(t, err)
+
+ var invalid *RefreshTokenInvalidError
+ require.True(t, errors.As(err, &invalid))
+ require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
+ require.Contains(t, invalid.Body, "invalid_grant")
+}
+
+func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "/token", r.URL.Path)
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
+ }))
+ defer server.Close()
+
+ previous := oidcEndpointOverride
+ oidcEndpointOverride = server.URL
+ t.Cleanup(func() { oidcEndpointOverride = previous })
+
+ _, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
+ require.Error(t, err)
+
+ var invalid *RefreshTokenInvalidError
+ require.True(t, errors.As(err, &invalid))
+ require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
+ require.Contains(t, invalid.Body, "invalid_grant")
+}
+
+func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
+ const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
+ default:
+ t.Fatalf("unexpected path: %s", r.URL.Path)
+ }
+ }))
+ defer server.Close()
+
+ previous := oidcEndpointOverride
+ oidcEndpointOverride = server.URL
+ t.Cleanup(func() { oidcEndpointOverride = previous })
+
+ token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
+ require.NoError(t, err)
+ require.Equal(t, profileArn, token.ProfileArn)
+ require.Equal(t, "kiro@example.com", token.Email)
+}
+
+func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
+ const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
+ default:
+ t.Fatalf("unexpected path: %s", r.URL.Path)
+ }
+ }))
+ defer server.Close()
+
+ previous := oidcEndpointOverride
+ oidcEndpointOverride = server.URL
+ t.Cleanup(func() { oidcEndpointOverride = previous })
+
+ token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
+ require.NoError(t, err)
+ require.Equal(t, profileArn, token.ProfileArn)
+ require.Equal(t, "kiro@example.com", token.Email)
+}
diff --git a/backend/internal/pkg/kiro/oauth_test.go b/backend/internal/pkg/kiro/oauth_test.go
new file mode 100644
index 00000000..b6b9b52d
--- /dev/null
+++ b/backend/internal/pkg/kiro/oauth_test.go
@@ -0,0 +1,56 @@
+//go:build unit
+
+package kiro
+
+import (
+ "fmt"
+ "testing"
+ "time"
+)
+
+func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
+ got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
+ want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
+ if got != want {
+ t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
+ }
+}
+
+func TestBuildSocialTokenRedirectURI(t *testing.T) {
+ got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
+ want := "http://localhost:49153/oauth/callback?login_option=github"
+ if got != want {
+ t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
+ }
+}
+
+func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
+ store := NewSessionStore()
+ store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
+
+ session, ok := store.Get("expired")
+ if ok || session != nil {
+ t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
+ }
+ if _, exists := store.data["expired"]; exists {
+ t.Fatalf("expired session should be deleted from the store")
+ }
+}
+
+func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
+ store := NewSessionStore()
+ now := time.Now()
+ for i := 0; i < sessionCleanupMin; i++ {
+ store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
+ }
+ store.setCount = sessionCleanupEvery - 1
+
+ store.Set("fresh", &AuthSession{CreatedAt: now})
+
+ if len(store.data) != 1 {
+ t.Fatalf("store size = %d, want 1", len(store.data))
+ }
+ if _, ok := store.data["fresh"]; !ok {
+ t.Fatalf("fresh session should remain after pruning")
+ }
+}
diff --git a/backend/internal/pkg/kiro/translator.go b/backend/internal/pkg/kiro/translator.go
new file mode 100644
index 00000000..2d26745c
--- /dev/null
+++ b/backend/internal/pkg/kiro/translator.go
@@ -0,0 +1,2730 @@
+package kiro
+
+import (
+ "bufio"
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+
+ "github.com/google/uuid"
+ "github.com/tidwall/gjson"
+)
+
+const (
+ kiroMaxToolDescLen = 10237
+ kiroMaxToolNameLen = 63
+ thinkingStartTag = ""
+ thinkingEndTag = ""
+ embeddedToolCallPrefix = "[Called "
+ minFrameSize = 16
+ maxEventMsgSize = 10 << 20
+ writeToolDescriptionSuffix = "IMPORTANT: If the content to write exceeds 150 lines, write only the first 50 lines with this tool, then append the remaining content using Edit calls in chunks of no more than 50 lines. Use a unique placeholder if needed. Do not write the whole file in one call."
+ editToolDescriptionSuffix = "IMPORTANT: If new content exceeds 50 lines, split it into multiple Edit calls, replacing or appending no more than 50 lines per call. If appending, use a unique placeholder and remove it in the final chunk."
+ systemChunkedWritePolicy = "When Write or Edit tools include chunking limits, comply silently and complete the operation through multiple tool calls when needed."
+)
+
+var (
+ trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`)
+ requiredToolFields = map[string][][]string{
+ "write": {{"file_path", "path"}, {"content"}},
+ "write_to_file": {{"path"}, {"content"}},
+ "fswrite": {{"path"}, {"content"}},
+ "create_file": {{"path"}, {"content"}},
+ "edit_file": {{"path"}},
+ "apply_diff": {{"path"}, {"diff"}},
+ "str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
+ "bash": {{"cmd", "command"}},
+ "execute": {{"command"}},
+ "run_command": {{"command"}},
+ }
+)
+
+type Usage struct {
+ InputTokens int
+ OutputTokens int
+ TotalTokens int
+ CacheReadInputTokens int
+}
+
+type StreamResult struct {
+ Usage Usage
+ StopReason string
+ FirstDeltaDur *time.Duration
+}
+
+type ParseResult struct {
+ ResponseBody []byte
+ Usage Usage
+ StopReason string
+}
+
+type KiroRequestContext struct {
+ ToolNameMap map[string]string
+ ThinkingEnabled bool
+}
+
+type KiroBuildResult struct {
+ Payload []byte
+ Context KiroRequestContext
+}
+
+type KiroPayload struct {
+ ConversationState KiroConversationState `json:"conversationState"`
+ ProfileArn string `json:"profileArn,omitempty"`
+ InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
+}
+
+type KiroInferenceConfig struct {
+ MaxTokens int `json:"maxTokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+}
+
+type thinkingDirective struct {
+ Mode string
+ BudgetTokens int
+ Effort string
+}
+
+type KiroConversationState struct {
+ AgentContinuationID string `json:"agentContinuationId,omitempty"`
+ AgentTaskType string `json:"agentTaskType,omitempty"`
+ ChatTriggerType string `json:"chatTriggerType"`
+ ConversationID string `json:"conversationId"`
+ CurrentMessage KiroCurrentMessage `json:"currentMessage"`
+ History []KiroHistoryMessage `json:"history,omitempty"`
+}
+
+type KiroCurrentMessage struct {
+ UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
+}
+
+type KiroHistoryMessage struct {
+ UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
+ AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
+}
+
+type KiroImage struct {
+ Format string `json:"format"`
+ Source KiroImageSource `json:"source"`
+}
+
+type KiroImageSource struct {
+ Bytes string `json:"bytes"`
+}
+
+type KiroUserInputMessage struct {
+ Content string `json:"content"`
+ ModelID string `json:"modelId"`
+ Origin string `json:"origin"`
+ Images []KiroImage `json:"images,omitempty"`
+ UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
+}
+
+type KiroUserInputMessageContext struct {
+ ToolResults []KiroToolResult `json:"toolResults,omitempty"`
+ Tools []KiroToolWrapper `json:"tools,omitempty"`
+}
+
+type KiroToolResult struct {
+ Content []KiroTextContent `json:"content"`
+ Status string `json:"status"`
+ ToolUseID string `json:"toolUseId"`
+}
+
+type KiroTextContent struct {
+ Text string `json:"text"`
+}
+
+type KiroToolWrapper struct {
+ ToolSpecification KiroToolSpecification `json:"toolSpecification"`
+}
+
+type KiroToolSpecification struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema KiroInputSchema `json:"inputSchema"`
+}
+
+type KiroInputSchema struct {
+ JSON interface{} `json:"json"`
+}
+
+type KiroAssistantResponseMessage struct {
+ Content string `json:"content"`
+ ToolUses []KiroToolUse `json:"toolUses,omitempty"`
+}
+
+type KiroToolUse struct {
+ ToolUseID string `json:"toolUseId"`
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
+ IsTruncated bool `json:"-"`
+ TruncatedRaw string `json:"-"`
+}
+
+type toolUseState struct {
+ ToolUseID string
+ Name string
+ InputBuffer strings.Builder
+}
+
+type eventStreamMessage struct {
+ EventType string
+ Payload []byte
+}
+
+func MapModel(model string) string {
+ switch strings.TrimSpace(strings.ToLower(model)) {
+ case "claude-opus-4-6", "claude-opus-4-6-thinking", "claude-opus-4.6":
+ return "claude-opus-4.6"
+ case "claude-sonnet-4-6", "claude-sonnet-4-6-thinking", "claude-sonnet-4.6":
+ return "claude-sonnet-4.6"
+ case "claude-opus-4-5-20251101", "claude-opus-4-5-20251101-thinking", "claude-opus-4.5":
+ return "claude-opus-4.5"
+ case "claude-sonnet-4-5-20250929", "claude-sonnet-4-5-20250929-thinking", "claude-sonnet-4.5":
+ return "claude-sonnet-4.5"
+ case "claude-haiku-4-5-20251001", "claude-haiku-4-5-20251001-thinking", "claude-haiku-4.5":
+ return "claude-haiku-4.5"
+ default:
+ return ""
+ }
+}
+
+func normalizeModelAlias(model string) string {
+ base := strings.TrimSpace(strings.ToLower(model))
+ for {
+ next := strings.TrimSuffix(base, "-thinking")
+ if next == base {
+ return next
+ }
+ base = next
+ }
+}
+
+func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, headers http.Header) ([]byte, error) {
+ result, err := BuildKiroPayloadWithContext(claudeBody, modelID, profileArn, origin, headers)
+ if err != nil {
+ return nil, err
+ }
+ return result.Payload, nil
+}
+
+func BuildKiroPayloadWithContext(claudeBody []byte, modelID, profileArn, origin string, headers http.Header) (*KiroBuildResult, error) {
+ const kiroMaxOutputTokens = 32000
+ requestCtx := KiroRequestContext{ToolNameMap: map[string]string{}}
+ var maxTokens int64
+ if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
+ maxTokens = mt.Int()
+ if maxTokens == -1 {
+ maxTokens = kiroMaxOutputTokens
+ }
+ }
+
+ var temperature float64
+ var hasTemperature bool
+ if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
+ temperature = temp.Float()
+ hasTemperature = true
+ }
+
+ var topP float64
+ var hasTopP bool
+ if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() {
+ topP = tp.Float()
+ hasTopP = true
+ }
+
+ messages := gjson.GetBytes(claudeBody, "messages")
+ thinking := deriveThinkingDirective(claudeBody, headers)
+ requestCtx.ThinkingEnabled = thinking != nil
+ toolChoiceHint := extractClaudeToolChoiceHint(claudeBody, &requestCtx)
+ systemPrompt := buildInjectedSystemPrompt(extractSystemPrompt(claudeBody), thinking, toolChoiceHint)
+
+ history, currentUserMsg, currentToolResults := processMessages(messages, modelID, normalizeOrigin(origin), &requestCtx)
+ history = prependSystemHistory(history, systemPrompt, modelID, normalizeOrigin(origin))
+ var tools gjson.Result
+ if !isToolChoiceNone(claudeBody) {
+ tools = gjson.GetBytes(claudeBody, "tools")
+ }
+ kiroTools := convertClaudeToolsToKiro(tools, &requestCtx)
+ currentToolResults, orphanedToolUseIDs := validateToolPairing(history, currentToolResults)
+ removeOrphanedToolUses(history, orphanedToolUseIDs)
+ kiroTools = appendMissingPlaceholderTools(kiroTools, collectHistoryToolNames(history))
+ if currentUserMsg != nil {
+ currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, currentToolResults)
+ currentToolResults = deduplicateToolResults(currentToolResults)
+ if len(kiroTools) > 0 || len(currentToolResults) > 0 {
+ currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
+ Tools: kiroTools,
+ ToolResults: currentToolResults,
+ }
+ }
+ }
+
+ var currentMessage KiroCurrentMessage
+ if currentUserMsg != nil {
+ currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
+ } else {
+ currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
+ Content: buildFinalContent("", nil),
+ ModelID: modelID,
+ Origin: normalizeOrigin(origin),
+ }}
+ }
+
+ var inferenceConfig *KiroInferenceConfig
+ if maxTokens > 0 || hasTemperature || hasTopP {
+ inferenceConfig = &KiroInferenceConfig{}
+ if maxTokens > 0 {
+ inferenceConfig.MaxTokens = int(maxTokens)
+ }
+ if hasTemperature {
+ inferenceConfig.Temperature = temperature
+ }
+ if hasTopP {
+ inferenceConfig.TopP = topP
+ }
+ }
+
+ conversationID := extractMetadataFromMessages(messages, "conversationId")
+ continuationID := extractMetadataFromMessages(messages, "continuationId")
+ if conversationID == "" {
+ conversationID = uuid.NewString()
+ }
+
+ payload := KiroPayload{
+ ConversationState: KiroConversationState{
+ AgentTaskType: "vibe",
+ ChatTriggerType: "MANUAL",
+ ConversationID: conversationID,
+ CurrentMessage: currentMessage,
+ History: history,
+ },
+ ProfileArn: profileArn,
+ InferenceConfig: inferenceConfig,
+ }
+ if continuationID != "" {
+ payload.ConversationState.AgentContinuationID = continuationID
+ }
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+ return &KiroBuildResult{Payload: payloadBytes, Context: requestCtx}, nil
+}
+
+func ParseNonStreamingEventStream(body io.Reader, model string) (*ParseResult, error) {
+ return ParseNonStreamingEventStreamWithContext(body, model, KiroRequestContext{})
+}
+
+func ParseNonStreamingEventStreamWithContext(body io.Reader, model string, requestCtx KiroRequestContext) (*ParseResult, error) {
+ content, toolUses, usage, stopReason, err := parseEventStream(body)
+ if err != nil {
+ return nil, err
+ }
+ return &ParseResult{
+ ResponseBody: buildClaudeResponse(content, toolUses, model, usage, stopReason, requestCtx),
+ Usage: usage,
+ StopReason: stopReason,
+ }, nil
+}
+
+func StreamEventStreamAsAnthropic(ctx context.Context, body io.Reader, w io.Writer, model string, inputTokens int) (*StreamResult, error) {
+ return StreamEventStreamAsAnthropicWithContext(ctx, body, w, model, inputTokens, KiroRequestContext{})
+}
+
+func StreamEventStreamAsAnthropicWithContext(ctx context.Context, body io.Reader, w io.Writer, model string, inputTokens int, requestCtx KiroRequestContext) (*StreamResult, error) {
+ reader := bufio.NewReader(body)
+ start := time.Now()
+ var firstDelta *time.Duration
+ usage := Usage{InputTokens: inputTokens}
+ contentBlockIndex := -1
+ thinkingBlockIndex := -1
+ messageStartSent := false
+ textBlockOpen := false
+ thinkingBlockOpen := false
+ processedIDs := make(map[string]bool)
+ emittedToolContents := make(map[string]bool)
+ streamingToolBlockIndices := make(map[string]int)
+ streamingToolStarted := make(map[string]bool)
+ streamingToolStopped := make(map[string]bool)
+ currentStreamingToolID := ""
+ pendingAssistantText := ""
+ pendingLeadingWhitespace := ""
+ stopReason := ""
+ thinkingBuffer := ""
+ inThinkingBlock := false
+ stripThinkingLeadingNewline := false
+ sawNonThinkingBlock := false
+
+ writeEvent := func(event string, data any) error {
+ payload, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ _, err = io.WriteString(w, "event: "+event+"\ndata: "+string(payload)+"\n\n")
+ return err
+ }
+ ensureMessageStart := func() error {
+ if messageStartSent {
+ return nil
+ }
+ if err := writeEvent("message_start", map[string]any{
+ "type": "message_start",
+ "message": map[string]any{
+ "id": "msg_" + uuid.NewString()[:24],
+ "type": "message",
+ "role": "assistant",
+ "content": []any{},
+ "model": model,
+ "stop_reason": nil,
+ "stop_sequence": nil,
+ "usage": map[string]any{
+ "input_tokens": usage.InputTokens,
+ "output_tokens": 0,
+ },
+ },
+ }); err != nil {
+ return err
+ }
+ messageStartSent = true
+ return nil
+ }
+
+ closeText := func() error {
+ if !textBlockOpen {
+ return nil
+ }
+ textBlockOpen = false
+ return writeEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": contentBlockIndex})
+ }
+ closeThinking := func() error {
+ if !thinkingBlockOpen {
+ return nil
+ }
+ thinkingBlockOpen = false
+ return writeEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": thinkingBlockIndex})
+ }
+ closeStreamingTool := func(toolUseID string) error {
+ if toolUseID == "" || !streamingToolStarted[toolUseID] || streamingToolStopped[toolUseID] {
+ return nil
+ }
+ streamingToolStopped[toolUseID] = true
+ if currentStreamingToolID == toolUseID {
+ currentStreamingToolID = ""
+ }
+ return writeEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": streamingToolBlockIndices[toolUseID]})
+ }
+ closeOpenStreamingTool := func() error {
+ return closeStreamingTool(currentStreamingToolID)
+ }
+ startStreamingToolUse := func(toolUseID, name string) error {
+ if toolUseID == "" || name == "" || streamingToolStopped[toolUseID] {
+ return nil
+ }
+ sawNonThinkingBlock = true
+ if currentStreamingToolID != "" && currentStreamingToolID != toolUseID {
+ if err := closeOpenStreamingTool(); err != nil {
+ return err
+ }
+ }
+ if stopReason == "" {
+ stopReason = "tool_use"
+ }
+ if err := ensureMessageStart(); err != nil {
+ return err
+ }
+ if firstDelta == nil {
+ delta := time.Since(start)
+ firstDelta = &delta
+ }
+ if err := closeThinking(); err != nil {
+ return err
+ }
+ if err := closeText(); err != nil {
+ return err
+ }
+ blockIndex, ok := streamingToolBlockIndices[toolUseID]
+ if !ok {
+ contentBlockIndex++
+ blockIndex = contentBlockIndex
+ streamingToolBlockIndices[toolUseID] = blockIndex
+ }
+ currentStreamingToolID = toolUseID
+ if streamingToolStarted[toolUseID] {
+ return nil
+ }
+ streamingToolStarted[toolUseID] = true
+ return writeEvent("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": blockIndex,
+ "content_block": map[string]any{
+ "type": "tool_use",
+ "id": toolUseID,
+ "name": restoreResponseToolName(name, requestCtx),
+ "input": map[string]any{},
+ },
+ })
+ }
+ emitStreamingToolInput := func(toolUseID, name, fragment string) error {
+ if fragment == "" {
+ return nil
+ }
+ if err := startStreamingToolUse(toolUseID, name); err != nil {
+ return err
+ }
+ if toolUseID == "" || !streamingToolStarted[toolUseID] || streamingToolStopped[toolUseID] {
+ return nil
+ }
+ return writeEvent("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": streamingToolBlockIndices[toolUseID],
+ "delta": map[string]any{
+ "type": "input_json_delta",
+ "partial_json": fragment,
+ },
+ })
+ }
+ processStreamingToolUseEvent := func(event map[string]interface{}) error {
+ tu := nestedEvent(event, "toolUseEvent")
+ toolUseID := getString(tu, "toolUseId")
+ name := getString(tu, "name")
+ if err := startStreamingToolUse(toolUseID, name); err != nil {
+ return err
+ }
+ if inputRaw, ok := tu["input"]; ok {
+ switch v := inputRaw.(type) {
+ case string:
+ if err := emitStreamingToolInput(toolUseID, name, v); err != nil {
+ return err
+ }
+ case map[string]interface{}:
+ encoded, err := json.Marshal(v)
+ if err != nil {
+ return err
+ }
+ if err := emitStreamingToolInput(toolUseID, name, string(encoded)); err != nil {
+ return err
+ }
+ }
+ }
+ isStop, _ := tu["stop"].(bool)
+ if isStop {
+ processedIDs[toolUseID] = true
+ if stopReason == "" {
+ stopReason = "tool_use"
+ }
+ return closeStreamingTool(toolUseID)
+ }
+ return nil
+ }
+ emitTextDelta := func(text string, allowWhitespace bool) error {
+ if text == "" || (!allowWhitespace && strings.TrimSpace(text) == "") {
+ return nil
+ }
+ if err := closeOpenStreamingTool(); err != nil {
+ return err
+ }
+ if !textBlockOpen && !allowWhitespace {
+ if pendingLeadingWhitespace != "" {
+ text = strings.TrimLeftFunc(pendingLeadingWhitespace+text, unicode.IsSpace)
+ pendingLeadingWhitespace = ""
+ if text == "" {
+ return nil
+ }
+ }
+ }
+ if err := ensureMessageStart(); err != nil {
+ return err
+ }
+ sawNonThinkingBlock = true
+ if firstDelta == nil {
+ delta := time.Since(start)
+ firstDelta = &delta
+ }
+ if err := closeThinking(); err != nil {
+ return err
+ }
+ if !textBlockOpen {
+ contentBlockIndex++
+ textBlockOpen = true
+ if err := writeEvent("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": contentBlockIndex,
+ "content_block": map[string]any{
+ "type": "text",
+ "text": "",
+ },
+ }); err != nil {
+ return err
+ }
+ }
+ return writeEvent("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": contentBlockIndex,
+ "delta": map[string]any{
+ "type": "text_delta",
+ "text": text,
+ },
+ })
+ }
+ emitToolUse := func(tool KiroToolUse) error {
+ if !shouldEmitToolUse(tool, emittedToolContents) {
+ return nil
+ }
+ sawNonThinkingBlock = true
+ if err := closeOpenStreamingTool(); err != nil {
+ return err
+ }
+ if err := ensureMessageStart(); err != nil {
+ return err
+ }
+ if err := closeText(); err != nil {
+ return err
+ }
+ if err := closeThinking(); err != nil {
+ return err
+ }
+ contentBlockIndex++
+ if err := writeEvent("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": contentBlockIndex,
+ "content_block": map[string]any{
+ "type": "tool_use",
+ "id": tool.ToolUseID,
+ "name": restoreResponseToolName(tool.Name, requestCtx),
+ "input": map[string]any{},
+ },
+ }); err != nil {
+ return err
+ }
+ inputJSON, _ := json.Marshal(tool.Input)
+ if err := writeEvent("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": contentBlockIndex,
+ "delta": map[string]any{
+ "type": "input_json_delta",
+ "partial_json": string(inputJSON),
+ },
+ }); err != nil {
+ return err
+ }
+ return writeEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": contentBlockIndex})
+ }
+ flushPendingAssistantText := func() error {
+ text, embeddedTools, pending := drainEmbeddedToolText(pendingAssistantText)
+ pendingAssistantText = pending
+ if err := emitTextDelta(text, false); err != nil {
+ return err
+ }
+ for _, tool := range embeddedTools {
+ if err := emitToolUse(tool); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ emitPlainAssistantText := func(text string) error {
+ if text == "" {
+ return nil
+ }
+ pendingAssistantText += text
+ return flushPendingAssistantText()
+ }
+ startThinkingBlock := func() error {
+ if err := closeOpenStreamingTool(); err != nil {
+ return err
+ }
+ if err := closeText(); err != nil {
+ return err
+ }
+ if err := ensureMessageStart(); err != nil {
+ return err
+ }
+ if firstDelta == nil {
+ delta := time.Since(start)
+ firstDelta = &delta
+ }
+ if thinkingBlockOpen {
+ return nil
+ }
+ contentBlockIndex++
+ thinkingBlockIndex = contentBlockIndex
+ thinkingBlockOpen = true
+ return writeEvent("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": thinkingBlockIndex,
+ "content_block": map[string]any{
+ "type": "thinking",
+ "thinking": "",
+ },
+ })
+ }
+ emitThinkingDelta := func(text string) error {
+ if !thinkingBlockOpen {
+ if err := startThinkingBlock(); err != nil {
+ return err
+ }
+ }
+ return writeEvent("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": thinkingBlockIndex,
+ "delta": map[string]any{
+ "type": "thinking_delta",
+ "thinking": text,
+ },
+ })
+ }
+ finishThinkingBlock := func() error {
+ if err := emitThinkingDelta(""); err != nil {
+ return err
+ }
+ return closeThinking()
+ }
+ processThinkingTaggedText := func(text string) error {
+ if text == "" {
+ return nil
+ }
+ thinkingBuffer += text
+ for {
+ if !inThinkingBlock {
+ startPos := findRealThinkingStartTag(thinkingBuffer, 0)
+ if startPos != -1 {
+ before := thinkingBuffer[:startPos]
+ if strings.TrimSpace(before) != "" {
+ if err := emitPlainAssistantText(before); err != nil {
+ return err
+ }
+ }
+ inThinkingBlock = true
+ stripThinkingLeadingNewline = true
+ thinkingBuffer = thinkingBuffer[startPos+len(thinkingStartTag):]
+ if err := startThinkingBlock(); err != nil {
+ return err
+ }
+ continue
+ }
+ safeLen := safeThinkingStreamFlushLen(thinkingBuffer, len(thinkingStartTag))
+ if safeLen > 0 {
+ safeText := thinkingBuffer[:safeLen]
+ if strings.TrimSpace(safeText) != "" {
+ if err := emitPlainAssistantText(safeText); err != nil {
+ return err
+ }
+ thinkingBuffer = thinkingBuffer[safeLen:]
+ }
+ }
+ break
+ }
+ if stripThinkingLeadingNewline {
+ if strings.HasPrefix(thinkingBuffer, "\n") {
+ thinkingBuffer = thinkingBuffer[1:]
+ stripThinkingLeadingNewline = false
+ } else if thinkingBuffer != "" {
+ stripThinkingLeadingNewline = false
+ }
+ }
+ endPos := findStreamThinkingEndTagStrict(thinkingBuffer, 0)
+ if endPos != -1 {
+ if thinkingText := thinkingBuffer[:endPos]; thinkingText != "" {
+ if err := emitThinkingDelta(thinkingText); err != nil {
+ return err
+ }
+ }
+ inThinkingBlock = false
+ if err := finishThinkingBlock(); err != nil {
+ return err
+ }
+ thinkingBuffer = thinkingBuffer[endPos+len(thinkingEndTag)+len("\n\n"):]
+ continue
+ }
+ safeLen := safeThinkingStreamFlushLen(thinkingBuffer, len(thinkingEndTag)+len("\n\n"))
+ if safeLen > 0 {
+ if err := emitThinkingDelta(thinkingBuffer[:safeLen]); err != nil {
+ return err
+ }
+ thinkingBuffer = thinkingBuffer[safeLen:]
+ }
+ break
+ }
+ return nil
+ }
+ flushThinkingAtBoundary := func() error {
+ if !requestCtx.ThinkingEnabled || thinkingBuffer == "" {
+ return nil
+ }
+ if inThinkingBlock {
+ endPos := findStreamThinkingEndTagAtBufferEnd(thinkingBuffer, 0)
+ if endPos != -1 {
+ if thinkingText := thinkingBuffer[:endPos]; thinkingText != "" {
+ if err := emitThinkingDelta(thinkingText); err != nil {
+ return err
+ }
+ }
+ afterPos := endPos + len(thinkingEndTag)
+ remaining := strings.TrimLeftFunc(thinkingBuffer[afterPos:], unicode.IsSpace)
+ thinkingBuffer = ""
+ inThinkingBlock = false
+ if err := finishThinkingBlock(); err != nil {
+ return err
+ }
+ return emitPlainAssistantText(remaining)
+ }
+ if err := emitThinkingDelta(thinkingBuffer); err != nil {
+ return err
+ }
+ thinkingBuffer = ""
+ inThinkingBlock = false
+ return finishThinkingBlock()
+ }
+ remaining := thinkingBuffer
+ thinkingBuffer = ""
+ return emitPlainAssistantText(remaining)
+ }
+ flushThinkingAtEOF := func() error {
+ if !requestCtx.ThinkingEnabled {
+ return nil
+ }
+ return flushThinkingAtBoundary()
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ msg, err := readEventStreamMessage(reader)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, err
+ }
+ if msg == nil || len(msg.Payload) == 0 {
+ continue
+ }
+
+ var event map[string]interface{}
+ if err := json.Unmarshal(msg.Payload, &event); err != nil {
+ continue
+ }
+
+ if sr := readStopReason(event); sr != "" {
+ stopReason = sr
+ }
+
+ switch msg.EventType {
+ case "assistantResponseEvent":
+ assistant := nestedEvent(event, "assistantResponseEvent")
+ if sr := readStopReason(assistant); sr != "" {
+ stopReason = sr
+ }
+ text := getString(assistant, "content")
+ if text == "" {
+ text = getString(event, "content")
+ }
+ if text != "" {
+ if requestCtx.ThinkingEnabled {
+ if err := processThinkingTaggedText(text); err != nil {
+ return nil, err
+ }
+ } else {
+ pendingAssistantText += text
+ if err := flushPendingAssistantText(); err != nil {
+ return nil, err
+ }
+ }
+ }
+ for _, tool := range readToolUses(assistant, event) {
+ if processedIDs[tool.ToolUseID] {
+ continue
+ }
+ processedIDs[tool.ToolUseID] = true
+ if err := flushThinkingAtBoundary(); err != nil {
+ return nil, err
+ }
+ if err := emitToolUse(tool); err != nil {
+ return nil, err
+ }
+ }
+ case "reasoningContentEvent":
+ reasoning := nestedEvent(event, "reasoningContentEvent")
+ text := getString(reasoning, "text")
+ if text == "" {
+ text = getString(event, "text")
+ }
+ if text == "" {
+ continue
+ }
+ if requestCtx.ThinkingEnabled {
+ wrapped := thinkingStartTag + text + thinkingEndTag + "\n\n"
+ if err := processThinkingTaggedText(wrapped); err != nil {
+ return nil, err
+ }
+ }
+ case "toolUseEvent":
+ if err := flushThinkingAtBoundary(); err != nil {
+ return nil, err
+ }
+ if err := processStreamingToolUseEvent(event); err != nil {
+ return nil, err
+ }
+ case "messageMetadataEvent", "metadataEvent", "supplementaryWebLinksEvent", "usageEvent", "messageStopEvent", "message_stop":
+ updateUsageFromEvent(&usage, msg.EventType, event)
+ default:
+ updateUsageFromEvent(&usage, msg.EventType, event)
+ }
+ }
+
+ if err := closeOpenStreamingTool(); err != nil {
+ return nil, err
+ }
+ if err := flushThinkingAtEOF(); err != nil {
+ return nil, err
+ }
+ if err := flushPendingAssistantText(); err != nil {
+ return nil, err
+ }
+ if requestCtx.ThinkingEnabled && thinkingBlockIndex != -1 && !sawNonThinkingBlock {
+ stopReason = "max_tokens"
+ if err := emitTextDelta(" ", true); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := closeText(); err != nil {
+ return nil, err
+ }
+ if err := closeThinking(); err != nil {
+ return nil, err
+ }
+ if usage.TotalTokens == 0 {
+ usage.TotalTokens = usage.InputTokens + usage.OutputTokens
+ }
+ if stopReason == "" {
+ if len(emittedToolContents) > 0 {
+ stopReason = "tool_use"
+ } else {
+ stopReason = "end_turn"
+ }
+ }
+ if err := ensureMessageStart(); err != nil {
+ return nil, err
+ }
+ if err := writeEvent("message_delta", map[string]any{
+ "type": "message_delta",
+ "delta": map[string]any{
+ "stop_reason": stopReason,
+ "stop_sequence": nil,
+ },
+ "usage": map[string]any{
+ "input_tokens": usage.InputTokens,
+ "output_tokens": usage.OutputTokens,
+ "cache_read_input_tokens": usage.CacheReadInputTokens,
+ "cache_creation_input_tokens": 0,
+ },
+ }); err != nil {
+ return nil, err
+ }
+ if err := writeEvent("message_stop", map[string]any{"type": "message_stop"}); err != nil {
+ return nil, err
+ }
+
+ return &StreamResult{
+ Usage: usage,
+ StopReason: stopReason,
+ FirstDeltaDur: firstDelta,
+ }, nil
+}
+
+func extractSystemPrompt(claudeBody []byte) string {
+ systemField := gjson.GetBytes(claudeBody, "system")
+ if systemField.IsArray() {
+ var sb strings.Builder
+ for _, block := range systemField.Array() {
+ if block.Get("type").String() == "text" {
+ _, _ = sb.WriteString(block.Get("text").String())
+ } else if block.Type == gjson.String {
+ _, _ = sb.WriteString(block.String())
+ }
+ }
+ return sb.String()
+ }
+ return systemField.String()
+}
+
+func isThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
+ return deriveThinkingDirective(body, headers) != nil
+}
+
+func deriveThinkingDirective(body []byte, headers http.Header) *thinkingDirective {
+ if override := thinkingDirectiveFromModel(gjson.GetBytes(body, "model").String()); override != nil {
+ return override
+ }
+ switch thinkingType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "thinking.type").String())); thinkingType {
+ case "adaptive":
+ effort := strings.TrimSpace(gjson.GetBytes(body, "output_config.effort").String())
+ if effort == "" {
+ effort = "high"
+ }
+ budget := int(gjson.GetBytes(body, "thinking.budget_tokens").Int())
+ if budget <= 0 {
+ budget = 20000
+ }
+ return &thinkingDirective{Mode: "adaptive", BudgetTokens: budget, Effort: effort}
+ case "enabled":
+ budget := int(gjson.GetBytes(body, "thinking.budget_tokens").Int())
+ if budget <= 0 {
+ budget = 16000
+ }
+ return &thinkingDirective{Mode: "enabled", BudgetTokens: budget}
+ }
+ if headers != nil {
+ if beta := headers.Get("Anthropic-Beta"); strings.Contains(beta, "interleaved-thinking") {
+ return &thinkingDirective{Mode: "enabled", BudgetTokens: 16000}
+ }
+ }
+ if effort := gjson.GetBytes(body, "reasoning_effort").String(); effort != "" && effort != "none" {
+ return &thinkingDirective{Mode: "enabled", BudgetTokens: 16000}
+ }
+ model := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "model").String()))
+ if strings.Contains(model, "-reason") {
+ return &thinkingDirective{Mode: "enabled", BudgetTokens: 16000}
+ }
+ return nil
+}
+
+func thinkingDirectiveFromModel(model string) *thinkingDirective {
+ model = strings.ToLower(strings.TrimSpace(model))
+ if !strings.Contains(model, "thinking") {
+ return nil
+ }
+
+ switch normalizeModelAlias(model) {
+ case "claude-opus-4-6", "claude-opus-4.6":
+ return &thinkingDirective{
+ Mode: "adaptive",
+ BudgetTokens: 20000,
+ Effort: "high",
+ }
+ default:
+ return &thinkingDirective{
+ Mode: "enabled",
+ BudgetTokens: 20000,
+ }
+ }
+}
+
+func buildInjectedSystemPrompt(systemPrompt string, thinking *thinkingDirective, toolChoiceHint string) string {
+ systemPrompt = strings.TrimSpace(systemPrompt)
+ timestampContext := fmt.Sprintf("[Context: Current time is %s]", time.Now().Format("2006-01-02 15:04:05 MST"))
+ if systemPrompt == "" {
+ systemPrompt = timestampContext
+ } else {
+ systemPrompt = timestampContext + "\n\n" + systemPrompt
+ }
+ if toolChoiceHint != "" {
+ if systemPrompt != "" {
+ systemPrompt += "\n"
+ }
+ systemPrompt += toolChoiceHint
+ }
+ if !strings.Contains(systemPrompt, systemChunkedWritePolicy) {
+ systemPrompt += "\n" + systemChunkedWritePolicy
+ }
+ if thinking != nil {
+ switch thinking.Mode {
+ case "adaptive":
+ effort := strings.TrimSpace(thinking.Effort)
+ if effort == "" {
+ effort = "high"
+ }
+ thinkingPrefix := "adaptive\n" + effort + ""
+ return thinkingPrefix + "\n\n" + systemPrompt
+ default:
+ budget := thinking.BudgetTokens
+ if budget <= 0 {
+ budget = 16000
+ }
+ thinkingPrefix := "enabled\n" + strconv.Itoa(budget) + ""
+ return thinkingPrefix + "\n\n" + systemPrompt
+ }
+ }
+ return systemPrompt
+}
+
+func extractClaudeToolChoiceHint(claudeBody []byte, requestCtx *KiroRequestContext) string {
+ toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
+ if !toolChoice.Exists() {
+ return ""
+ }
+
+ if toolChoice.Type == gjson.String {
+ switch strings.ToLower(strings.TrimSpace(toolChoice.String())) {
+ case "none":
+ return "[INSTRUCTION: Do not use any tools. Respond with text only.]"
+ case "auto", "":
+ return ""
+ }
+ }
+
+ switch strings.ToLower(strings.TrimSpace(toolChoice.Get("type").String())) {
+ case "any":
+ return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
+ case "tool":
+ toolName := mapKiroToolName(toolChoice.Get("name").String(), requestCtx)
+ if toolName != "" {
+ return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
+ }
+ case "none":
+ return "[INSTRUCTION: Do not use any tools. Respond with text only.]"
+ }
+
+ return ""
+}
+
+func isToolChoiceNone(claudeBody []byte) bool {
+ toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
+ if !toolChoice.Exists() {
+ return false
+ }
+ if toolChoice.Type == gjson.String {
+ return strings.EqualFold(strings.TrimSpace(toolChoice.String()), "none")
+ }
+ return strings.EqualFold(strings.TrimSpace(toolChoice.Get("type").String()), "none")
+}
+
+func kiroToolNameAlias(name string) string {
+ return mapKiroToolName(name, nil)
+}
+
+func prependSystemHistory(history []KiroHistoryMessage, systemPrompt, modelID, origin string) []KiroHistoryMessage {
+ systemPrompt = strings.TrimSpace(systemPrompt)
+ if systemPrompt == "" {
+ return history
+ }
+
+ prefix := []KiroHistoryMessage{
+ {
+ UserInputMessage: &KiroUserInputMessage{
+ Content: systemPrompt,
+ ModelID: modelID,
+ Origin: origin,
+ },
+ },
+ {
+ AssistantResponseMessage: &KiroAssistantResponseMessage{
+ Content: "I will follow these instructions.",
+ },
+ },
+ }
+
+ return append(prefix, history...)
+}
+
+func normalizeOrigin(origin string) string {
+ switch origin {
+ case "KIRO_CLI", "AMAZON_Q":
+ return "CLI"
+ case "KIRO_AI_EDITOR", "KIRO_IDE", "":
+ return "AI_EDITOR"
+ default:
+ return origin
+ }
+}
+
+func extractMetadataFromMessages(messages gjson.Result, key string) string {
+ arr := messages.Array()
+ for i := len(arr) - 1; i >= 0; i-- {
+ if val := arr[i].Get("additional_kwargs." + key); val.Exists() && val.String() != "" {
+ return val.String()
+ }
+ }
+ return ""
+}
+
+func convertClaudeToolsToKiro(tools gjson.Result, requestCtx *KiroRequestContext) []KiroToolWrapper {
+ if !tools.IsArray() {
+ return nil
+ }
+ var out []KiroToolWrapper
+ for _, tool := range tools.Array() {
+ originalName := tool.Get("name").String()
+ if strings.TrimSpace(originalName) == "" {
+ originalName = tool.Get("type").String()
+ }
+ isWebSearch := strings.TrimSpace(originalName) == "web_search"
+ name := mapKiroToolName(originalName, requestCtx)
+ description := strings.TrimSpace(tool.Get("description").String())
+ if isWebSearch {
+ if cached := GetCachedWebSearchDescription(); cached != "" {
+ description = cached
+ } else {
+ description = remoteWebSearchDescription
+ }
+ }
+ if description == "" {
+ description = "Tool: " + name
+ }
+ description = appendChunkedToolDescription(originalName, description)
+ description = truncateKiroToolDescription(description)
+ inputSchema := normalizeKiroJSONSchema(tool.Get("input_schema").Value())
+ out = append(out, KiroToolWrapper{
+ ToolSpecification: KiroToolSpecification{
+ Name: name,
+ Description: description,
+ InputSchema: KiroInputSchema{JSON: inputSchema},
+ },
+ })
+ }
+ return out
+}
+
+func appendChunkedToolDescription(name, description string) string {
+ suffix := chunkedToolDescriptionSuffix(name)
+ if suffix == "" {
+ return description
+ }
+ if strings.Contains(description, suffix) {
+ description = strings.Replace(description, suffix, "", 1)
+ }
+ if strings.TrimSpace(description) == "" {
+ return suffix
+ }
+ base := strings.TrimRight(description, "\n")
+ joined := base + "\n" + suffix
+ if len(joined) <= kiroMaxToolDescLen {
+ return joined
+ }
+ const truncationMarker = "... (description truncated)"
+ baseLimit := kiroMaxToolDescLen - len(suffix) - 1 - len(truncationMarker)
+ if baseLimit <= 0 {
+ return truncateKiroToolDescription(joined)
+ }
+ return truncateUTF8(base, baseLimit) + truncationMarker + "\n" + suffix
+}
+
+func chunkedToolDescriptionSuffix(name string) string {
+ switch strings.ToLower(strings.TrimSpace(name)) {
+ case "write", "write_to_file", "fswrite", "create_file":
+ return writeToolDescriptionSuffix
+ case "edit", "edit_file", "str_replace_editor", "apply_diff":
+ return editToolDescriptionSuffix
+ default:
+ return ""
+ }
+}
+
+func truncateKiroToolDescription(description string) string {
+ if len(description) <= kiroMaxToolDescLen {
+ return description
+ }
+ return truncateUTF8(description, kiroMaxToolDescLen-30) + "... (description truncated)"
+}
+
+func truncateUTF8(s string, limit int) string {
+ if limit <= 0 {
+ return ""
+ }
+ if len(s) <= limit {
+ return s
+ }
+ for limit > 0 && !utf8.RuneStart(s[limit]) {
+ limit--
+ }
+ return s[:limit]
+}
+
+func shortenToolNameIfNeeded(name string) string {
+ name = strings.TrimSpace(name)
+ if len(name) <= kiroMaxToolNameLen {
+ return name
+ }
+ sum := sha256.Sum256([]byte(name))
+ suffix := fmt.Sprintf("%x", sum[:])[:8]
+ prefixLen := kiroMaxToolNameLen - 1 - len(suffix)
+ prefix := name
+ if len(prefix) > prefixLen {
+ prefix = prefix[:prefixLen]
+ for len(prefix) > 0 && !utf8.ValidString(prefix) {
+ prefix = prefix[:len(prefix)-1]
+ }
+ }
+ return prefix + "_" + suffix
+}
+
+func mapKiroToolName(name string, requestCtx *KiroRequestContext) string {
+ name = strings.TrimSpace(name)
+ if name == "" {
+ return ""
+ }
+ if name == "web_search" {
+ return "remote_web_search"
+ }
+ short := shortenToolNameIfNeeded(name)
+ if short != name && requestCtx != nil {
+ if requestCtx.ToolNameMap == nil {
+ requestCtx.ToolNameMap = make(map[string]string)
+ }
+ requestCtx.ToolNameMap[short] = name
+ }
+ return short
+}
+
+func normalizeKiroJSONSchema(schema any) any {
+ return normalizeKiroJSONSchemaValue(schema, true)
+}
+
+func normalizeKiroJSONSchemaValue(schema any, enforceObjectKeywords bool) any {
+ obj, ok := schema.(map[string]interface{})
+ if !ok || obj == nil {
+ return defaultKiroJSONSchema()
+ }
+ normalized := make(map[string]interface{}, len(obj)+4)
+ for key, value := range obj {
+ normalized[key] = normalizeSchemaChild(key, value)
+ }
+ if typ, ok := normalized["type"].(string); !ok || strings.TrimSpace(typ) == "" {
+ normalized["type"] = "object"
+ }
+ typ, _ := normalized["type"].(string)
+ needsObjectKeywords := enforceObjectKeywords ||
+ strings.TrimSpace(typ) == "object" ||
+ hasSchemaKey(normalized, "properties") ||
+ hasSchemaKey(normalized, "required") ||
+ hasSchemaKey(normalized, "additionalProperties")
+ if needsObjectKeywords {
+ properties, ok := normalized["properties"].(map[string]interface{})
+ if !ok || properties == nil {
+ normalized["properties"] = map[string]interface{}{}
+ } else {
+ for key, value := range properties {
+ properties[key] = normalizeKiroJSONSchemaValue(value, false)
+ }
+ normalized["properties"] = properties
+ }
+ normalized["required"] = normalizeSchemaRequired(normalized["required"])
+ switch additional := normalized["additionalProperties"].(type) {
+ case bool:
+ case map[string]interface{}:
+ normalized["additionalProperties"] = normalizeKiroJSONSchemaValue(additional, false)
+ default:
+ normalized["additionalProperties"] = true
+ }
+ }
+ return normalized
+}
+
+func hasSchemaKey(schema map[string]interface{}, key string) bool {
+ _, ok := schema[key]
+ return ok
+}
+
+func defaultKiroJSONSchema() map[string]interface{} {
+ return map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{},
+ "required": []interface{}{},
+ "additionalProperties": true,
+ }
+}
+
+func normalizeSchemaRequired(value interface{}) []interface{} {
+ arr, ok := value.([]interface{})
+ if !ok {
+ return []interface{}{}
+ }
+ out := make([]interface{}, 0, len(arr))
+ for _, item := range arr {
+ if s, ok := item.(string); ok {
+ out = append(out, s)
+ }
+ }
+ return out
+}
+
+func normalizeSchemaChild(key string, value interface{}) interface{} {
+ switch key {
+ case "items", "not":
+ if obj, ok := value.(map[string]interface{}); ok {
+ return normalizeKiroJSONSchemaValue(obj, false)
+ }
+ if arr, ok := value.([]interface{}); ok {
+ out := make([]interface{}, 0, len(arr))
+ for _, item := range arr {
+ out = append(out, normalizeKiroJSONSchemaValue(item, false))
+ }
+ return out
+ }
+ case "oneOf", "anyOf", "allOf":
+ if arr, ok := value.([]interface{}); ok {
+ out := make([]interface{}, 0, len(arr))
+ for _, item := range arr {
+ out = append(out, normalizeKiroJSONSchemaValue(item, false))
+ }
+ return out
+ }
+ }
+ return value
+}
+
+func processMessages(messages gjson.Result, modelID, origin string, requestCtx *KiroRequestContext) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
+ messagesArray := mergeAdjacentMessages(messages.Array())
+ if len(messagesArray) > 0 && messagesArray[0].Get("role").String() == "assistant" {
+ messagesArray = append([]gjson.Result{gjson.Parse(`{"role":"user","content":"."}`)}, messagesArray...)
+ }
+
+ var history []KiroHistoryMessage
+ var currentUserMsg *KiroUserInputMessage
+ var currentToolResults []KiroToolResult
+
+ for i, msg := range messagesArray {
+ role := msg.Get("role").String()
+ last := i == len(messagesArray)-1
+ switch role {
+ case "user":
+ userMsg, toolResults := buildUserMessageStruct(msg, modelID, origin)
+ if strings.TrimSpace(userMsg.Content) == "" {
+ if len(toolResults) > 0 {
+ userMsg.Content = "Tool results provided."
+ } else {
+ userMsg.Content = "Continue"
+ }
+ }
+ if last {
+ currentUserMsg = &userMsg
+ currentToolResults = toolResults
+ } else {
+ if len(toolResults) > 0 {
+ userMsg.UserInputMessageContext = &KiroUserInputMessageContext{ToolResults: toolResults}
+ }
+ history = append(history, KiroHistoryMessage{UserInputMessage: &userMsg})
+ }
+ case "assistant":
+ assistantMsg := buildAssistantMessageStruct(msg, requestCtx)
+ if last {
+ history = append(history, KiroHistoryMessage{AssistantResponseMessage: &assistantMsg})
+ currentUserMsg = &KiroUserInputMessage{
+ Content: "Continue",
+ ModelID: modelID,
+ Origin: origin,
+ }
+ } else {
+ history = append(history, KiroHistoryMessage{AssistantResponseMessage: &assistantMsg})
+ }
+ }
+ }
+
+ return history, currentUserMsg, currentToolResults
+}
+
+func validateToolPairing(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroToolResult, map[string]bool) {
+ allToolUseIDs := make(map[string]bool)
+ pairedToolUseIDs := make(map[string]bool)
+ for _, h := range history {
+ if h.AssistantResponseMessage != nil {
+ for _, tu := range h.AssistantResponseMessage.ToolUses {
+ allToolUseIDs[tu.ToolUseID] = true
+ }
+ }
+ if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
+ for _, tr := range h.UserInputMessage.UserInputMessageContext.ToolResults {
+ pairedToolUseIDs[tr.ToolUseID] = true
+ }
+ }
+ }
+
+ filtered := currentToolResults[:0]
+ for _, tr := range currentToolResults {
+ if allToolUseIDs[tr.ToolUseID] && !pairedToolUseIDs[tr.ToolUseID] {
+ filtered = append(filtered, tr)
+ pairedToolUseIDs[tr.ToolUseID] = true
+ }
+ }
+ orphaned := make(map[string]bool)
+ for toolUseID := range allToolUseIDs {
+ if !pairedToolUseIDs[toolUseID] {
+ orphaned[toolUseID] = true
+ }
+ }
+ return filtered, orphaned
+}
+
+func removeOrphanedToolUses(history []KiroHistoryMessage, orphaned map[string]bool) {
+ if len(orphaned) == 0 {
+ return
+ }
+ for i := range history {
+ msg := history[i].AssistantResponseMessage
+ if msg == nil || len(msg.ToolUses) == 0 {
+ continue
+ }
+ filtered := msg.ToolUses[:0]
+ for _, toolUse := range msg.ToolUses {
+ if !orphaned[toolUse.ToolUseID] {
+ filtered = append(filtered, toolUse)
+ }
+ }
+ msg.ToolUses = filtered
+ }
+}
+
+func collectHistoryToolNames(history []KiroHistoryMessage) []string {
+ seen := make(map[string]bool)
+ var names []string
+ for _, h := range history {
+ if h.AssistantResponseMessage == nil {
+ continue
+ }
+ for _, tu := range h.AssistantResponseMessage.ToolUses {
+ name := strings.TrimSpace(tu.Name)
+ if name == "" {
+ continue
+ }
+ key := strings.ToLower(name)
+ if seen[key] {
+ continue
+ }
+ seen[key] = true
+ names = append(names, name)
+ }
+ }
+ return names
+}
+
+func appendMissingPlaceholderTools(tools []KiroToolWrapper, historyToolNames []string) []KiroToolWrapper {
+ if len(historyToolNames) == 0 {
+ return tools
+ }
+ seen := make(map[string]bool)
+ for _, tool := range tools {
+ seen[strings.ToLower(strings.TrimSpace(tool.ToolSpecification.Name))] = true
+ }
+ for _, name := range historyToolNames {
+ key := strings.ToLower(strings.TrimSpace(name))
+ if key == "" || seen[key] {
+ continue
+ }
+ seen[key] = true
+ tools = append(tools, KiroToolWrapper{
+ ToolSpecification: KiroToolSpecification{
+ Name: name,
+ Description: "Tool used in conversation history",
+ InputSchema: KiroInputSchema{JSON: normalizeKiroJSONSchema(nil)},
+ },
+ })
+ }
+ return tools
+}
+
+func buildFinalContent(content string, toolResults []KiroToolResult) string {
+ if strings.TrimSpace(content) == "" {
+ if len(toolResults) > 0 {
+ return "Tool results provided."
+ }
+ return "Continue"
+ }
+ return content
+}
+
+func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
+ seen := make(map[string]bool)
+ out := make([]KiroToolResult, 0, len(toolResults))
+ for _, tr := range toolResults {
+ if seen[tr.ToolUseID] {
+ continue
+ }
+ seen[tr.ToolUseID] = true
+ out = append(out, tr)
+ }
+ return out
+}
+
+func buildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
+ content := msg.Get("content")
+ var contentBuilder strings.Builder
+ var toolResults []KiroToolResult
+ var images []KiroImage
+ seenToolUseIDs := make(map[string]bool)
+
+ if content.IsArray() {
+ for _, part := range content.Array() {
+ switch part.Get("type").String() {
+ case "text":
+ _, _ = contentBuilder.WriteString(part.Get("text").String())
+ case "image":
+ mediaType := part.Get("source.media_type").String()
+ data := part.Get("source.data").String()
+ format := ""
+ if idx := strings.LastIndex(mediaType, "/"); idx != -1 {
+ format = mediaType[idx+1:]
+ }
+ if format != "" && data != "" {
+ images = append(images, KiroImage{
+ Format: format,
+ Source: KiroImageSource{Bytes: data},
+ })
+ }
+ case "tool_result":
+ toolUseID := part.Get("tool_use_id").String()
+ if toolUseID == "" || seenToolUseIDs[toolUseID] {
+ continue
+ }
+ seenToolUseIDs[toolUseID] = true
+ status := "success"
+ if part.Get("is_error").Bool() {
+ status = "error"
+ }
+ textContents := []KiroTextContent{{Text: "Tool use was cancelled by the user"}}
+ resultContent := part.Get("content")
+ if resultContent.IsArray() {
+ textContents = textContents[:0]
+ for _, item := range resultContent.Array() {
+ if item.Get("type").String() == "text" {
+ textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
+ } else if item.Type == gjson.String {
+ textContents = append(textContents, KiroTextContent{Text: item.String()})
+ }
+ }
+ } else if resultContent.Type == gjson.String {
+ textContents = []KiroTextContent{{Text: resultContent.String()}}
+ }
+ toolResults = append(toolResults, KiroToolResult{
+ ToolUseID: toolUseID,
+ Content: textContents,
+ Status: status,
+ })
+ }
+ }
+ } else {
+ _, _ = contentBuilder.WriteString(content.String())
+ }
+
+ userMsg := KiroUserInputMessage{
+ Content: contentBuilder.String(),
+ ModelID: modelID,
+ Origin: origin,
+ }
+ if len(images) > 0 {
+ userMsg.Images = images
+ }
+ return userMsg, toolResults
+}
+
+func buildAssistantMessageStruct(msg gjson.Result, requestCtx *KiroRequestContext) KiroAssistantResponseMessage {
+ content := msg.Get("content")
+ var contentBuilder strings.Builder
+ var toolUses []KiroToolUse
+
+ if content.IsArray() {
+ for _, part := range content.Array() {
+ switch part.Get("type").String() {
+ case "text":
+ _, _ = contentBuilder.WriteString(part.Get("text").String())
+ case "tool_use":
+ toolName := mapKiroToolName(part.Get("name").String(), requestCtx)
+ input := map[string]interface{}{}
+ toolInput := part.Get("input")
+ if toolInput.IsObject() {
+ toolInput.ForEach(func(key, value gjson.Result) bool {
+ input[key.String()] = value.Value()
+ return true
+ })
+ }
+ toolUses = append(toolUses, KiroToolUse{
+ ToolUseID: part.Get("id").String(),
+ Name: toolName,
+ Input: input,
+ })
+ }
+ }
+ } else {
+ _, _ = contentBuilder.WriteString(content.String())
+ }
+
+ finalContent := contentBuilder.String()
+ if strings.TrimSpace(finalContent) == "" {
+ finalContent = " "
+ }
+ return KiroAssistantResponseMessage{
+ Content: finalContent,
+ ToolUses: toolUses,
+ }
+}
+
+func mergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
+ if len(messages) <= 1 {
+ return messages
+ }
+ var merged []gjson.Result
+ for _, msg := range messages {
+ if len(merged) == 0 {
+ merged = append(merged, msg)
+ continue
+ }
+ lastMsg := merged[len(merged)-1]
+ role := msg.Get("role").String()
+ lastRole := lastMsg.Get("role").String()
+ if role == "tool" || lastRole == "tool" || role != lastRole {
+ merged = append(merged, msg)
+ continue
+ }
+ mergedMsg := map[string]interface{}{
+ "role": role,
+ "content": json.RawMessage(mergeMessageContent(lastMsg, msg)),
+ }
+ encoded, _ := json.Marshal(mergedMsg)
+ merged[len(merged)-1] = gjson.ParseBytes(encoded)
+ }
+ return merged
+}
+
+func mergeMessageContent(msg1, msg2 gjson.Result) string {
+ var blocks1, blocks2 []map[string]interface{}
+ content1 := msg1.Get("content")
+ content2 := msg2.Get("content")
+ if content1.IsArray() {
+ for _, block := range content1.Array() {
+ blocks1 = append(blocks1, blockToMap(block))
+ }
+ } else if content1.Type == gjson.String {
+ blocks1 = append(blocks1, map[string]interface{}{"type": "text", "text": content1.String()})
+ }
+ if content2.IsArray() {
+ for _, block := range content2.Array() {
+ blocks2 = append(blocks2, blockToMap(block))
+ }
+ } else if content2.Type == gjson.String {
+ blocks2 = append(blocks2, map[string]interface{}{"type": "text", "text": content2.String()})
+ }
+ if len(blocks1) > 0 && len(blocks2) > 0 && blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" {
+ leftText, leftOK := blocks1[len(blocks1)-1]["text"].(string)
+ rightText, rightOK := blocks2[0]["text"].(string)
+ if leftOK && rightOK {
+ blocks1[len(blocks1)-1]["text"] = leftText + "\n\n" + rightText
+ blocks2 = blocks2[1:]
+ }
+ }
+ allBlocks := append(blocks1, blocks2...)
+ result, _ := json.Marshal(allBlocks)
+ return string(result)
+}
+
+func blockToMap(block gjson.Result) map[string]interface{} {
+ result := make(map[string]interface{})
+ block.ForEach(func(key, value gjson.Result) bool {
+ if value.IsObject() {
+ result[key.String()] = blockToMap(value)
+ } else if value.IsArray() {
+ var arr []interface{}
+ for _, item := range value.Array() {
+ if item.IsObject() {
+ arr = append(arr, blockToMap(item))
+ } else {
+ arr = append(arr, item.Value())
+ }
+ }
+ result[key.String()] = arr
+ } else {
+ result[key.String()] = value.Value()
+ }
+ return true
+ })
+ return result
+}
+
+func parseEventStream(body io.Reader) (string, []KiroToolUse, Usage, string, error) {
+ reader := bufio.NewReader(body)
+ var content strings.Builder
+ var toolUses []KiroToolUse
+ var usage Usage
+ stopReason := ""
+ processedIDs := make(map[string]bool)
+ var currentTool *toolUseState
+
+ for {
+ msg, err := readEventStreamMessage(reader)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return "", nil, usage, stopReason, err
+ }
+ if msg == nil || len(msg.Payload) == 0 {
+ continue
+ }
+
+ var event map[string]interface{}
+ if err := json.Unmarshal(msg.Payload, &event); err != nil {
+ continue
+ }
+ if sr := readStopReason(event); sr != "" {
+ stopReason = sr
+ }
+ switch msg.EventType {
+ case "assistantResponseEvent":
+ assistant := nestedEvent(event, "assistantResponseEvent")
+ if text := getString(assistant, "content"); text != "" {
+ _, _ = content.WriteString(text)
+ } else if text := getString(event, "content"); text != "" {
+ _, _ = content.WriteString(text)
+ }
+ if sr := readStopReason(assistant); sr != "" {
+ stopReason = sr
+ }
+ for _, tool := range readToolUses(assistant, event) {
+ if processedIDs[tool.ToolUseID] {
+ continue
+ }
+ processedIDs[tool.ToolUseID] = true
+ toolUses = append(toolUses, tool)
+ }
+ case "toolUseEvent":
+ completed, next := processToolUseEvent(event, currentTool, processedIDs)
+ currentTool = next
+ toolUses = append(toolUses, completed...)
+ case "reasoningContentEvent":
+ reasoning := nestedEvent(event, "reasoningContentEvent")
+ text := getString(reasoning, "text")
+ if text == "" {
+ text = getString(event, "text")
+ }
+ if text != "" {
+ _, _ = content.WriteString(thinkingStartTag)
+ _, _ = content.WriteString(text)
+ _, _ = content.WriteString(thinkingEndTag)
+ }
+ default:
+ updateUsageFromEvent(&usage, msg.EventType, event)
+ }
+ }
+
+ if currentTool != nil && currentTool.ToolUseID != "" && !processedIDs[currentTool.ToolUseID] {
+ completed, _ := processToolUseEvent(map[string]interface{}{
+ "toolUseEvent": map[string]interface{}{
+ "toolUseId": currentTool.ToolUseID,
+ "name": currentTool.Name,
+ "stop": true,
+ "input": currentTool.InputBuffer.String(),
+ },
+ }, currentTool, processedIDs)
+ toolUses = append(toolUses, completed...)
+ }
+ cleanText, embeddedToolUses, _ := drainEmbeddedToolText(content.String())
+ toolUses = append(toolUses, embeddedToolUses...)
+ toolUses = deduplicateToolUses(toolUses)
+
+ if usage.TotalTokens == 0 {
+ usage.TotalTokens = usage.InputTokens + usage.OutputTokens
+ }
+ if stopReason == "" {
+ if hasUsableToolUses(toolUses) {
+ stopReason = "tool_use"
+ } else {
+ stopReason = "end_turn"
+ }
+ }
+ return cleanText, toolUses, usage, stopReason, nil
+}
+
+func buildClaudeResponse(content string, toolUses []KiroToolUse, model string, usage Usage, stopReason string, requestCtx KiroRequestContext) []byte {
+ var blocks []map[string]interface{}
+ blocks = append(blocks, extractThinkingBlocks(content)...)
+ usableTools := 0
+ for _, tool := range toolUses {
+ if tool.IsTruncated {
+ continue
+ }
+ usableTools++
+ blocks = append(blocks, map[string]interface{}{
+ "type": "tool_use",
+ "id": tool.ToolUseID,
+ "name": restoreResponseToolName(tool.Name, requestCtx),
+ "input": tool.Input,
+ })
+ }
+ pureThinking := hasThinkingBlocksOnly(blocks) && usableTools == 0
+ if pureThinking {
+ blocks = append(blocks, map[string]interface{}{"type": "text", "text": ""})
+ stopReason = "max_tokens"
+ }
+ if len(blocks) == 0 {
+ blocks = append(blocks, map[string]interface{}{"type": "text", "text": ""})
+ }
+ if stopReason == "" {
+ if usableTools > 0 {
+ stopReason = "tool_use"
+ } else {
+ stopReason = "end_turn"
+ }
+ }
+ response := map[string]interface{}{
+ "id": "msg_" + uuid.NewString()[:24],
+ "type": "message",
+ "role": "assistant",
+ "model": model,
+ "content": blocks,
+ "stop_reason": stopReason,
+ "usage": map[string]interface{}{
+ "input_tokens": usage.InputTokens,
+ "output_tokens": usage.OutputTokens,
+ "cache_read_input_tokens": usage.CacheReadInputTokens,
+ },
+ }
+ result, _ := json.Marshal(response)
+ return result
+}
+
+func restoreResponseToolName(name string, requestCtx KiroRequestContext) string {
+ name = strings.TrimSpace(name)
+ if requestCtx.ToolNameMap == nil {
+ return name
+ }
+ if original := strings.TrimSpace(requestCtx.ToolNameMap[name]); original != "" {
+ return original
+ }
+ return name
+}
+
+func hasThinkingBlocksOnly(blocks []map[string]interface{}) bool {
+ if len(blocks) == 0 {
+ return false
+ }
+ hasThinking := false
+ for _, block := range blocks {
+ blockType, _ := block["type"].(string)
+ switch blockType {
+ case "thinking":
+ hasThinking = true
+ case "text":
+ return false
+ default:
+ return false
+ }
+ }
+ return hasThinking
+}
+
+func extractThinkingBlocks(content string) []map[string]interface{} {
+ if content == "" {
+ return nil
+ }
+ if findRealThinkingStartTag(content, 0) == -1 {
+ return []map[string]interface{}{{"type": "text", "text": content}}
+ }
+ var blocks []map[string]interface{}
+ pos := 0
+ for pos < len(content) {
+ start := findRealThinkingStartTag(content, pos)
+ if start == -1 {
+ if text := content[pos:]; strings.TrimSpace(text) != "" {
+ blocks = append(blocks, map[string]interface{}{"type": "text", "text": text})
+ }
+ break
+ }
+ end := findRealThinkingEndTag(content, start+len(thinkingStartTag))
+ if end == -1 {
+ if text := content[pos:]; strings.TrimSpace(text) != "" {
+ blocks = append(blocks, map[string]interface{}{"type": "text", "text": text})
+ }
+ break
+ }
+ if text := content[pos:start]; strings.TrimSpace(text) != "" {
+ blocks = append(blocks, map[string]interface{}{"type": "text", "text": text})
+ }
+ thinking := strings.TrimPrefix(content[start+len(thinkingStartTag):end], "\n")
+ if strings.TrimSpace(thinking) != "" {
+ blocks = append(blocks, map[string]interface{}{
+ "type": "thinking",
+ "thinking": thinking,
+ "signature": thinkingSignature(thinking),
+ })
+ }
+ pos = end + len(thinkingEndTag)
+ if strings.HasPrefix(content[pos:], "\n\n") {
+ pos += len("\n\n")
+ }
+ }
+ if len(blocks) == 0 {
+ blocks = append(blocks, map[string]interface{}{"type": "text", "text": ""})
+ }
+ return blocks
+}
+
+func findRealThinkingStartTag(content string, from int) int {
+ return findRealThinkingTag(content, thinkingStartTag, from, false)
+}
+
+func findRealThinkingEndTag(content string, from int) int {
+ searchFrom := from
+ for {
+ pos := findRealThinkingTag(content, thinkingEndTag, searchFrom, true)
+ if pos == -1 {
+ return -1
+ }
+ after := pos + len(thinkingEndTag)
+ if strings.HasPrefix(content[after:], "\n\n") || strings.TrimSpace(content[after:]) == "" {
+ return pos
+ }
+ searchFrom = pos + 1
+ }
+}
+
+func findStreamThinkingEndTagStrict(content string, from int) int {
+ searchFrom := from
+ for {
+ pos := findRealThinkingTag(content, thinkingEndTag, searchFrom, true)
+ if pos == -1 {
+ return -1
+ }
+ after := pos + len(thinkingEndTag)
+ if strings.HasPrefix(content[after:], "\n\n") {
+ return pos
+ }
+ searchFrom = pos + 1
+ }
+}
+
+func findStreamThinkingEndTagAtBufferEnd(content string, from int) int {
+ searchFrom := from
+ for {
+ pos := findRealThinkingTag(content, thinkingEndTag, searchFrom, true)
+ if pos == -1 {
+ return -1
+ }
+ after := pos + len(thinkingEndTag)
+ if strings.TrimSpace(content[after:]) == "" {
+ return pos
+ }
+ searchFrom = pos + 1
+ }
+}
+
+func safeThinkingStreamFlushLen(content string, keepBytes int) int {
+ if keepBytes <= 0 || len(content) <= keepBytes {
+ return 0
+ }
+ pos := len(content) - keepBytes
+ for pos > 0 && !utf8.ValidString(content[:pos]) {
+ pos--
+ }
+ for pos > 0 && !utf8.RuneStart(content[pos]) {
+ pos--
+ }
+ return pos
+}
+
+func findRealThinkingTag(content, tag string, from int, allowEndBoundary bool) int {
+ if from < 0 {
+ from = 0
+ }
+ searchFrom := from
+ for searchFrom < len(content) {
+ rel := strings.Index(content[searchFrom:], tag)
+ if rel == -1 {
+ return -1
+ }
+ pos := searchFrom + rel
+ after := pos + len(tag)
+ if !isThinkingTagQuoted(content, pos, after) &&
+ !isInsideMarkdownFence(content, pos) &&
+ !isLineBlockQuote(content, pos) &&
+ (!allowEndBoundary || after <= len(content)) {
+ return pos
+ }
+ searchFrom = pos + 1
+ }
+ return -1
+}
+
+func isThinkingTagQuoted(content string, start, after int) bool {
+ if start > 0 && isThinkingQuoteChar(content[start-1]) {
+ return true
+ }
+ return after < len(content) && isThinkingQuoteChar(content[after])
+}
+
+func isThinkingQuoteChar(ch byte) bool {
+ switch ch {
+ case '`', '"', '\'', '\\':
+ return true
+ default:
+ return false
+ }
+}
+
+func isInsideMarkdownFence(content string, pos int) bool {
+ inFence := false
+ lineStart := 0
+ for lineStart < pos {
+ lineEnd := strings.IndexByte(content[lineStart:], '\n')
+ if lineEnd == -1 {
+ lineEnd = len(content)
+ } else {
+ lineEnd += lineStart
+ }
+ line := strings.TrimSpace(content[lineStart:lineEnd])
+ if strings.HasPrefix(line, "```") || strings.HasPrefix(line, "~~~") {
+ inFence = !inFence
+ }
+ lineStart = lineEnd + 1
+ }
+ return inFence
+}
+
+func isLineBlockQuote(content string, pos int) bool {
+ lineStart := strings.LastIndexByte(content[:pos], '\n') + 1
+ return strings.HasPrefix(strings.TrimLeftFunc(content[lineStart:pos], unicode.IsSpace), ">")
+}
+
+func thinkingSignature(content string) string {
+ if content == "" {
+ return ""
+ }
+ sum := sha256.Sum256([]byte(content))
+ return base64.StdEncoding.EncodeToString(sum[:])
+}
+
+func readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, error) {
+ prelude := make([]byte, 12)
+ _, err := io.ReadFull(reader, prelude)
+ if err != nil {
+ return nil, err
+ }
+ totalLength := binary.BigEndian.Uint32(prelude[0:4])
+ headersLength := binary.BigEndian.Uint32(prelude[4:8])
+ if totalLength < minFrameSize || totalLength > maxEventMsgSize {
+ return nil, fmt.Errorf("invalid kiro eventstream frame length: %d", totalLength)
+ }
+ if headersLength > totalLength-16 {
+ return nil, fmt.Errorf("invalid kiro eventstream headers length: %d", headersLength)
+ }
+ remaining := make([]byte, totalLength-12)
+ if _, err := io.ReadFull(reader, remaining); err != nil {
+ return nil, err
+ }
+ eventType := extractEventType(remaining[:headersLength])
+ payloadStart := headersLength
+ payloadEnd := uint32(len(remaining)) - 4
+ if payloadStart >= payloadEnd {
+ return &eventStreamMessage{EventType: eventType}, nil
+ }
+ return &eventStreamMessage{
+ EventType: eventType,
+ Payload: remaining[payloadStart:payloadEnd],
+ }, nil
+}
+
+func extractEventType(headers []byte) string {
+ offset := 0
+ for offset < len(headers) {
+ nameLen := int(headers[offset])
+ offset++
+ if offset+nameLen > len(headers) {
+ break
+ }
+ name := string(headers[offset : offset+nameLen])
+ offset += nameLen
+ if offset >= len(headers) {
+ break
+ }
+ valueType := headers[offset]
+ offset++
+ if valueType == 7 {
+ if offset+2 > len(headers) {
+ break
+ }
+ valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
+ offset += 2
+ if offset+valueLen > len(headers) {
+ break
+ }
+ value := string(headers[offset : offset+valueLen])
+ offset += valueLen
+ if name == ":event-type" {
+ return value
+ }
+ continue
+ }
+ next, ok := skipHeaderValue(headers, offset, valueType)
+ if !ok {
+ break
+ }
+ offset = next
+ }
+ return ""
+}
+
+func skipHeaderValue(headers []byte, offset int, valueType byte) (int, bool) {
+ switch valueType {
+ case 0, 1:
+ return offset, true
+ case 2:
+ if offset+1 > len(headers) {
+ return offset, false
+ }
+ return offset + 1, true
+ case 3:
+ if offset+2 > len(headers) {
+ return offset, false
+ }
+ return offset + 2, true
+ case 4:
+ if offset+4 > len(headers) {
+ return offset, false
+ }
+ return offset + 4, true
+ case 5, 8:
+ if offset+8 > len(headers) {
+ return offset, false
+ }
+ return offset + 8, true
+ case 6:
+ if offset+2 > len(headers) {
+ return offset, false
+ }
+ length := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
+ offset += 2
+ if offset+length > len(headers) {
+ return offset, false
+ }
+ return offset + length, true
+ case 9:
+ if offset+16 > len(headers) {
+ return offset, false
+ }
+ return offset + 16, true
+ default:
+ return offset, false
+ }
+}
+
+func processToolUseEvent(event map[string]interface{}, currentTool *toolUseState, processedIDs map[string]bool) ([]KiroToolUse, *toolUseState) {
+ tu := nestedEvent(event, "toolUseEvent")
+ toolUseID := getString(tu, "toolUseId")
+ name := getString(tu, "name")
+ isStop, _ := tu["stop"].(bool)
+
+ var inputFragment string
+ var inputMap map[string]interface{}
+ if inputRaw, ok := tu["input"]; ok {
+ switch v := inputRaw.(type) {
+ case string:
+ inputFragment = v
+ case map[string]interface{}:
+ inputMap = v
+ }
+ }
+
+ if toolUseID != "" && name != "" {
+ if currentTool == nil || currentTool.ToolUseID != toolUseID {
+ if processedIDs[toolUseID] {
+ return nil, currentTool
+ }
+ currentTool = &toolUseState{ToolUseID: toolUseID, Name: name}
+ }
+ }
+ if currentTool != nil && inputFragment != "" {
+ _, _ = currentTool.InputBuffer.WriteString(inputFragment)
+ }
+ if currentTool != nil && inputMap != nil {
+ currentTool.InputBuffer.Reset()
+ encoded, _ := json.Marshal(inputMap)
+ _, _ = currentTool.InputBuffer.Write(encoded)
+ }
+ if !isStop || currentTool == nil {
+ return nil, currentTool
+ }
+ processedIDs[currentTool.ToolUseID] = true
+ return []KiroToolUse{finalizeRawToolUse(currentTool.ToolUseID, currentTool.Name, currentTool.InputBuffer.String())}, nil
+}
+
+func repairJSON(input string) string {
+ str := strings.TrimSpace(input)
+ if str == "" {
+ return "{}"
+ }
+ var parsed interface{}
+ if err := json.Unmarshal([]byte(str), &parsed); err == nil {
+ return str
+ }
+ str = escapeControlCharsInStrings(str)
+ str = trailingCommaPattern.ReplaceAllString(str, "$1")
+ openBraces, openBrackets, inString := jsonBalance(str)
+ if inString {
+ str += `"`
+ openBraces, openBrackets, _ = jsonBalance(str)
+ }
+ if openBraces > 0 {
+ str += strings.Repeat("}", openBraces)
+ }
+ if openBrackets > 0 {
+ str += strings.Repeat("]", openBrackets)
+ }
+ if err := json.Unmarshal([]byte(str), &parsed); err != nil {
+ return strings.TrimSpace(input)
+ }
+ return str
+}
+
+func escapeControlCharsInStrings(input string) string {
+ var out strings.Builder
+ inString := false
+ escape := false
+ for i := 0; i < len(input); i++ {
+ ch := input[i]
+ if escape {
+ _ = out.WriteByte(ch)
+ escape = false
+ continue
+ }
+ if ch == '\\' {
+ _ = out.WriteByte(ch)
+ escape = true
+ continue
+ }
+ if ch == '"' {
+ inString = !inString
+ _ = out.WriteByte(ch)
+ continue
+ }
+ if inString {
+ switch ch {
+ case '\n':
+ _, _ = out.WriteString("\\n")
+ continue
+ case '\r':
+ _, _ = out.WriteString("\\r")
+ continue
+ case '\t':
+ _, _ = out.WriteString("\\t")
+ continue
+ }
+ }
+ _ = out.WriteByte(ch)
+ }
+ return out.String()
+}
+
+func jsonBalance(input string) (openBraces int, openBrackets int, inString bool) {
+ escape := false
+ for i := 0; i < len(input); i++ {
+ ch := input[i]
+ if escape {
+ escape = false
+ continue
+ }
+ if ch == '\\' {
+ escape = true
+ continue
+ }
+ if ch == '"' {
+ inString = !inString
+ continue
+ }
+ if inString {
+ continue
+ }
+ switch ch {
+ case '{':
+ openBraces++
+ case '}':
+ openBraces--
+ case '[':
+ openBrackets++
+ case ']':
+ openBrackets--
+ }
+ }
+ return openBraces, openBrackets, inString
+}
+
+func finalizeRawToolUse(toolUseID, name, rawInput string) KiroToolUse {
+ tool := KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: normalizeResponseToolName(name),
+ Input: map[string]interface{}{},
+ }
+ rawInput = strings.TrimSpace(rawInput)
+ tool.TruncatedRaw = rawInput
+ repaired := repairJSON(rawInput)
+ if strings.TrimSpace(repaired) != "" {
+ _ = json.Unmarshal([]byte(repaired), &tool.Input)
+ }
+ tool.IsTruncated = isTruncatedToolUse(tool.Name, rawInput, tool.Input)
+ return tool
+}
+
+func finalizeStructuredToolUse(toolUseID, name string, input map[string]interface{}) KiroToolUse {
+ if input == nil {
+ input = map[string]interface{}{}
+ }
+ tool := KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: normalizeResponseToolName(name),
+ Input: input,
+ }
+ tool.IsTruncated = hasMissingRequiredFields(tool.Name, tool.Input)
+ return tool
+}
+
+func normalizeResponseToolName(name string) string {
+ name = strings.TrimSpace(name)
+ if name == "web_search" {
+ return "remote_web_search"
+ }
+ return name
+}
+
+func shouldEmitToolUse(tool KiroToolUse, emittedToolContents map[string]bool) bool {
+ if tool.IsTruncated {
+ return false
+ }
+ key := toolUseContentKey(tool)
+ if key == "" {
+ return false
+ }
+ if emittedToolContents[key] {
+ return false
+ }
+ emittedToolContents[key] = true
+ return true
+}
+
+func hasUsableToolUses(toolUses []KiroToolUse) bool {
+ for _, tool := range toolUses {
+ if !tool.IsTruncated {
+ return true
+ }
+ }
+ return false
+}
+
+func deduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
+ seenIDs := make(map[string]bool)
+ seenContent := make(map[string]bool)
+ out := make([]KiroToolUse, 0, len(toolUses))
+ for _, tool := range toolUses {
+ if tool.ToolUseID != "" {
+ if seenIDs[tool.ToolUseID] {
+ continue
+ }
+ seenIDs[tool.ToolUseID] = true
+ }
+ key := toolUseContentKey(tool)
+ if key != "" && seenContent[key] {
+ continue
+ }
+ if key != "" {
+ seenContent[key] = true
+ }
+ out = append(out, tool)
+ }
+ return out
+}
+
+func toolUseContentKey(tool KiroToolUse) string {
+ name := strings.TrimSpace(tool.Name)
+ if name == "" {
+ return ""
+ }
+ inputJSON, _ := json.Marshal(tool.Input)
+ return name + ":" + string(inputJSON)
+}
+
+func drainEmbeddedToolText(text string) (cleanText string, toolUses []KiroToolUse, pending string) {
+ complete, pending := splitCompleteEmbeddedToolText(text)
+ if strings.TrimSpace(complete) == "" {
+ return "", nil, pending
+ }
+ cleanText, toolUses = parseEmbeddedToolCalls(complete)
+ return cleanText, deduplicateToolUses(toolUses), pending
+}
+
+func splitCompleteEmbeddedToolText(text string) (complete string, pending string) {
+ searchFrom := 0
+ for {
+ idx := strings.Index(text[searchFrom:], embeddedToolCallPrefix)
+ if idx == -1 {
+ return text, ""
+ }
+ idx += searchFrom
+ _, _, end, ok := parseEmbeddedToolCallAt(text, idx)
+ if !ok {
+ return text[:idx], text[idx:]
+ }
+ searchFrom = end
+ }
+}
+
+func parseEmbeddedToolCalls(text string) (string, []KiroToolUse) {
+ if !strings.Contains(text, embeddedToolCallPrefix) {
+ return text, nil
+ }
+ var (
+ builder strings.Builder
+ toolUses []KiroToolUse
+ index int
+ )
+ for index < len(text) {
+ start := strings.Index(text[index:], embeddedToolCallPrefix)
+ if start == -1 {
+ builder.WriteString(text[index:])
+ break
+ }
+ start += index
+ builder.WriteString(text[index:start])
+ tool, _, end, ok := parseEmbeddedToolCallAt(text, start)
+ if !ok {
+ builder.WriteString(text[start:])
+ break
+ }
+ toolUses = append(toolUses, tool)
+ index = end
+ }
+ return builder.String(), toolUses
+}
+
+func parseEmbeddedToolCallAt(text string, start int) (KiroToolUse, int, int, bool) {
+ if start < 0 || start >= len(text) || !strings.HasPrefix(text[start:], embeddedToolCallPrefix) {
+ return KiroToolUse{}, 0, 0, false
+ }
+ pos := start + len(embeddedToolCallPrefix)
+ argsMarker := " with args:"
+ argsIndex := strings.Index(text[pos:], argsMarker)
+ if argsIndex == -1 {
+ return KiroToolUse{}, 0, 0, false
+ }
+ argsIndex += pos
+ toolName := strings.TrimSpace(text[pos:argsIndex])
+ if toolName == "" {
+ return KiroToolUse{}, 0, 0, false
+ }
+ jsonStart := argsIndex + len(argsMarker)
+ for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t' || text[jsonStart] == '\n') {
+ jsonStart++
+ }
+ if jsonStart >= len(text) || text[jsonStart] != '{' {
+ return KiroToolUse{}, 0, 0, false
+ }
+ jsonEnd := findMatchingJSONBracket(text, jsonStart)
+ if jsonEnd == -1 {
+ return KiroToolUse{}, 0, 0, false
+ }
+ end := jsonEnd + 1
+ for end < len(text) && text[end] != ']' {
+ end++
+ }
+ if end >= len(text) {
+ return KiroToolUse{}, 0, 0, false
+ }
+ rawJSON := text[jsonStart : jsonEnd+1]
+ tool := finalizeRawToolUse("toolu_"+GenerateToolUseID(), toolName, rawJSON)
+ return tool, start, end + 1, true
+}
+
+func findMatchingJSONBracket(text string, start int) int {
+ depth := 0
+ inString := false
+ escape := false
+ for i := start; i < len(text); i++ {
+ ch := text[i]
+ if escape {
+ escape = false
+ continue
+ }
+ if ch == '\\' {
+ escape = true
+ continue
+ }
+ if ch == '"' {
+ inString = !inString
+ continue
+ }
+ if inString {
+ continue
+ }
+ switch ch {
+ case '{':
+ depth++
+ case '}':
+ depth--
+ if depth == 0 {
+ return i
+ }
+ }
+ }
+ return -1
+}
+
+func isTruncatedToolUse(name, rawInput string, input map[string]interface{}) bool {
+ rawInput = strings.TrimSpace(rawInput)
+ if rawInput == "" {
+ return hasToolRequirements(name)
+ }
+ if looksLikeTruncatedJSON(rawInput) {
+ return true
+ }
+ return hasMissingRequiredFields(name, input)
+}
+
+func looksLikeTruncatedJSON(raw string) bool {
+ raw = strings.TrimSpace(raw)
+ if raw == "" || raw[0] != '{' {
+ return false
+ }
+ openBraces, openBrackets, inString := jsonBalance(raw)
+ if openBraces > 0 || openBrackets > 0 || inString {
+ return true
+ }
+ last := raw[len(raw)-1]
+ return last == ':' || last == ','
+}
+
+func hasToolRequirements(name string) bool {
+ _, ok := requiredToolFields[strings.ToLower(strings.TrimSpace(name))]
+ return ok
+}
+
+func hasMissingRequiredFields(name string, input map[string]interface{}) bool {
+ groups, ok := requiredToolFields[strings.ToLower(strings.TrimSpace(name))]
+ if !ok {
+ return false
+ }
+ for _, group := range groups {
+ matched := false
+ for _, candidate := range group {
+ if _, exists := input[candidate]; exists {
+ matched = true
+ break
+ }
+ }
+ if !matched {
+ return true
+ }
+ }
+ return false
+}
+
+func updateUsageFromEvent(usage *Usage, eventType string, event map[string]interface{}) {
+ if usage == nil {
+ return
+ }
+ meta := nestedEvent(event, eventType)
+ if len(meta) == 0 {
+ meta = event
+ }
+ if tokenUsage, ok := meta["tokenUsage"].(map[string]interface{}); ok {
+ if value, ok := toInt(tokenUsage["uncachedInputTokens"]); ok {
+ usage.InputTokens = value
+ }
+ if value, ok := toInt(tokenUsage["outputTokens"]); ok {
+ usage.OutputTokens = value
+ }
+ if value, ok := toInt(tokenUsage["totalTokens"]); ok {
+ usage.TotalTokens = value
+ }
+ if value, ok := toInt(tokenUsage["cacheReadInputTokens"]); ok {
+ usage.CacheReadInputTokens = value
+ if usage.InputTokens == 0 {
+ usage.InputTokens = value
+ } else {
+ usage.InputTokens += value
+ }
+ }
+ }
+ if value, ok := toInt(event["inputTokens"]); ok && value > 0 {
+ usage.InputTokens = value
+ }
+ if value, ok := toInt(event["outputTokens"]); ok && value > 0 {
+ usage.OutputTokens = value
+ }
+ if value, ok := toInt(event["totalTokens"]); ok && value > 0 {
+ usage.TotalTokens = value
+ }
+ if value, ok := toInt(meta["inputTokens"]); ok && value > 0 {
+ usage.InputTokens = value
+ }
+ if value, ok := toInt(meta["outputTokens"]); ok && value > 0 {
+ usage.OutputTokens = value
+ }
+ if value, ok := toInt(meta["totalTokens"]); ok && value > 0 {
+ usage.TotalTokens = value
+ }
+}
+
+func readToolUses(primary, fallback map[string]interface{}) []KiroToolUse {
+ var raw []interface{}
+ if value, ok := primary["toolUses"].([]interface{}); ok {
+ raw = value
+ } else if value, ok := fallback["toolUses"].([]interface{}); ok {
+ raw = value
+ }
+ if len(raw) == 0 {
+ return nil
+ }
+ out := make([]KiroToolUse, 0, len(raw))
+ for _, item := range raw {
+ tool, ok := item.(map[string]interface{})
+ if !ok {
+ continue
+ }
+ input := map[string]interface{}{}
+ if value, ok := tool["input"].(map[string]interface{}); ok {
+ input = value
+ }
+ out = append(out, finalizeStructuredToolUse(getString(tool, "toolUseId"), getString(tool, "name"), input))
+ }
+ return out
+}
+
+func nestedEvent(event map[string]interface{}, key string) map[string]interface{} {
+ if nested, ok := event[key].(map[string]interface{}); ok {
+ return nested
+ }
+ return event
+}
+
+func getString(m map[string]interface{}, key string) string {
+ if value, ok := m[key].(string); ok {
+ return value
+ }
+ return ""
+}
+
+func readStopReason(m map[string]interface{}) string {
+ if stop := getString(m, "stop_reason"); stop != "" {
+ return stop
+ }
+ return getString(m, "stopReason")
+}
+
+func toInt(value interface{}) (int, bool) {
+ switch v := value.(type) {
+ case float64:
+ return int(v), true
+ case int:
+ return v, true
+ case int64:
+ return int(v), true
+ case json.Number:
+ n, err := v.Int64()
+ return int(n), err == nil
+ default:
+ return 0, false
+ }
+}
diff --git a/backend/internal/pkg/kiro/translator_test.go b/backend/internal/pkg/kiro/translator_test.go
new file mode 100644
index 00000000..04424ee6
--- /dev/null
+++ b/backend/internal/pkg/kiro/translator_test.go
@@ -0,0 +1,1222 @@
+package kiro
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestBuildRuntimeUserAgentStable(t *testing.T) {
+ key := BuildAccountKey("client-id", "", "", "", 1)
+ machineID := BuildMachineID("refresh-token", "", "")
+ ua1 := BuildRuntimeUserAgent(key, machineID)
+ ua2 := BuildRuntimeUserAgent(key, machineID)
+ amzUA := BuildRuntimeAmzUserAgent(key, machineID)
+
+ require.Equal(t, ua1, ua2)
+ require.Contains(t, ua1, "KiroIDE-")
+ require.Contains(t, amzUA, "KiroIDE-")
+ require.Contains(t, ua1, "KiroIDE-0.11.")
+ require.Contains(t, ua1, "aws-sdk-js/1.0.34")
+ require.Contains(t, ua1, "md/nodejs#22.22.0")
+ require.Contains(t, ua1, machineID)
+ require.Contains(t, amzUA, machineID)
+}
+
+func TestBuildKiroPayloadBasic(t *testing.T) {
+ SetCachedWebSearchDescription("")
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "system":"You are a test system prompt.",
+ "messages":[{"role":"user","content":"hello kiro"}],
+ "tools":[{"name":"web_search","description":"", "input_schema":{"type":"object","properties":{"query":{"type":"string"}}}}]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "arn:aws:codewhisperer:us-east-1:123456789012:profile/test", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ require.Equal(t, "claude-sonnet-4.5", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
+ require.Equal(t, "AI_EDITOR", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.origin").String())
+ require.Equal(t, "remote_web_search", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.name").String())
+ require.Equal(t, remoteWebSearchDescription, gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.description").String())
+ require.Equal(t, "hello kiro", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.content").String())
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "[Context: Current time is ")
+ require.Contains(t, systemContent, "You are a test system prompt.")
+ require.Equal(t, "I will follow these instructions.", gjson.GetBytes(payload, "conversationState.history.1.assistantResponseMessage.content").String())
+}
+
+func TestBuildKiroPayloadWebSearchUsesCachedDescription(t *testing.T) {
+ SetCachedWebSearchDescription("cached web search description")
+ t.Cleanup(func() { SetCachedWebSearchDescription("") })
+
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello kiro"}],
+ "tools":[{"name":"web_search","description":"caller description", "input_schema":{"type":"object","properties":{"query":{"type":"string"}}}}]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+ require.Equal(t, "remote_web_search", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.name").String())
+ require.Equal(t, "cached web search description", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.description").String())
+}
+
+func TestBuildKiroPayloadAppendsChunkedWritePolicyToWriteAndEditTools(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello"}],
+ "tools":[
+ {"name":"Write","description":"write file", "input_schema":{"type":"object"}},
+ {"name":"Edit","description":"edit file", "input_schema":{"type":"object"}},
+ {"name":"read_file","description":"read file", "input_schema":{"type":"object"}}
+ ]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ tools := gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools").Array()
+ require.Len(t, tools, 3)
+ require.Contains(t, tools[0].Get("toolSpecification.description").String(), writeToolDescriptionSuffix)
+ require.Contains(t, tools[1].Get("toolSpecification.description").String(), editToolDescriptionSuffix)
+ require.NotContains(t, tools[2].Get("toolSpecification.description").String(), "chunks of no more than 50 lines")
+}
+
+func TestBuildKiroPayloadChunkedWritePolicyIsIdempotentAndTruncated(t *testing.T) {
+ longDescription := strings.Repeat("long description ", 900) + "\n" + writeToolDescriptionSuffix
+ body := []byte(fmt.Sprintf(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello"}],
+ "tools":[{"name":"write_to_file","description":%q, "input_schema":{"type":"object"}}]
+ }`, longDescription))
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ description := gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.description").String()
+ require.LessOrEqual(t, len(description), kiroMaxToolDescLen)
+ require.Equal(t, 1, strings.Count(description, writeToolDescriptionSuffix))
+ require.Contains(t, description, writeToolDescriptionSuffix)
+}
+
+func TestBuildKiroPayloadInjectsChunkedWritePolicyIntoSystemPrompt(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "system":"Follow user instructions.",
+ "thinking":{"type":"enabled","budget_tokens":2048},
+ "messages":[{"role":"user","content":"hello"}]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "enabled")
+ require.Contains(t, systemContent, "Follow user instructions.")
+ require.Contains(t, systemContent, systemChunkedWritePolicy)
+ require.Equal(t, 1, strings.Count(systemContent, systemChunkedWritePolicy))
+}
+
+func TestBuildKiroPayloadInjectsThinkingIntoHistory(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "thinking":{"type":"enabled","budget_tokens":2048},
+ "messages":[{"role":"user","content":"hello kiro"}]
+ }`)
+
+ headers := http.Header{}
+ headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", headers)
+ require.NoError(t, err)
+
+ require.Equal(t, "hello kiro", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.content").String())
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "enabled\n2048")
+ require.Contains(t, systemContent, "[Context: Current time is ")
+ require.Equal(t, "I will follow these instructions.", gjson.GetBytes(payload, "conversationState.history.1.assistantResponseMessage.content").String())
+}
+
+func TestBuildKiroPayloadInjectsAdaptiveThinkingForOpus46ThinkingModel(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-opus-4-6-thinking",
+ "messages":[{"role":"user","content":"hello kiro"}]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-opus-4.6", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "adaptive\nhigh")
+ require.Contains(t, systemContent, "[Context: Current time is ")
+}
+
+func TestBuildKiroPayloadInjectsThinkingForThinkingAliasModel(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5-20250929-thinking",
+ "messages":[{"role":"user","content":"hello kiro"}]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "enabled\n20000")
+}
+
+func TestBuildKiroPayloadHeaderOnlyThinking(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello kiro"}]
+ }`)
+
+ headers := http.Header{}
+ headers.Set("Anthropic-Beta", "oauth-2025-04-20,interleaved-thinking-2025-05-14")
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", headers)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "enabled\n16000")
+}
+
+func TestBuildKiroPayloadInjectsToolChoiceHints(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello kiro"}],
+ "tools":[{"name":"web_search","description":"search", "input_schema":{"type":"object","properties":{"query":{"type":"string"}}}}],
+ "tool_choice":{"type":"tool","name":"web_search"}
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "MUST use the tool named 'remote_web_search'")
+}
+
+func TestBuildKiroPayloadInjectsRequiredToolChoiceHint(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello kiro"}],
+ "tools":[{"name":"web_search","description":"search", "input_schema":{"type":"object","properties":{"query":{"type":"string"}}}}],
+ "tool_choice":{"type":"any"}
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "MUST use at least one of the available tools")
+}
+
+func TestBuildKiroPayloadToolChoiceNoneOmitsTools(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello kiro"}],
+ "tools":[{"name":"web_search","description":"search", "input_schema":{"type":"object","properties":{"query":{"type":"string"}}}}],
+ "tool_choice":{"type":"none"}
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+
+ systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
+ require.Contains(t, systemContent, "Do not use any tools. Respond with text only.")
+ require.False(t, gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools").Exists())
+}
+
+func TestParseNonStreamingEventStream(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "hello from kiro",
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "messageMetadataEvent", map[string]any{
+ "messageMetadataEvent": map[string]any{
+ "tokenUsage": map[string]any{
+ "uncachedInputTokens": 12,
+ "outputTokens": 7,
+ "cacheReadInputTokens": 3,
+ "totalTokens": 22,
+ },
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "messageStopEvent", map[string]any{
+ "messageStopEvent": map[string]any{
+ "stop_reason": "end_turn",
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+ require.Equal(t, 15, result.Usage.InputTokens)
+ require.Equal(t, 7, result.Usage.OutputTokens)
+ require.Equal(t, 22, result.Usage.TotalTokens)
+
+ var response map[string]any
+ require.NoError(t, json.Unmarshal(result.ResponseBody, &response))
+ require.Equal(t, "end_turn", response["stop_reason"])
+ content, _ := response["content"].([]any)
+ require.NotEmpty(t, content)
+ first, _ := content[0].(map[string]any)
+ require.Equal(t, "text", first["type"])
+ firstText, ok := first["text"].(string)
+ require.True(t, ok)
+ require.True(t, strings.Contains(firstText, "hello from kiro"))
+}
+
+func TestExtractThinkingBlocksIgnoresLiteralTags(t *testing.T) {
+ content := strings.Join([]string{
+ "Use `` literally.",
+ "Quote \"\" and ''.",
+ "> quoted",
+ "```",
+ "code",
+ "```",
+ }, "\n")
+
+ blocks := extractThinkingBlocks(content)
+ require.Len(t, blocks, 1)
+ require.Equal(t, "text", blocks[0]["type"])
+ require.Equal(t, content, blocks[0]["text"])
+}
+
+func TestExtractThinkingBlocksParsesRealTags(t *testing.T) {
+ blocks := extractThinkingBlocks("\nreason\n\nfinal text")
+
+ require.Len(t, blocks, 2)
+ require.Equal(t, "thinking", blocks[0]["type"])
+ require.Equal(t, "reason", blocks[0]["thinking"])
+ require.NotEmpty(t, blocks[0]["signature"])
+ require.Equal(t, "text", blocks[1]["type"])
+ require.Equal(t, "final text", blocks[1]["text"])
+}
+
+func TestParseNonStreamingEventStreamPureThinkingFallback(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "reason only",
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "max_tokens", gjson.GetBytes(result.ResponseBody, "stop_reason").String())
+
+ content := gjson.GetBytes(result.ResponseBody, "content").Array()
+ require.Len(t, content, 2)
+ require.Equal(t, "thinking", content[0].Get("type").String())
+ require.Equal(t, "reason only", content[0].Get("thinking").String())
+ require.Equal(t, "text", content[1].Get("type").String())
+ require.Equal(t, "", content[1].Get("text").String())
+}
+
+func TestParseNonStreamingEventStreamThinkingWithTextKeepsEndTurn(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "reason\n\nfinal",
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", gjson.GetBytes(result.ResponseBody, "stop_reason").String())
+ require.Equal(t, "thinking", gjson.GetBytes(result.ResponseBody, "content.0.type").String())
+ require.Equal(t, "text", gjson.GetBytes(result.ResponseBody, "content.1.type").String())
+ require.Equal(t, "final", gjson.GetBytes(result.ResponseBody, "content.1.text").String())
+}
+
+func TestParseNonStreamingEventStreamThinkingWithToolUseKeepsToolUseStopReason(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "reason only",
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_search",
+ "name": "remote_web_search",
+ "input": `{"query":"golang"}`,
+ "stop": true,
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "tool_use", gjson.GetBytes(result.ResponseBody, "stop_reason").String())
+ require.Equal(t, "thinking", gjson.GetBytes(result.ResponseBody, "content.0.type").String())
+ require.Equal(t, "tool_use", gjson.GetBytes(result.ResponseBody, "content.1.type").String())
+ require.False(t, gjson.GetBytes(result.ResponseBody, "content.2.text").Exists())
+}
+
+func TestParseNonStreamingEventStreamExtractsEmbeddedToolCall(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": `Before [Called web_search with args: {"query":"golang concurrency"}] After`,
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "tool_use", result.StopReason)
+ require.NotContains(t, string(result.ResponseBody), "[Called")
+
+ content := gjson.GetBytes(result.ResponseBody, "content").Array()
+ require.Len(t, content, 2)
+ require.Equal(t, "text", content[0].Get("type").String())
+ require.Equal(t, "Before After", content[0].Get("text").String())
+ require.Equal(t, "tool_use", content[1].Get("type").String())
+ require.Equal(t, "remote_web_search", content[1].Get("name").String())
+ require.Equal(t, "golang concurrency", content[1].Get("input.query").String())
+}
+
+func TestParseNonStreamingEventStreamDeduplicatesToolUsesByContent(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "toolUses": []map[string]any{
+ {
+ "toolUseId": "toolu_first",
+ "name": "remote_web_search",
+ "input": map[string]any{
+ "query": "golang",
+ },
+ },
+ },
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_second",
+ "name": "remote_web_search",
+ "input": map[string]any{
+ "query": "golang",
+ },
+ "stop": true,
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+
+ content := gjson.GetBytes(result.ResponseBody, "content").Array()
+ toolUseCount := 0
+ for _, block := range content {
+ if block.Get("type").String() == "tool_use" {
+ toolUseCount++
+ }
+ }
+ require.Equal(t, 1, toolUseCount)
+}
+
+func TestParseNonStreamingEventStreamSkipsTruncatedToolUse(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_truncated",
+ "name": "write_to_file",
+ "input": `{"path":"main.go","content":"package main`,
+ "stop": true,
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+
+ content := gjson.GetBytes(result.ResponseBody, "content").Array()
+ require.Len(t, content, 1)
+ require.Equal(t, "text", content[0].Get("type").String())
+ require.NotContains(t, string(result.ResponseBody), `"type":"tool_use"`)
+}
+
+func TestParseNonStreamingEventStreamDropsIncompleteEmbeddedToolTail(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": `Before [Called web_search with args: {"query":"golang`,
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+ require.NotContains(t, string(result.ResponseBody), "[Called")
+ require.Equal(t, "Before ", gjson.GetBytes(result.ResponseBody, "content.0.text").String())
+}
+
+func TestParseNonStreamingEventStreamThinkingOnlyResponse(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "reasoningContentEvent", map[string]any{
+ "reasoningContentEvent": map[string]any{
+ "text": "I should think first.",
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
+ require.NoError(t, err)
+ require.Equal(t, "max_tokens", gjson.GetBytes(result.ResponseBody, "stop_reason").String())
+ require.Equal(t, "thinking", gjson.GetBytes(result.ResponseBody, "content.0.type").String())
+ require.Equal(t, "I should think first.", gjson.GetBytes(result.ResponseBody, "content.0.thinking").String())
+ require.Equal(t, "text", gjson.GetBytes(result.ResponseBody, "content.1.type").String())
+ require.Equal(t, "", gjson.GetBytes(result.ResponseBody, "content.1.text").String())
+}
+
+func TestStreamEventStreamAsAnthropicExtractsEmbeddedToolCall(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": `Before [Called web_search with args: {"query":"gol`,
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": `ang"}] After`,
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "tool_use", result.StopReason)
+
+ output := out.String()
+ require.NotContains(t, output, "[Called")
+ require.Contains(t, output, `"text":"Before "`)
+ require.Contains(t, output, `"text":" After"`)
+ require.Contains(t, output, `"name":"remote_web_search"`)
+ require.Contains(t, output, `"partial_json":"{\"query\":\"golang\"}"`)
+}
+
+func TestStreamEventStreamAsAnthropicSkipsLeadingWhitespaceOnlyChunk(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "\n",
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "Hello from Kiro",
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+
+ output := out.String()
+ require.Contains(t, output, `"text":"Hello from Kiro"`)
+ require.NotContains(t, output, `"delta":{"text":"\n","type":"text_delta"}`)
+ require.NotContains(t, output, `"delta":{"text":"","type":"text_delta"}`)
+}
+
+func TestStreamEventStreamAsAnthropicSkipsTrailingWhitespaceOnlyChunk(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "Hello from Kiro",
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "\n",
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "\n\n",
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+
+ output := out.String()
+ require.Contains(t, output, `"text":"Hello from Kiro"`)
+ require.NotContains(t, output, `"text":"\n"`)
+ require.NotContains(t, output, `"text":"\n\n"`)
+}
+
+func TestStreamEventStreamAsAnthropicDelaysMessageStartUntilContent(t *testing.T) {
+ pr, pw := io.Pipe()
+ var out bytes.Buffer
+ errCh := make(chan error, 1)
+
+ go func() {
+ _, err := StreamEventStreamAsAnthropic(context.Background(), pr, &out, "claude-sonnet-4-5", 9)
+ errCh <- err
+ }()
+
+ _, err := pw.Write(buildEventStreamFrame(t, "messageMetadataEvent", map[string]any{
+ "messageMetadataEvent": map[string]any{
+ "tokenUsage": map[string]any{
+ "uncachedInputTokens": 9,
+ },
+ },
+ }))
+ require.NoError(t, err)
+
+ time.Sleep(50 * time.Millisecond)
+ require.Empty(t, out.String())
+
+ _, err = pw.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_delayed",
+ "name": "remote_web_search",
+ "input": map[string]any{
+ "query": "golang",
+ },
+ "stop": true,
+ },
+ }))
+ require.NoError(t, err)
+ require.NoError(t, pw.Close())
+ require.NoError(t, <-errCh)
+
+ output := out.String()
+ require.Contains(t, output, "event: message_start")
+ require.Contains(t, output, `"name":"remote_web_search"`)
+ require.Contains(t, output, `"partial_json":"{\"query\":\"golang\"}`)
+ messageStartIdx := strings.Index(output, "event: message_start")
+ toolUseIdx := strings.Index(output, `"name":"remote_web_search"`)
+ require.NotEqual(t, -1, messageStartIdx)
+ require.NotEqual(t, -1, toolUseIdx)
+ require.Less(t, messageStartIdx, toolUseIdx)
+}
+
+func TestStreamEventStreamAsAnthropicStreamsToolUseFragments(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_stream",
+ "name": "write_file",
+ "input": `{"path":"/tmp/a.txt",`,
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_stream",
+ "name": "write_file",
+ "input": `"content":"hello"}`,
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_stream",
+ "name": "write_file",
+ "stop": true,
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "tool_use", result.StopReason)
+
+ output := out.String()
+ require.Equal(t, 1, strings.Count(output, `"id":"toolu_stream"`))
+ require.Contains(t, output, `"partial_json":"{\"path\":\"/tmp/a.txt\","`)
+ require.Contains(t, output, `"partial_json":"\"content\":\"hello\"}"`)
+ require.Contains(t, output, `event: content_block_stop`)
+}
+
+func TestStreamEventStreamAsAnthropicStreamsIncompleteToolUseFragment(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_incomplete",
+ "name": "write_file",
+ "input": `{"path":`,
+ "stop": true,
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "tool_use", result.StopReason)
+ require.Contains(t, out.String(), `"partial_json":"{\"path\":"`)
+}
+
+func TestStreamEventStreamAsAnthropicStopsPreviousToolWhenIDChanges(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_one",
+ "name": "write_file",
+ "input": `{"path":"a"}`,
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_two",
+ "name": "read_file",
+ "input": `{"path":"b"}`,
+ "stop": true,
+ },
+ }))
+
+ var out bytes.Buffer
+ _, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+
+ output := out.String()
+ firstStart := strings.Index(output, `"id":"toolu_one"`)
+ firstStop := strings.Index(output[firstStart:], `event: content_block_stop`)
+ secondStart := strings.Index(output, `"id":"toolu_two"`)
+ require.NotEqual(t, -1, firstStart)
+ require.NotEqual(t, -1, firstStop)
+ require.NotEqual(t, -1, secondStart)
+ require.Less(t, firstStart+firstStop, secondStart)
+}
+
+func TestStreamEventStreamAsAnthropicClosesToolBeforeText(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_before_text",
+ "name": "write_file",
+ "input": `{"path":"a"}`,
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "done",
+ },
+ }))
+
+ var out bytes.Buffer
+ _, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+
+ output := out.String()
+ toolStart := strings.Index(output, `"id":"toolu_before_text"`)
+ toolStop := strings.Index(output[toolStart:], `event: content_block_stop`)
+ textDelta := strings.Index(output, `"text":"done"`)
+ require.NotEqual(t, -1, toolStart)
+ require.NotEqual(t, -1, toolStop)
+ require.NotEqual(t, -1, textDelta)
+ require.Less(t, toolStart+toolStop, textDelta)
+}
+
+func TestStreamEventStreamAsAnthropicClosesThinkingBeforeTool(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "reasoningContentEvent", map[string]any{
+ "reasoningContentEvent": map[string]any{
+ "text": "thinking first",
+ },
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_after_thinking",
+ "name": "write_file",
+ "input": `{"path":"a"}`,
+ "stop": true,
+ },
+ }))
+
+ var out bytes.Buffer
+ _, err := StreamEventStreamAsAnthropicWithContext(context.Background(), stream, &out, "claude-sonnet-4-5", 9, KiroRequestContext{ThinkingEnabled: true})
+ require.NoError(t, err)
+
+ output := out.String()
+ thinkingDelta := strings.Index(output, `"thinking":"thinking first"`)
+ toolStart := strings.Index(output, `"id":"toolu_after_thinking"`)
+ require.NotEqual(t, -1, thinkingDelta)
+ thinkingStop := strings.Index(output[thinkingDelta:], `event: content_block_stop`)
+ require.NotEqual(t, -1, thinkingStop)
+ require.NotEqual(t, -1, toolStart)
+ require.Less(t, thinkingDelta+thinkingStop, toolStart)
+}
+
+func TestStreamEventStreamAsAnthropicClosesOpenToolAtEOF(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_eof",
+ "name": "write_file",
+ "input": `{"path":"a"}`,
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "tool_use", result.StopReason)
+ require.Contains(t, out.String(), `event: content_block_stop`)
+}
+
+func TestStreamEventStreamAsAnthropicStreamsToolUseMapInput(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_map",
+ "name": "remote_web_search",
+ "input": map[string]any{
+ "query": "golang",
+ },
+ "stop": true,
+ },
+ }))
+
+ var out bytes.Buffer
+ _, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Contains(t, out.String(), `"partial_json":"{\"query\":\"golang\"}"`)
+}
+
+func TestStreamEventStreamAsAnthropicIgnoresPingFrames(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "ping", map[string]any{}))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "Hello after ping",
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+ require.Contains(t, out.String(), `"text":"Hello after ping"`)
+}
+
+func TestStreamEventStreamAsAnthropicThinkingOnlyResponse(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "reasoningContentEvent", map[string]any{
+ "reasoningContentEvent": map[string]any{
+ "text": "I should think first.",
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropicWithContext(context.Background(), stream, &out, "claude-sonnet-4-5", 9, KiroRequestContext{ThinkingEnabled: true})
+ require.NoError(t, err)
+ require.Equal(t, "max_tokens", result.StopReason)
+
+ output := out.String()
+ require.Contains(t, output, `"type":"thinking"`)
+ require.Contains(t, output, `"type":"thinking_delta"`)
+ require.Contains(t, output, `"thinking":"I should think first."`)
+ require.Contains(t, output, `"text":" "`)
+ require.Contains(t, output, `event: message_delta`)
+ require.Contains(t, output, `event: message_stop`)
+}
+
+func TestStreamEventStreamAsAnthropicParsesMultipleReasoningEventsWhenEnabled(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "reasoningContentEvent", map[string]any{
+ "reasoningContentEvent": map[string]any{"text": "first thought"},
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "reasoningContentEvent", map[string]any{
+ "reasoningContentEvent": map[string]any{"text": "second thought"},
+ }))
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{"content": "final"},
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropicWithContext(context.Background(), stream, &out, "claude-sonnet-4-5", 9, KiroRequestContext{ThinkingEnabled: true})
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+
+ output := out.String()
+ require.Contains(t, output, `"thinking":"first thought"`)
+ require.Contains(t, output, `"thinking":"second thought"`)
+ require.Contains(t, output, `"text":"final"`)
+}
+
+func TestStreamEventStreamAsAnthropicParsesTaggedThinkingWhenEnabled(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "\nreason\n\nfinal",
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropicWithContext(context.Background(), stream, &out, "claude-sonnet-4-5", 9, KiroRequestContext{ThinkingEnabled: true})
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+
+ output := out.String()
+ thinkingDelta := strings.Index(output, `"thinking":"reason"`)
+ textDelta := strings.Index(output, `"text":"final"`)
+ require.NotEqual(t, -1, thinkingDelta)
+ require.NotEqual(t, -1, textDelta)
+ require.Less(t, thinkingDelta, textDelta)
+ require.NotContains(t, output, `\u003c/thinking\u003e`)
+}
+
+func TestStreamEventStreamAsAnthropicBuffersSplitThinkingTags(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ for _, chunk := range []string{"\n\n\nrea", "son", "\n\nfinal"} {
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{"content": chunk},
+ }))
+ }
+
+ var out bytes.Buffer
+ _, err := StreamEventStreamAsAnthropicWithContext(context.Background(), stream, &out, "claude-sonnet-4-5", 9, KiroRequestContext{ThinkingEnabled: true})
+ require.NoError(t, err)
+
+ output := out.String()
+ thinkingStart := strings.Index(output, `"type":"thinking"`)
+ textDelta := strings.Index(output, `"text":"final"`)
+ require.NotEqual(t, -1, thinkingStart)
+ require.NotEqual(t, -1, textDelta)
+ require.Less(t, thinkingStart, textDelta)
+ require.NotContains(t, output, `\u003cthink`)
+ require.NotContains(t, output, `\u003c/thinking\u003e`)
+ require.NotContains(t, output, `"text":"\n\n"`)
+}
+
+func TestStreamEventStreamAsAnthropicTreatsThinkingTagsAsTextWhenDisabled(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
+ "assistantResponseEvent": map[string]any{
+ "content": "reason\n\nfinal",
+ },
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+
+ output := out.String()
+ require.Contains(t, output, `\u003cthinking\u003ereason\u003c/thinking\u003e`)
+ require.NotContains(t, output, `"type":"thinking_delta"`)
+}
+
+func TestStreamEventStreamAsAnthropicIgnoresReasoningContentWhenThinkingDisabled(t *testing.T) {
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "reasoningContentEvent", map[string]any{
+ "reasoningContentEvent": map[string]any{"text": "hidden reasoning"},
+ }))
+
+ var out bytes.Buffer
+ result, err := StreamEventStreamAsAnthropic(context.Background(), stream, &out, "claude-sonnet-4-5", 9)
+ require.NoError(t, err)
+ require.Equal(t, "end_turn", result.StopReason)
+ require.NotContains(t, out.String(), "hidden reasoning")
+ require.NotContains(t, out.String(), `"type":"thinking"`)
+}
+
+func TestBuildAssistantMessageStructUsesSpacePlaceholderForToolOnly(t *testing.T) {
+ msg := gjson.Parse(`{
+ "role":"assistant",
+ "content":[
+ {"type":"tool_use","id":"toolu_01ABC","name":"read_file","input":{"path":"/tmp/test.txt"}}
+ ]
+ }`)
+
+ result := buildAssistantMessageStruct(msg, nil)
+ require.Equal(t, " ", result.Content)
+ require.Len(t, result.ToolUses, 1)
+ require.Equal(t, "read_file", result.ToolUses[0].Name)
+ require.Equal(t, "/tmp/test.txt", result.ToolUses[0].Input["path"])
+}
+
+func TestBuildKiroPayloadAddsPlaceholderToolForHistoryToolUse(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[
+ {"role":"assistant","content":[{"type":"tool_use","id":"toolu_01","name":"read_file","input":{"path":"/tmp/a.txt"}}]},
+ {"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_01","content":"ok"},{"type":"text","text":"continue"}]}
+ ]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+ tools := gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools").Array()
+ require.Len(t, tools, 1)
+ require.Equal(t, "read_file", tools[0].Get("toolSpecification.name").String())
+ require.Equal(t, "Tool used in conversation history", tools[0].Get("toolSpecification.description").String())
+ require.Equal(t, "object", tools[0].Get("toolSpecification.inputSchema.json.type").String())
+}
+
+func TestBuildKiroPayloadNormalizesToolJSONSchema(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":"hello"}],
+ "tools":[{
+ "name":"bad_schema",
+ "description":"bad schema",
+ "input_schema":{
+ "properties":null,
+ "required":null,
+ "additionalProperties":"sometimes",
+ "items":{"properties":null,"required":[1,"ok"],"additionalProperties":7}
+ }
+ }]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+ schema := gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.inputSchema.json")
+ require.Equal(t, "object", schema.Get("type").String())
+ require.True(t, schema.Get("properties").IsObject())
+ require.True(t, schema.Get("required").IsArray())
+ require.Len(t, schema.Get("required").Array(), 0)
+ require.True(t, schema.Get("additionalProperties").Bool())
+ require.Equal(t, "object", schema.Get("items.type").String())
+ require.Equal(t, "ok", schema.Get("items.required.0").String())
+ require.True(t, schema.Get("items.additionalProperties").Bool())
+}
+
+func TestBuildKiroPayloadFiltersCurrentOrphanToolResult(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"missing","content":"orphaned"}]}]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+ require.False(t, gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.toolResults").Exists())
+}
+
+func TestBuildKiroPayloadRemovesHistoryOrphanToolUse(t *testing.T) {
+ body := []byte(`{
+ "model":"claude-sonnet-4-5",
+ "messages":[
+ {"role":"assistant","content":[{"type":"tool_use","id":"toolu_orphan","name":"read_file","input":{"path":"/tmp/a.txt"}}]},
+ {"role":"user","content":"continue"}
+ ]
+ }`)
+
+ payload, err := BuildKiroPayload(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+ history := gjson.GetBytes(payload, "conversationState.history").Array()
+ foundAssistantWithoutToolUses := false
+ for _, msg := range history {
+ if msg.Get("assistantResponseMessage").Exists() && msg.Get("assistantResponseMessage.content").String() == " " {
+ foundAssistantWithoutToolUses = true
+ require.False(t, msg.Get("assistantResponseMessage.toolUses").Exists())
+ }
+ }
+ require.True(t, foundAssistantWithoutToolUses)
+ require.False(t, gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools").Exists())
+}
+
+func TestMergeAdjacentMessagesUsesDoubleNewline(t *testing.T) {
+ messages := gjson.Parse(`[
+ {"role":"user","content":"first"},
+ {"role":"user","content":"second"}
+ ]`).Array()
+
+ merged := mergeAdjacentMessages(messages)
+ require.Len(t, merged, 1)
+ require.Equal(t, "first\n\nsecond", merged[0].Get("content.0.text").String())
+}
+
+func TestLongToolNamesUseHashSuffixAndDoNotCollide(t *testing.T) {
+ nameA := strings.Repeat("tool_prefix_", 8) + "alpha"
+ nameB := strings.Repeat("tool_prefix_", 8) + "bravo"
+ shortA := shortenToolNameIfNeeded(nameA)
+ shortB := shortenToolNameIfNeeded(nameB)
+
+ require.Len(t, shortA, kiroMaxToolNameLen)
+ require.Len(t, shortB, kiroMaxToolNameLen)
+ require.NotEqual(t, shortA, shortB)
+ require.Regexp(t, `_[0-9a-f]{8}$`, shortA)
+ require.Regexp(t, `_[0-9a-f]{8}$`, shortB)
+}
+
+func TestBuildKiroPayloadMapsLongToolNameConsistently(t *testing.T) {
+ longName := strings.Repeat("mcp__very_long_server__", 4) + "read_file"
+ body := []byte(fmt.Sprintf(`{
+ "model":"claude-sonnet-4-5",
+ "system":"Follow tool choice.",
+ "tool_choice":{"type":"tool","name":%q},
+ "messages":[
+ {"role":"assistant","content":[{"type":"tool_use","id":"toolu_01","name":%q,"input":{"path":"/tmp/a.txt"}}]},
+ {"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_01","content":"ok"},{"type":"text","text":"continue"}]}
+ ],
+ "tools":[{"name":%q,"description":"read","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]
+ }`, longName, longName, longName))
+
+ result, err := BuildKiroPayloadWithContext(body, "claude-sonnet-4.5", "", "AI_EDITOR", nil)
+ require.NoError(t, err)
+ require.Len(t, result.Context.ToolNameMap, 1)
+ var shortName string
+ for short, original := range result.Context.ToolNameMap {
+ shortName = short
+ require.Equal(t, longName, original)
+ }
+ require.NotEmpty(t, shortName)
+ require.Equal(t, shortName, gjson.GetBytes(result.Payload, "conversationState.currentMessage.userInputMessage.userInputMessageContext.tools.0.toolSpecification.name").String())
+ require.Contains(t, gjson.GetBytes(result.Payload, "conversationState.history.0.userInputMessage.content").String(), "MUST use the tool named '"+shortName+"'")
+
+ found := false
+ for _, msg := range gjson.GetBytes(result.Payload, "conversationState.history").Array() {
+ for _, toolUse := range msg.Get("assistantResponseMessage.toolUses").Array() {
+ if toolUse.Get("toolUseId").String() == "toolu_01" {
+ found = true
+ require.Equal(t, shortName, toolUse.Get("name").String())
+ }
+ }
+ }
+ require.True(t, found)
+}
+
+func TestParseNonStreamingEventStreamRestoresShortToolName(t *testing.T) {
+ longName := strings.Repeat("long_tool_name_", 6)
+ shortName := shortenToolNameIfNeeded(longName)
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_long",
+ "name": shortName,
+ "input": `{"path":"/tmp/a.txt"}`,
+ "stop": true,
+ },
+ }))
+
+ result, err := ParseNonStreamingEventStreamWithContext(stream, "claude-sonnet-4-5", KiroRequestContext{
+ ToolNameMap: map[string]string{shortName: longName},
+ })
+ require.NoError(t, err)
+ require.Equal(t, longName, gjson.GetBytes(result.ResponseBody, "content.0.name").String())
+}
+
+func TestStreamEventStreamAsAnthropicRestoresShortToolName(t *testing.T) {
+ longName := strings.Repeat("long_tool_name_", 6)
+ shortName := shortenToolNameIfNeeded(longName)
+ stream := bytes.NewBuffer(nil)
+ _, _ = stream.Write(buildEventStreamFrame(t, "toolUseEvent", map[string]any{
+ "toolUseEvent": map[string]any{
+ "toolUseId": "toolu_long",
+ "name": shortName,
+ "input": `{"path":"/tmp/a.txt"}`,
+ "stop": true,
+ },
+ }))
+
+ var out bytes.Buffer
+ _, err := StreamEventStreamAsAnthropicWithContext(context.Background(), stream, &out, "claude-sonnet-4-5", 1, KiroRequestContext{
+ ToolNameMap: map[string]string{shortName: longName},
+ })
+ require.NoError(t, err)
+ require.Contains(t, out.String(), `"name":"`+longName+`"`)
+ require.NotContains(t, out.String(), `"name":"`+shortName+`"`)
+}
+
+func TestRepairJSONKeepsStringBracesWhileRepairingTrailingComma(t *testing.T) {
+ raw := `{"key":"value with {nested}",}`
+ repaired := repairJSON(raw)
+
+ var parsed map[string]string
+ require.NoError(t, json.Unmarshal([]byte(repaired), &parsed))
+ require.Equal(t, "value with {nested}", parsed["key"])
+}
+
+func TestMapModel_MatchesKiroReferenceMapping(t *testing.T) {
+ t.Parallel()
+
+ cases := map[string]string{
+ "claude-sonnet-4-6": "claude-sonnet-4.6",
+ "claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
+ "claude-sonnet-4.6": "claude-sonnet-4.6",
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
+ "claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
+ "claude-sonnet-4.5": "claude-sonnet-4.5",
+ "claude-opus-4-6": "claude-opus-4.6",
+ "claude-opus-4-6-thinking": "claude-opus-4.6",
+ "claude-opus-4.6": "claude-opus-4.6",
+ "claude-opus-4-5-20251101": "claude-opus-4.5",
+ "claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
+ "claude-opus-4.5": "claude-opus-4.5",
+ "claude-haiku-4-5-20251001": "claude-haiku-4.5",
+ "claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
+ "claude-haiku-4.5": "claude-haiku-4.5",
+ }
+
+ for input, want := range cases {
+ if got := MapModel(input); got != want {
+ t.Fatalf("MapModel(%q) = %q, want %q", input, got, want)
+ }
+ }
+
+ rejected := []string{
+ "claude-sonnet-4-6-chat",
+ " claude-sonnet-4-6-thinking-chat ",
+ "claude-sonnet-4-6-agentic",
+ " claude-sonnet-4-6-thinking-agentic ",
+ "claude-3-5-sonnet-20241022",
+ "claude-opus-4-20250514",
+ "claude-sonnet-4",
+ "claude-opus-4-5",
+ "claude-sonnet-4-5",
+ "claude-haiku-4-5",
+ }
+ for _, input := range rejected {
+ if got := MapModel(input); got != "" {
+ t.Fatalf("MapModel(%q) = %q, want empty", input, got)
+ }
+ }
+}
+
+func TestMapModel_ReturnsEmptyForUnsupportedModels(t *testing.T) {
+ t.Parallel()
+
+ cases := []string{
+ "auto",
+ "gpt-4",
+ "gpt-4o",
+ "deepseek-3-2",
+ "minimax-m2-1",
+ "qwen3-coder-next",
+ }
+
+ for _, input := range cases {
+ if got := MapModel(input); got != "" {
+ t.Fatalf("MapModel(%q) = %q, want empty string", input, got)
+ }
+ }
+}
+
+func buildEventStreamFrame(t *testing.T, eventType string, payload any) []byte {
+ t.Helper()
+ payloadBytes, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ headers := bytes.NewBuffer(nil)
+ _ = headers.WriteByte(byte(len(":event-type")))
+ _, _ = headers.WriteString(":event-type")
+ _ = headers.WriteByte(7)
+ require.NoError(t, binary.Write(headers, binary.BigEndian, uint16(len(eventType))))
+ _, _ = headers.WriteString(eventType)
+
+ totalLength := uint32(12 + headers.Len() + len(payloadBytes) + 4)
+ frame := bytes.NewBuffer(nil)
+ require.NoError(t, binary.Write(frame, binary.BigEndian, totalLength))
+ require.NoError(t, binary.Write(frame, binary.BigEndian, uint32(headers.Len())))
+ require.NoError(t, binary.Write(frame, binary.BigEndian, uint32(0)))
+ _, _ = frame.Write(headers.Bytes())
+ _, _ = frame.Write(payloadBytes)
+ require.NoError(t, binary.Write(frame, binary.BigEndian, uint32(0)))
+ return frame.Bytes()
+}
diff --git a/backend/internal/pkg/kiro/websearch.go b/backend/internal/pkg/kiro/websearch.go
new file mode 100644
index 00000000..28987852
--- /dev/null
+++ b/backend/internal/pkg/kiro/websearch.go
@@ -0,0 +1,368 @@
+package kiro
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/tidwall/gjson"
+)
+
+const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
+const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
+
+var cachedWebSearchDescription atomic.Value // stores string
+
+type MCPRequest struct {
+ ID string `json:"id"`
+ JSONRPC string `json:"jsonrpc"`
+ Method string `json:"method"`
+ Params interface{} `json:"params,omitempty"`
+}
+
+type MCPResponse struct {
+ Result *struct {
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ } `json:"content"`
+ Tools []struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ } `json:"tools"`
+ } `json:"result,omitempty"`
+ Error *struct {
+ Code *int `json:"code,omitempty"`
+ Message *string `json:"message,omitempty"`
+ } `json:"error,omitempty"`
+}
+
+type WebSearchResults struct {
+ Results []WebSearchResult `json:"results"`
+}
+
+type WebSearchResult struct {
+ Title string `json:"title"`
+ URL string `json:"url"`
+ Snippet *string `json:"snippet,omitempty"`
+ PublishedDate *int64 `json:"publishedDate,omitempty"`
+ ID *string `json:"id,omitempty"`
+ Domain *string `json:"domain,omitempty"`
+ MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
+ PublicDomain *bool `json:"publicDomain,omitempty"`
+}
+
+type SearchIndicator struct {
+ ToolUseID string
+ Query string
+ Results *WebSearchResults
+}
+
+func GetCachedWebSearchDescription() string {
+ if v := cachedWebSearchDescription.Load(); v != nil {
+ return strings.TrimSpace(v.(string))
+ }
+ return ""
+}
+
+func SetCachedWebSearchDescription(desc string) {
+ cachedWebSearchDescription.Store(strings.TrimSpace(desc))
+}
+
+func BuildMcpEndpoint(region string) string {
+ if strings.TrimSpace(region) == "" {
+ region = "us-east-1"
+ }
+ return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
+}
+
+func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
+ if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
+ return nil
+ }
+ for _, item := range resp.Result.Content {
+ if item.Type != "" && item.Type != "text" {
+ continue
+ }
+ var results WebSearchResults
+ if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
+ return &results
+ }
+ }
+ return nil
+}
+
+func ExtractSearchQuery(body []byte) string {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() {
+ return ""
+ }
+ arr := messages.Array()
+ for i := len(arr) - 1; i >= 0; i-- {
+ msg := arr[i]
+ if msg.Get("role").String() != "user" {
+ continue
+ }
+ text := extractSearchText(msg.Get("content"))
+ const prefix = "Perform a web search for the query: "
+ text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
+ if text != "" {
+ return text
+ }
+ }
+ return ""
+}
+
+func extractSearchText(content gjson.Result) string {
+ if content.Type == gjson.String {
+ return content.String()
+ }
+ if !content.IsArray() {
+ return ""
+ }
+ for _, block := range content.Array() {
+ if block.Get("type").String() == "text" {
+ if text := strings.TrimSpace(block.Get("text").String()); text != "" {
+ return text
+ }
+ }
+ }
+ return ""
+}
+
+func GenerateToolUseID() string {
+ return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
+}
+
+func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
+ var payload map[string]interface{}
+ if err := json.Unmarshal(body, &payload); err != nil {
+ return body, err
+ }
+ rawTools, ok := payload["tools"].([]interface{})
+ if !ok {
+ return body, nil
+ }
+
+ replaced := make([]interface{}, 0, len(rawTools))
+ for _, rawTool := range rawTools {
+ tool, ok := rawTool.(map[string]interface{})
+ if !ok {
+ replaced = append(replaced, rawTool)
+ continue
+ }
+ name := getInterfaceString(tool["name"])
+ toolType := getInterfaceString(tool["type"])
+ if !isWebSearchToolName(name, toolType) {
+ replaced = append(replaced, rawTool)
+ continue
+ }
+ replaced = append(replaced, map[string]interface{}{
+ "name": "web_search",
+ "description": minimalWebSearchDescription,
+ "input_schema": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "query": map[string]interface{}{
+ "type": "string",
+ "description": "The search query to execute",
+ },
+ },
+ "required": []string{"query"},
+ "additionalProperties": false,
+ },
+ })
+ }
+
+ payload["tools"] = replaced
+ updated, err := json.Marshal(payload)
+ if err != nil {
+ return body, err
+ }
+ return updated, nil
+}
+
+func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
+ var payload map[string]interface{}
+ if err := json.Unmarshal(claudePayload, &payload); err != nil {
+ return claudePayload, fmt.Errorf("parse claude payload: %w", err)
+ }
+
+ rawMessages, ok := payload["messages"].([]interface{})
+ if !ok {
+ return claudePayload, fmt.Errorf("claude payload missing messages array")
+ }
+
+ assistantMsg := map[string]interface{}{
+ "role": "assistant",
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": toolUseID,
+ "name": "web_search",
+ "input": map[string]interface{}{"query": query},
+ },
+ },
+ }
+
+ userContent := []interface{}{
+ map[string]interface{}{
+ "type": "tool_result",
+ "tool_use_id": toolUseID,
+ "content": formatToolResultText(results),
+ },
+ }
+ if guidance := searchGuidanceText(); guidance != "" {
+ userContent = append(userContent, map[string]interface{}{
+ "type": "text",
+ "text": guidance,
+ })
+ }
+ userMsg := map[string]interface{}{
+ "role": "user",
+ "content": userContent,
+ }
+
+ rawMessages = append(rawMessages, assistantMsg, userMsg)
+ payload["messages"] = rawMessages
+ updated, err := json.Marshal(payload)
+ if err != nil {
+ return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
+ }
+ return updated, nil
+}
+
+func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
+ if len(searches) == 0 {
+ return responsePayload, nil
+ }
+
+ var response map[string]interface{}
+ if err := json.Unmarshal(responsePayload, &response); err != nil {
+ return responsePayload, err
+ }
+ content, _ := response["content"].([]interface{})
+ updated := make([]interface{}, 0, len(searches)*2+len(content))
+ for _, search := range searches {
+ updated = append(updated, map[string]interface{}{
+ "type": "server_tool_use",
+ "id": search.ToolUseID,
+ "name": "web_search",
+ "input": map[string]interface{}{"query": search.Query},
+ })
+ updated = append(updated, map[string]interface{}{
+ "type": "web_search_tool_result",
+ "content": buildSearchResultContent(search.Results),
+ })
+ }
+ updated = append(updated, content...)
+ response["content"] = updated
+
+ encoded, err := json.Marshal(response)
+ if err != nil {
+ return responsePayload, err
+ }
+ return encoded, nil
+}
+
+func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
+ content := make([]map[string]interface{}, 0)
+ if results == nil {
+ return content
+ }
+ for _, result := range results.Results {
+ snippet := ""
+ if result.Snippet != nil {
+ snippet = strings.TrimSpace(*result.Snippet)
+ }
+ content = append(content, map[string]interface{}{
+ "type": "web_search_result",
+ "title": result.Title,
+ "url": result.URL,
+ "encrypted_content": snippet,
+ "page_age": nil,
+ })
+ }
+ return content
+}
+
+func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
+ content := gjson.GetBytes(responsePayload, "content")
+ if !content.IsArray() {
+ return "", "", false
+ }
+ for _, block := range content.Array() {
+ if block.Get("type").String() != "tool_use" {
+ continue
+ }
+ name := block.Get("name").String()
+ if !isWebSearchToolName(name, "") {
+ continue
+ }
+ query = strings.TrimSpace(block.Get("input.query").String())
+ if query == "" {
+ continue
+ }
+ return block.Get("id").String(), query, true
+ }
+ return "", "", false
+}
+
+func isWebSearchToolName(name, toolType string) bool {
+ name = strings.ToLower(strings.TrimSpace(name))
+ toolType = strings.ToLower(strings.TrimSpace(toolType))
+ if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
+ return true
+ }
+ switch name {
+ case "web_search", "web_search_20250305", "google_search", "remote_web_search":
+ return true
+ default:
+ return false
+ }
+}
+
+func getInterfaceString(v interface{}) string {
+ if v == nil {
+ return ""
+ }
+ switch val := v.(type) {
+ case string:
+ return strings.TrimSpace(val)
+ default:
+ return strings.TrimSpace(fmt.Sprint(val))
+ }
+}
+
+func formatToolResultText(results *WebSearchResults) string {
+ if results == nil || len(results.Results) == 0 {
+ return "No search results found."
+ }
+ payload, err := json.MarshalIndent(results.Results, "", " ")
+ if err != nil {
+ return "Found search results, but failed to format them."
+ }
+ return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
+}
+
+func searchGuidanceText() string {
+ now := time.Now()
+ return fmt.Sprintf(`
+Current date: %s (%s)
+
+IMPORTANT: Evaluate the search results above carefully. If the results are:
+- Mostly spam, SEO junk, or unrelated websites
+- Missing actual information about the query topic
+- Outdated or not matching the requested time frame
+
+Then you MUST use the web_search tool again with a refined query. Try:
+- Rephrasing in English for better coverage
+- Using more specific keywords
+- Adding date context
+
+Do NOT apologize for bad results without first attempting a re-search.
+`, now.Format("January 2, 2006"), now.Format("Monday"))
+}
diff --git a/backend/internal/pkg/kiro/websearch_stream.go b/backend/internal/pkg/kiro/websearch_stream.go
new file mode 100644
index 00000000..2a6a910a
--- /dev/null
+++ b/backend/internal/pkg/kiro/websearch_stream.go
@@ -0,0 +1,297 @@
+package kiro
+
+import (
+ "encoding/json"
+ "strings"
+)
+
+type BufferedStreamResult struct {
+ StopReason string
+ WebSearchQuery string
+ WebSearchToolUseID string
+ HasWebSearchToolUse bool
+ WebSearchToolUseIndex int
+}
+
+func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
+ searchContent := make([]map[string]interface{}, 0)
+ if results != nil {
+ for _, result := range results.Results {
+ snippet := ""
+ if result.Snippet != nil {
+ snippet = strings.TrimSpace(*result.Snippet)
+ }
+ searchContent = append(searchContent, map[string]interface{}{
+ "type": "web_search_result",
+ "title": result.Title,
+ "url": result.URL,
+ "encrypted_content": snippet,
+ "page_age": nil,
+ })
+ }
+ }
+
+ inputJSON, _ := json.Marshal(map[string]string{"query": query})
+
+ events := []map[string]interface{}{
+ {
+ "type": "content_block_start",
+ "index": startIndex,
+ "content_block": map[string]interface{}{
+ "type": "server_tool_use",
+ "id": toolUseID,
+ "name": "web_search",
+ "input": map[string]interface{}{},
+ },
+ },
+ {
+ "type": "content_block_delta",
+ "index": startIndex,
+ "delta": map[string]interface{}{
+ "type": "input_json_delta",
+ "partial_json": string(inputJSON),
+ },
+ },
+ {
+ "type": "content_block_stop",
+ "index": startIndex,
+ },
+ {
+ "type": "content_block_start",
+ "index": startIndex + 1,
+ "content_block": map[string]interface{}{
+ "type": "web_search_tool_result",
+ "content": searchContent,
+ },
+ },
+ {
+ "type": "content_block_stop",
+ "index": startIndex + 1,
+ },
+ }
+
+ result := make([][]byte, 0, len(events))
+ for _, event := range events {
+ eventType, _ := event["type"].(string)
+ payload, _ := json.Marshal(event)
+ result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
+ }
+ return result
+}
+
+func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
+ result := BufferedStreamResult{WebSearchToolUseIndex: -1}
+ var currentToolName string
+ currentToolIndex := -1
+ var toolInputBuilder strings.Builder
+
+ for _, chunk := range chunks {
+ lines := strings.Split(string(chunk), "\n")
+ for _, line := range lines {
+ if !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+ payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+ if payload == "" || payload == "[DONE]" {
+ continue
+ }
+
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(payload), &event); err != nil {
+ continue
+ }
+
+ switch eventType, _ := event["type"].(string); eventType {
+ case "message_delta":
+ if delta, ok := event["delta"].(map[string]interface{}); ok {
+ if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
+ result.StopReason = stopReason
+ }
+ }
+ case "content_block_start":
+ contentBlock, ok := event["content_block"].(map[string]interface{})
+ if !ok {
+ continue
+ }
+ blockType, _ := contentBlock["type"].(string)
+ if blockType != "tool_use" {
+ continue
+ }
+ currentToolName, _ = contentBlock["name"].(string)
+ currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
+ if idx, ok := event["index"].(float64); ok {
+ currentToolIndex = int(idx)
+ }
+ if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
+ result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
+ }
+ toolInputBuilder.Reset()
+ case "content_block_delta":
+ if currentToolName == "" {
+ continue
+ }
+ delta, ok := event["delta"].(map[string]interface{})
+ if !ok {
+ continue
+ }
+ deltaType, _ := delta["type"].(string)
+ if deltaType != "input_json_delta" {
+ continue
+ }
+ if partialJSON, ok := delta["partial_json"].(string); ok {
+ toolInputBuilder.WriteString(partialJSON)
+ }
+ case "content_block_stop":
+ if !isWebSearchToolName(currentToolName, "") {
+ currentToolName = ""
+ currentToolIndex = -1
+ toolInputBuilder.Reset()
+ continue
+ }
+ result.HasWebSearchToolUse = true
+ result.WebSearchToolUseIndex = currentToolIndex
+ var input map[string]string
+ if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
+ result.WebSearchQuery = strings.TrimSpace(input["query"])
+ }
+ currentToolName = ""
+ currentToolIndex = -1
+ toolInputBuilder.Reset()
+ }
+ }
+ }
+
+ return result
+}
+
+func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
+ filtered := make([][]byte, 0, len(chunks))
+ for _, chunk := range chunks {
+ adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
+ if shouldForward {
+ filtered = append(filtered, adjusted)
+ }
+ }
+ return filtered
+}
+
+func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
+ return filterSSEChunk(chunk, -1, offset)
+}
+
+func MaxContentBlockIndex(chunks [][]byte) int {
+ maxIndex := -1
+ for _, chunk := range chunks {
+ lines := strings.Split(string(chunk), "\n")
+ for _, line := range lines {
+ if !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+ payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+ if payload == "" || payload == "[DONE]" {
+ continue
+ }
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(payload), &event); err != nil {
+ continue
+ }
+ switch eventType, _ := event["type"].(string); eventType {
+ case "content_block_start", "content_block_delta", "content_block_stop":
+ if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
+ maxIndex = int(idx)
+ }
+ }
+ }
+ }
+ return maxIndex
+}
+
+func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
+ lines := strings.Split(string(chunk), "\n")
+ var builder strings.Builder
+ hasContent := false
+
+ for i := 0; i < len(lines); i++ {
+ line := lines[i]
+ if strings.HasPrefix(line, "event: ") {
+ if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
+ payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
+ if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
+ i++
+ continue
+ }
+ }
+ builder.WriteString(line + "\n")
+ hasContent = true
+ continue
+ }
+
+ if strings.HasPrefix(line, "data: ") {
+ payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+ if payload == "[DONE]" {
+ continue
+ }
+ if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
+ continue
+ }
+ adjusted := adjustEventPayload(payload, indexOffset)
+ if adjusted == "" {
+ continue
+ }
+ builder.WriteString("data: " + adjusted + "\n")
+ hasContent = true
+ continue
+ }
+
+ builder.WriteString(line + "\n")
+ if strings.TrimSpace(line) != "" {
+ hasContent = true
+ }
+ }
+
+ if !hasContent {
+ return nil, false
+ }
+ return []byte(builder.String()), true
+}
+
+func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
+ if payload == "" {
+ return false
+ }
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(payload), &event); err != nil {
+ return false
+ }
+ eventType, _ := event["type"].(string)
+ if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
+ return true
+ }
+ if webSearchToolUseIndex < 0 {
+ return false
+ }
+ if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
+ return true
+ }
+ return false
+}
+
+func adjustEventPayload(payload string, indexOffset int) string {
+ if payload == "" || indexOffset == 0 {
+ return payload
+ }
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(payload), &event); err != nil {
+ return payload
+ }
+ switch eventType, _ := event["type"].(string); eventType {
+ case "content_block_start", "content_block_delta", "content_block_stop":
+ if idx, ok := event["index"].(float64); ok {
+ event["index"] = int(idx) + indexOffset
+ if adjusted, err := json.Marshal(event); err == nil {
+ return string(adjusted)
+ }
+ }
+ }
+ return payload
+}
diff --git a/backend/internal/pkg/kiro/websearch_stream_test.go b/backend/internal/pkg/kiro/websearch_stream_test.go
new file mode 100644
index 00000000..1b48f803
--- /dev/null
+++ b/backend/internal/pkg/kiro/websearch_stream_test.go
@@ -0,0 +1,73 @@
+package kiro
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
+ snippet := "result snippet"
+ events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
+ Results: []WebSearchResult{
+ {Title: "Go", URL: "https://go.dev", Snippet: &snippet},
+ },
+ }, 0)
+
+ require.Len(t, events, 5)
+ require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
+ require.Contains(t, string(events[0]), `"input":{}`)
+ require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
+ require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
+ require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
+ require.NotContains(t, string(events[3]), `"tool_use_id"`)
+ require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
+}
+
+func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
+ chunks := [][]byte{
+ []byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
+ []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
+ []byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
+ []byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
+ []byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
+ }
+
+ result := AnalyzeBufferedStream(chunks)
+ require.True(t, result.HasWebSearchToolUse)
+ require.Equal(t, "golang concurrency", result.WebSearchQuery)
+ require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
+ require.Equal(t, 1, result.WebSearchToolUseIndex)
+ require.Equal(t, "tool_use", result.StopReason)
+}
+
+func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
+ chunks := [][]byte{
+ []byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
+ []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
+ []byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
+ []byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
+ []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
+ []byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
+ []byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
+ []byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
+ }
+
+ filtered := FilterChunksForClient(chunks, 1, 2)
+ require.NotEmpty(t, filtered)
+ joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
+ require.NotContains(t, joined, `"type":"message_start"`)
+ require.NotContains(t, joined, `"type":"message_delta"`)
+ require.NotContains(t, joined, `"name":"web_search"`)
+ require.Contains(t, joined, `"index":2`)
+ require.Equal(t, 2, MaxContentBlockIndex(filtered))
+}
+
+func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
+ _, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
+ require.False(t, shouldForward)
+
+ adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
+ require.True(t, shouldForward)
+ require.Contains(t, string(adjusted), `"index":3`)
+}
diff --git a/backend/internal/pkg/kiro/websearch_test.go b/backend/internal/pkg/kiro/websearch_test.go
new file mode 100644
index 00000000..394e9d9f
--- /dev/null
+++ b/backend/internal/pkg/kiro/websearch_test.go
@@ -0,0 +1,138 @@
+package kiro
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
+ body := []byte(`{
+ "tools":[{"type":"web_search_20250305","description":"old"}],
+ "messages":[{"role":"user","content":"golang"}]
+ }`)
+
+ updated, err := ReplaceWebSearchToolDescription(body)
+ require.NoError(t, err)
+ require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
+ require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
+ require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
+ require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
+ require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
+ require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
+}
+
+func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
+ body := []byte(`{
+ "messages":[{"role":"user","content":"what is golang"}]
+ }`)
+ results := &WebSearchResults{
+ Results: []WebSearchResult{
+ {Title: "Go", URL: "https://go.dev"},
+ },
+ }
+
+ updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
+ require.NoError(t, err)
+ require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
+ require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
+ require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
+ require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
+ require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
+ require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
+ require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
+ require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "")
+}
+
+func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
+ response := []byte(`{
+ "content":[
+ {"type":"text","text":"let me search"},
+ {"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
+ ]
+ }`)
+
+ toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
+ require.True(t, ok)
+ require.Equal(t, "srvtoolu_next", toolUseID)
+ require.Equal(t, "golang concurrency", query)
+}
+
+func TestInjectSearchIndicatorsInResponse(t *testing.T) {
+ response := []byte(`{
+ "id":"msg_1",
+ "type":"message",
+ "role":"assistant",
+ "model":"kiro",
+ "content":[{"type":"text","text":"final"}],
+ "stop_reason":"end_turn",
+ "usage":{"input_tokens":1,"output_tokens":1}
+ }`)
+
+ snippet := "result snippet"
+ updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
+ {
+ ToolUseID: "srvtoolu_test",
+ Query: "golang",
+ Results: &WebSearchResults{
+ Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(updated, &decoded))
+ require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
+ require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
+ require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
+ require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
+ require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
+ require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
+ require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
+ require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
+}
+
+func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
+ resp := &MCPResponse{
+ Result: &struct {
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ } `json:"content"`
+ Tools []struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ } `json:"tools"`
+ }{
+ Content: []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ }{
+ {
+ Type: "text",
+ Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
+ },
+ },
+ },
+ }
+
+ results := ParseSearchResults(resp)
+ require.NotNil(t, results)
+ require.Len(t, results.Results, 1)
+ require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
+ require.Equal(t, "doc-1", *results.Results[0].ID)
+ require.Equal(t, "go.dev", *results.Results[0].Domain)
+ require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
+ require.True(t, *results.Results[0].PublicDomain)
+}
+
+func TestSearchGuidanceText_IsStructured(t *testing.T) {
+ guidance := searchGuidanceText()
+ require.Contains(t, guidance, "")
+ require.Contains(t, guidance, "Current date:")
+ require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
+ require.Contains(t, guidance, "Rephrasing in English for better coverage")
+}
diff --git a/backend/internal/pkg/kirocooldown/store.go b/backend/internal/pkg/kirocooldown/store.go
new file mode 100644
index 00000000..b2a9980d
--- /dev/null
+++ b/backend/internal/pkg/kirocooldown/store.go
@@ -0,0 +1,479 @@
+package kirocooldown
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "math/rand"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
+const (
+ MinRequestInterval = time.Second
+ MaxRequestInterval = 2 * time.Second
+
+ CooldownReason429 = "rate_limit_exceeded"
+ CooldownReasonSuspended = "account_suspended"
+
+ ShortCooldown = time.Minute
+ MaxCooldown = 5 * time.Minute
+ LongCooldown = 24 * time.Hour
+
+ redisTimeout = 3 * time.Second
+ activeTTL = 10 * time.Second
+ stateTTL = 25 * time.Hour
+ keyPrefix = "kiro:cooldown:"
+)
+
+var (
+ ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
+
+ reserveRequestScript = redis.NewScript(`
+local t = redis.call('TIME')
+local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
+local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
+local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
+local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
+local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
+local interval_ms = tonumber(ARGV[1])
+local active_ttl_ms = tonumber(ARGV[2])
+local state_ttl_ms = tonumber(ARGV[3])
+
+if cooldown_until_ms > now_ms then
+ return {1, cooldown_until_ms - now_ms, cooldown_reason}
+end
+
+if cooldown_until_ms > 0 then
+ redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
+end
+
+local next_slot_ms = now_ms
+if last_request_ms > 0 then
+ local candidate_ms = last_request_ms + interval_ms
+ if candidate_ms > now_ms then
+ next_slot_ms = candidate_ms
+ end
+end
+
+redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
+if fail_count > 0 or cooldown_until_ms > now_ms then
+ redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
+else
+ redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
+end
+return {0, next_slot_ms - now_ms, ''}
+`)
+
+ mark429Script = redis.NewScript(`
+local t = redis.call('TIME')
+local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
+local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
+local short_cooldown_ms = tonumber(ARGV[1])
+local max_cooldown_ms = tonumber(ARGV[2])
+local state_ttl_ms = tonumber(ARGV[3])
+local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
+if cooldown_ms > max_cooldown_ms then
+ cooldown_ms = max_cooldown_ms
+end
+redis.call('HSET', KEYS[1],
+ 'fail_count', fail_count,
+ 'cooldown_until_ms', now_ms + cooldown_ms,
+ 'cooldown_reason', ARGV[4]
+)
+redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
+return cooldown_ms
+`)
+
+ markSuccessScript = redis.NewScript(`
+redis.call('HSET', KEYS[1],
+ 'fail_count', 0,
+ 'cooldown_until_ms', 0,
+ 'cooldown_reason', ''
+)
+redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
+return 1
+`)
+
+ markSuspendedScript = redis.NewScript(`
+local t = redis.call('TIME')
+local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
+local cooldown_ms = tonumber(ARGV[1])
+local state_ttl_ms = tonumber(ARGV[2])
+redis.call('HSET', KEYS[1],
+ 'fail_count', 0,
+ 'cooldown_until_ms', now_ms + cooldown_ms,
+ 'cooldown_reason', ARGV[3]
+)
+redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
+return cooldown_ms
+`)
+)
+
+type Error struct {
+ remaining time.Duration
+ reason string
+}
+
+type State struct {
+ Active bool
+ Reason string
+ CooldownUntil time.Time
+ Remaining time.Duration
+ FailCount int
+}
+
+func NewError(remaining time.Duration, reason string) error {
+ return &Error{remaining: remaining, reason: reason}
+}
+
+func (e *Error) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.reason == "" {
+ return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
+ }
+ return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
+}
+
+func Calculate429Cooldown(retryCount int) time.Duration {
+ if retryCount < 0 {
+ retryCount = 0
+ }
+ cooldown := ShortCooldown * time.Duration(1< MaxCooldown {
+ return MaxCooldown
+ }
+ return cooldown
+}
+
+type Store struct {
+ client *redis.Client
+ rngMu sync.Mutex
+ rng *rand.Rand
+}
+
+func NewStore(client *redis.Client) *Store {
+ return &Store{
+ client: client,
+ rng: rand.New(rand.NewSource(time.Now().UnixNano())),
+ }
+}
+
+func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
+ if err := s.validate(); err != nil {
+ return 0, err
+ }
+ cacheCtx, cancel := withRedisTimeout(ctx)
+ defer cancel()
+
+ values, err := reserveRequestScript.Run(
+ cacheCtx,
+ s.client,
+ []string{RedisKey(tokenKey)},
+ s.nextInterval().Milliseconds(),
+ activeTTL.Milliseconds(),
+ stateTTL.Milliseconds(),
+ ).Result()
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
+ }
+ parts, ok := values.([]interface{})
+ if !ok || len(parts) != 3 {
+ return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
+ }
+ state, err := luaInt64(parts[0])
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
+ }
+ waitMS, err := luaInt64(parts[1])
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
+ }
+ reason, err := luaString(parts[2])
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
+ }
+ if state == 1 {
+ return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
+ }
+ if waitMS <= 0 {
+ return 0, nil
+ }
+ return time.Duration(waitMS) * time.Millisecond, nil
+}
+
+func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
+ if err := s.validate(); err != nil {
+ return err
+ }
+ cacheCtx, cancel := withRedisTimeout(ctx)
+ defer cancel()
+ if err := markSuccessScript.Run(
+ cacheCtx,
+ s.client,
+ []string{RedisKey(tokenKey)},
+ activeTTL.Milliseconds(),
+ ).Err(); err != nil {
+ return fmt.Errorf("kiro cooldown mark success: %w", err)
+ }
+ return nil
+}
+
+func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
+ if err := s.validate(); err != nil {
+ return 0, err
+ }
+ cacheCtx, cancel := withRedisTimeout(ctx)
+ defer cancel()
+ result, err := mark429Script.Run(
+ cacheCtx,
+ s.client,
+ []string{RedisKey(tokenKey)},
+ ShortCooldown.Milliseconds(),
+ MaxCooldown.Milliseconds(),
+ stateTTL.Milliseconds(),
+ CooldownReason429,
+ ).Result()
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
+ }
+ cooldownMS, err := luaInt64(result)
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
+ }
+ return time.Duration(cooldownMS) * time.Millisecond, nil
+}
+
+func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
+ if err := s.validate(); err != nil {
+ return 0, err
+ }
+ cacheCtx, cancel := withRedisTimeout(ctx)
+ defer cancel()
+ result, err := markSuspendedScript.Run(
+ cacheCtx,
+ s.client,
+ []string{RedisKey(tokenKey)},
+ LongCooldown.Milliseconds(),
+ stateTTL.Milliseconds(),
+ CooldownReasonSuspended,
+ ).Result()
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
+ }
+ cooldownMS, err := luaInt64(result)
+ if err != nil {
+ return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
+ }
+ return time.Duration(cooldownMS) * time.Millisecond, nil
+}
+
+func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
+ if err := s.validate(); err != nil {
+ return nil, err
+ }
+ cacheCtx, cancel := withRedisTimeout(ctx)
+ defer cancel()
+
+ values, err := s.client.HMGet(
+ cacheCtx,
+ RedisKey(tokenKey),
+ "cooldown_until_ms",
+ "cooldown_reason",
+ "fail_count",
+ ).Result()
+ if err != nil {
+ return nil, fmt.Errorf("kiro cooldown get state: %w", err)
+ }
+ if len(values) != 3 {
+ return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
+ }
+
+ cooldownUntilMS, err := luaInt64(values[0])
+ if err != nil && values[0] != nil {
+ return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
+ }
+ reason, err := luaString(values[1])
+ if err != nil {
+ return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
+ }
+ failCount, err := luaInt64(values[2])
+ if err != nil && values[2] != nil {
+ return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
+ }
+ if cooldownUntilMS <= 0 {
+ return nil, nil
+ }
+
+ cooldownUntil := time.UnixMilli(cooldownUntilMS)
+ remaining := time.Until(cooldownUntil)
+ if remaining <= 0 {
+ return nil, nil
+ }
+
+ return &State{
+ Active: true,
+ Reason: reason,
+ CooldownUntil: cooldownUntil,
+ Remaining: remaining,
+ FailCount: int(failCount),
+ }, nil
+}
+
+func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
+ if err := s.validate(); err != nil {
+ return false, err
+ }
+ uniqueKeys := make([]string, 0, len(tokenKeys))
+ seen := make(map[string]struct{}, len(tokenKeys))
+ for _, tokenKey := range tokenKeys {
+ tokenKey = strings.TrimSpace(tokenKey)
+ if tokenKey == "" {
+ continue
+ }
+ redisKey := RedisKey(tokenKey)
+ if _, ok := seen[redisKey]; ok {
+ continue
+ }
+ seen[redisKey] = struct{}{}
+ uniqueKeys = append(uniqueKeys, redisKey)
+ }
+ if len(uniqueKeys) == 0 {
+ return false, nil
+ }
+
+ cacheCtx, cancel := withRedisTimeout(ctx)
+ defer cancel()
+
+ type candidate struct {
+ redisKey string
+ cooldownUntilMS int64
+ failCount int64
+ }
+ now := time.Now().UnixMilli()
+ var best *candidate
+
+ pipe := s.client.Pipeline()
+ cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
+ for _, redisKey := range uniqueKeys {
+ cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
+ }
+ if _, err := pipe.Exec(cacheCtx); err != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
+ }
+
+ for i, cmd := range cmds {
+ values, err := cmd.Result()
+ if err != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
+ }
+ if len(values) != 3 {
+ return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
+ }
+ cooldownUntilMS, err := luaInt64(values[0])
+ if err != nil && values[0] != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
+ }
+ reason, err := luaString(values[1])
+ if err != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
+ }
+ failCount, err := luaInt64(values[2])
+ if err != nil && values[2] != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
+ }
+ if cooldownUntilMS <= now || reason != CooldownReason429 {
+ continue
+ }
+ current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
+ if best == nil ||
+ current.cooldownUntilMS < best.cooldownUntilMS ||
+ (current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
+ best = current
+ }
+ }
+ if best == nil {
+ return false, nil
+ }
+
+ if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
+ }
+ if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
+ return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
+ }
+ return true, nil
+}
+
+func RedisKey(tokenKey string) string {
+ sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
+ digest := hex.EncodeToString(sum[:])
+ return keyPrefix + "{" + digest + "}"
+}
+
+func ActiveTTL() time.Duration {
+ return activeTTL
+}
+
+func StateTTL() time.Duration {
+ return stateTTL
+}
+
+func (s *Store) validate() error {
+ if s == nil || s.client == nil {
+ return ErrStoreUnavailable
+ }
+ return nil
+}
+
+func (s *Store) nextInterval() time.Duration {
+ s.rngMu.Lock()
+ defer s.rngMu.Unlock()
+ if MaxRequestInterval <= MinRequestInterval {
+ return MinRequestInterval
+ }
+ return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
+}
+
+func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return context.WithTimeout(ctx, redisTimeout)
+}
+
+func luaInt64(v any) (int64, error) {
+ switch n := v.(type) {
+ case int64:
+ return n, nil
+ case int:
+ return int64(n), nil
+ case string:
+ return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
+ case []byte:
+ return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
+ default:
+ return 0, fmt.Errorf("unsupported lua numeric type %T", v)
+ }
+}
+
+func luaString(v any) (string, error) {
+ switch s := v.(type) {
+ case string:
+ return s, nil
+ case []byte:
+ return string(s), nil
+ case nil:
+ return "", nil
+ default:
+ return "", fmt.Errorf("unsupported lua string type %T", v)
+ }
+}
diff --git a/backend/internal/pkg/kirocooldown/store_test.go b/backend/internal/pkg/kirocooldown/store_test.go
new file mode 100644
index 00000000..aabae5e2
--- /dev/null
+++ b/backend/internal/pkg/kirocooldown/store_test.go
@@ -0,0 +1,32 @@
+package kirocooldown
+
+import (
+ "context"
+ "testing"
+
+ "github.com/redis/go-redis/v9"
+)
+
+func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
+ store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
+
+ cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
+ if err != nil {
+ t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
+ }
+ if cleared {
+ t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
+ }
+}
+
+func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
+ store := NewStore(nil)
+
+ cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
+ if err == nil {
+ t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
+ }
+ if cleared {
+ t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
+ }
+}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 1c786f50..adcf1874 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -41,6 +41,9 @@ func RegisterAdminRoutes(
// Antigravity OAuth
registerAntigravityOAuthRoutes(admin, h)
+ // Kiro OAuth / IDC
+ registerKiroOAuthRoutes(admin, h)
+
// 代理管理
registerProxyRoutes(admin, h)
@@ -295,6 +298,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Antigravity 默认模型映射
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
+ accounts.GET("/kiro/default-model-mapping", h.Admin.Account.GetKiroDefaultModelMapping)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
@@ -347,6 +351,17 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
}
}
+func registerKiroOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ kiro := admin.Group("/kiro")
+ {
+ kiro.POST("/oauth/auth-url", h.Admin.KiroOAuth.GenerateAuthURL)
+ kiro.POST("/oauth/idc-auth-url", h.Admin.KiroOAuth.GenerateIDCAuthURL)
+ kiro.POST("/oauth/exchange-code", h.Admin.KiroOAuth.ExchangeCode)
+ kiro.POST("/oauth/refresh-token", h.Admin.KiroOAuth.RefreshToken)
+ kiro.POST("/oauth/import-token", h.Admin.KiroOAuth.ImportToken)
+ }
+}
+
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies := admin.Group("/proxies")
{
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index cd06ffa3..375f1cf1 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -48,6 +48,13 @@ type Account struct {
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
+ KiroQuotaState string
+ KiroQuotaReason string
+ KiroQuotaResetAt *time.Time
+ KiroRuntimeState string
+ KiroRuntimeReason string
+ KiroRuntimeResetAt *time.Time
+
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
@@ -164,6 +171,10 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
+func (a *Account) IsKiro() bool {
+ return a.Platform == PlatformKiro
+}
+
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
@@ -478,17 +489,17 @@ func (a *Account) GetModelMapping() map[string]string {
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil {
- // Antigravity 平台使用默认映射
- if a.Platform == domain.PlatformAntigravity {
- return domain.DefaultAntigravityModelMapping
+ // 部分平台在未显式配置 model_mapping 时仍应使用默认映射,
+ // 以限制可调度/可转发的模型集合。
+ if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
+ return defaults
}
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
return nil
}
if len(rawMapping) == 0 {
- // Antigravity 平台使用默认映射
- if a.Platform == domain.PlatformAntigravity {
- return domain.DefaultAntigravityModelMapping
+ if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
+ return defaults
}
return nil
}
@@ -510,13 +521,23 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
return result
}
- // Antigravity 平台使用默认映射
- if a.Platform == domain.PlatformAntigravity {
- return domain.DefaultAntigravityModelMapping
+ if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
+ return defaults
}
return nil
}
+func defaultModelMappingForPlatform(platform string) map[string]string {
+ switch platform {
+ case domain.PlatformAntigravity:
+ return domain.DefaultAntigravityModelMapping
+ case domain.PlatformKiro:
+ return domain.DefaultKiroModelMapping
+ default:
+ return nil
+ }
+}
+
func mapPtr(m map[string]any) uintptr {
if m == nil {
return 0
@@ -608,8 +629,8 @@ func resolveRequestedModelInMapping(mapping map[string]string, requestedModel st
return matchWildcardMappingResult(mapping, requestedModel)
}
-// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
-// 如果未配置 mapping,返回 true(允许所有模型)
+// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)。
+// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时也会先回退到默认映射。
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
@@ -622,8 +643,8 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
}
-// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
-// 如果未配置 mapping,返回原始模型名
+// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)。
+// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时返回默认映射结果。
func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel
@@ -725,6 +746,9 @@ func (a *Account) GetBaseURL() string {
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
+ if a.Platform == PlatformKiro {
+ return ""
+ }
return "https://api.anthropic.com"
}
if a.Platform == PlatformAntigravity {
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index 3189a729..0d8a304b 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -180,7 +180,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
if err != nil {
return nil, err
}
- if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
+ if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
}
}
@@ -296,7 +296,7 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if err != nil {
return nil, err
}
- if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
+ if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
}
}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 391e7475..088d35c9 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
@@ -65,6 +66,7 @@ type AccountTestService struct {
accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider
claudeTokenProvider *ClaudeTokenProvider
+ kiroTokenProvider *KiroTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
@@ -76,6 +78,7 @@ func NewAccountTestService(
accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider,
claudeTokenProvider *ClaudeTokenProvider,
+ kiroTokenProvider *KiroTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
@@ -85,6 +88,7 @@ func NewAccountTestService(
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
claudeTokenProvider: claudeTokenProvider,
+ kiroTokenProvider: kiroTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
@@ -191,6 +195,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt)
}
+ if account.IsKiro() && account.Type == AccountTypeOAuth {
+ return s.testKiroAccountConnection(c, account, modelID)
+ }
+
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -239,6 +247,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
}
baseURL := account.GetBaseURL()
+ if baseURL == "" && account.Platform == PlatformKiro {
+ return s.sendErrorAndEnd(c, "Kiro API Key accounts require a Base URL")
+ }
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
@@ -387,6 +398,149 @@ func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Con
return s.processClaudeStream(c, resp.Body)
}
+func (s *AccountTestService) testKiroAccountConnection(c *gin.Context, account *Account, modelID string) error {
+ ctx := c.Request.Context()
+
+ testModelID := strings.TrimSpace(modelID)
+ if testModelID == "" {
+ testModelID = "claude-sonnet-4-6"
+ }
+ if mappedModel := account.GetMappedModel(testModelID); strings.TrimSpace(mappedModel) != "" {
+ testModelID = mappedModel
+ }
+
+ if account.Type != AccountTypeOAuth {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Kiro account type: %s", account.Type))
+ }
+
+ if s.kiroTokenProvider == nil {
+ return s.sendErrorAndEnd(c, "Kiro token provider not configured")
+ }
+
+ accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get Kiro access token: %s", err.Error()))
+ }
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ payload, err := createTestPayload(testModelID)
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create test payload")
+ }
+ payloadBytes, _ := json.Marshal(payload)
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+
+ resp, err := s.executeKiroTestUpstream(ctx, account, payloadBytes, testModelID, accessToken)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ defer func() { _ = resp.Body.Close() }()
+ body, _ := io.ReadAll(resp.Body)
+ return s.sendErrorAndEnd(c, formatKiroTestError(resp.StatusCode, body, testModelID, account))
+ }
+
+ pr, pw := io.Pipe()
+ go func() {
+ defer func() { _ = resp.Body.Close() }()
+ _, streamErr := kiropkg.StreamEventStreamAsAnthropic(ctx, resp.Body, pw, testModelID, estimateKiroInputTokens(payloadBytes))
+ if streamErr != nil {
+ _ = pw.CloseWithError(streamErr)
+ return
+ }
+ _ = pw.Close()
+ }()
+
+ return s.processClaudeStream(c, pr)
+}
+
+func formatKiroTestError(statusCode int, body []byte, requestedModel string, account *Account) string {
+ return fmt.Sprintf("API returned %d: %s", statusCode, string(body))
+}
+
+func (s *AccountTestService) executeKiroTestUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string) (*http.Response, error) {
+ modelID := kiropkg.MapModel(mappedModel)
+ currentToken := token
+ buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
+ if err != nil {
+ return nil, err
+ }
+ payload := buildResult.Payload
+
+ endpoints := buildKiroEndpoints(account)
+ proxyURL := kiroProxyURL(account)
+ tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
+ accountKey := buildKiroAccountKey(account)
+ maxRetries := 2
+ for idx, endpoint := range endpoints {
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
+ if idx+1 < len(endpoints) {
+ _ = resp.Body.Close()
+ break
+ }
+ return resp, nil
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
+ respBody, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ return nil, readErr
+ }
+
+ if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
+ refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
+ if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
+ currentToken = refreshedToken
+ accountKey = buildKiroAccountKey(account)
+ buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
+ if err != nil {
+ return nil, err
+ }
+ payload = buildResult.Payload
+ continue
+ }
+ }
+
+ resetHTTPResponseBody(resp, respBody)
+ return resp, nil
+ }
+
+ if resp.StatusCode == http.StatusBadRequest {
+ respBody, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ return nil, readErr
+ }
+ resetHTTPResponseBody(resp, respBody)
+ return resp, nil
+ }
+
+ return resp, nil
+ }
+ }
+
+ return nil, fmt.Errorf("kiro upstream endpoints exhausted")
+}
+
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
diff --git a/backend/internal/service/account_test_service_kiro_apikey_fallback_test.go b/backend/internal/service/account_test_service_kiro_apikey_fallback_test.go
new file mode 100644
index 00000000..d492efba
--- /dev/null
+++ b/backend/internal/service/account_test_service_kiro_apikey_fallback_test.go
@@ -0,0 +1,84 @@
+//go:build unit
+
+package service
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+)
+
+func TestAccountTestService_KiroAPIKeyUsesGenericAnthropicCompatiblePath(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 19,
+ Name: "kiro-apikey-test",
+ Platform: PlatformKiro,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "base_url": "https://kiro-upstream.example.com",
+ "api_key": "kiro-api-key",
+ "model_mapping": map[string]any{
+ "claude-sonnet-4-6": "claude-sonnet-4-6",
+ },
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"invalid api key"}}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Len(t, upstream.requests, 1)
+
+ req := upstream.requests[0]
+ require.Equal(t, "kiro-upstream.example.com", req.URL.Host)
+ require.Equal(t, "/v1/messages", req.URL.Path)
+ require.Equal(t, "kiro-api-key", req.Header.Get("x-api-key"))
+ require.Empty(t, req.Header.Get("Authorization"))
+ require.Equal(t, claude.APIKeyBetaHeader, req.Header.Get("anthropic-beta"))
+}
+
+func TestAccountTestService_KiroAPIKeyWithoutBaseURLErrors(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 20,
+ Name: "kiro-apikey-missing-base-url",
+ Platform: PlatformKiro,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "kiro-api-key",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: &queuedHTTPUpstream{},
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "Base URL")
+}
diff --git a/backend/internal/service/account_test_service_kiro_test.go b/backend/internal/service/account_test_service_kiro_test.go
new file mode 100644
index 00000000..4039cfa2
--- /dev/null
+++ b/backend/internal/service/account_test_service_kiro_test.go
@@ -0,0 +1,317 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountTestService_KiroUsesKiroUpstreamInsteadOfAnthropic(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 1,
+ Name: "kiro-test",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "kiro-access-token",
+ "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/TESTSOCIAL",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{1: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"Invalid bearer token"}}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
+ httpUpstream: upstream,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "gpt-4o", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Len(t, upstream.requests, 1)
+
+ req := upstream.requests[0]
+ require.Equal(t, "q.us-east-1.amazonaws.com", req.URL.Host)
+ require.Equal(t, "/generateAssistantResponse", req.URL.Path)
+ require.Equal(t, "Bearer kiro-access-token", req.Header.Get("Authorization"))
+ require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
+ require.Empty(t, req.Header.Get("anthropic-version"))
+ require.NotContains(t, req.URL.Host, "api.anthropic.com")
+}
+
+func TestAccountTestService_Kiro429DoesNotFallbackToCodeWhispererEndpoint(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 2,
+ Name: "kiro-fallback",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "kiro-access-token",
+ "api_region": "us-west-2",
+ "region": "us-west-2",
+ "profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTFALLBACK",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{2: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusTooManyRequests, `{"message":"slow down"}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
+ httpUpstream: upstream,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Len(t, upstream.requests, 1)
+
+ require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
+ require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
+ require.Contains(t, err.Error(), "API returned 429")
+}
+
+func TestAccountTestService_KiroIDCWithoutProfileArnOmitsProfileArnAndUsesDefaultRuntimeRegion(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 5,
+ Name: "kiro-idc-default-region",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "kiro-access-token",
+ "auth_method": "idc",
+ "provider": "AWS",
+ "region": "ap-northeast-2",
+ "start_url": "https://d-example.awsapps.com/start",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{5: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
+ httpUpstream: upstream,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Len(t, upstream.requests, 1)
+ require.Equal(t, "q.us-east-1.amazonaws.com", upstream.requests[0].URL.Host)
+ body, readErr := io.ReadAll(upstream.requests[0].Body)
+ require.NoError(t, readErr)
+ require.NotContains(t, string(body), `"profileArn":`)
+}
+
+func TestAccountTestService_KiroInvalidModelErrorPassthrough(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 6,
+ Name: "kiro-invalid-model",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "kiro-access-token",
+ "profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTINVALIDMODEL",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
+ httpUpstream: upstream,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Equal(t, `API returned 400: {"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`, err.Error())
+}
+
+func TestAccountTestService_KiroInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 7,
+ Name: "kiro-invalid-model-refresh",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "kiro-access-token",
+ "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{7: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
+ httpUpstream: upstream,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "API returned 400")
+ require.Len(t, upstream.requests, 1)
+
+ firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
+ require.NoError(t, readErr)
+ require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
+ require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
+}
+
+func TestAccountTestService_KiroPreferredEndpointIsIgnored(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ account := &Account{
+ ID: 6,
+ Name: "kiro-preferred-endpoint",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "kiro-access-token",
+ "api_region": "us-west-2",
+ "profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/PREFERRED",
+ "preferred_endpoint": "codewhisperer",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
+ },
+ }
+ svc := &AccountTestService{
+ accountRepo: repo,
+ kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
+ httpUpstream: upstream,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
+ require.Error(t, err)
+ require.Len(t, upstream.requests, 1)
+ require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
+ require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
+}
+
+func TestBuildKiroPayloadForAccount_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
+ account := &Account{
+ ID: 3,
+ Name: "kiro-builder-id",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "auth_method": "idc",
+ "provider": "BuilderId",
+ "region": "us-east-1",
+ "client_id": "builder-client-id",
+ },
+ }
+
+ testPayload, err := createTestPayload("claude-sonnet-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(testPayload)
+ require.NoError(t, err)
+
+ kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
+ require.NoError(t, err)
+ require.NotContains(t, string(kiroPayload), `"profileArn":`)
+}
+
+func TestBuildKiroPayloadForAccount_KiroBuilderIDUsesCredentialProfileArn(t *testing.T) {
+ account := &Account{
+ ID: 33,
+ Name: "kiro-builder-id-cached",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "auth_method": "builder-id",
+ "provider": "BuilderId",
+ "region": "us-east-1",
+ "client_id": "builder-client-id",
+ "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED",
+ },
+ }
+
+ testPayload, err := createTestPayload("claude-sonnet-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(testPayload)
+ require.NoError(t, err)
+
+ kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
+ require.NoError(t, err)
+ require.Contains(t, string(kiroPayload), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED"`)
+}
+
+func TestBuildKiroPayloadForAccount_KiroEnterpriseIDCOmitsMissingProfileArn(t *testing.T) {
+ account := &Account{
+ ID: 4,
+ Name: "kiro-enterprise-idc",
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "auth_method": "idc",
+ "provider": "AWS",
+ "region": "us-east-1",
+ "client_id": "enterprise-client-id",
+ "start_url": "https://d-example.awsapps.com/start",
+ },
+ }
+
+ testPayload, err := createTestPayload("claude-sonnet-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(testPayload)
+ require.NoError(t, err)
+
+ kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
+ require.NoError(t, err)
+ require.NotContains(t, string(kiroPayload), `"profileArn":`)
+}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index 68ba8f8c..684495bd 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -103,10 +103,17 @@ type antigravityUsageCache struct {
timestamp time.Time
}
+// kiroUsageCache 缓存 Kiro 额度快照
+type kiroUsageCache struct {
+ usageInfo *UsageInfo
+ timestamp time.Time
+}
+
const (
apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
+ kiroUsageErrorTTL = 1 * time.Minute // Kiro 错误缓存 TTL(可恢复错误)
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
@@ -118,8 +125,10 @@ type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
+ kiroUsageCache sync.Map // accountID -> *kiroUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
+ kiroUsageFlight singleflight.Group // 防止同一 Kiro 账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time
}
@@ -176,6 +185,23 @@ type AICredit struct {
MinimumBalance float64 `json:"minimum_balance,omitempty"`
}
+// KiroCreditProgress 表示 Kiro 主额度或 Bonus 的用量进度。
+type KiroCreditProgress struct {
+ CurrentUsage float64 `json:"current_usage"`
+ UsageLimit float64 `json:"usage_limit"`
+ PercentageUsed float64 `json:"percentage_used"`
+ DaysRemaining int `json:"days_remaining,omitempty"`
+ ExpiryDate *time.Time `json:"expiry_date,omitempty"`
+}
+
+// KiroOverageInfo 表示 Kiro 账号的 overage 状态。
+type KiroOverageInfo struct {
+ CurrentOverages float64 `json:"current_overages"`
+ OverageCharges float64 `json:"overage_charges"`
+ CurrencyCode string `json:"currency_code,omitempty"`
+ CurrencySymbol string `json:"currency_symbol,omitempty"`
+}
+
// UsageInfo 账号使用量信息
type UsageInfo struct {
Source string `json:"source,omitempty"` // "passive" or "active"
@@ -203,6 +229,21 @@ type UsageInfo struct {
// Antigravity AI Credits 余额
AICredits []AICredit `json:"ai_credits,omitempty"`
+ // Kiro Credits 额度与 overage 信息
+ KiroSubscriptionName string `json:"kiro_subscription_name,omitempty"`
+ KiroSubscriptionType string `json:"kiro_subscription_type,omitempty"`
+ KiroResetAt *time.Time `json:"kiro_reset_at,omitempty"`
+ KiroOveragesEnabled bool `json:"kiro_overages_enabled,omitempty"`
+ KiroCredit *KiroCreditProgress `json:"kiro_credit,omitempty"`
+ KiroBonus *KiroCreditProgress `json:"kiro_bonus,omitempty"`
+ KiroOverage *KiroOverageInfo `json:"kiro_overage,omitempty"`
+ KiroQuotaState string `json:"kiro_quota_state,omitempty"`
+ KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
+ KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
+ KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
+ KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
+ KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
+
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
@@ -266,6 +307,7 @@ type AccountUsageService struct {
cache *UsageCache
identityCache IdentityCache
tlsFPProfileService *TLSFingerprintProfileService
+ kiroCooldownStore KiroCooldownStore
}
// NewAccountUsageService 创建AccountUsageService实例
@@ -291,6 +333,13 @@ func NewAccountUsageService(
}
}
+func (s *AccountUsageService) SetKiroCooldownStore(store KiroCooldownStore) *AccountUsageService {
+ if s != nil {
+ s.kiroCooldownStore = store
+ }
+ return s
+}
+
// GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
@@ -317,6 +366,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return usage, err
}
+ if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
+ return s.getKiroUsage(ctx, account, "active", false)
+ }
+
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
usage, err := s.getAntigravityUsage(ctx, account)
@@ -425,6 +478,13 @@ func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int
return nil, fmt.Errorf("get account failed: %w", err)
}
+ if account.Platform == PlatformKiro {
+ if account.Type != AccountTypeOAuth {
+ return nil, fmt.Errorf("passive usage only supported for Kiro OAuth accounts")
+ }
+ return s.getKiroUsage(ctx, account, "passive", false)
+ }
+
if !account.IsAnthropicOAuthOrSetupToken() {
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
}
diff --git a/backend/internal/service/account_usage_service_kiro_apikey_test.go b/backend/internal/service/account_usage_service_kiro_apikey_test.go
new file mode 100644
index 00000000..d3f9c023
--- /dev/null
+++ b/backend/internal/service/account_usage_service_kiro_apikey_test.go
@@ -0,0 +1,40 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountUsageService_GetUsage_KiroAPIKeyUnsupported(t *testing.T) {
+ account := &Account{
+ ID: 9101,
+ Platform: PlatformKiro,
+ Type: AccountTypeAPIKey,
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
+ svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
+
+ usage, err := svc.GetUsage(context.Background(), account.ID)
+ require.Nil(t, usage)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "does not support usage query")
+}
+
+func TestAccountUsageService_GetPassiveUsage_KiroAPIKeyUnsupported(t *testing.T) {
+ account := &Account{
+ ID: 9102,
+ Platform: PlatformKiro,
+ Type: AccountTypeAPIKey,
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
+ svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
+
+ usage, err := svc.GetPassiveUsage(context.Background(), account.ID)
+ require.Nil(t, usage)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "Kiro OAuth")
+}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index d966c684..919e507f 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -1448,7 +1448,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
// require_oauth_only: 过滤掉 apikey 类型账号
- if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
+ if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
@@ -1728,7 +1728,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
// require_oauth_only: 过滤掉 apikey 类型账号
- if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
+ if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index bb32540b..1b9bb9d5 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -37,6 +37,7 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
+ PlatformKiro = domain.PlatformKiro
)
// Account type constants
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 074013c3..3d7b1c10 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -27,6 +27,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
@@ -56,6 +57,7 @@ const (
defaultModelsListCacheTTL = 15 * time.Second
postUsageBillingTimeout = 15 * time.Second
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
+ defaultKiroStreamKeepalive = 25 * time.Second
)
const (
@@ -70,6 +72,7 @@ const (
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
+type kiroCooldownRecoveryAttemptedKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
@@ -78,6 +81,7 @@ type accountWithLoad struct {
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
+var kiroCooldownRecoveryAttemptedKey = kiroCooldownRecoveryAttemptedKeyType{}
var (
windowCostPrefetchCacheHitTotal atomic.Int64
@@ -554,6 +558,8 @@ type GatewayService struct {
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
+ kiroTokenProvider *KiroTokenProvider
+ kiroCooldownStore KiroCooldownStore
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
userGroupRateResolver *userGroupRateResolver
@@ -592,6 +598,8 @@ func NewGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
+ kiroTokenProvider *KiroTokenProvider,
+ kiroCooldownStore KiroCooldownStore,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
digestStore *DigestSessionStore,
@@ -624,6 +632,8 @@ func NewGatewayService(
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
+ kiroTokenProvider: kiroTokenProvider,
+ kiroCooldownStore: kiroCooldownStore,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
@@ -1969,6 +1979,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if len(candidates) == 0 {
+ if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, useMixed) {
+ retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
+ return s.SelectAccountWithLoadAwareness(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID)
+ }
return nil, ErrNoAvailableAccounts
}
@@ -2348,14 +2362,91 @@ func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool
if account == nil {
return false
}
- return account.IsSchedulable()
+ if !account.IsSchedulable() {
+ return false
+ }
+ return s.isKiroRuntimeSchedulable(context.Background(), account)
}
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
if account == nil {
return false
}
- return account.IsSchedulableForModelWithContext(ctx, requestedModel)
+ if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ return false
+ }
+ return s.isKiroRuntimeSchedulable(ctx, account)
+}
+
+func (s *GatewayService) isKiroRuntimeSchedulable(ctx context.Context, account *Account) bool {
+ if account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth || s == nil || s.kiroCooldownStore == nil {
+ return true
+ }
+ state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(account))
+ if err != nil {
+ return true
+ }
+ return state == nil || !state.Active
+}
+
+func (s *GatewayService) tryRecoverKiroCooldownPool(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) bool {
+ if s == nil || s.kiroCooldownStore == nil || ctx.Value(kiroCooldownRecoveryAttemptedKey) == true {
+ return false
+ }
+ tokenKeys := s.kiroTransientCooldownRecoveryKeys(ctx, accounts, requestedModel, excludedIDs, allowMixedScheduling)
+ if len(tokenKeys) == 0 {
+ return false
+ }
+ cleared, err := s.kiroCooldownStore.ClearEarliestTransientCooldown(ctx, tokenKeys)
+ if err != nil {
+ logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery failed: %v", err)
+ return false
+ }
+ if cleared {
+ logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery cleared one transient cooldown")
+ }
+ return cleared
+}
+
+func (s *GatewayService) kiroTransientCooldownRecoveryKeys(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) []string {
+ tokenKeys := make([]string, 0, len(accounts))
+ eligible := 0
+ for i := range accounts {
+ acc := &accounts[i]
+ if acc == nil || acc.Platform != PlatformKiro || acc.Type != AccountTypeOAuth {
+ if allowMixedScheduling {
+ continue
+ }
+ return nil
+ }
+ if _, excluded := excludedIDs[acc.ID]; excluded {
+ continue
+ }
+ if !acc.IsSchedulable() {
+ continue
+ }
+ if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
+ continue
+ }
+ if !s.isAccountSchedulableForQuota(acc) ||
+ !s.isAccountSchedulableForWindowCost(ctx, acc, false) ||
+ !s.isAccountSchedulableForRPM(ctx, acc, false) {
+ continue
+ }
+ eligible++
+ state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc))
+ if err != nil || state == nil || !state.Active {
+ return nil
+ }
+ if state.Reason != kirocooldown.CooldownReason429 {
+ return nil
+ }
+ tokenKeys = append(tokenKeys, buildKiroAccountKey(acc))
+ }
+ if eligible == 0 || len(tokenKeys) != eligible {
+ return nil
+ }
+ return tokenKeys
}
// isAccountInGroup checks if the account belongs to the specified group.
@@ -3234,6 +3325,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if selected == nil {
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
+ if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, false) {
+ retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
+ return s.selectAccountForModelWithPlatform(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, platform)
+ }
if requestedModel != "" {
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
}
@@ -3613,6 +3708,17 @@ func (s *GatewayService) diagnoseSelectionFailure(
if _, excluded := excludedIDs[acc.ID]; excluded {
return selectionFailureDiagnosis{Category: "excluded"}
}
+ if !acc.IsSchedulable() {
+ return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
+ }
+ if acc.Platform == PlatformKiro && acc.Type == AccountTypeOAuth {
+ if state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc)); err == nil && state != nil && state.Active {
+ return selectionFailureDiagnosis{
+ Category: "unschedulable",
+ Detail: fmt.Sprintf("kiro_runtime_%s remaining=%s", state.Reason, state.Remaining.Truncate(time.Second)),
+ }
+ }
+ }
if !s.isAccountSchedulableForSelection(acc) {
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
@@ -3776,6 +3882,13 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
}
return accessToken, "oauth", nil
}
+ if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth && s.kiroTokenProvider != nil {
+ accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return "", "", err
+ }
+ return accessToken, "oauth", nil
+ }
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token")
@@ -4319,11 +4432,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
- // Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
- if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
- return s.handleWebSearchEmulation(ctx, c, account, parsed)
- }
-
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model
@@ -4347,6 +4455,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return s.forwardBedrock(ctx, c, account, parsed, startTime)
}
+ if account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
+ return s.forwardKiroMessages(ctx, c, account, parsed, startTime)
+ }
+
+ // Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
+ if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
+ return s.handleWebSearchEmulation(ctx, c, account, parsed)
+ }
+
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
@@ -4439,7 +4556,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel := reqModel
mappingSource := ""
- if account.Type == AccountTypeAPIKey {
+ if account.Platform == PlatformKiro {
+ if next := account.GetMappedModel(reqModel); next != "" && next != reqModel {
+ mappedModel = next
+ mappingSource = "account"
+ }
+ } else if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
mappingSource = "account"
@@ -5938,6 +6060,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
+ if baseURL == "" && account.Platform == PlatformKiro {
+ return nil, fmt.Errorf("kiro api key account requires base_url")
+ }
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
@@ -7199,10 +7324,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
- keepaliveInterval := time.Duration(0)
- if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
- keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
- }
+ keepaliveInterval := s.streamKeepaliveIntervalForAccount(account)
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
@@ -8241,6 +8363,9 @@ type recordUsageOpts struct {
// 长上下文计费(仅 Gemini 路径需要)
LongContextThreshold int
LongContextMultiplier float64
+
+ // Kiro 账号在上游返回 auto 等无法定价模型时使用保守计费兜底。
+ IsKiroAccount bool
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -8377,6 +8502,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
}
// 计算费用
+ opts.IsKiroAccount = account != nil && account.Platform == PlatformKiro
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式
@@ -8454,6 +8580,28 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
+const kiroConservativeFallbackBillingModel = "claude-opus-4-6"
+
+func shouldUseKiroConservativeBillingFallback(result *ForwardResult, billingModel string, opts *recordUsageOpts) bool {
+ if result == nil {
+ return false
+ }
+
+ return opts != nil && opts.IsKiroAccount
+}
+
+func (s *GatewayService) calculateKiroConservativeTokenCost(tokens UsageTokens, multiplier float64) *CostBreakdown {
+ if s == nil || s.billingService == nil {
+ return nil
+ }
+ cost, err := s.billingService.CalculateCost(kiroConservativeFallbackBillingModel, tokens, multiplier)
+ if err != nil {
+ logger.LegacyPrintf("service.gateway", "Calculate conservative Kiro fallback cost failed: %v", err)
+ return nil
+ }
+ return cost
+}
+
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -8557,6 +8705,12 @@ func (s *GatewayService) calculateTokenCost(
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
+ if shouldUseKiroConservativeBillingFallback(result, billingModel, opts) {
+ if fallback := s.calculateKiroConservativeTokenCost(tokens, multiplier); fallback != nil {
+ logger.LegacyPrintf("service.gateway", "Using conservative Kiro fallback pricing for model=%s", billingModel)
+ return fallback
+ }
+ }
return &CostBreakdown{ActualCost: 0}
}
return cost
@@ -9444,6 +9598,19 @@ func reconcileCachedTokens(usage map[string]any) bool {
return true
}
+func (s *GatewayService) streamKeepaliveIntervalForAccount(account *Account) time.Duration {
+ if account != nil && account.Platform == PlatformKiro {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.KiroStreamKeepaliveInterval > 0 {
+ return time.Duration(s.cfg.Gateway.KiroStreamKeepaliveInterval) * time.Second
+ }
+ return defaultKiroStreamKeepalive
+ }
+ if s != nil && s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
+ return time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
+ }
+ return 0
+}
+
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
// initDebugGatewayBodyFile 初始化网关调试日志文件。
diff --git a/backend/internal/service/kiro_error_classifier.go b/backend/internal/service/kiro_error_classifier.go
new file mode 100644
index 00000000..15874603
--- /dev/null
+++ b/backend/internal/service/kiro_error_classifier.go
@@ -0,0 +1,222 @@
+package service
+
+import (
+ "errors"
+ "net"
+ "net/http"
+ "strings"
+
+ "github.com/tidwall/gjson"
+)
+
+const (
+ kiroErrorAuthError = "auth_error"
+ kiroErrorMonthlyRequest = "monthly_request_count"
+ kiroErrorProfileError = "profile_error"
+ kiroErrorQuotaExhausted = "quota_exhausted"
+ kiroErrorOverageExhausted = "overage_exhausted"
+ kiroErrorRateLimited = "rate_limited"
+ kiroErrorSuspended = "suspended"
+ kiroErrorUsageForbidden = "usage_forbidden"
+ kiroErrorUpstreamTransient = "upstream_transient"
+ kiroErrorBadRequestSchema = "bad_request_schema"
+ kiroErrorBadRequestToolPairing = "bad_request_tool_pairing"
+ kiroErrorBadRequestInvalidModel = "bad_request_invalid_model"
+ kiroErrorBadRequestAuth = "bad_request_auth"
+ kiroErrorBadRequestQuota = "bad_request_quota"
+ kiroErrorBadRequestUnknown = "bad_request_unknown"
+ kiroErrorRefreshTokenInvalid = "refresh_token_invalid"
+
+ kiroQuotaStateNormal = "normal"
+ kiroQuotaStateOverageActive = "overage_active"
+ kiroQuotaStateCreditsExhausted = "credits_exhausted"
+ kiroQuotaStateOverageExhausted = "overage_exhausted"
+)
+
+type kiroErrorClassification struct {
+ Category string
+ StatusCode int
+ Message string
+}
+
+func classifyKiroHTTPError(statusCode int, body string) kiroErrorClassification {
+ trimmed := strings.TrimSpace(body)
+ lower := strings.ToLower(trimmed)
+
+ switch {
+ case statusCode == http.StatusUnauthorized:
+ return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
+ case statusCode == http.StatusPaymentRequired && looksLikeKiroMonthlyRequestCountError(trimmed):
+ return kiroErrorClassification{Category: kiroErrorMonthlyRequest, StatusCode: statusCode, Message: trimmed}
+ case statusCode == http.StatusForbidden && isKiroSuspendedBody([]byte(trimmed)):
+ return kiroErrorClassification{Category: kiroErrorSuspended, StatusCode: statusCode, Message: trimmed}
+ case looksLikeKiroProfileError(lower):
+ return kiroErrorClassification{Category: kiroErrorProfileError, StatusCode: statusCode, Message: trimmed}
+ case statusCode == http.StatusBadRequest:
+ return classifyKiroBadRequest(trimmed, lower)
+ case statusCode == http.StatusForbidden && isKiroTokenErrorBody([]byte(trimmed)):
+ return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
+ case looksLikeKiroOverageExhaustedError(lower):
+ return kiroErrorClassification{Category: kiroErrorOverageExhausted, StatusCode: statusCode, Message: trimmed}
+ case looksLikeKiroQuotaExhaustedError(lower):
+ return kiroErrorClassification{Category: kiroErrorQuotaExhausted, StatusCode: statusCode, Message: trimmed}
+ case statusCode == http.StatusTooManyRequests:
+ return kiroErrorClassification{Category: kiroErrorRateLimited, StatusCode: statusCode, Message: trimmed}
+ case statusCode == http.StatusForbidden:
+ return kiroErrorClassification{Category: kiroErrorUsageForbidden, StatusCode: statusCode, Message: trimmed}
+ case statusCode >= 500:
+ return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
+ default:
+ return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
+ }
+}
+
+func classifyKiroError(err error) kiroErrorClassification {
+ if err == nil {
+ return kiroErrorClassification{}
+ }
+
+ var httpErr *kiroUsageHTTPError
+ if errors.As(err, &httpErr) && httpErr != nil {
+ return classifyKiroHTTPError(httpErr.StatusCode, httpErr.Body)
+ }
+
+ errStr := strings.TrimSpace(err.Error())
+ lower := strings.ToLower(errStr)
+ switch {
+ case looksLikeKiroInvalidGrantError(lower):
+ return kiroErrorClassification{Category: kiroErrorRefreshTokenInvalid, Message: errStr}
+ case looksLikeKiroMonthlyRequestCountError(errStr):
+ return kiroErrorClassification{Category: kiroErrorMonthlyRequest, Message: errStr}
+ case looksLikeKiroProfileError(lower):
+ return kiroErrorClassification{Category: kiroErrorProfileError, Message: errStr}
+ case looksLikeKiroOverageExhaustedError(lower):
+ return kiroErrorClassification{Category: kiroErrorOverageExhausted, Message: errStr}
+ case looksLikeKiroQuotaExhaustedError(lower):
+ return kiroErrorClassification{Category: kiroErrorQuotaExhausted, Message: errStr}
+ case strings.Contains(lower, "context deadline exceeded"),
+ strings.Contains(lower, "timeout"),
+ isNetErr(err):
+ return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
+ default:
+ return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
+ }
+}
+
+func classifyKiroBadRequest(trimmed, lower string) kiroErrorClassification {
+ switch {
+ case looksLikeKiroBadRequestSchemaError(lower):
+ return kiroErrorClassification{Category: kiroErrorBadRequestSchema, StatusCode: http.StatusBadRequest, Message: trimmed}
+ case looksLikeKiroBadRequestToolPairingError(lower):
+ return kiroErrorClassification{Category: kiroErrorBadRequestToolPairing, StatusCode: http.StatusBadRequest, Message: trimmed}
+ case looksLikeKiroBadRequestInvalidModelError(lower):
+ return kiroErrorClassification{Category: kiroErrorBadRequestInvalidModel, StatusCode: http.StatusBadRequest, Message: trimmed}
+ case looksLikeKiroInvalidGrantError(lower) || looksLikeKiroBadRequestAuthError(lower):
+ return kiroErrorClassification{Category: kiroErrorBadRequestAuth, StatusCode: http.StatusBadRequest, Message: trimmed}
+ case looksLikeKiroQuotaExhaustedError(lower) || looksLikeKiroMonthlyRequestCountError(trimmed):
+ return kiroErrorClassification{Category: kiroErrorBadRequestQuota, StatusCode: http.StatusBadRequest, Message: trimmed}
+ default:
+ return kiroErrorClassification{Category: kiroErrorBadRequestUnknown, StatusCode: http.StatusBadRequest, Message: trimmed}
+ }
+}
+
+func looksLikeKiroBadRequestSchemaError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return strings.Contains(lower, "schema") ||
+ strings.Contains(lower, "inputschema") ||
+ strings.Contains(lower, "improperly formed request") ||
+ strings.Contains(lower, "additionalproperties") ||
+ (strings.Contains(lower, "properties") && strings.Contains(lower, "required"))
+}
+
+func looksLikeKiroBadRequestToolPairingError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return strings.Contains(lower, "tool_use") ||
+ strings.Contains(lower, "tool_result") ||
+ strings.Contains(lower, "tooluseid") ||
+ strings.Contains(lower, "toolresults") ||
+ strings.Contains(lower, "must be paired") ||
+ strings.Contains(lower, "missing tool result")
+}
+
+func looksLikeKiroBadRequestInvalidModelError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return strings.Contains(lower, "invalid model") ||
+ strings.Contains(lower, "invalid_model_id") ||
+ strings.Contains(lower, "model not supported") ||
+ strings.Contains(lower, "unsupportedmodel") ||
+ strings.Contains(lower, "modelid")
+}
+
+func looksLikeKiroBadRequestAuthError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return strings.Contains(lower, "invalid token") ||
+ strings.Contains(lower, "expired token") ||
+ strings.Contains(lower, "access token") ||
+ strings.Contains(lower, "refresh token")
+}
+
+func looksLikeKiroInvalidGrantError(lower string) bool {
+ return strings.Contains(lower, "invalid_grant")
+}
+
+func looksLikeKiroMonthlyRequestCountError(body string) bool {
+ trimmed := strings.TrimSpace(body)
+ if trimmed == "" {
+ return false
+ }
+ if strings.Contains(trimmed, "MONTHLY_REQUEST_COUNT") {
+ return true
+ }
+ if !gjson.Valid(trimmed) {
+ return false
+ }
+ return gjson.Get(trimmed, "reason").String() == "MONTHLY_REQUEST_COUNT" ||
+ gjson.Get(trimmed, "error.reason").String() == "MONTHLY_REQUEST_COUNT"
+}
+
+func looksLikeKiroProfileError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return (strings.Contains(lower, "profilearn") && strings.Contains(lower, "required")) ||
+ (strings.Contains(lower, "profile arn") && strings.Contains(lower, "required")) ||
+ (strings.Contains(lower, "profile") && strings.Contains(lower, "not found")) ||
+ (strings.Contains(lower, "invalid profile")) ||
+ (strings.Contains(lower, "listavailableprofiles"))
+}
+
+func looksLikeKiroQuotaExhaustedError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return (strings.Contains(lower, "credit") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "depleted"))) ||
+ (strings.Contains(lower, "quota") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "exceeded") || strings.Contains(lower, "depleted"))) ||
+ (strings.Contains(lower, "usage limit") && (strings.Contains(lower, "reached") || strings.Contains(lower, "exceeded"))) ||
+ (strings.Contains(lower, "resource has been exhausted"))
+}
+
+func looksLikeKiroOverageExhaustedError(lower string) bool {
+ if lower == "" {
+ return false
+ }
+ return strings.Contains(lower, "overage") &&
+ (strings.Contains(lower, "exhaust") ||
+ strings.Contains(lower, "disabled") ||
+ strings.Contains(lower, "not enabled") ||
+ strings.Contains(lower, "not allowed") ||
+ strings.Contains(lower, "limit"))
+}
+
+func isNetErr(err error) bool {
+ var netErr net.Error
+ return errors.As(err, &netErr)
+}
diff --git a/backend/internal/service/kiro_error_classifier_test.go b/backend/internal/service/kiro_error_classifier_test.go
new file mode 100644
index 00000000..5cb12b08
--- /dev/null
+++ b/backend/internal/service/kiro_error_classifier_test.go
@@ -0,0 +1,66 @@
+package service
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestClassifyKiroHTTPErrorBadRequestCategories(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ want string
+ }{
+ {
+ name: "schema",
+ body: `{"message":"Improperly formed request: inputSchema.properties must be an object"}`,
+ want: kiroErrorBadRequestSchema,
+ },
+ {
+ name: "tool pairing",
+ body: `{"message":"tool_use must be paired with a matching tool_result"}`,
+ want: kiroErrorBadRequestToolPairing,
+ },
+ {
+ name: "invalid model id",
+ body: `{"message":"invalid modelId: model not supported"}`,
+ want: kiroErrorBadRequestInvalidModel,
+ },
+ {
+ name: "invalid model upstream",
+ body: `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
+ want: kiroErrorBadRequestInvalidModel,
+ },
+ {
+ name: "invalid model reason",
+ body: `{"message":"model route unavailable","reason":"INVALID_MODEL_ID"}`,
+ want: kiroErrorBadRequestInvalidModel,
+ },
+ {
+ name: "auth",
+ body: `{"error":"invalid_grant","message":"Invalid refresh token provided"}`,
+ want: kiroErrorBadRequestAuth,
+ },
+ {
+ name: "quota",
+ body: `{"message":"resource has been exhausted"}`,
+ want: kiroErrorBadRequestQuota,
+ },
+ {
+ name: "unknown",
+ body: `{"message":"bad request"}`,
+ want: kiroErrorBadRequestUnknown,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ classification := classifyKiroHTTPError(http.StatusBadRequest, tt.body)
+ require.Equal(t, tt.want, classification.Category)
+ require.Equal(t, http.StatusBadRequest, classification.StatusCode)
+ require.Equal(t, tt.body, classification.Message)
+ })
+ }
+}
diff --git a/backend/internal/service/kiro_http_helpers.go b/backend/internal/service/kiro_http_helpers.go
new file mode 100644
index 00000000..351eb17a
--- /dev/null
+++ b/backend/internal/service/kiro_http_helpers.go
@@ -0,0 +1,180 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+ "github.com/google/uuid"
+)
+
+func buildKiroAccountKey(account *Account) string {
+ if account == nil {
+ return ""
+ }
+ return kiropkg.BuildAccountKey(
+ account.GetCredential("client_id"),
+ account.GetCredential("client_id_hash"),
+ account.GetCredential("refresh_token"),
+ account.GetCredential("profile_arn"),
+ account.ID,
+ )
+}
+
+func buildKiroMachineID(account *Account) string {
+ if account == nil {
+ return kiropkg.BuildMachineID("", "", "account:nil")
+ }
+ for _, key := range []string{"machine_id", "machineId"} {
+ if machineID, ok := kiropkg.NormalizeMachineID(account.GetCredential(key)); ok {
+ return machineID
+ }
+ }
+ fallbackKey := buildKiroMachineIDFallbackKey(account)
+ if account.Type == AccountTypeAPIKey {
+ return kiropkg.BuildMachineID("", firstKiroCredential(account, "kiro_api_key", "kiroApiKey", "api_key"), fallbackKey)
+ }
+ return kiropkg.BuildMachineID(account.GetCredential("refresh_token"), "", fallbackKey)
+}
+
+func firstKiroCredential(account *Account, keys ...string) string {
+ if account == nil {
+ return ""
+ }
+ for _, key := range keys {
+ if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
+ return value
+ }
+ }
+ return ""
+}
+
+func buildKiroMachineIDFallbackKey(account *Account) string {
+ if account == nil {
+ return "account:nil"
+ }
+ if account.ID > 0 {
+ return fmt.Sprintf("account:%d", account.ID)
+ }
+ for _, key := range []string{"client_id", "profile_arn"} {
+ if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
+ return key + ":" + value
+ }
+ }
+ if name := strings.TrimSpace(account.Name); name != "" {
+ return "name:" + name
+ }
+ return "account:unknown"
+}
+
+func buildKiroRequestID(resp *http.Response) string {
+ if resp == nil {
+ return ""
+ }
+ if requestID := strings.TrimSpace(resp.Header.Get("x-request-id")); requestID != "" {
+ return requestID
+ }
+ if requestID := strings.TrimSpace(resp.Header.Get("x-amzn-requestid")); requestID != "" {
+ return requestID
+ }
+ return strings.TrimSpace(resp.Header.Get("x-amz-request-id"))
+}
+
+func isKiroInvalidModelIDBody(respBody []byte) bool {
+ var payload struct {
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+ Error struct {
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+ } `json:"error"`
+ }
+ if json.Unmarshal(respBody, &payload) != nil {
+ return looksLikeKiroBadRequestInvalidModelError(strings.ToLower(string(respBody)))
+ }
+ return strings.EqualFold(strings.TrimSpace(payload.Reason), "INVALID_MODEL_ID") ||
+ strings.EqualFold(strings.TrimSpace(payload.Error.Reason), "INVALID_MODEL_ID") ||
+ looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Message)) ||
+ looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Error.Message))
+}
+
+func isKiroSuspendedBody(respBody []byte) bool {
+ body := string(respBody)
+ return strings.Contains(body, "SUSPENDED") || strings.Contains(body, "TEMPORARILY_SUSPENDED")
+}
+
+func isKiroTokenErrorBody(respBody []byte) bool {
+ lower := strings.ToLower(string(respBody))
+ return strings.Contains(lower, "token") ||
+ strings.Contains(lower, "expired") ||
+ strings.Contains(lower, "invalid") ||
+ strings.Contains(lower, "unauthorized")
+}
+
+func kiroProxyURL(account *Account) string {
+ if account != nil && account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+}
+
+func kiroAPIRegion(account *Account) string {
+ if account == nil {
+ return kiroDefaultRegion
+ }
+ region := strings.TrimSpace(account.GetCredential("api_region"))
+ if region == "" {
+ region = kiroDefaultRegion
+ }
+ return region
+}
+
+func applyKiroConditionalHeaders(req *http.Request, account *Account) {
+ if req == nil || account == nil {
+ return
+ }
+ if strings.EqualFold(strings.TrimSpace(account.GetCredential("auth_method")), "external_idp") {
+ req.Header.Set("TokenType", "EXTERNAL_IDP")
+ }
+ if strings.EqualFold(strings.TrimSpace(account.GetCredential("provider")), "Internal") {
+ req.Header.Set("redirect-for-internal", "true")
+ }
+}
+
+func resolveKiroPayloadProfileArn(account *Account) string {
+ if account == nil {
+ return ""
+ }
+ return strings.TrimSpace(account.GetCredential("profile_arn"))
+}
+
+func newKiroJSONRequest(ctx context.Context, endpointURL string, payload []byte, token, accountKey, machineID, amzTarget string, account *Account) (*http.Request, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "*/*")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
+ req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
+ req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
+ req.Header.Set("x-amzn-codewhisperer-optout", "true")
+ req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
+ req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
+ if amzTarget != "" {
+ req.Header.Set("X-Amz-Target", amzTarget)
+ }
+ if account != nil {
+ profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
+ if profileArn != "" {
+ req.Header.Set("x-amzn-kiro-profile-arn", profileArn)
+ }
+ }
+ applyKiroConditionalHeaders(req, account)
+ return req, nil
+}
diff --git a/backend/internal/service/kiro_http_helpers_test.go b/backend/internal/service/kiro_http_helpers_test.go
new file mode 100644
index 00000000..8bd2c508
--- /dev/null
+++ b/backend/internal/service/kiro_http_helpers_test.go
@@ -0,0 +1,208 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "net/http"
+ "strings"
+ "testing"
+
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
+ accountA := &Account{
+ ID: 99,
+ Credentials: map[string]any{
+ "access_token": "token-a",
+ },
+ }
+ accountB := &Account{
+ ID: 99,
+ Credentials: map[string]any{
+ "access_token": "token-b",
+ },
+ }
+
+ require.Equal(t, buildKiroAccountKey(accountA), buildKiroAccountKey(accountB))
+}
+
+func TestBuildKiroMachineIDPrefersExplicitCredential(t *testing.T) {
+ account := &Account{
+ ID: 101,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "machineId": "2582956e-cc88-4669-b546-07adbffcb894",
+ "refresh_token": "refresh-token",
+ },
+ }
+
+ require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", buildKiroMachineID(account))
+}
+
+func TestBuildKiroMachineIDDerivesFromRefreshToken(t *testing.T) {
+ account := &Account{
+ ID: 102,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "refresh-token",
+ },
+ }
+
+ require.Equal(t, kiropkg.BuildMachineID("refresh-token", "", "account:102"), buildKiroMachineID(account))
+}
+
+func TestBuildKiroMachineIDDerivesFromAPIKeyAccount(t *testing.T) {
+ account := &Account{
+ ID: 103,
+ Platform: PlatformKiro,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "kiroApiKey": "kiro-api-key",
+ },
+ }
+
+ require.Equal(t, kiropkg.BuildMachineID("", "kiro-api-key", "account:103"), buildKiroMachineID(account))
+}
+
+func TestNewKiroJSONRequestAddsConditionalHeaders(t *testing.T) {
+ account := &Account{
+ Credentials: map[string]any{
+ "auth_method": "external_idp",
+ "provider": "Internal",
+ "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER",
+ },
+ }
+
+ req, err := newKiroJSONRequest(
+ context.Background(),
+ "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
+ []byte(`{"ok":true}`),
+ "access-token",
+ "account-key",
+ buildKiroMachineID(account),
+ "",
+ account,
+ )
+ require.NoError(t, err)
+ require.Equal(t, "EXTERNAL_IDP", req.Header.Get("TokenType"))
+ require.Equal(t, "true", req.Header.Get("redirect-for-internal"))
+ require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER", req.Header.Get("x-amzn-kiro-profile-arn"))
+ require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
+ require.Equal(t, "true", req.Header.Get("x-amzn-codewhisperer-optout"))
+ require.Contains(t, req.Header.Get("User-Agent"), "aws-sdk-js/1.0.34")
+ require.Contains(t, req.Header.Get("User-Agent"), "md/nodejs#22.22.0")
+ require.Contains(t, req.Header.Get("User-Agent"), buildKiroMachineID(account))
+ require.Contains(t, req.Header.Get("X-Amz-User-Agent"), buildKiroMachineID(account))
+ require.True(t, strings.Contains(req.Header.Get("User-Agent"), "api/codewhispererstreaming#1.0.34"))
+ require.Empty(t, req.Header.Get("Anthropic-Beta"))
+}
+
+func TestIsKiroInvalidModelIDBodyRecognizesKnownForms(t *testing.T) {
+ tests := []string{
+ `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`,
+ `{"message":"Invalid model. Please select a different model to continue."}`,
+ `API Error: 400 {"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
+ }
+
+ for _, body := range tests {
+ require.True(t, isKiroInvalidModelIDBody([]byte(body)), body)
+ }
+}
+
+func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
+ account := &Account{
+ ID: 7,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/test",
+ },
+ }
+ body := []byte(`{
+ "model":"claude-sonnet-4-6",
+ "messages":[{"role":"user","content":"hello"}]
+ }`)
+ headers := http.Header{}
+ headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
+
+ payload, err := buildKiroPayloadForAccount(
+ context.Background(),
+ account,
+ body,
+ "claude-sonnet-4.6",
+ "kiro-access-token",
+ "claude-sonnet-4-6",
+ headers,
+ )
+ require.NoError(t, err)
+ require.NotContains(t, string(payload), "CHUNKED WRITE PROTOCOL")
+ require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
+}
+
+func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
+ account := &Account{
+ Credentials: map[string]any{
+ "api_region": "eu-west-1",
+ "profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
+ "region": "ap-northeast-1",
+ },
+ }
+
+ require.Equal(t, "eu-west-1", kiroAPIRegion(account))
+}
+
+func TestKiroAPIRegionIgnoresProfileARNRegionFallback(t *testing.T) {
+ account := &Account{
+ Credentials: map[string]any{
+ "profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
+ },
+ }
+
+ require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
+}
+
+func TestKiroAPIRegionIgnoresOIDCRegionFallback(t *testing.T) {
+ account := &Account{
+ Credentials: map[string]any{
+ "region": "ap-northeast-2",
+ },
+ }
+
+ require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
+}
+
+func TestBuildKiroEndpointsUsesOnlyAmazonQEndpoint(t *testing.T) {
+ account := &Account{
+ Credentials: map[string]any{
+ "api_region": "us-west-2",
+ "preferred_endpoint": "cw",
+ },
+ }
+
+ endpoints := buildKiroEndpoints(account)
+ require.Len(t, endpoints, 1)
+ require.Equal(t, "AmazonQ", endpoints[0].Name)
+ require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
+ require.Empty(t, endpoints[0].AmzTarget)
+}
+
+func TestBuildKiroEndpointsIgnoresPreferredEndpoint(t *testing.T) {
+ for _, preferred := range []string{"codewhisperer", "cw", "unknown"} {
+ account := &Account{
+ Credentials: map[string]any{
+ "api_region": "us-west-2",
+ "preferred_endpoint": preferred,
+ },
+ }
+
+ endpoints := buildKiroEndpoints(account)
+ require.Len(t, endpoints, 1)
+ require.Equal(t, "AmazonQ", endpoints[0].Name)
+ require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
+ }
+}
diff --git a/backend/internal/service/kiro_mapping_fallback_test.go b/backend/internal/service/kiro_mapping_fallback_test.go
new file mode 100644
index 00000000..075ea080
--- /dev/null
+++ b/backend/internal/service/kiro_mapping_fallback_test.go
@@ -0,0 +1,74 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountKiroDefaultMappingRestrictsUnsupportedModels(t *testing.T) {
+ account := &Account{Platform: PlatformKiro}
+
+ require.False(t, account.IsModelSupported("gpt-4o"))
+ require.False(t, account.IsModelSupported("kiro-gpt-4o"))
+ require.False(t, account.IsModelSupported("auto"))
+ require.Equal(t, "claude-sonnet-4.6", account.GetMappedModel("claude-sonnet-4-6"))
+}
+
+func TestGatewayServiceCalculateTokenCost_KiroAutoUsesConservativeFallback(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Default.RateMultiplier = 1.1
+
+ svc := NewGatewayService(
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ cfg,
+ nil,
+ nil,
+ NewBillingService(cfg, nil),
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ result := &ForwardResult{
+ Model: "auto",
+ UpstreamModel: "auto",
+ Usage: ClaudeUsage{
+ InputTokens: 20,
+ OutputTokens: 10,
+ },
+ }
+
+ expected, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
+ InputTokens: 20,
+ OutputTokens: 10,
+ }, 1.1)
+ require.NoError(t, err)
+
+ cost := svc.calculateTokenCost(context.Background(), result, &APIKey{}, "auto", 1.1, &recordUsageOpts{IsKiroAccount: true})
+ require.NotNil(t, cost)
+ require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-12)
+ require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-12)
+}
diff --git a/backend/internal/service/kiro_oauth_service.go b/backend/internal/service/kiro_oauth_service.go
new file mode 100644
index 00000000..1b86639a
--- /dev/null
+++ b/backend/internal/service/kiro_oauth_service.go
@@ -0,0 +1,369 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+)
+
+const (
+ // Kiro desktop social auth uses localhost loopback callbacks from a fixed
+ // allowlist. Use one of the bundled ports from the official client.
+ kiroSocialRedirectURI = "http://localhost:49153"
+ // AWS IAM Identity Center native/public clients require an explicit loopback IP redirect URI.
+ kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
+)
+
+type KiroOAuthService struct {
+ sessionStore *kiropkg.SessionStore
+ proxyRepo ProxyRepository
+}
+
+func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
+ return &KiroOAuthService{
+ sessionStore: kiropkg.NewSessionStore(),
+ proxyRepo: proxyRepo,
+ }
+}
+
+func (s *KiroOAuthService) Stop() {}
+
+type KiroAuthURLResult struct {
+ AuthURL string `json:"auth_url"`
+ SessionID string `json:"session_id"`
+ State string `json:"state"`
+}
+
+type KiroIDCAuthURLResult struct {
+ AuthURL string `json:"auth_url"`
+ SessionID string `json:"session_id"`
+ State string `json:"state"`
+ ClientID string `json:"client_id"`
+ Region string `json:"region"`
+ StartURL string `json:"start_url"`
+}
+
+type KiroTokenInfo struct {
+ AccessToken string `json:"access_token,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+ ProfileArn string `json:"profile_arn,omitempty"`
+ ExpiresAt string `json:"expires_at,omitempty"`
+ AuthMethod string `json:"auth_method,omitempty"`
+ Provider string `json:"provider,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ ClientSecret string `json:"client_secret,omitempty"`
+ ClientIDHash string `json:"client_id_hash,omitempty"`
+ Email string `json:"email,omitempty"`
+ StartURL string `json:"start_url,omitempty"`
+ Region string `json:"region,omitempty"`
+}
+
+type KiroGenerateAuthURLInput struct {
+ ProxyID *int64
+ Provider string
+}
+
+type KiroExchangeCodeInput struct {
+ SessionID string
+ State string
+ Code string
+ CallbackPath string
+ LoginOption string
+ ProxyID *int64
+}
+
+type KiroGenerateIDCAuthURLInput struct {
+ ProxyID *int64
+ StartURL string
+ Region string
+}
+
+type KiroRefreshTokenInput struct {
+ RefreshToken string
+ AuthMethod string
+ Provider string
+ ClientID string
+ ClientSecret string
+ StartURL string
+ Region string
+ ProfileArn string
+ ProxyID *int64
+}
+
+type KiroImportTokenInput struct {
+ TokenJSON string
+ DeviceRegistrationJSON string
+}
+
+func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
+ provider := strings.TrimSpace(input.Provider)
+ if provider == "" {
+ provider = string(kiropkg.SocialProviderGoogle)
+ }
+ if provider != string(kiropkg.SocialProviderGoogle) && provider != string(kiropkg.SocialProviderGitHub) {
+ return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
+ }
+ state, err := kiropkg.GenerateState()
+ if err != nil {
+ return nil, fmt.Errorf("generate state failed: %w", err)
+ }
+ codeVerifier, err := kiropkg.GenerateCodeVerifier()
+ if err != nil {
+ return nil, fmt.Errorf("generate code verifier failed: %w", err)
+ }
+ sessionID := kiropkg.GenerateSessionID()
+ proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
+ s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
+ State: state,
+ CodeVerifier: codeVerifier,
+ ProxyURL: proxyURL,
+ CreatedAt: time.Now(),
+ AuthType: "social",
+ Provider: provider,
+ RedirectURI: kiroSocialRedirectURI,
+ })
+ return &KiroAuthURLResult{
+ AuthURL: kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
+ SessionID: sessionID,
+ State: state,
+ }, nil
+}
+
+func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
+ session, ok := s.sessionStore.Get(input.SessionID)
+ if !ok {
+ return nil, fmt.Errorf("session not found or expired")
+ }
+ if strings.TrimSpace(input.State) == "" || input.State != session.State {
+ return nil, fmt.Errorf("state invalid")
+ }
+ proxyURL := session.ProxyURL
+ if input.ProxyID != nil {
+ proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
+ }
+
+ switch session.AuthType {
+ case "social":
+ token, err := kiropkg.CreateSocialToken(
+ ctx,
+ proxyURL,
+ input.Code,
+ session.CodeVerifier,
+ buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
+ )
+ if err != nil {
+ return nil, err
+ }
+ token.Provider = session.Provider
+ s.sessionStore.Delete(input.SessionID)
+ return toKiroTokenInfo(token), nil
+ case "idc":
+ token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
+ if err != nil {
+ return nil, err
+ }
+ s.sessionStore.Delete(input.SessionID)
+ return toKiroTokenInfo(token), nil
+ default:
+ return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
+ }
+}
+
+func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
+ option := strings.ToLower(strings.TrimSpace(loginOption))
+ if option == "" {
+ switch provider {
+ case string(kiropkg.SocialProviderGitHub):
+ option = "github"
+ case string(kiropkg.SocialProviderGoogle):
+ option = "google"
+ }
+ }
+ return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, option)
+}
+
+func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
+ startURL := strings.TrimSpace(input.StartURL)
+ if startURL == "" {
+ startURL = kiropkg.BuilderIDStartURL
+ }
+ region := strings.TrimSpace(input.Region)
+ if region == "" {
+ region = "us-east-1"
+ }
+ state, err := kiropkg.GenerateState()
+ if err != nil {
+ return nil, fmt.Errorf("generate state failed: %w", err)
+ }
+ codeVerifier, err := kiropkg.GenerateCodeVerifier()
+ if err != nil {
+ return nil, fmt.Errorf("generate code verifier failed: %w", err)
+ }
+ proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
+ reg, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
+ if err != nil {
+ return nil, err
+ }
+ sessionID := kiropkg.GenerateSessionID()
+ s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
+ State: state,
+ CodeVerifier: codeVerifier,
+ ProxyURL: proxyURL,
+ CreatedAt: time.Now(),
+ AuthType: "idc",
+ RedirectURI: kiroIDCRedirectURI,
+ ClientID: reg.ClientID,
+ ClientSecret: reg.ClientSecret,
+ Region: region,
+ StartURL: startURL,
+ })
+ return &KiroIDCAuthURLResult{
+ AuthURL: kiropkg.BuildIDCAuthURL(reg.ClientID, kiroIDCRedirectURI, state, kiropkg.GenerateCodeChallenge(codeVerifier), region),
+ SessionID: sessionID,
+ State: state,
+ ClientID: reg.ClientID,
+ Region: region,
+ StartURL: startURL,
+ }, nil
+}
+
+func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
+ proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
+ authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
+ if authMethod == "" {
+ authMethod = "social"
+ }
+
+ var token *kiropkg.TokenData
+ var err error
+ switch authMethod {
+ case "idc":
+ token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
+ default:
+ token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if token.ProfileArn == "" {
+ token.ProfileArn = input.ProfileArn
+ }
+ if token.ClientID == "" {
+ token.ClientID = input.ClientID
+ }
+ if token.ClientSecret == "" {
+ token.ClientSecret = input.ClientSecret
+ }
+ if token.StartURL == "" {
+ token.StartURL = input.StartURL
+ }
+ if token.Region == "" {
+ token.Region = input.Region
+ }
+ return toKiroTokenInfo(token), nil
+}
+
+func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
+ if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
+ return nil, fmt.Errorf("not a kiro oauth account")
+ }
+ return s.RefreshToken(ctx, &KiroRefreshTokenInput{
+ RefreshToken: account.GetCredential("refresh_token"),
+ AuthMethod: account.GetCredential("auth_method"),
+ Provider: account.GetCredential("provider"),
+ ClientID: account.GetCredential("client_id"),
+ ClientSecret: account.GetCredential("client_secret"),
+ StartURL: account.GetCredential("start_url"),
+ Region: account.GetCredential("region"),
+ ProfileArn: account.GetCredential("profile_arn"),
+ ProxyID: account.ProxyID,
+ })
+}
+
+func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
+ token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
+ if err != nil {
+ return nil, err
+ }
+ return toKiroTokenInfo(token), nil
+}
+
+func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
+ if tokenInfo == nil {
+ return map[string]any{}
+ }
+
+ creds := map[string]any{}
+ if tokenInfo.AccessToken != "" {
+ creds["access_token"] = tokenInfo.AccessToken
+ }
+ if tokenInfo.RefreshToken != "" {
+ creds["refresh_token"] = tokenInfo.RefreshToken
+ }
+ if tokenInfo.ProfileArn != "" {
+ creds["profile_arn"] = tokenInfo.ProfileArn
+ }
+ if tokenInfo.ExpiresAt != "" {
+ creds["expires_at"] = tokenInfo.ExpiresAt
+ }
+ if tokenInfo.AuthMethod != "" {
+ creds["auth_method"] = tokenInfo.AuthMethod
+ }
+ if tokenInfo.Provider != "" {
+ creds["provider"] = tokenInfo.Provider
+ }
+ if tokenInfo.ClientID != "" {
+ creds["client_id"] = tokenInfo.ClientID
+ }
+ if tokenInfo.ClientSecret != "" {
+ creds["client_secret"] = tokenInfo.ClientSecret
+ }
+ if tokenInfo.ClientIDHash != "" {
+ creds["client_id_hash"] = tokenInfo.ClientIDHash
+ }
+ if tokenInfo.Email != "" {
+ creds["email"] = tokenInfo.Email
+ }
+ if tokenInfo.StartURL != "" {
+ creds["start_url"] = tokenInfo.StartURL
+ }
+ if tokenInfo.Region != "" {
+ creds["region"] = tokenInfo.Region
+ }
+
+ return creds
+}
+
+func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
+ if token == nil {
+ return nil
+ }
+ return &KiroTokenInfo{
+ AccessToken: token.AccessToken,
+ RefreshToken: token.RefreshToken,
+ ProfileArn: token.ProfileArn,
+ ExpiresAt: token.ExpiresAt,
+ AuthMethod: token.AuthMethod,
+ Provider: token.Provider,
+ ClientID: token.ClientID,
+ ClientSecret: token.ClientSecret,
+ ClientIDHash: token.ClientIDHash,
+ Email: token.Email,
+ StartURL: token.StartURL,
+ Region: token.Region,
+ }
+}
+
+func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
+ if proxyID == nil || s.proxyRepo == nil {
+ return "", nil
+ }
+ proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
+ if err != nil || proxy == nil {
+ return "", err
+ }
+ return proxy.URL(), nil
+}
diff --git a/backend/internal/service/kiro_oauth_service_test.go b/backend/internal/service/kiro_oauth_service_test.go
new file mode 100644
index 00000000..46358579
--- /dev/null
+++ b/backend/internal/service/kiro_oauth_service_test.go
@@ -0,0 +1,51 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+ "github.com/stretchr/testify/require"
+)
+
+func TestKiroIDCAuthRedirectURIUsesLoopbackIP(t *testing.T) {
+ require.Equal(t, "http://127.0.0.1:9876/oauth/callback", kiroIDCRedirectURI)
+}
+
+func TestKiroSocialAuthRedirectURIUsesLoopbackIP(t *testing.T) {
+ require.Equal(t, "http://localhost:49153", kiroSocialRedirectURI)
+}
+
+func TestBuildKiroSocialExchangeRedirectURIUsesProviderDefault(t *testing.T) {
+ require.Equal(
+ t,
+ "http://localhost:49153/oauth/callback?login_option=github",
+ buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "", ""),
+ )
+}
+
+func TestBuildKiroSocialExchangeRedirectURIPreservesParsedCallbackData(t *testing.T) {
+ require.Equal(
+ t,
+ "http://localhost:49153/signin/callback?login_option=google",
+ buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "/signin/callback", "google"),
+ )
+}
+
+func TestKiroOAuthService_ExchangeCodeRejectsExpiredSession(t *testing.T) {
+ svc := NewKiroOAuthService(nil)
+ svc.sessionStore.Set("expired-session", &kiropkg.AuthSession{
+ State: "expected-state",
+ CreatedAt: time.Now().Add(-11 * time.Minute),
+ })
+
+ _, err := svc.ExchangeCode(context.Background(), &KiroExchangeCodeInput{
+ SessionID: "expired-session",
+ State: "expected-state",
+ Code: "auth-code",
+ })
+ require.EqualError(t, err, "session not found or expired")
+}
diff --git a/backend/internal/service/kiro_runtime.go b/backend/internal/service/kiro_runtime.go
new file mode 100644
index 00000000..78ed30b2
--- /dev/null
+++ b/backend/internal/service/kiro_runtime.go
@@ -0,0 +1,724 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ mathrand "math/rand"
+ "net/http"
+ "strings"
+ "time"
+
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+ "go.uber.org/zap"
+)
+
+type kiroEndpointConfig struct {
+ URL string
+ AmzTarget string
+ Name string
+}
+
+const kiroInvalidModelTempUnschedDuration = time.Minute
+
+const (
+ kiroRetryBaseDelay = 200 * time.Millisecond
+ kiroRetryMaxDelay = 2 * time.Second
+)
+
+var kiroRetrySleep = sleepWithContext
+
+func kiroRetryBackoffDelay(attempt int) time.Duration {
+ if attempt < 0 {
+ attempt = 0
+ }
+ delay := kiroRetryBaseDelay * time.Duration(1< kiroRetryMaxDelay {
+ delay = kiroRetryMaxDelay
+ }
+ jitterMax := delay / 4
+ if jitterMax <= 0 {
+ return delay
+ }
+ return delay + time.Duration(mathrand.Int63n(int64(jitterMax)+1))
+}
+
+func sleepKiroRetry(ctx context.Context, attempt int) error {
+ return kiroRetrySleep(ctx, kiroRetryBackoffDelay(attempt))
+}
+
+func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) {
+ if account == nil || parsed == nil {
+ return nil, fmt.Errorf("kiro forward: missing account or request")
+ }
+
+ originalModel := parsed.Model
+ mappedModel := originalModel
+ if next := account.GetMappedModel(originalModel); next != "" {
+ mappedModel = next
+ }
+ body := parsed.Body
+ if mappedModel != originalModel {
+ body = s.replaceModelInBody(body, mappedModel)
+ }
+ logger.L().Debug("gateway forward_kiro_messages: request prepared",
+ zap.Int64("account_id", account.ID),
+ zap.String("auth_method", strings.TrimSpace(account.GetCredential("auth_method"))),
+ zap.String("requested_model", originalModel),
+ zap.String("mapped_model", mappedModel),
+ zap.Bool("has_profile_arn", strings.TrimSpace(account.GetCredential("profile_arn")) != ""),
+ )
+
+ if s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, body) {
+ parsedForEmulation := *parsed
+ parsedForEmulation.Body = body
+ return s.handleWebSearchEmulation(ctx, c, account, &parsedForEmulation)
+ }
+
+ if parsed.Stream {
+ resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
+ if err != nil {
+ var failoverErr *UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: failoverErr.StatusCode,
+ Kind: "failover",
+ Message: sanitizeUpstreamErrorMessage(err.Error()),
+ })
+ return nil, failoverErr
+ }
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ setOpsUpstreamError(c, 0, safeErr, "")
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: 0,
+ Kind: "request_error",
+ Message: safeErr,
+ })
+ c.JSON(http.StatusBadGateway, gin.H{
+ "type": "error",
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": "Upstream request failed",
+ },
+ })
+ return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode >= 400 {
+ return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
+ }
+ upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
+ streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel, false)
+ if err != nil {
+ return nil, err
+ }
+ if streamResult.usage == nil {
+ streamResult.usage = &ClaudeUsage{}
+ }
+ return &ForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: *streamResult.usage,
+ Model: originalModel,
+ UpstreamModel: upstreamModel,
+ Stream: true,
+ Duration: time.Since(startTime),
+ FirstTokenMs: streamResult.firstTokenMs,
+ ClientDisconnect: streamResult.clientDisconnect,
+ }, nil
+ }
+
+ token, tokenType, err := s.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ if tokenType != "oauth" {
+ return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
+ }
+ if isOnlyWebSearchToolInBody(body) {
+ webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
+ switch {
+ case errors.Is(webSearchErr, errKiroWebSearchFallback):
+ case webSearchErr == nil:
+ upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
+ c.Header("Content-Type", "application/json")
+ if webSearchResult.RequestID != "" {
+ c.Header("x-request-id", webSearchResult.RequestID)
+ }
+ c.Data(http.StatusOK, "application/json", webSearchResult.ResponseBody)
+ return &ForwardResult{
+ RequestID: webSearchResult.RequestID,
+ Usage: webSearchResult.Usage,
+ Model: originalModel,
+ UpstreamModel: upstreamModel,
+ Stream: false,
+ Duration: time.Since(startTime),
+ }, nil
+ default:
+ var httpErr *kiroWebSearchHTTPError
+ if errors.As(webSearchErr, &httpErr) && httpErr.Response != nil {
+ return nil, s.handleKiroHTTPError(ctx, httpErr.Response, c, account, mappedModel, body)
+ }
+ var failoverErr *UpstreamFailoverError
+ if errors.As(webSearchErr, &failoverErr) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: failoverErr.StatusCode,
+ Kind: "failover",
+ Message: sanitizeUpstreamErrorMessage(webSearchErr.Error()),
+ })
+ return nil, failoverErr
+ }
+ safeErr := sanitizeUpstreamErrorMessage(webSearchErr.Error())
+ c.JSON(http.StatusBadGateway, gin.H{
+ "type": "error",
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": "Upstream request failed",
+ },
+ })
+ return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
+ }
+ }
+
+ inputTokens := estimateKiroInputTokens(body)
+ resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
+ if err != nil {
+ var failoverErr *UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: failoverErr.StatusCode,
+ Kind: "failover",
+ Message: sanitizeUpstreamErrorMessage(err.Error()),
+ })
+ return nil, failoverErr
+ }
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ c.JSON(http.StatusBadGateway, gin.H{
+ "type": "error",
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": "Upstream request failed",
+ },
+ })
+ return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode >= 400 {
+ return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
+ }
+
+ parseResult, err := kiropkg.ParseNonStreamingEventStreamWithContext(resp.Body, mappedModel, requestCtx)
+ if err != nil {
+ c.JSON(http.StatusBadGateway, gin.H{
+ "type": "error",
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": "Failed to parse Kiro upstream response",
+ },
+ })
+ return nil, err
+ }
+
+ c.Header("Content-Type", "application/json")
+ if requestID := resp.Header.Get("x-request-id"); requestID != "" {
+ c.Header("x-request-id", requestID)
+ }
+ c.Data(http.StatusOK, "application/json", parseResult.ResponseBody)
+
+ upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
+
+ return &ForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
+ Model: originalModel,
+ UpstreamModel: upstreamModel,
+ Stream: false,
+ Duration: time.Since(startTime),
+ }, nil
+}
+
+func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
+ token, tokenType, err := s.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, 0, err
+ }
+ if tokenType != "oauth" {
+ return nil, 0, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
+ }
+
+ inputTokens := estimateKiroInputTokens(anthropicBody)
+ if isOnlyWebSearchToolInBody(anthropicBody) {
+ pr, pw := io.Pipe()
+ headers := make(http.Header)
+ headers.Set("Content-Type", "text/event-stream")
+ go func() {
+ streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
+ if streamErr != nil {
+ _ = pw.CloseWithError(streamErr)
+ return
+ }
+ _ = pw.Close()
+ }()
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: headers,
+ Body: pr,
+ }, inputTokens, nil
+ }
+
+ resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
+ if err != nil {
+ var failoverErr *UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ return nil, inputTokens, err
+ }
+ return nil, inputTokens, err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return resp, inputTokens, nil
+ }
+
+ pr, pw := io.Pipe()
+ wrappedHeaders := resp.Header.Clone()
+ wrappedHeaders.Set("Content-Type", "text/event-stream")
+ if requestID := buildKiroRequestID(resp); requestID != "" {
+ wrappedHeaders.Set("x-request-id", requestID)
+ }
+
+ go func() {
+ defer func() { _ = resp.Body.Close() }()
+ _, streamErr := kiropkg.StreamEventStreamAsAnthropicWithContext(ctx, resp.Body, pw, mappedModel, inputTokens, requestCtx)
+ if streamErr != nil {
+ _ = pw.CloseWithError(streamErr)
+ return
+ }
+ _ = pw.Close()
+ }()
+
+ return &http.Response{
+ StatusCode: resp.StatusCode,
+ Header: wrappedHeaders,
+ Body: pr,
+ }, inputTokens, nil
+}
+
+func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
+ var requestCtx kiropkg.KiroRequestContext
+ if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
+ if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
+ return nil, requestCtx, failoverErr
+ }
+ return nil, requestCtx, err
+ }
+
+ modelID := kiropkg.MapModel(mappedModel)
+ currentToken := token
+ buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
+ if err != nil {
+ return nil, requestCtx, err
+ }
+ payload := buildResult.Payload
+ requestCtx = buildResult.Context
+
+ endpoints := buildKiroEndpoints(account)
+ proxyURL := kiroProxyURL(account)
+ tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
+ accountKey := buildKiroAccountKey(account)
+ maxRetries := 2
+
+ for idx, endpoint := range endpoints {
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
+ if err != nil {
+ return nil, requestCtx, err
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
+ if err != nil {
+ if attempt < maxRetries {
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, requestCtx, sleepErr
+ }
+ continue
+ }
+ return nil, requestCtx, err
+ }
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ cooldown, err := s.markKiro429(ctx, accountKey)
+ if err != nil {
+ _ = resp.Body.Close()
+ return nil, requestCtx, err
+ }
+ if idx+1 < len(endpoints) {
+ _ = resp.Body.Close()
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, requestCtx, sleepErr
+ }
+ break
+ }
+ resp.Header.Set("x-kiro-cooldown", cooldown.String())
+ return resp, requestCtx, nil
+ }
+
+ if resp.StatusCode == http.StatusRequestTimeout || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
+ if attempt < maxRetries {
+ _ = resp.Body.Close()
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, requestCtx, sleepErr
+ }
+ continue
+ }
+ if idx+1 < len(endpoints) {
+ _ = resp.Body.Close()
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, requestCtx, sleepErr
+ }
+ break
+ }
+ return resp, requestCtx, nil
+ }
+
+ if resp.StatusCode == http.StatusPaymentRequired {
+ respBody, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ return nil, requestCtx, readErr
+ }
+ classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
+ if classification.Category == kiroErrorMonthlyRequest {
+ s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
+ }
+ return nil, requestCtx, &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ ResponseHeaders: resp.Header.Clone(),
+ }
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
+ respBody, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ return nil, requestCtx, readErr
+ }
+
+ if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
+ if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
+ return nil, requestCtx, err
+ }
+ resetHTTPResponseBody(resp, respBody)
+ return resp, requestCtx, nil
+ }
+
+ if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
+ refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
+ if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
+ currentToken = refreshedToken
+ accountKey = buildKiroAccountKey(account)
+ buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
+ if err != nil {
+ return nil, requestCtx, err
+ }
+ payload = buildResult.Payload
+ requestCtx = buildResult.Context
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, requestCtx, sleepErr
+ }
+ continue
+ }
+ if refreshErr != nil && isNonRetryableRefreshError(refreshErr) {
+ resetHTTPResponseBody(resp, respBody)
+ return resp, requestCtx, nil
+ }
+ }
+
+ if classifyKiroHTTPError(resp.StatusCode, string(respBody)).Category == kiroErrorAuthError {
+ s.markKiroAuthTemporarilyUnavailable(ctx, account, resp.StatusCode, string(respBody))
+ }
+
+ resetHTTPResponseBody(resp, respBody)
+ return resp, requestCtx, nil
+ }
+
+ if resp.StatusCode == http.StatusBadRequest {
+ respBody, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ return nil, requestCtx, readErr
+ }
+ classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
+ logKiroBadRequestClassification(classification, account, mappedModel, resp.Header, respBody)
+ resetHTTPResponseBody(resp, respBody)
+ return resp, requestCtx, nil
+ }
+
+ if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+ if err := s.markKiroSuccess(ctx, accountKey); err != nil {
+ _ = resp.Body.Close()
+ return nil, requestCtx, err
+ }
+ }
+ return resp, requestCtx, nil
+ }
+ }
+ return nil, requestCtx, fmt.Errorf("kiro upstream endpoints exhausted")
+}
+
+func buildKiroEndpoints(account *Account) []kiroEndpointConfig {
+ region := kiroAPIRegion(account)
+ return []kiroEndpointConfig{
+ {
+ URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
+ Name: "AmazonQ",
+ },
+ }
+}
+
+func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) ([]byte, error) {
+ result, err := buildKiroPayloadForAccountWithRepo(ctx, nil, account, anthropicBody, modelID, token, requestModel, headers)
+ if err != nil {
+ return nil, err
+ }
+ return result.Payload, nil
+}
+
+func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
+ profileArn := resolveKiroPayloadProfileArn(account)
+ return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
+}
+
+func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
+ if s == nil || s.accountRepo == nil || account == nil {
+ return
+ }
+ until := time.Now().Add(10 * time.Minute)
+ reason := fmt.Sprintf("kiro auth failure (%d): %s", statusCode, strings.TrimSpace(body))
+ _ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason)
+}
+
+func (s *GatewayService) markKiroMonthlyRequestCountRateLimited(ctx context.Context, account *Account, body string) {
+ if s == nil || s.accountRepo == nil || account == nil {
+ return
+ }
+ resetAt := nextKiroMonthlyResetUTC(time.Now())
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
+ logger.L().Warn("kiro monthly request count rate-limit failed",
+ zap.Int64("account_id", account.ID),
+ zap.Time("reset_at", resetAt),
+ zap.Error(err),
+ )
+ return
+ }
+ reason := "kiro monthly request count exhausted (402): MONTHLY_REQUEST_COUNT"
+ if trimmed := strings.TrimSpace(body); trimmed != "" {
+ reason = fmt.Sprintf("%s body=%s", reason, truncateForLog([]byte(trimmed), 512))
+ }
+ logger.L().Warn("kiro monthly request count rate-limited",
+ zap.Int64("account_id", account.ID),
+ zap.Time("reset_at", resetAt),
+ zap.String("reason", reason),
+ )
+}
+
+func nextKiroMonthlyResetUTC(now time.Time) time.Time {
+ utc := now.UTC()
+ year, month, _ := utc.Date()
+ return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC)
+}
+
+func resetHTTPResponseBody(resp *http.Response, body []byte) {
+ if resp == nil {
+ return
+ }
+ resp.Body = io.NopCloser(bytes.NewReader(body))
+ resp.ContentLength = int64(len(body))
+}
+
+func estimateKiroInputTokens(body []byte) int {
+ if len(body) == 0 {
+ return 0
+ }
+ if tokens := gjson.GetBytes(body, "metadata.input_tokens").Int(); tokens > 0 {
+ return int(tokens)
+ }
+ tokens := len(body) / 4
+ if tokens == 0 {
+ return 1
+ }
+ return tokens
+}
+
+func kiroUsageToClaude(usage kiropkg.Usage, fallbackInput int) ClaudeUsage {
+ inputTokens := usage.InputTokens
+ if inputTokens == 0 {
+ inputTokens = fallbackInput
+ }
+ return ClaudeUsage{
+ InputTokens: inputTokens,
+ OutputTokens: usage.OutputTokens,
+ CacheReadInputTokens: usage.CacheReadInputTokens,
+ }
+}
+
+func (s *GatewayService) markKiroInvalidModelRateLimited(ctx context.Context, account *Account, mappedModel string) {
+ if s == nil || s.accountRepo == nil || account == nil || account.Type != AccountTypeOAuth {
+ return
+ }
+ resetAt := time.Now().Add(kiroInvalidModelTempUnschedDuration)
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
+ logger.L().Warn("kiro invalid model rate-limit failed",
+ zap.Int64("account_id", account.ID),
+ zap.String("mapped_model", strings.TrimSpace(mappedModel)),
+ zap.Time("reset_at", resetAt),
+ zap.Error(err),
+ )
+ return
+ }
+ logger.L().Warn("kiro invalid model rate-limited",
+ zap.Int64("account_id", account.ID),
+ zap.String("mapped_model", strings.TrimSpace(mappedModel)),
+ zap.Time("reset_at", resetAt),
+ )
+}
+
+func (s *GatewayService) handleKiroHTTPError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, mappedModel string, requestBody []byte) error {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ if upstreamMsg == "" {
+ upstreamMsg = strings.TrimSpace(string(respBody))
+ }
+ classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
+ if resp.StatusCode == http.StatusBadRequest {
+ logKiroBadRequestClassification(classification, account, "", resp.Header, respBody)
+ }
+ if classification.Category == kiroErrorMonthlyRequest {
+ s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
+ }
+ if classification.Category == kiroErrorBadRequestInvalidModel && account != nil && account.Type == AccountTypeOAuth {
+ s.markKiroInvalidModelRateLimited(ctx, account, mappedModel)
+ event := s.buildKiroInvalidModelUpstreamEvent(account, resp, upstreamMsg, mappedModel, requestBody, c)
+ appendOpsUpstreamError(c, event)
+ return &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ ResponseHeaders: resp.Header.Clone(),
+ }
+ }
+
+ if resp.StatusCode == http.StatusPaymentRequired || s.shouldFailoverUpstreamError(resp.StatusCode) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ Kind: "failover",
+ Message: upstreamMsg,
+ })
+ if s.rateLimitService != nil {
+ s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
+ }
+ return &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ ResponseHeaders: resp.Header.Clone(),
+ }
+ }
+
+ setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, "")
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ Kind: "http_error",
+ Message: upstreamMsg,
+ })
+ c.JSON(mapUpstreamStatusCode(resp.StatusCode), gin.H{
+ "type": "error",
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": coalesceKiroErrorMessage(resp.StatusCode, upstreamMsg),
+ },
+ })
+ return fmt.Errorf("kiro upstream error: %d %s", resp.StatusCode, upstreamMsg)
+}
+
+func (s *GatewayService) buildKiroInvalidModelUpstreamEvent(account *Account, resp *http.Response, upstreamMsg, mappedModel string, requestBody []byte, c *gin.Context) OpsUpstreamErrorEvent {
+ _ = s
+ requestedModel := strings.TrimSpace(gjson.GetBytes(requestBody, "model").String())
+ hasTools := gjson.GetBytes(requestBody, "tools").Exists()
+ hasAdaptiveThinking := strings.EqualFold(strings.TrimSpace(gjson.GetBytes(requestBody, "thinking.type").String()), "adaptive")
+ hasContext1MBeta := false
+ if c != nil {
+ hasContext1MBeta = strings.Contains(c.GetHeader("Anthropic-Beta"), "context-1m")
+ }
+ return OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ Kind: "failover",
+ Message: upstreamMsg,
+ RequestedModel: requestedModel,
+ MappedModel: strings.TrimSpace(mappedModel),
+ KiroModelID: kiropkg.MapModel(mappedModel),
+ HasTools: hasTools,
+ HasAdaptiveThinking: hasAdaptiveThinking,
+ HasContext1MBeta: hasContext1MBeta,
+ }
+}
+
+func logKiroBadRequestClassification(classification kiroErrorClassification, account *Account, model string, headers http.Header, body []byte) {
+ if classification.StatusCode != http.StatusBadRequest {
+ return
+ }
+ var accountID int64
+ if account != nil {
+ accountID = account.ID
+ }
+ logger.L().Warn("kiro upstream bad request classified",
+ zap.String("category", classification.Category),
+ zap.Int("status", classification.StatusCode),
+ zap.Int64("account_id", accountID),
+ zap.String("model", strings.TrimSpace(model)),
+ zap.String("request_id", headers.Get("x-request-id")),
+ zap.String("body_excerpt", truncateForLog(body, 512)),
+ )
+}
+
+func coalesceKiroErrorMessage(statusCode int, upstreamMsg string) string {
+ if upstreamMsg != "" {
+ return upstreamMsg
+ }
+ switch statusCode {
+ case http.StatusTooManyRequests:
+ return "Rate limit exceeded"
+ case http.StatusForbidden:
+ return "Access denied"
+ case http.StatusUnauthorized:
+ return "Authentication failed"
+ default:
+ return "Upstream request failed"
+ }
+}
diff --git a/backend/internal/service/kiro_runtime_state.go b/backend/internal/service/kiro_runtime_state.go
new file mode 100644
index 00000000..68e1f745
--- /dev/null
+++ b/backend/internal/service/kiro_runtime_state.go
@@ -0,0 +1,99 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
+)
+
+var errKiroCooldownStoreUnavailable = errors.New("kiro cooldown store unavailable")
+
+type KiroCooldownStore interface {
+ ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error)
+ MarkSuccess(ctx context.Context, tokenKey string) error
+ Mark429(ctx context.Context, tokenKey string) (time.Duration, error)
+ MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error)
+ GetState(ctx context.Context, tokenKey string) (*kirocooldown.State, error)
+ ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error)
+}
+
+func asKiroCooldownFailoverError(err error) *UpstreamFailoverError {
+ if err == nil {
+ return nil
+ }
+ var cooldownErr *kirocooldown.Error
+ if !errors.As(err, &cooldownErr) {
+ return nil
+ }
+ return &UpstreamFailoverError{
+ StatusCode: http.StatusTooManyRequests,
+ ResponseBody: []byte(cooldownErr.Error()),
+ }
+}
+
+func (s *GatewayService) checkAndWaitKiroCooldown(ctx context.Context, tokenKey string) error {
+ if s == nil || s.kiroCooldownStore == nil {
+ return errKiroCooldownStoreUnavailable
+ }
+ waitFor, err := s.kiroCooldownStore.ReserveRequest(ctx, tokenKey)
+ if err != nil {
+ return err
+ }
+ if waitFor <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(waitFor)
+ select {
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+}
+
+func (s *GatewayService) markKiroSuccess(ctx context.Context, tokenKey string) error {
+ if s == nil || s.kiroCooldownStore == nil {
+ return errKiroCooldownStoreUnavailable
+ }
+ return s.kiroCooldownStore.MarkSuccess(ctx, tokenKey)
+}
+
+func (s *GatewayService) markKiro429(ctx context.Context, tokenKey string) (time.Duration, error) {
+ if s == nil || s.kiroCooldownStore == nil {
+ return 0, errKiroCooldownStoreUnavailable
+ }
+ return s.kiroCooldownStore.Mark429(ctx, tokenKey)
+}
+
+func (s *GatewayService) markKiroSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
+ if s == nil || s.kiroCooldownStore == nil {
+ return 0, errKiroCooldownStoreUnavailable
+ }
+ return s.kiroCooldownStore.MarkSuspended(ctx, tokenKey)
+}
+
+func (s *GatewayService) getKiroCooldownState(ctx context.Context, tokenKey string) (*kirocooldown.State, error) {
+ if s == nil || s.kiroCooldownStore == nil {
+ return nil, errKiroCooldownStoreUnavailable
+ }
+ return s.kiroCooldownStore.GetState(ctx, tokenKey)
+}
+
+func kiroRuntimeStateSnapshot(state *kirocooldown.State) (string, string, *time.Time) {
+ if state == nil || !state.Active {
+ return "", "", nil
+ }
+ resetAt := state.CooldownUntil
+ switch state.Reason {
+ case kirocooldown.CooldownReasonSuspended:
+ return "suspended", state.Reason, &resetAt
+ default:
+ return "cooldown", state.Reason, &resetAt
+ }
+}
diff --git a/backend/internal/service/kiro_runtime_state_integration_test.go b/backend/internal/service/kiro_runtime_state_integration_test.go
new file mode 100644
index 00000000..6a3f52a7
--- /dev/null
+++ b/backend/internal/service/kiro_runtime_state_integration_test.go
@@ -0,0 +1,192 @@
+//go:build integration
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/require"
+ tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
+)
+
+const kiroCooldownRedisImageTag = "redis:8.4-alpine"
+
+func TestRedisKiroCooldownStoreSharesCooldownAcrossInstances(t *testing.T) {
+ ctx := context.Background()
+ rdb := startKiroCooldownRedis(t, ctx)
+ storeA := kirocooldown.NewStore(rdb)
+ storeB := kirocooldown.NewStore(rdb)
+
+ cooldown, err := storeA.Mark429(ctx, "token-shared")
+ require.NoError(t, err)
+ require.Equal(t, time.Minute, cooldown)
+
+ wait, err := storeB.ReserveRequest(ctx, "token-shared")
+ require.Zero(t, wait)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), kirocooldown.CooldownReason429)
+
+ require.NoError(t, storeB.MarkSuccess(ctx, "token-shared"))
+
+ wait, err = storeA.ReserveRequest(ctx, "token-shared")
+ require.NoError(t, err)
+ require.GreaterOrEqual(t, wait, 0*time.Second)
+}
+
+func TestRedisKiroCooldownStoreSharesReservationAcrossInstances(t *testing.T) {
+ ctx := context.Background()
+ rdb := startKiroCooldownRedis(t, ctx)
+ storeA := kirocooldown.NewStore(rdb)
+ storeB := kirocooldown.NewStore(rdb)
+
+ wait, err := storeA.ReserveRequest(ctx, "token-rate")
+ require.NoError(t, err)
+ require.Zero(t, wait)
+
+ wait, err = storeB.ReserveRequest(ctx, "token-rate")
+ require.NoError(t, err)
+ require.Greater(t, wait, 0*time.Millisecond)
+ require.LessOrEqual(t, wait, kirocooldown.MaxRequestInterval)
+}
+
+func TestRedisKiroCooldownStoreSharesSuspendedStateAcrossInstances(t *testing.T) {
+ ctx := context.Background()
+ rdb := startKiroCooldownRedis(t, ctx)
+ storeA := kirocooldown.NewStore(rdb)
+ storeB := kirocooldown.NewStore(rdb)
+
+ cooldown, err := storeA.MarkSuspended(ctx, "token-suspended")
+ require.NoError(t, err)
+ require.Equal(t, kirocooldown.LongCooldown, cooldown)
+
+ wait, err := storeB.ReserveRequest(ctx, "token-suspended")
+ require.Zero(t, wait)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), kirocooldown.CooldownReasonSuspended)
+}
+
+func TestRedisKiroCooldownStoreSuspendedResetsFailCount(t *testing.T) {
+ ctx := context.Background()
+ rdb := startKiroCooldownRedis(t, ctx)
+ store := kirocooldown.NewStore(rdb)
+
+ _, err := store.Mark429(ctx, "token-reset")
+ require.NoError(t, err)
+ _, err = store.Mark429(ctx, "token-reset")
+ require.NoError(t, err)
+
+ cooldown, err := store.MarkSuspended(ctx, "token-reset")
+ require.NoError(t, err)
+ require.Equal(t, kirocooldown.LongCooldown, cooldown)
+
+ cooldown, err = store.Mark429(ctx, "token-reset")
+ require.NoError(t, err)
+ require.Equal(t, time.Minute, cooldown)
+}
+
+func TestRedisKiroCooldownStoreReserveDifferentTokenIgnoresOldCooldown(t *testing.T) {
+ ctx := context.Background()
+ rdb := startKiroCooldownRedis(t, ctx)
+ store := kirocooldown.NewStore(rdb)
+
+ _, err := store.Mark429(ctx, "token-old")
+ require.NoError(t, err)
+
+ wait, err := store.ReserveRequest(ctx, "token-new")
+ require.NoError(t, err)
+ require.Zero(t, wait)
+}
+
+func TestRedisKiroCooldownStoreUsesExpectedTTLs(t *testing.T) {
+ ctx := context.Background()
+ rdb := startKiroCooldownRedis(t, ctx)
+ store := kirocooldown.NewStore(rdb)
+
+ _, err := store.ReserveRequest(ctx, "token-ttl-active")
+ require.NoError(t, err)
+ activeTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-active")).Result()
+ require.NoError(t, err)
+ require.Greater(t, activeTTL, 0*time.Second)
+ require.LessOrEqual(t, activeTTL, kirocooldown.ActiveTTL())
+
+ _, err = store.MarkSuspended(ctx, "token-ttl-state")
+ require.NoError(t, err)
+ stateTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-state")).Result()
+ require.NoError(t, err)
+ require.Greater(t, stateTTL, 24*time.Hour)
+ require.LessOrEqual(t, stateTTL, kirocooldown.StateTTL())
+}
+
+func startKiroCooldownRedis(t *testing.T, ctx context.Context) *redis.Client {
+ t.Helper()
+ ensureKiroCooldownDockerAvailable(t)
+
+ redisContainer, err := tcredis.Run(ctx, kiroCooldownRedisImageTag)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ _ = redisContainer.Terminate(ctx)
+ })
+
+ host, err := redisContainer.Host(ctx)
+ require.NoError(t, err)
+ port, err := redisContainer.MappedPort(ctx, "6379/tcp")
+ require.NoError(t, err)
+
+ rdb := redis.NewClient(&redis.Options{
+ Addr: fmt.Sprintf("%s:%d", host, port.Int()),
+ DB: 0,
+ })
+ require.NoError(t, rdb.Ping(ctx).Err())
+ t.Cleanup(func() {
+ _ = rdb.Close()
+ })
+ return rdb
+}
+
+func ensureKiroCooldownDockerAvailable(t *testing.T) {
+ t.Helper()
+ if kiroCooldownDockerAvailable() {
+ return
+ }
+ t.Skip("Docker 未启用,跳过依赖 testcontainers 的 Kiro cooldown 集成测试")
+}
+
+func kiroCooldownDockerAvailable() bool {
+ if os.Getenv("DOCKER_HOST") != "" {
+ return true
+ }
+
+ socketCandidates := []string{
+ "/var/run/docker.sock",
+ filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
+ filepath.Join(kiroCooldownUserHomeDir(), ".docker", "run", "docker.sock"),
+ filepath.Join(kiroCooldownUserHomeDir(), ".docker", "desktop", "docker.sock"),
+ filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
+ }
+
+ for _, socket := range socketCandidates {
+ if socket == "" {
+ continue
+ }
+ if _, err := os.Stat(socket); err == nil {
+ return true
+ }
+ }
+ return false
+}
+
+func kiroCooldownUserHomeDir() string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return ""
+ }
+ return home
+}
diff --git a/backend/internal/service/kiro_runtime_state_test.go b/backend/internal/service/kiro_runtime_state_test.go
new file mode 100644
index 00000000..8eeba068
--- /dev/null
+++ b/backend/internal/service/kiro_runtime_state_test.go
@@ -0,0 +1,583 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type stubKiroCooldownStore struct {
+ reserveWait time.Duration
+ reserveErr error
+ successErr error
+ mark429TTL time.Duration
+ mark429Err error
+ suspendedTTL time.Duration
+ suspendedErr error
+ state *kirocooldown.State
+ stateErr error
+ clearCalled bool
+ clearKeys []string
+ clearResult bool
+ clearErr error
+}
+
+type recordingKiroTempUnschedRepo struct {
+ mockAccountRepoForGemini
+ called bool
+ id int64
+ until time.Time
+ reason string
+ rateCalled bool
+ rateID int64
+ rateLimitReset time.Time
+ rateLimitedCall int
+}
+
+func (r *recordingKiroTempUnschedRepo) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
+ r.called = true
+ r.id = id
+ r.until = until
+ r.reason = reason
+ return nil
+}
+
+func (r *recordingKiroTempUnschedRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
+ r.rateCalled = true
+ r.rateID = id
+ r.rateLimitReset = resetAt
+ r.rateLimitedCall++
+ return nil
+}
+
+type recordingKiroErrorRepo struct {
+ recordingKiroTempUnschedRepo
+ setErrorCalls int
+ errorID int64
+ errorMsg string
+}
+
+func (r *recordingKiroErrorRepo) SetError(_ context.Context, id int64, errorMsg string) error {
+ r.setErrorCalls++
+ r.errorID = id
+ r.errorMsg = errorMsg
+ return nil
+}
+
+func (s *stubKiroCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
+ return s.reserveWait, s.reserveErr
+}
+
+func (s *stubKiroCooldownStore) MarkSuccess(context.Context, string) error {
+ return s.successErr
+}
+
+func (s *stubKiroCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
+ return s.mark429TTL, s.mark429Err
+}
+
+func (s *stubKiroCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
+ return s.suspendedTTL, s.suspendedErr
+}
+
+func (s *stubKiroCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
+ if s.clearCalled && s.clearResult {
+ return nil, nil
+ }
+ return s.state, s.stateErr
+}
+
+func (s *stubKiroCooldownStore) ClearEarliestTransientCooldown(_ context.Context, tokenKeys []string) (bool, error) {
+ s.clearCalled = true
+ s.clearKeys = append([]string(nil), tokenKeys...)
+ return s.clearResult, s.clearErr
+}
+
+func TestCalculateKiro429Cooldown(t *testing.T) {
+ require.Equal(t, time.Minute, kirocooldown.Calculate429Cooldown(0))
+ require.Equal(t, 2*time.Minute, kirocooldown.Calculate429Cooldown(1))
+ require.Equal(t, 4*time.Minute, kirocooldown.Calculate429Cooldown(2))
+ require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(3))
+ require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(10))
+}
+
+func TestGatewayServiceCheckAndWaitKiroCooldownReturnsNilWithoutWait(t *testing.T) {
+ svc := &GatewayService{
+ kiroCooldownStore: &stubKiroCooldownStore{},
+ }
+
+ require.NoError(t, svc.checkAndWaitKiroCooldown(context.Background(), "token1"))
+}
+
+func TestGatewayServiceCheckAndWaitKiroCooldownPropagatesReserveError(t *testing.T) {
+ expected := errors.New("redis unavailable")
+ svc := &GatewayService{
+ kiroCooldownStore: &stubKiroCooldownStore{reserveErr: expected},
+ }
+
+ err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
+ require.ErrorIs(t, err, expected)
+}
+
+func TestGatewayServiceCheckAndWaitKiroCooldownRequiresStore(t *testing.T) {
+ svc := &GatewayService{}
+ err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
+ require.ErrorIs(t, err, errKiroCooldownStoreUnavailable)
+}
+
+func TestGatewayServiceCheckAndWaitKiroCooldownWaitsAndHonorsContext(t *testing.T) {
+ svc := &GatewayService{
+ kiroCooldownStore: &stubKiroCooldownStore{reserveWait: 200 * time.Millisecond},
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer cancel()
+
+ err := svc.checkAndWaitKiroCooldown(ctx, "token1")
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+}
+
+func TestAsKiroCooldownFailoverError(t *testing.T) {
+ err := kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429)
+
+ var cooldownErr *kirocooldown.Error
+ require.ErrorAs(t, err, &cooldownErr)
+
+ failoverErr := asKiroCooldownFailoverError(err)
+ require.NotNil(t, failoverErr)
+ require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
+ require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
+ require.False(t, failoverErr.RetryableOnSameAccount)
+}
+
+func TestAsKiroCooldownFailoverErrorIgnoresNonCooldownErrors(t *testing.T) {
+ require.Nil(t, asKiroCooldownFailoverError(errors.New("redis unavailable")))
+}
+
+func TestGatewayServiceTryRecoverKiroCooldownPoolClearsOnlyTransientCooldown(t *testing.T) {
+ store := &stubKiroCooldownStore{
+ state: &kirocooldown.State{
+ Active: true,
+ Reason: kirocooldown.CooldownReason429,
+ CooldownUntil: time.Now().Add(time.Minute),
+ Remaining: time.Minute,
+ },
+ clearResult: true,
+ }
+ svc := &GatewayService{kiroCooldownStore: store}
+ accounts := []Account{
+ {
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ },
+ }
+
+ recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
+ require.True(t, recovered)
+ require.True(t, store.clearCalled)
+ require.Len(t, store.clearKeys, 1)
+ require.Equal(t, buildKiroAccountKey(&accounts[0]), store.clearKeys[0])
+}
+
+func TestGatewayServiceTryRecoverKiroCooldownPoolSkipsSuspended(t *testing.T) {
+ store := &stubKiroCooldownStore{
+ state: &kirocooldown.State{
+ Active: true,
+ Reason: kirocooldown.CooldownReasonSuspended,
+ CooldownUntil: time.Now().Add(time.Hour),
+ Remaining: time.Hour,
+ },
+ clearResult: true,
+ }
+ svc := &GatewayService{kiroCooldownStore: store}
+ accounts := []Account{
+ {
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ },
+ }
+
+ recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
+ require.False(t, recovered)
+ require.False(t, store.clearCalled)
+}
+
+func TestSelectAccountWithLoadAwarenessRecoversKiroCooldownPool(t *testing.T) {
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ account := Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ }
+ store := &stubKiroCooldownStore{
+ state: &kirocooldown.State{
+ Active: true,
+ Reason: kirocooldown.CooldownReason429,
+ CooldownUntil: time.Now().Add(time.Minute),
+ Remaining: time.Minute,
+ },
+ clearResult: true,
+ }
+ svc := &GatewayService{
+ accountRepo: &mockAccountRepoForGemini{accounts: []Account{account}},
+ concurrencyService: NewConcurrencyService(&mockConcurrencyCache{}),
+ cfg: cfg,
+ kiroCooldownStore: store,
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+ ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformKiro)
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "", nil, "", 0)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, account.ID, result.Account.ID)
+ require.True(t, store.clearCalled)
+ require.Equal(t, []string{buildKiroAccountKey(&account)}, store.clearKeys)
+}
+
+func TestClassifyKiroHTTPErrorMonthlyRequestCount(t *testing.T) {
+ tests := []string{
+ `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
+ `{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}`,
+ `API returned 402: {"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
+ }
+
+ for _, body := range tests {
+ classification := classifyKiroHTTPError(http.StatusPaymentRequired, body)
+ require.Equal(t, kiroErrorMonthlyRequest, classification.Category)
+ }
+}
+
+func TestClassifyKiroHTTPErrorPlain402IsTransient(t *testing.T) {
+ classification := classifyKiroHTTPError(http.StatusPaymentRequired, `{"message":"payment required"}`)
+ require.Equal(t, kiroErrorUpstreamTransient, classification.Category)
+}
+
+func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
+ svc := &GatewayService{
+ kiroCooldownStore: &stubKiroCooldownStore{
+ reserveErr: kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429),
+ },
+ }
+
+ _, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
+ require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
+ require.False(t, failoverErr.RetryableOnSameAccount)
+}
+
+func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
+ },
+ }
+ repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
+ },
+ }
+ svc := &GatewayService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ kiroCooldownStore: &stubKiroCooldownStore{},
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ payload, err := createTestPayload("claude-opus-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
+ require.NoError(t, err)
+ require.Equal(t, http.StatusBadRequest, resp.StatusCode)
+ require.Len(t, upstream.requests, 1)
+
+ firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
+ require.NoError(t, readErr)
+ require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
+ require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
+}
+
+func TestHandleKiroHTTPErrorOAuthInvalidModelRateLimitsAndFailovers(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+ c.Request.Header.Set("Anthropic-Beta", "context-1m-2025-08-07")
+
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Name: "kiro-oauth",
+ }
+ repo := &recordingKiroTempUnschedRepo{}
+ svc := &GatewayService{accountRepo: repo}
+ requestBody := []byte(`{"model":"claude-opus-4-7","tools":[{"name":"search"}],"thinking":{"type":"adaptive"}}`)
+ resp := newJSONResponse(http.StatusBadRequest, `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`)
+ resp.Header.Set("x-request-id", "req-invalid-model")
+
+ err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", requestBody)
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "Invalid model")
+ require.False(t, failoverErr.RetryableOnSameAccount)
+
+ require.False(t, repo.called)
+ require.True(t, repo.rateCalled)
+ require.Equal(t, account.ID, repo.rateID)
+ require.WithinDuration(t, time.Now().Add(kiroInvalidModelTempUnschedDuration), repo.rateLimitReset, 5*time.Second)
+
+ rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
+ require.True(t, ok)
+ events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
+ require.True(t, ok)
+ require.Len(t, events, 1)
+ require.Equal(t, PlatformKiro, events[0].Platform)
+ require.Equal(t, account.ID, events[0].AccountID)
+ require.Equal(t, account.Name, events[0].AccountName)
+ require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
+ require.Equal(t, "req-invalid-model", events[0].UpstreamRequestID)
+ require.Equal(t, "failover", events[0].Kind)
+ require.Equal(t, "claude-opus-4-7", events[0].RequestedModel)
+ require.Equal(t, "claude-opus-4.6", events[0].MappedModel)
+ require.Equal(t, "claude-opus-4.6", events[0].KiroModelID)
+ require.True(t, events[0].HasTools)
+ require.True(t, events[0].HasAdaptiveThinking)
+ require.True(t, events[0].HasContext1MBeta)
+}
+
+func TestHandleKiroHTTPErrorAPIKeyInvalidModelDoesNotFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ account := &Account{
+ ID: 43,
+ Platform: PlatformKiro,
+ Type: AccountTypeAPIKey,
+ }
+ repo := &recordingKiroTempUnschedRepo{}
+ svc := &GatewayService{accountRepo: repo}
+ resp := newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model. Please select a different model to continue."}`)
+
+ err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", []byte(`{"model":"claude-opus-4-7"}`))
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.NotErrorAs(t, err, &failoverErr)
+ require.False(t, repo.called)
+ require.False(t, repo.rateCalled)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestNextKiroMonthlyResetUTC(t *testing.T) {
+ tests := []struct {
+ name string
+ now time.Time
+ want time.Time
+ }{
+ {
+ name: "middle of month",
+ now: time.Date(2026, time.April, 27, 10, 30, 45, 123, time.FixedZone("CST", 8*3600)),
+ want: time.Date(2026, time.May, 1, 0, 0, 0, 0, time.UTC),
+ },
+ {
+ name: "december rolls year",
+ now: time.Date(2026, time.December, 31, 23, 59, 59, 0, time.UTC),
+ want: time.Date(2027, time.January, 1, 0, 0, 0, 0, time.UTC),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, nextKiroMonthlyResetUTC(tt.now))
+ })
+ }
+}
+
+func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailovers(t *testing.T) {
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ }
+ repo := &recordingKiroTempUnschedRepo{}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusPaymentRequired, `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`),
+ },
+ }
+ svc := &GatewayService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ kiroCooldownStore: &stubKiroCooldownStore{},
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ payload, err := createTestPayload("claude-sonnet-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ _, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "MONTHLY_REQUEST_COUNT")
+ require.False(t, repo.called)
+ require.True(t, repo.rateCalled)
+ require.Equal(t, account.ID, repo.rateID)
+ require.Equal(t, nextKiroMonthlyResetUTC(time.Now()), repo.rateLimitReset)
+}
+
+func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) {
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ }
+ repo := &recordingKiroTempUnschedRepo{}
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusPaymentRequired, `{"message":"payment required"}`),
+ },
+ }
+ svc := &GatewayService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ kiroCooldownStore: &stubKiroCooldownStore{},
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ }
+
+ payload, err := createTestPayload("claude-sonnet-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ _, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
+ require.False(t, repo.called)
+ require.False(t, repo.rateCalled)
+}
+
+func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedule(t *testing.T) {
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "refresh_token": "old-refresh",
+ },
+ }
+ repo := &recordingKiroErrorRepo{
+ recordingKiroTempUnschedRepo: recordingKiroTempUnschedRepo{
+ mockAccountRepoForGemini: mockAccountRepoForGemini{
+ accountsByID: map[int64]*Account{account.ID: account},
+ },
+ },
+ }
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"message":"token expired"}`),
+ },
+ }
+ provider := NewKiroTokenProvider(repo, nil, nil)
+ provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
+ svc := &GatewayService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ kiroCooldownStore: &stubKiroCooldownStore{},
+ tlsFPProfileService: &TLSFingerprintProfileService{},
+ kiroTokenProvider: provider,
+ }
+
+ payload, err := createTestPayload("claude-sonnet-4-6")
+ require.NoError(t, err)
+ payloadBytes, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
+ require.NoError(t, err)
+ require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Equal(t, account.ID, repo.errorID)
+ require.Contains(t, repo.errorMsg, "invalid_grant")
+ require.False(t, repo.called, "non-retryable refresh errors should not mark temporary unschedulable")
+}
+
+func TestGatewayServiceIsAccountSchedulableForSelectionSkipsActiveKiroCooldown(t *testing.T) {
+ now := time.Now().Add(2 * time.Minute)
+ svc := &GatewayService{
+ kiroCooldownStore: &stubKiroCooldownStore{
+ state: &kirocooldown.State{
+ Active: true,
+ Reason: kirocooldown.CooldownReason429,
+ CooldownUntil: now,
+ Remaining: 2 * time.Minute,
+ },
+ },
+ }
+
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ }
+ require.False(t, svc.isAccountSchedulableForSelection(account))
+}
diff --git a/backend/internal/service/kiro_token_provider.go b/backend/internal/service/kiro_token_provider.go
new file mode 100644
index 00000000..56013f64
--- /dev/null
+++ b/backend/internal/service/kiro_token_provider.go
@@ -0,0 +1,221 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ kiroTokenRefreshSkew = 3 * time.Minute
+ kiroTokenCacheSkew = 5 * time.Minute
+)
+
+type KiroTokenCache = GeminiTokenCache
+
+type kiroAccountTokenRefresher interface {
+ RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error)
+ BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any
+}
+
+type KiroTokenProvider struct {
+ accountRepo AccountRepository
+ tokenCache KiroTokenCache
+ kiroOAuthService kiroAccountTokenRefresher
+ refreshAPI *OAuthRefreshAPI
+ executor OAuthRefreshExecutor
+ refreshPolicy ProviderRefreshPolicy
+}
+
+func NewKiroTokenProvider(
+ accountRepo AccountRepository,
+ tokenCache KiroTokenCache,
+ kiroOAuthService *KiroOAuthService,
+) *KiroTokenProvider {
+ return &KiroTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: tokenCache,
+ kiroOAuthService: kiroOAuthService,
+ refreshPolicy: GeminiProviderRefreshPolicy(),
+ }
+}
+
+func (p *KiroTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
+ p.refreshAPI = api
+ p.executor = executor
+}
+
+func (p *KiroTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
+ p.refreshPolicy = policy
+}
+
+func (p *KiroTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+ if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
+ return "", errors.New("not a kiro oauth account")
+ }
+
+ cacheKey := KiroTokenCacheKey(account)
+ if p.tokenCache != nil {
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ return token, nil
+ }
+ }
+
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= kiroTokenRefreshSkew
+
+ if needsRefresh && p.refreshAPI != nil && p.executor != nil {
+ result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, kiroTokenRefreshSkew)
+ if err != nil {
+ if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
+ return "", err
+ }
+ } else if result.LockHeld {
+ if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
+ if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
+ return token, nil
+ }
+ }
+ } else {
+ account = result.Account
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ }
+ } else if needsRefresh && p.tokenCache != nil {
+ locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+ if lockErr == nil && locked {
+ defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+ }
+ }
+
+ accessToken := account.GetCredential("access_token")
+ if strings.TrimSpace(accessToken) == "" {
+ return "", errors.New("access_token not found in credentials")
+ }
+
+ if p.tokenCache != nil {
+ latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
+ if isStale && latestAccount != nil {
+ accessToken = latestAccount.GetCredential("access_token")
+ if strings.TrimSpace(accessToken) == "" {
+ return "", errors.New("access_token not found after version check")
+ }
+ } else {
+ ttl := 30 * time.Minute
+ if expiresAt != nil {
+ until := time.Until(*expiresAt)
+ switch {
+ case until > kiroTokenCacheSkew:
+ ttl = until - kiroTokenCacheSkew
+ case until > 0:
+ ttl = until
+ default:
+ ttl = time.Minute
+ }
+ }
+ _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
+ }
+ }
+
+ return accessToken, nil
+}
+
+func KiroTokenCacheKey(account *Account) string {
+ if account == nil {
+ return "kiro:account:0"
+ }
+ if clientIDHash := strings.TrimSpace(account.GetCredential("client_id_hash")); clientIDHash != "" {
+ return "kiro:" + clientIDHash
+ }
+ if clientID := strings.TrimSpace(account.GetCredential("client_id")); clientID != "" {
+ return "kiro:client:" + clientID
+ }
+ return "kiro:account:" + strconv.FormatInt(account.ID, 10)
+}
+
+func (p *KiroTokenProvider) ForceRefreshAccessToken(ctx context.Context, account *Account) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+ if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
+ return "", errors.New("not a kiro oauth account")
+ }
+ if p.kiroOAuthService == nil {
+ return "", errors.New("kiro oauth service is nil")
+ }
+
+ cacheKey := KiroTokenCacheKey(account)
+ lockHeld := false
+ if p.tokenCache != nil {
+ locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+ if lockErr == nil && locked {
+ lockHeld = true
+ defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+ }
+ }
+
+ if p.accountRepo != nil {
+ if latestAccount, err := p.accountRepo.GetByID(ctx, account.ID); err == nil && latestAccount != nil {
+ account = latestAccount
+ }
+ }
+
+ tokenInfo, err := p.kiroOAuthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ if !lockHeld {
+ if latestAccount, stale := CheckTokenVersion(ctx, account, p.accountRepo); stale && latestAccount != nil {
+ account = latestAccount
+ if accessToken := strings.TrimSpace(account.GetCredential("access_token")); accessToken != "" {
+ _ = p.cacheAccessToken(ctx, account, accessToken)
+ return accessToken, nil
+ }
+ }
+ }
+ if isNonRetryableRefreshError(err) && p.accountRepo != nil {
+ errorMsg := "Token refresh failed (non-retryable): " + err.Error()
+ _ = p.accountRepo.SetError(ctx, account.ID, errorMsg)
+ }
+ return "", err
+ }
+
+ newCredentials := MergeCredentials(account.Credentials, p.kiroOAuthService.BuildAccountCredentials(tokenInfo))
+ newCredentials["_token_version"] = time.Now().UnixMilli()
+ if err := persistAccountCredentials(ctx, p.accountRepo, account, newCredentials); err != nil {
+ return "", err
+ }
+
+ accessToken := strings.TrimSpace(account.GetCredential("access_token"))
+ if accessToken == "" {
+ accessToken = strings.TrimSpace(tokenInfo.AccessToken)
+ }
+ if accessToken == "" {
+ return "", errors.New("access_token not found after kiro refresh")
+ }
+ if err := p.cacheAccessToken(ctx, account, accessToken); err != nil {
+ return "", err
+ }
+ return accessToken, nil
+}
+
+func (p *KiroTokenProvider) cacheAccessToken(ctx context.Context, account *Account, accessToken string) error {
+ if p.tokenCache == nil || account == nil || strings.TrimSpace(accessToken) == "" {
+ return nil
+ }
+ ttl := 30 * time.Minute
+ if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
+ until := time.Until(*expiresAt)
+ switch {
+ case until > kiroTokenCacheSkew:
+ ttl = until - kiroTokenCacheSkew
+ case until > 0:
+ ttl = until
+ default:
+ ttl = time.Minute
+ }
+ }
+ return p.tokenCache.SetAccessToken(ctx, KiroTokenCacheKey(account), accessToken, ttl)
+}
diff --git a/backend/internal/service/kiro_token_provider_test.go b/backend/internal/service/kiro_token_provider_test.go
new file mode 100644
index 00000000..0f6edb1d
--- /dev/null
+++ b/backend/internal/service/kiro_token_provider_test.go
@@ -0,0 +1,112 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type kiroTokenProviderRepo struct {
+ mockAccountRepoForGemini
+ setErrorCalls int
+ setErrorID int64
+ setErrorMsg string
+}
+
+func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
+ r.setErrorCalls++
+ r.setErrorID = id
+ r.setErrorMsg = errorMsg
+ return nil
+}
+
+type kiroTokenProviderSequenceRepo struct {
+ kiroTokenProviderRepo
+ accounts []*Account
+ reads int
+}
+
+func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
+ if len(r.accounts) == 0 {
+ return nil, errors.New("account not found")
+ }
+ idx := r.reads
+ if idx >= len(r.accounts) {
+ idx = len(r.accounts) - 1
+ }
+ r.reads++
+ return r.accounts[idx], nil
+}
+
+type stubKiroAccountTokenRefresher struct {
+ tokenInfo *KiroTokenInfo
+ err error
+}
+
+func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.tokenInfo, nil
+}
+
+func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
+ if tokenInfo == nil {
+ return nil
+ }
+ return map[string]any{
+ "access_token": tokenInfo.AccessToken,
+ "expires_at": tokenInfo.ExpiresAt,
+ }
+}
+
+func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
+ account := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{"refresh_token": "old-refresh"},
+ }
+ repo := &kiroTokenProviderRepo{
+ mockAccountRepoForGemini: mockAccountRepoForGemini{
+ accountsByID: map[int64]*Account{account.ID: account},
+ },
+ }
+ provider := NewKiroTokenProvider(repo, nil, nil)
+ provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
+
+ token, err := provider.ForceRefreshAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Empty(t, token)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Equal(t, account.ID, repo.setErrorID)
+ require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
+ require.Contains(t, repo.setErrorMsg, "invalid_grant")
+}
+
+func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
+ usedAccount := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{"refresh_token": "old-refresh"},
+ }
+ latestAccount := &Account{
+ ID: 42,
+ Platform: PlatformKiro,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
+ }
+ repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
+ provider := NewKiroTokenProvider(repo, nil, nil)
+ provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
+
+ token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
+ require.NoError(t, err)
+ require.Equal(t, "fresh-access", token)
+ require.Equal(t, 0, repo.setErrorCalls)
+}
diff --git a/backend/internal/service/kiro_token_refresher.go b/backend/internal/service/kiro_token_refresher.go
new file mode 100644
index 00000000..db0673d5
--- /dev/null
+++ b/backend/internal/service/kiro_token_refresher.go
@@ -0,0 +1,47 @@
+package service
+
+import (
+ "context"
+ "time"
+)
+
+const kiroRefreshWindow = 15 * time.Minute
+
+type KiroTokenRefresher struct {
+ kiroOAuthService *KiroOAuthService
+}
+
+func NewKiroTokenRefresher(kiroOAuthService *KiroOAuthService) *KiroTokenRefresher {
+ return &KiroTokenRefresher{
+ kiroOAuthService: kiroOAuthService,
+ }
+}
+
+func (r *KiroTokenRefresher) CacheKey(account *Account) string {
+ return KiroTokenCacheKey(account)
+}
+
+func (r *KiroTokenRefresher) CanRefresh(account *Account) bool {
+ return account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
+}
+
+func (r *KiroTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
+ if !r.CanRefresh(account) {
+ return false
+ }
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ if expiresAt == nil {
+ return false
+ }
+ return time.Until(*expiresAt) <= kiroRefreshWindow
+}
+
+func (r *KiroTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ tokenInfo, err := r.kiroOAuthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+
+ newCredentials := r.kiroOAuthService.BuildAccountCredentials(tokenInfo)
+ return MergeCredentials(account.Credentials, newCredentials), nil
+}
diff --git a/backend/internal/service/kiro_usage_fetcher.go b/backend/internal/service/kiro_usage_fetcher.go
new file mode 100644
index 00000000..37a39355
--- /dev/null
+++ b/backend/internal/service/kiro_usage_fetcher.go
@@ -0,0 +1,608 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+ "github.com/google/uuid"
+)
+
+const (
+ kiroUsageOrigin = "AI_EDITOR"
+ kiroUsageResourceType = "AGENTIC_REQUEST"
+
+ kiroDefaultRegion = "us-east-1"
+)
+
+var resolveKiroRuntimeEndpoint = kiroRuntimeEndpoint
+
+type kiroUsageLimitsResponse struct {
+ NextDateReset any `json:"nextDateReset"`
+ OverageConfiguration kiroOverageConfiguration `json:"overageConfiguration"`
+ SubscriptionInfo kiroSubscriptionInfo `json:"subscriptionInfo"`
+ UsageBreakdownList []kiroUsageBreakdown `json:"usageBreakdownList"`
+}
+
+type kiroOverageConfiguration struct {
+ OverageStatus string `json:"overageStatus"`
+}
+
+type kiroSubscriptionInfo struct {
+ SubscriptionTitle string `json:"subscriptionTitle"`
+ Type string `json:"type"`
+}
+
+type kiroUsageBreakdown struct {
+ Currency string `json:"currency"`
+ CurrentOverages *float64 `json:"currentOverages"`
+ CurrentOveragesWithPrecision *float64 `json:"currentOveragesWithPrecision"`
+ CurrentUsage *float64 `json:"currentUsage"`
+ CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
+ DisplayName string `json:"displayName"`
+ DisplayNamePlural string `json:"displayNamePlural"`
+ FreeTrialInfo *kiroFreeTrialInfo `json:"freeTrialInfo"`
+ NextDateReset any `json:"nextDateReset"`
+ OverageCharges *float64 `json:"overageCharges"`
+ ResourceType string `json:"resourceType"`
+ UsageLimit *float64 `json:"usageLimit"`
+ UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
+}
+
+type kiroFreeTrialInfo struct {
+ CurrentUsage *float64 `json:"currentUsage"`
+ CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
+ FreeTrialExpiry any `json:"freeTrialExpiry"`
+ FreeTrialStatus string `json:"freeTrialStatus"`
+ UsageLimit *float64 `json:"usageLimit"`
+ UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
+}
+
+type kiroUsageHTTPError struct {
+ StatusCode int
+ Body string
+}
+
+func (e *kiroUsageHTTPError) Error() string {
+ if e == nil {
+ return "kiro usage request failed"
+ }
+ if strings.TrimSpace(e.Body) == "" {
+ return fmt.Sprintf("kiro usage request failed (status %d)", e.StatusCode)
+ }
+ return fmt.Sprintf("kiro usage request failed (status %d): %s", e.StatusCode, e.Body)
+}
+
+func (s *AccountUsageService) getKiroUsage(ctx context.Context, account *Account, source string, forceRefresh bool) (*UsageInfo, error) {
+ now := time.Now()
+ if account == nil {
+ return &UsageInfo{
+ Source: source,
+ UpdatedAt: &now,
+ Error: "account is nil",
+ ErrorCode: errorCodeNetworkError,
+ }, nil
+ }
+ if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
+ return &UsageInfo{
+ Source: source,
+ UpdatedAt: &now,
+ }, nil
+ }
+
+ cached, hasCached := s.getCachedKiroUsage(account.ID)
+ if hasCached && (cached.ErrorCode != "" || cached.Error != "") {
+ cached.Source = source
+ s.attachKiroRuntimeState(ctx, account, cached)
+ return cached, nil
+ }
+ if !forceRefresh && hasCached {
+ cached.Source = source
+ s.attachKiroRuntimeState(ctx, account, cached)
+ return cached, nil
+ }
+
+ flightKey := fmt.Sprintf("kiro-usage:%d", account.ID)
+ result, fetchErr, _ := s.cache.kiroUsageFlight.Do(flightKey, func() (any, error) {
+ if !forceRefresh {
+ if usage, ok := s.getCachedKiroUsage(account.ID); ok {
+ return usage, nil
+ }
+ }
+ usage, err := s.fetchAndCacheKiroUsage(ctx, account, source)
+ if err != nil {
+ return nil, err
+ }
+ return usage, nil
+ })
+ if fetchErr == nil {
+ if usage, ok := result.(*UsageInfo); ok && usage != nil {
+ usage.Source = source
+ s.attachKiroRuntimeState(ctx, account, usage)
+ if source == "active" {
+ s.tryClearRecoverableAccountError(ctx, account)
+ }
+ return usage, nil
+ }
+ }
+
+ degraded := buildKiroDegradedUsage(fetchErr)
+ degraded.Source = source
+ if hasCached {
+ cached.Error = degraded.Error
+ cached.ErrorCode = degraded.ErrorCode
+ cached.NeedsReauth = degraded.NeedsReauth
+ cached.KiroQuotaState = degraded.KiroQuotaState
+ cached.KiroQuotaReason = degraded.KiroQuotaReason
+ cached.KiroQuotaResetAt = degraded.KiroQuotaResetAt
+ cached.Source = source
+ s.attachKiroRuntimeState(ctx, account, cached)
+ return cached, nil
+ }
+ s.storeKiroUsageSnapshot(account.ID, degraded)
+ s.attachKiroRuntimeState(ctx, account, degraded)
+ return degraded, nil
+}
+
+func (s *AccountUsageService) fetchAndCacheKiroUsage(ctx context.Context, account *Account, source string) (*UsageInfo, error) {
+ token := strings.TrimSpace(account.GetCredential("access_token"))
+ if token == "" {
+ return nil, fmt.Errorf("no access token available")
+ }
+
+ region := kiroAPIRegion(account)
+ profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
+
+ resp, err := s.requestKiroUsageLimits(ctx, account, region, profileArn, token)
+ if err != nil {
+ return nil, err
+ }
+
+ usage := mapKiroUsageToInfo(resp)
+ usage.Source = source
+ s.storeKiroUsageSnapshot(account.ID, usage)
+ return usage, nil
+}
+
+func (s *AccountUsageService) storeKiroUsageSnapshot(accountID int64, usage *UsageInfo) {
+ if s == nil || s.cache == nil || accountID <= 0 || usage == nil {
+ return
+ }
+ now := time.Now()
+ if usage.UpdatedAt == nil {
+ usage.UpdatedAt = &now
+ }
+ s.cache.kiroUsageCache.Store(accountID, &kiroUsageCache{
+ usageInfo: cloneUsageInfo(usage),
+ timestamp: now,
+ })
+}
+
+func (s *AccountUsageService) getCachedKiroUsage(accountID int64) (*UsageInfo, bool) {
+ if s == nil || s.cache == nil || accountID <= 0 {
+ return nil, false
+ }
+ cached, ok := s.cache.kiroUsageCache.Load(accountID)
+ if !ok {
+ return nil, false
+ }
+ cache, ok := cached.(*kiroUsageCache)
+ if !ok || cache == nil || cache.usageInfo == nil {
+ return nil, false
+ }
+ if time.Since(cache.timestamp) >= kiroCacheTTL(cache.usageInfo) {
+ return nil, false
+ }
+ return cloneUsageInfo(cache.usageInfo), true
+}
+
+func kiroCacheTTL(info *UsageInfo) time.Duration {
+ if info == nil {
+ return kiroUsageErrorTTL
+ }
+ if info.ErrorCode != "" || info.Error != "" {
+ return kiroUsageErrorTTL
+ }
+ return apiCacheTTL
+}
+
+func cloneUsageInfo(info *UsageInfo) *UsageInfo {
+ if info == nil {
+ return nil
+ }
+ cloned := *info
+ return &cloned
+}
+
+func (s *AccountUsageService) requestKiroUsageLimits(ctx context.Context, account *Account, region, profileArn, token string) (*kiroUsageLimitsResponse, error) {
+ endpoint := resolveKiroRuntimeEndpoint(region)
+ reqURL, err := url.Parse(endpoint + "/getUsageLimits")
+ if err != nil {
+ return nil, fmt.Errorf("build kiro usage url failed: %w", err)
+ }
+ q := reqURL.Query()
+ q.Set("origin", kiroUsageOrigin)
+ if profileArn = strings.TrimSpace(profileArn); profileArn != "" {
+ q.Set("profileArn", profileArn)
+ }
+ q.Set("resourceType", kiroUsageResourceType)
+ reqURL.RawQuery = q.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("create kiro usage request failed: %w", err)
+ }
+ s.applyKiroRuntimeHeaders(req, account, token)
+
+ client, err := httpclient.GetClient(httpclient.Options{
+ ProxyURL: accountProxyURL(account),
+ Timeout: 30 * time.Second,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: isLoopbackEndpoint(endpoint),
+ })
+ if err != nil {
+ return nil, fmt.Errorf("create kiro usage client failed: %w", err)
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("kiro usage request failed: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("read kiro usage response failed: %w", err)
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
+ }
+
+ var parsed kiroUsageLimitsResponse
+ if err := json.Unmarshal(body, &parsed); err != nil {
+ return nil, fmt.Errorf("decode kiro usage response failed: %w", err)
+ }
+ return &parsed, nil
+}
+
+func (s *AccountUsageService) applyKiroRuntimeHeaders(req *http.Request, account *Account, token string) {
+ if req == nil {
+ return
+ }
+ accountKey := buildKiroAccountKey(account)
+ machineID := buildKiroMachineID(account)
+ req.Header.Set("Accept", "*/*")
+ req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
+ req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
+ req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
+ req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
+ req.Header.Set("x-amzn-codewhisperer-optout", "true")
+ req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
+ req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
+
+ if account == nil {
+ return
+ }
+ applyKiroConditionalHeaders(req, account)
+}
+
+func accountProxyURL(account *Account) string {
+ if account == nil || account.ProxyID == nil || account.Proxy == nil {
+ return ""
+ }
+ return account.Proxy.URL()
+}
+
+func kiroRuntimeEndpoint(region string) string {
+ region = strings.TrimSpace(region)
+ if region == "" {
+ region = kiroDefaultRegion
+ }
+ switch region {
+ case "us-east-1":
+ return "https://q.us-east-1.amazonaws.com"
+ case "eu-central-1":
+ return "https://q.eu-central-1.amazonaws.com"
+ case "us-gov-east-1":
+ return "https://q-fips.us-gov-east-1.amazonaws.com"
+ case "us-gov-west-1":
+ return "https://q-fips.us-gov-west-1.amazonaws.com"
+ case "us-iso-east-1":
+ return "https://q.us-iso-east-1.c2s.ic.gov"
+ case "us-isob-east-1":
+ return "https://q.us-isob-east-1.sc2s.sgov.gov"
+ case "us-isof-south-1":
+ return "https://q.us-isof-south-1.csp.hci.ic.gov"
+ case "us-isof-east-1":
+ return "https://q.us-isof-east-1.csp.hci.ic.gov"
+ default:
+ if strings.HasPrefix(region, "us-gov-") {
+ return "https://q-fips." + region + ".amazonaws.com"
+ }
+ return "https://q." + region + ".amazonaws.com"
+ }
+}
+
+func isLoopbackEndpoint(raw string) bool {
+ parsed, err := url.Parse(strings.TrimSpace(raw))
+ if err != nil {
+ return false
+ }
+ host := strings.TrimSpace(parsed.Hostname())
+ if host == "" {
+ return false
+ }
+ if strings.EqualFold(host, "localhost") {
+ return true
+ }
+ ip := net.ParseIP(host)
+ return ip != nil && ip.IsLoopback()
+}
+
+func mapKiroUsageToInfo(resp *kiroUsageLimitsResponse) *UsageInfo {
+ now := time.Now()
+ if resp == nil {
+ return &UsageInfo{UpdatedAt: &now}
+ }
+ info := &UsageInfo{
+ UpdatedAt: &now,
+ KiroSubscriptionName: strings.TrimSpace(resp.SubscriptionInfo.SubscriptionTitle),
+ KiroSubscriptionType: strings.TrimSpace(resp.SubscriptionInfo.Type),
+ KiroOveragesEnabled: strings.EqualFold(strings.TrimSpace(resp.OverageConfiguration.OverageStatus), "ENABLED"),
+ }
+
+ resetAt := parseKiroTimestamp(resp.NextDateReset)
+ if credit := selectKiroCreditBreakdown(resp.UsageBreakdownList); credit != nil {
+ if breakdownReset := parseKiroTimestamp(credit.NextDateReset); breakdownReset != nil {
+ resetAt = breakdownReset
+ }
+ info.KiroCredit = &KiroCreditProgress{
+ CurrentUsage: selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage),
+ UsageLimit: selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit),
+ PercentageUsed: percentageOrZero(selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage), selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit)),
+ }
+ info.KiroOverage = &KiroOverageInfo{
+ CurrentOverages: selectKiroFloat(credit.CurrentOveragesWithPrecision, credit.CurrentOverages),
+ OverageCharges: selectKiroFloat(credit.OverageCharges, nil),
+ CurrencyCode: strings.TrimSpace(credit.Currency),
+ CurrencySymbol: kiroCurrencySymbol(strings.TrimSpace(credit.Currency)),
+ }
+ if ft := credit.FreeTrialInfo; ft != nil && strings.EqualFold(strings.TrimSpace(ft.FreeTrialStatus), "ACTIVE") {
+ expiry := parseKiroTimestamp(ft.FreeTrialExpiry)
+ daysRemaining := 0
+ if expiry != nil {
+ daysRemaining = int(time.Until(*expiry).Hours() / 24)
+ if time.Until(*expiry)%(24*time.Hour) != 0 {
+ daysRemaining++
+ }
+ if daysRemaining < 0 {
+ daysRemaining = 0
+ }
+ }
+ current := selectKiroFloat(ft.CurrentUsageWithPrecision, ft.CurrentUsage)
+ limit := selectKiroFloat(ft.UsageLimitWithPrecision, ft.UsageLimit)
+ info.KiroBonus = &KiroCreditProgress{
+ CurrentUsage: current,
+ UsageLimit: limit,
+ PercentageUsed: percentageOrZero(current, limit),
+ DaysRemaining: daysRemaining,
+ ExpiryDate: expiry,
+ }
+ }
+ }
+ info.KiroResetAt = resetAt
+ setKiroQuotaStateFromUsage(info)
+ return info
+}
+
+func selectKiroCreditBreakdown(items []kiroUsageBreakdown) *kiroUsageBreakdown {
+ for i := range items {
+ if strings.EqualFold(strings.TrimSpace(items[i].ResourceType), "CREDIT") {
+ return &items[i]
+ }
+ }
+ if len(items) == 0 {
+ return nil
+ }
+ return &items[0]
+}
+
+func selectKiroFloat(preferred *float64, fallback *float64) float64 {
+ switch {
+ case preferred != nil:
+ return *preferred
+ case fallback != nil:
+ return *fallback
+ default:
+ return 0
+ }
+}
+
+func percentageOrZero(current, limit float64) float64 {
+ if limit <= 0 {
+ return 0
+ }
+ return current / limit * 100
+}
+
+func parseKiroTimestamp(raw any) *time.Time {
+ if raw == nil {
+ return nil
+ }
+ switch v := raw.(type) {
+ case string:
+ trimmed := strings.TrimSpace(v)
+ if trimmed == "" {
+ return nil
+ }
+ if parsed, err := time.Parse(time.RFC3339, trimmed); err == nil {
+ return &parsed
+ }
+ if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
+ return unixishToTime(i)
+ }
+ if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
+ return unixishFloatToTime(f)
+ }
+ case float64:
+ return unixishFloatToTime(v)
+ case int64:
+ return unixishToTime(v)
+ case int:
+ return unixishToTime(int64(v))
+ case json.Number:
+ if i, err := v.Int64(); err == nil {
+ return unixishToTime(i)
+ }
+ if f, err := v.Float64(); err == nil {
+ return unixishFloatToTime(f)
+ }
+ }
+ return nil
+}
+
+func unixishFloatToTime(v float64) *time.Time {
+ if v <= 0 {
+ return nil
+ }
+ if v >= 1e12 {
+ t := time.UnixMilli(int64(v))
+ return &t
+ }
+ t := time.Unix(int64(v), 0)
+ return &t
+}
+
+func unixishToTime(v int64) *time.Time {
+ if v <= 0 {
+ return nil
+ }
+ if v >= 1e12 {
+ t := time.UnixMilli(v)
+ return &t
+ }
+ t := time.Unix(v, 0)
+ return &t
+}
+
+func kiroCurrencySymbol(code string) string {
+ switch strings.ToUpper(strings.TrimSpace(code)) {
+ case "USD":
+ return "$"
+ default:
+ return ""
+ }
+}
+
+func buildKiroDegradedUsage(err error) *UsageInfo {
+ now := time.Now()
+ info := &UsageInfo{
+ UpdatedAt: &now,
+ Error: "usage API error",
+ ErrorCode: errorCodeNetworkError,
+ }
+ if err == nil {
+ return info
+ }
+
+ info.Error = fmt.Sprintf("usage API error: %v", err)
+
+ classification := classifyKiroError(err)
+ switch classification.Category {
+ case kiroErrorAuthError:
+ info.ErrorCode = errorCodeUnauthenticated
+ info.NeedsReauth = true
+ case kiroErrorRateLimited:
+ info.ErrorCode = errorCodeRateLimited
+ case kiroErrorQuotaExhausted:
+ info.ErrorCode = errorCodeNetworkError
+ info.KiroQuotaState = kiroQuotaStateCreditsExhausted
+ info.KiroQuotaReason = classification.Message
+ case kiroErrorOverageExhausted:
+ info.ErrorCode = errorCodeNetworkError
+ info.KiroQuotaState = kiroQuotaStateOverageExhausted
+ info.KiroQuotaReason = classification.Message
+ case kiroErrorSuspended, kiroErrorUsageForbidden, kiroErrorProfileError:
+ info.ErrorCode = errorCodeForbidden
+ default:
+ info.ErrorCode = errorCodeNetworkError
+ }
+ return info
+}
+
+func (s *AccountUsageService) attachKiroRuntimeState(ctx context.Context, account *Account, usage *UsageInfo) {
+ if s == nil || usage == nil || account == nil || account.Platform != PlatformKiro || s.kiroCooldownStore == nil {
+ return
+ }
+ usage.KiroRuntimeState = ""
+ usage.KiroRuntimeReason = ""
+ usage.KiroRuntimeResetAt = nil
+ state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
+ if err != nil || state == nil {
+ return
+ }
+ usage.KiroRuntimeState, usage.KiroRuntimeReason, usage.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
+}
+
+func (s *AccountUsageService) EnrichAccountWithKiroRuntimeState(ctx context.Context, account *Account) {
+ if s == nil || account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
+ return
+ }
+ account.KiroQuotaState = ""
+ account.KiroQuotaReason = ""
+ account.KiroQuotaResetAt = nil
+ account.KiroRuntimeState = ""
+ account.KiroRuntimeReason = ""
+ account.KiroRuntimeResetAt = nil
+ if usage, ok := s.getCachedKiroUsage(account.ID); ok {
+ account.KiroQuotaState = usage.KiroQuotaState
+ account.KiroQuotaReason = usage.KiroQuotaReason
+ account.KiroQuotaResetAt = usage.KiroQuotaResetAt
+ }
+ if s.kiroCooldownStore == nil {
+ return
+ }
+ state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
+ if err != nil || state == nil {
+ return
+ }
+ account.KiroRuntimeState, account.KiroRuntimeReason, account.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
+}
+
+func setKiroQuotaStateFromUsage(info *UsageInfo) {
+ if info == nil {
+ return
+ }
+ info.KiroQuotaState = ""
+ info.KiroQuotaReason = ""
+ info.KiroQuotaResetAt = nil
+
+ creditExhausted := false
+ if info.KiroCredit != nil && info.KiroCredit.UsageLimit > 0 {
+ creditExhausted = info.KiroCredit.CurrentUsage >= info.KiroCredit.UsageLimit
+ }
+ overageActive := info.KiroOverage != nil &&
+ (info.KiroOverage.CurrentOverages > 0 || info.KiroOverage.OverageCharges > 0)
+
+ switch {
+ case info.KiroOveragesEnabled && (overageActive || creditExhausted):
+ info.KiroQuotaState = kiroQuotaStateOverageActive
+ info.KiroQuotaReason = "overages_enabled"
+ info.KiroQuotaResetAt = info.KiroResetAt
+ case creditExhausted:
+ info.KiroQuotaState = kiroQuotaStateCreditsExhausted
+ info.KiroQuotaReason = "credits_exhausted"
+ info.KiroQuotaResetAt = info.KiroResetAt
+ default:
+ info.KiroQuotaState = kiroQuotaStateNormal
+ }
+}
diff --git a/backend/internal/service/kiro_websearch.go b/backend/internal/service/kiro_websearch.go
new file mode 100644
index 00000000..dc97e992
--- /dev/null
+++ b/backend/internal/service/kiro_websearch.go
@@ -0,0 +1,458 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "sync"
+
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
+)
+
+const kiroMaxWebSearchIterations = 5
+
+var (
+ errKiroWebSearchFallback = errors.New("kiro web search fallback")
+ kiroWebSearchDescCache sync.Map
+)
+
+type kiroWebSearchExecution struct {
+ ResponseBody []byte
+ Usage ClaudeUsage
+ RequestID string
+}
+
+type kiroWebSearchHTTPError struct {
+ Response *http.Response
+}
+
+type kiroStreamChunkCollector struct {
+ chunks [][]byte
+}
+
+func (e *kiroWebSearchHTTPError) Error() string {
+ if e == nil || e.Response == nil {
+ return "kiro web search http error"
+ }
+ return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
+}
+
+func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
+ if len(p) > 0 {
+ w.chunks = append(w.chunks, append([]byte(nil), p...))
+ }
+ return len(p), nil
+}
+
+func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
+ collector := &kiroStreamChunkCollector{}
+ result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens)
+ if err != nil {
+ return nil, nil, err
+ }
+ return collector.chunks, result, nil
+}
+
+func writeSSEChunks(w io.Writer, chunks [][]byte) error {
+ for _, chunk := range chunks {
+ if len(chunk) == 0 {
+ continue
+ }
+ if _, err := w.Write(chunk); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
+ if strings.TrimSpace(msgID) == "" {
+ msgID = "msg_" + kiropkg.GenerateToolUseID()
+ }
+ if strings.TrimSpace(model) == "" {
+ model = "kiro"
+ }
+ payload, err := json.Marshal(map[string]any{
+ "type": "message_start",
+ "message": map[string]any{
+ "id": msgID,
+ "type": "message",
+ "role": "assistant",
+ "model": model,
+ "content": []any{},
+ "stop_reason": nil,
+ "stop_sequence": nil,
+ "usage": map[string]any{
+ "input_tokens": inputTokens,
+ "output_tokens": 0,
+ "cache_creation_input_tokens": 0,
+ "cache_read_input_tokens": 0,
+ },
+ },
+ })
+ if err != nil {
+ return err
+ }
+ _, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n")
+ return err
+}
+
+func (s *GatewayService) streamKiroWebSearchAsAnthropic(
+ ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
+) error {
+ query := kiropkg.ExtractSearchQuery(anthropicBody)
+ if strings.TrimSpace(query) == "" {
+ return errKiroWebSearchFallback
+ }
+
+ currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
+ if err != nil {
+ currentBody = anthropicBody
+ }
+ currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
+ nextContentBlockIndex := 0
+
+ if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil {
+ return err
+ }
+
+ for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
+ s.prefetchKiroWebSearchDescription(ctx, account, token)
+
+ results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
+ if strings.TrimSpace(nextToken) != "" {
+ token = nextToken
+ }
+ if mcpErr != nil {
+ results = nil
+ }
+
+ if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil {
+ return err
+ }
+ nextContentBlockIndex += 2
+
+ currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
+ if err != nil {
+ return errKiroWebSearchFallback
+ }
+
+ resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return &kiroWebSearchHTTPError{Response: resp}
+ }
+
+ chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) {
+ defer func() { _ = resp.Body.Close() }()
+ return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
+ }()
+ if streamErr != nil {
+ return streamErr
+ }
+
+ analysis := kiropkg.AnalyzeBufferedStream(chunks)
+ if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations {
+ filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex)
+ if err := writeSSEChunks(w, filtered); err != nil {
+ return err
+ }
+ if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex {
+ nextContentBlockIndex = maxIndex + 1
+ }
+ query = analysis.WebSearchQuery
+ if strings.TrimSpace(analysis.WebSearchToolUseID) == "" {
+ currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
+ } else {
+ currentToolUseID = analysis.WebSearchToolUseID
+ }
+ continue
+ }
+
+ for _, chunk := range chunks {
+ adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex)
+ if !shouldForward {
+ continue
+ }
+ if _, err := w.Write(adjusted); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+
+ return fmt.Errorf("kiro web search exceeded max iterations")
+}
+
+func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
+ query := kiropkg.ExtractSearchQuery(anthropicBody)
+ if strings.TrimSpace(query) == "" {
+ return nil, errKiroWebSearchFallback
+ }
+
+ currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
+ if err != nil {
+ currentBody = anthropicBody
+ }
+
+ inputTokens := estimateKiroInputTokens(anthropicBody)
+ currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
+ searches := make([]kiropkg.SearchIndicator, 0, 2)
+ requestID := ""
+
+ for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
+ s.prefetchKiroWebSearchDescription(ctx, account, token)
+
+ results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
+ if strings.TrimSpace(nextToken) != "" {
+ token = nextToken
+ }
+ if mcpErr != nil {
+ results = nil
+ }
+ searches = append(searches, kiropkg.SearchIndicator{
+ ToolUseID: currentToolUseID,
+ Query: query,
+ Results: results,
+ })
+
+ currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
+ if err != nil {
+ return nil, errKiroWebSearchFallback
+ }
+
+ resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
+ if err != nil {
+ return nil, err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, &kiroWebSearchHTTPError{Response: resp}
+ }
+
+ parseResult, parseErr := func() (*kiropkg.ParseResult, error) {
+ defer func() { _ = resp.Body.Close() }()
+ return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
+ }()
+ if parseErr != nil {
+ return nil, parseErr
+ }
+ if requestID == "" {
+ requestID = buildKiroRequestID(resp)
+ }
+
+ nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody)
+ if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations {
+ finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches)
+ if injectErr == nil {
+ parseResult.ResponseBody = finalBody
+ }
+ return &kiroWebSearchExecution{
+ ResponseBody: parseResult.ResponseBody,
+ Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
+ RequestID: requestID,
+ }, nil
+ }
+
+ query = nextQuery
+ if strings.TrimSpace(nextToolUseID) == "" {
+ nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
+ }
+ currentToolUseID = nextToolUseID
+ }
+
+ return nil, fmt.Errorf("kiro web search exceeded max iterations")
+}
+
+func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
+ endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
+ if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
+ if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
+ kiropkg.SetCachedWebSearchDescription(desc)
+ }
+ return
+ }
+
+ reqBody, _ := json.Marshal(kiropkg.MCPRequest{
+ ID: "tools_list",
+ JSONRPC: "2.0",
+ Method: "tools/list",
+ })
+ resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
+ if err != nil || resp == nil {
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode != http.StatusOK {
+ return
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return
+ }
+ var result kiropkg.MCPResponse
+ if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
+ return
+ }
+ for _, tool := range result.Result.Tools {
+ if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
+ kiroWebSearchDescCache.Store(endpoint, tool.Description)
+ kiropkg.SetCachedWebSearchDescription(tool.Description)
+ return
+ }
+ }
+}
+
+func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
+ reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
+ if err != nil {
+ return nil, token, err
+ }
+
+ endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
+ resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
+ if err != nil {
+ return nil, nextToken, err
+ }
+ if resp == nil {
+ return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, nextToken, err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+ }
+
+ var parsed kiropkg.MCPResponse
+ if err := json.Unmarshal(body, &parsed); err != nil {
+ return nil, nextToken, err
+ }
+ if parsed.Error != nil {
+ msg := "unknown error"
+ if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" {
+ msg = strings.TrimSpace(*parsed.Error.Message)
+ }
+ code := 0
+ if parsed.Error.Code != nil {
+ code = *parsed.Error.Code
+ }
+ return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
+ }
+
+ return kiropkg.ParseSearchResults(&parsed), nextToken, nil
+}
+
+func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
+ return kiropkg.MCPRequest{
+ ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
+ JSONRPC: "2.0",
+ Method: "tools/call",
+ Params: map[string]interface{}{
+ "name": "web_search",
+ "arguments": map[string]interface{}{
+ "query": query,
+ "_meta": map[string]interface{}{
+ "_isValid": true,
+ "_activePath": []string{"query"},
+ "_completedPaths": [][]string{{"query"}},
+ },
+ },
+ },
+ }
+}
+
+func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
+ currentToken := token
+ accountKey := buildKiroAccountKey(account)
+ proxyURL := kiroProxyURL(account)
+ tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
+
+ for attempt := 0; attempt < 3; attempt++ {
+ if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
+ if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
+ return nil, currentToken, failoverErr
+ }
+ return nil, currentToken, err
+ }
+
+ req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
+ if err != nil {
+ return nil, currentToken, err
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
+ if err != nil {
+ return nil, currentToken, err
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
+ respBody, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ return nil, currentToken, readErr
+ }
+ if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
+ if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
+ return nil, currentToken, err
+ }
+ resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
+ return resp, currentToken, nil
+ }
+ if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) {
+ resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
+ return resp, currentToken, nil
+ }
+ if s.kiroTokenProvider == nil {
+ resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
+ return resp, currentToken, nil
+ }
+ refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
+ if refreshErr != nil {
+ resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
+ return resp, currentToken, nil
+ }
+ currentToken = refreshedToken
+ accountKey = buildKiroAccountKey(account)
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, currentToken, sleepErr
+ }
+ continue
+ }
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ if _, err := s.markKiro429(ctx, accountKey); err != nil {
+ _ = resp.Body.Close()
+ return nil, currentToken, err
+ }
+ }
+ if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
+ if attempt < 2 {
+ _ = resp.Body.Close()
+ if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
+ return nil, currentToken, sleepErr
+ }
+ continue
+ }
+ }
+ if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+ if err := s.markKiroSuccess(ctx, accountKey); err != nil {
+ _ = resp.Body.Close()
+ return nil, currentToken, err
+ }
+ }
+
+ return resp, currentToken, nil
+ }
+
+ return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
+}
diff --git a/backend/internal/service/kiro_websearch_test.go b/backend/internal/service/kiro_websearch_test.go
new file mode 100644
index 00000000..26b7acaf
--- /dev/null
+++ b/backend/internal/service/kiro_websearch_test.go
@@ -0,0 +1,28 @@
+//go:build unit
+
+package service
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestBuildKiroWebSearchMCPRequest_UsesUnderscoredMetaKeys(t *testing.T) {
+ req := buildKiroWebSearchMCPRequest("golang concurrency")
+
+ body, err := json.Marshal(req)
+ require.NoError(t, err)
+
+ require.Equal(t, "tools/call", gjson.GetBytes(body, "method").String())
+ require.Equal(t, "web_search", gjson.GetBytes(body, "params.name").String())
+ require.Equal(t, "golang concurrency", gjson.GetBytes(body, "params.arguments.query").String())
+ require.True(t, gjson.GetBytes(body, "params.arguments._meta._isValid").Bool())
+ require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._activePath.0").String())
+ require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._completedPaths.0.0").String())
+ require.False(t, gjson.GetBytes(body, "params.arguments._meta.isValid").Exists())
+ require.False(t, gjson.GetBytes(body, "params.arguments._meta.activePath").Exists())
+ require.False(t, gjson.GetBytes(body, "params.arguments._meta.completedPaths").Exists())
+}
diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go
index 05d444e1..fb4c2ca5 100644
--- a/backend/internal/service/ops_upstream_context.go
+++ b/backend/internal/service/ops_upstream_context.go
@@ -89,6 +89,14 @@ type OpsUpstreamErrorEvent struct {
AccountID int64 `json:"account_id,omitempty"`
AccountName string `json:"account_name,omitempty"`
+ // Model diagnostics.
+ RequestedModel string `json:"requested_model,omitempty"`
+ MappedModel string `json:"mapped_model,omitempty"`
+ KiroModelID string `json:"kiro_model_id,omitempty"`
+ HasTools bool `json:"has_tools,omitempty"`
+ HasAdaptiveThinking bool `json:"has_adaptive_thinking,omitempty"`
+ HasContext1MBeta bool `json:"has_context_1m_beta,omitempty"`
+
// Outcome
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go
index 74c9edc3..93f07176 100644
--- a/backend/internal/service/token_cache_invalidator.go
+++ b/backend/internal/service/token_cache_invalidator.go
@@ -42,6 +42,9 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
+ case PlatformKiro:
+ keysToDelete = append(keysToDelete, KiroTokenCacheKey(account))
+ keysToDelete = append(keysToDelete, "kiro:"+accountIDKey)
case PlatformOpenAI:
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic:
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index 22f4aa29..1ff5edcc 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -10,6 +10,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
@@ -44,6 +45,7 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
+ kiroOAuthService *KiroOAuthService,
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
@@ -64,6 +66,7 @@ func NewTokenRefreshService(
claudeRefresher := NewClaudeTokenRefresher(oauthService)
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
+ kiroRefresher := NewKiroTokenRefresher(kiroOAuthService)
// 注册平台特定的刷新器(TokenRefresher 接口)
s.refreshers = []TokenRefresher{
@@ -71,6 +74,7 @@ func NewTokenRefreshService(
openAIRefresher,
geminiRefresher,
agRefresher,
+ kiroRefresher,
}
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
@@ -79,6 +83,7 @@ func NewTokenRefreshService(
openAIRefresher,
geminiRefresher,
agRefresher,
+ kiroRefresher,
}
return s
@@ -415,6 +420,10 @@ func isNonRetryableRefreshError(err error) bool {
if err == nil {
return false
}
+ var kiroInvalidGrant *kiropkg.RefreshTokenInvalidError
+ if errors.As(err, &kiroInvalidGrant) {
+ return true
+ }
msg := strings.ToLower(err.Error())
nonRetryable := []string{
"invalid_grant", // refresh_token 已失效
diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go
index 7cc8a713..cdea1aa3 100644
--- a/backend/internal/service/usage_log_helpers.go
+++ b/backend/internal/service/usage_log_helpers.go
@@ -22,9 +22,9 @@ func optionalNonEqualStringPtr(value, compare string) *string {
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
- return trimmed
+ return normalizeModelNameForPricing(trimmed)
}
- return strings.TrimSpace(upstreamModel)
+ return normalizeModelNameForPricing(upstreamModel)
}
func optionalInt64Ptr(v int64) *int64 {
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 8b50e478..6db44acf 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -8,6 +8,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
@@ -51,6 +52,7 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
+ kiroOAuthService *KiroOAuthService,
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
@@ -59,7 +61,7 @@ func ProvideTokenRefreshService(
proxyRepo ProxyRepository,
refreshAPI *OAuthRefreshAPI,
) *TokenRefreshService {
- svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
+ svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 OpenAI privacy opt-out 依赖
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
@@ -128,6 +130,23 @@ func ProvideAntigravityTokenProvider(
return p
}
+func ProvideKiroTokenProvider(
+ accountRepo AccountRepository,
+ tokenCache GeminiTokenCache,
+ kiroOAuthService *KiroOAuthService,
+ refreshAPI *OAuthRefreshAPI,
+) *KiroTokenProvider {
+ p := NewKiroTokenProvider(accountRepo, tokenCache, kiroOAuthService)
+ executor := NewKiroTokenRefresher(kiroOAuthService)
+ p.SetRefreshAPI(refreshAPI, executor)
+ p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
+ return p
+}
+
+func ProvideKiroCooldownStore(redisClient *redis.Client) KiroCooldownStore {
+ return kirocooldown.NewStore(redisClient)
+}
+
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
@@ -448,8 +467,11 @@ var ProviderSet = wire.NewSet(
NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService,
+ NewKiroOAuthService,
ProvideOAuthRefreshAPI,
ProvideGeminiTokenProvider,
+ ProvideKiroTokenProvider,
+ ProvideKiroCooldownStore,
NewGeminiMessagesCompatService,
ProvideAntigravityTokenProvider,
ProvideOpenAITokenProvider,