Merge remote-tracking branch 'pr/2131' into release/v0.1.133
# Conflicts: # backend/cmd/server/wire_gen.go # backend/internal/config/config.go # backend/internal/service/gateway_service.go # backend/internal/service/pricing_service.go # backend/internal/service/wire.go # deploy/config.example.yaml # frontend/src/views/admin/AccountsView.vue
This commit is contained in:
@@ -93,6 +93,7 @@ func provideCleanup(
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
kiroOAuth *service.KiroOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
@@ -216,6 +217,10 @@ func provideCleanup(
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"KiroOAuthService", func() error {
|
||||
kiroOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIWSPool", func() error {
|
||||
if openAIGateway != nil {
|
||||
openAIGateway.CloseOpenAIWSPool()
|
||||
|
||||
@@ -146,13 +146,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
|
||||
kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
@@ -166,6 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
|
||||
kiroOAuthHandler := admin.NewKiroOAuthHandler(kiroOAuthService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
@@ -179,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
kiroCooldownStore := service.ProvideKiroCooldownStore(redisClient)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@@ -236,7 +240,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
contentModerationHandler := admin.NewContentModerationHandler(contentModerationService)
|
||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, kiroOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -260,13 +264,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -316,6 +320,7 @@ func provideCleanup(
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
kiroOAuth *service.KiroOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
@@ -438,6 +443,10 @@ func provideCleanup(
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"KiroOAuthService", func() error {
|
||||
kiroOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIWSPool", func() error {
|
||||
if openAIGateway != nil {
|
||||
openAIGateway.CloseOpenAIWSPool()
|
||||
|
||||
@@ -36,6 +36,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
)
|
||||
@@ -72,6 +73,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
openAIOAuthSvc,
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // kiroOAuth
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
|
||||
@@ -682,6 +682,8 @@ type GatewayConfig struct {
|
||||
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
|
||||
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用
|
||||
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
|
||||
// KiroStreamKeepaliveInterval: Kiro 流式 keepalive 间隔(秒),0使用默认 25 秒
|
||||
KiroStreamKeepaliveInterval int `mapstructure:"kiro_stream_keepalive_interval"`
|
||||
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
|
||||
MaxLineSize int `mapstructure:"max_line_size"`
|
||||
|
||||
@@ -1752,6 +1754,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
|
||||
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.kiro_stream_keepalive_interval", 25)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
@@ -2369,6 +2372,13 @@ func (c *Config) Validate() error {
|
||||
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
|
||||
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
|
||||
}
|
||||
if c.Gateway.KiroStreamKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be non-negative")
|
||||
}
|
||||
if c.Gateway.KiroStreamKeepaliveInterval != 0 &&
|
||||
(c.Gateway.KiroStreamKeepaliveInterval < 5 || c.Gateway.KiroStreamKeepaliveInterval > 30) {
|
||||
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||
}
|
||||
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||||
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformKiro = "kiro"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
@@ -117,6 +118,21 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultKiroModelMapping 是 Kiro 平台的默认模型映射。
|
||||
// 键为对外暴露/允许请求的模型名,值为实际发送到 Kiro 上游的模型名。
|
||||
var DefaultKiroModelMapping = map[string]string{
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4-6-thinking": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4.5",
|
||||
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
|
||||
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package domain
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -24,3 +27,56 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expected := map[string]string{
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4-6-thinking": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4.5",
|
||||
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
|
||||
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
if len(DefaultKiroModelMapping) != len(expected) {
|
||||
t.Fatalf("expected %d Kiro mappings, got %d", len(expected), len(DefaultKiroModelMapping))
|
||||
}
|
||||
for model, want := range expected {
|
||||
if got := DefaultKiroModelMapping[model]; got != want {
|
||||
t.Fatalf("unexpected Kiro mapping for %q: got %q want %q", model, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range []string{
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
"deepseek-3-2",
|
||||
"minimax-m2-1",
|
||||
"qwen3-coder-next",
|
||||
"claude-opus-4-7",
|
||||
"claude-sonnet-4-6-chat",
|
||||
} {
|
||||
if _, ok := DefaultKiroModelMapping[model]; ok {
|
||||
t.Fatalf("did not expect %q to remain in DefaultKiroModelMapping", model)
|
||||
}
|
||||
}
|
||||
for model := range DefaultKiroModelMapping {
|
||||
if strings.HasSuffix(model, "-agentic") {
|
||||
t.Fatalf("did not expect agentic Kiro mapping %q", model)
|
||||
}
|
||||
if strings.HasSuffix(model, "-chat") {
|
||||
t.Fatalf("did not expect chat-only Kiro mapping %q", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -179,6 +180,9 @@ type AccountWithConcurrency struct {
|
||||
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||
|
||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||
if h.accountUsageService != nil {
|
||||
h.accountUsageService.EnrichAccountWithKiroRuntimeState(ctx, account)
|
||||
}
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(account),
|
||||
CurrentConcurrency: 0,
|
||||
@@ -351,6 +355,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
result := make([]AccountWithConcurrency, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if h.accountUsageService != nil {
|
||||
h.accountUsageService.EnrichAccountWithKiroRuntimeState(c.Request.Context(), acc)
|
||||
}
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(acc),
|
||||
CurrentConcurrency: concurrencyCounts[acc.ID],
|
||||
@@ -1953,6 +1960,18 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Kiro accounts
|
||||
if account.Platform == service.PlatformKiro {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, kiropkg.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, buildMappedKiroModels(mapping))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
@@ -1994,6 +2013,28 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
func buildMappedKiroModels(mapping map[string]string) []kiropkg.Model {
|
||||
models := make([]kiropkg.Model, 0, len(mapping))
|
||||
for requestedModel := range mapping {
|
||||
var found bool
|
||||
for _, dm := range kiropkg.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
models = append(models, kiropkg.Model{
|
||||
ID: requestedModel,
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||
@@ -2206,6 +2247,12 @@ func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
|
||||
response.Success(c, domain.DefaultAntigravityModelMapping)
|
||||
}
|
||||
|
||||
// GetKiroDefaultModelMapping 获取 Kiro 平台的默认模型映射
|
||||
// GET /api/v1/admin/accounts/kiro/default-model-mapping
|
||||
func (h *AccountHandler) GetKiroDefaultModelMapping(c *gin.Context) {
|
||||
response.Success(c, domain.DefaultKiroModelMapping)
|
||||
}
|
||||
|
||||
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
|
||||
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
|
||||
func sanitizeExtraBaseRPM(extra map[string]any) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "kiro-oauth",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 47,
|
||||
Name: "kiro-oauth-mapped",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "kiro-apikey",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 46,
|
||||
Name: "kiro-apikey-defaults",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -123,7 +123,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
|
||||
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(createReq))
|
||||
|
||||
updateReq := UpdateGroupRequest{Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type KiroOAuthHandler struct {
|
||||
kiroOAuthService *service.KiroOAuthService
|
||||
}
|
||||
|
||||
func NewKiroOAuthHandler(kiroOAuthService *service.KiroOAuthService) *KiroOAuthHandler {
|
||||
return &KiroOAuthHandler{kiroOAuthService: kiroOAuthService}
|
||||
}
|
||||
|
||||
type KiroGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req KiroGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.kiroOAuthService.GenerateAuthURL(c.Request.Context(), &service.KiroGenerateAuthURLInput{
|
||||
ProxyID: req.ProxyID,
|
||||
Provider: req.Provider,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "生成授权链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type KiroGenerateIDCAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
StartURL string `json:"start_url"`
|
||||
Region string `json:"region"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) GenerateIDCAuthURL(c *gin.Context) {
|
||||
var req KiroGenerateIDCAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.kiroOAuthService.GenerateIDCAuthURL(c.Request.Context(), &service.KiroGenerateIDCAuthURLInput{
|
||||
ProxyID: req.ProxyID,
|
||||
StartURL: req.StartURL,
|
||||
Region: req.Region,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "生成 IDC 授权链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type KiroExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
CallbackPath string `json:"callback_path"`
|
||||
LoginOption string `json:"login_option"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req KiroExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
tokenInfo, err := h.kiroOAuthService.ExchangeCode(c.Request.Context(), &service.KiroExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
CallbackPath: req.CallbackPath,
|
||||
LoginOption: req.LoginOption,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Token 交换失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
type KiroRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
Provider string `json:"provider"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
StartURL string `json:"start_url"`
|
||||
Region string `json:"region"`
|
||||
ProfileArn string `json:"profile_arn"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req KiroRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
tokenInfo, err := h.kiroOAuthService.RefreshToken(c.Request.Context(), &service.KiroRefreshTokenInput{
|
||||
RefreshToken: req.RefreshToken,
|
||||
AuthMethod: req.AuthMethod,
|
||||
Provider: req.Provider,
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
StartURL: req.StartURL,
|
||||
Region: req.Region,
|
||||
ProfileArn: req.ProfileArn,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "刷新 Kiro Token 失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
type KiroImportTokenRequest struct {
|
||||
TokenJSON string `json:"token_json" binding:"required"`
|
||||
DeviceRegistrationJSON string `json:"device_registration_json"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
|
||||
var req KiroImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
|
||||
TokenJSON: req.TokenJSON,
|
||||
DeviceRegistrationJSON: req.DeviceRegistrationJSON,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
@@ -230,6 +230,12 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
TempUnschedulableUntil: a.TempUnschedulableUntil,
|
||||
TempUnschedulableReason: a.TempUnschedulableReason,
|
||||
KiroQuotaState: a.KiroQuotaState,
|
||||
KiroQuotaReason: a.KiroQuotaReason,
|
||||
KiroQuotaResetAt: a.KiroQuotaResetAt,
|
||||
KiroRuntimeState: a.KiroRuntimeState,
|
||||
KiroRuntimeReason: a.KiroRuntimeReason,
|
||||
KiroRuntimeResetAt: a.KiroRuntimeResetAt,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
|
||||
@@ -183,6 +183,12 @@ type Account struct {
|
||||
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
|
||||
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
|
||||
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
|
||||
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
|
||||
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
|
||||
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
|
||||
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
|
||||
@@ -17,6 +17,7 @@ type AdminHandlers struct {
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
KiroOAuth *admin.KiroOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
|
||||
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
kiroOAuthHandler *admin.KiroOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
@@ -52,6 +53,7 @@ func ProvideAdminHandlers(
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
KiroOAuth: kiroOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
@@ -156,6 +158,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewKiroOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
|
||||
@@ -36,11 +36,12 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformKiro = "kiro"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RuntimeFingerprint struct {
|
||||
OIDCSDKVersion string
|
||||
RuntimeSDKVersion string
|
||||
StreamingSDKVersion string
|
||||
OSType string
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string
|
||||
}
|
||||
|
||||
type runtimeFingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*RuntimeFingerprint
|
||||
}
|
||||
|
||||
var (
|
||||
globalRuntimeFingerprintManager *runtimeFingerprintManager
|
||||
globalRuntimeFingerprintManagerOnce sync.Once
|
||||
|
||||
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
|
||||
runtimeSDKVersions = []string{"1.0.0"}
|
||||
streamingSDKVersions = []string{"1.0.34"}
|
||||
osTypes = []string{"darwin", "win32"}
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"24.6.0"},
|
||||
"win32": {"10.0.22631"},
|
||||
}
|
||||
nodeVersions = []string{"22.22.0"}
|
||||
kiroVersions = []string{
|
||||
"0.11.132", "0.11.131", "0.11.130",
|
||||
}
|
||||
)
|
||||
|
||||
func globalRuntimeFingerprints() *runtimeFingerprintManager {
|
||||
globalRuntimeFingerprintManagerOnce.Do(func() {
|
||||
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
|
||||
fingerprints: make(map[string]*RuntimeFingerprint),
|
||||
}
|
||||
})
|
||||
return globalRuntimeFingerprintManager
|
||||
}
|
||||
|
||||
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
|
||||
lookupKey := fingerprintLookupKey(accountKey, "runtime")
|
||||
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
|
||||
|
||||
m.mu.RLock()
|
||||
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||
m.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||
return fp
|
||||
}
|
||||
fp := generateRuntimeFingerprint(lookupKey, machineID)
|
||||
m.fingerprints[lookupKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
|
||||
hash := sha256.Sum256([]byte(accountKey))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
osType := goOSToNodePlatform(runtime.GOOS)
|
||||
if !containsString(osTypes, osType) {
|
||||
osType = osTypes[rng.Intn(len(osTypes))]
|
||||
}
|
||||
osVersionPool := osVersions[osType]
|
||||
if len(osVersionPool) == 0 {
|
||||
osVersionPool = osVersions["darwin"]
|
||||
}
|
||||
|
||||
return &RuntimeFingerprint{
|
||||
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||
OSType: osType,
|
||||
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
|
||||
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||
KiroHash: machineID,
|
||||
}
|
||||
}
|
||||
|
||||
func goOSToNodePlatform(goos string) string {
|
||||
switch strings.TrimSpace(goos) {
|
||||
case "windows":
|
||||
return "win32"
|
||||
default:
|
||||
return strings.TrimSpace(goos)
|
||||
}
|
||||
}
|
||||
|
||||
func containsString(items []string, target string) bool {
|
||||
for _, item := range items {
|
||||
if item == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
|
||||
switch {
|
||||
case strings.TrimSpace(clientIDHash) != "":
|
||||
return clientIDHash
|
||||
case strings.TrimSpace(clientID) != "":
|
||||
return shortSHA(clientID)
|
||||
case strings.TrimSpace(refreshToken) != "":
|
||||
return shortSHA(refreshToken)
|
||||
case strings.TrimSpace(profileArn) != "":
|
||||
return shortSHA(profileArn)
|
||||
case accountID > 0:
|
||||
return shortSHA(fmt.Sprintf("account:%d", accountID))
|
||||
default:
|
||||
return shortSHA(uuid.NewString())
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeMachineID(machineID string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(machineID)
|
||||
if len(trimmed) == 64 && isHexString(trimmed) {
|
||||
return strings.ToLower(trimmed), true
|
||||
}
|
||||
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
|
||||
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
|
||||
normalized := strings.ToLower(withoutDashes)
|
||||
return normalized + normalized, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
|
||||
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
|
||||
return sha256Hex("KotlinNativeAPI/" + refreshToken)
|
||||
}
|
||||
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
|
||||
return sha256Hex("KiroAPIKey/" + apiKey)
|
||||
}
|
||||
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
|
||||
return sha256Hex("KiroFallback/" + fallbackKey)
|
||||
}
|
||||
return sha256Hex("KiroFallback/default")
|
||||
}
|
||||
|
||||
func shortSHA(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func sha256Hex(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func isHexString(value string) bool {
|
||||
for _, c := range value {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
|
||||
if normalized, ok := NormalizeMachineID(machineID); ok {
|
||||
return normalized
|
||||
}
|
||||
return BuildMachineID("", "", fallbackKey)
|
||||
}
|
||||
|
||||
func fingerprintLookupKey(accountKey, fallback string) string {
|
||||
key := strings.TrimSpace(accountKey)
|
||||
if key != "" {
|
||||
return key
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func BuildRuntimeUserAgent(accountKey, machineID string) string {
|
||||
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
|
||||
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
|
||||
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
|
||||
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
|
||||
"amz-sdk-invocation-id": uuid.NewString(),
|
||||
"amz-sdk-request": "attempt=1; max=4",
|
||||
}
|
||||
}
|
||||
|
||||
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
|
||||
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
}
|
||||
}
|
||||
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := baseDelay << attempt
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
const jitterFactor = 0.3
|
||||
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
|
||||
return time.Duration(float64(delay) * jitter)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildLoginHeadersStable(t *testing.T) {
|
||||
headers1 := BuildLoginHeaders("", "")
|
||||
headers2 := BuildLoginHeaders("", "")
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
|
||||
require.Equal(t, "application/json", headers1["Content-Type"])
|
||||
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
|
||||
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
|
||||
}
|
||||
|
||||
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
|
||||
machineIDA := BuildMachineID("refresh-a", "", "")
|
||||
machineIDB := BuildMachineID("refresh-b", "", "")
|
||||
headers1 := BuildLoginHeaders("account-a", machineIDA)
|
||||
headers2 := BuildLoginHeaders("account-a", machineIDA)
|
||||
headers3 := BuildLoginHeaders("account-a", machineIDB)
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
|
||||
require.Contains(t, headers1["User-Agent"], machineIDA)
|
||||
}
|
||||
|
||||
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
|
||||
machineID := BuildMachineID("", "", "oidc-machine")
|
||||
headers1 := BuildOIDCHeaders("account-a", machineID)
|
||||
headers2 := BuildOIDCHeaders("account-a", machineID)
|
||||
headers3 := BuildOIDCHeaders("account-b", machineID)
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
|
||||
}
|
||||
|
||||
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
|
||||
key1 := BuildAccountKey("", "", "", "", 42)
|
||||
key2 := BuildAccountKey("", "", "", "", 42)
|
||||
key3 := BuildAccountKey("", "", "", "", 43)
|
||||
|
||||
require.Equal(t, key1, key2)
|
||||
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
|
||||
require.NotEqual(t, key1, key3)
|
||||
}
|
||||
|
||||
func TestBuildMachineID(t *testing.T) {
|
||||
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
|
||||
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
|
||||
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
|
||||
|
||||
fallback1 := BuildMachineID("", "", "account:1")
|
||||
fallback2 := BuildMachineID("", "", "account:1")
|
||||
fallback3 := BuildMachineID("", "", "account:2")
|
||||
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
|
||||
require.Equal(t, fallback1, fallback2)
|
||||
require.NotEqual(t, fallback1, fallback3)
|
||||
require.Len(t, fallback1, 64)
|
||||
}
|
||||
|
||||
func TestNormalizeMachineID(t *testing.T) {
|
||||
hex64 := strings.Repeat("A", 64)
|
||||
normalized, ok := NormalizeMachineID(hex64)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, strings.ToLower(hex64), normalized)
|
||||
|
||||
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
|
||||
|
||||
_, ok = NormalizeMachineID("not-a-machine-id")
|
||||
require.False(t, ok)
|
||||
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func expectedKiroMachineID(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package kiro
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
var DefaultModels = []Model{
|
||||
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
|
||||
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
|
||||
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
|
||||
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
|
||||
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
|
||||
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
|
||||
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
|
||||
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
|
||||
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
|
||||
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
|
||||
ids := make([]string, 0, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
|
||||
require.Equal(t, []string{
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-6-thinking",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6-thinking",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5-20251101-thinking",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5-20250929-thinking",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5-20251001-thinking",
|
||||
}, ids)
|
||||
|
||||
require.Contains(t, ids, "claude-sonnet-4-6")
|
||||
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
|
||||
require.NotContains(t, ids, "auto")
|
||||
require.NotContains(t, ids, "claude-sonnet-4")
|
||||
require.NotContains(t, ids, "gpt-4o")
|
||||
require.NotContains(t, ids, "deepseek-3-2")
|
||||
require.NotContains(t, ids, "minimax-m2-1")
|
||||
require.NotContains(t, ids, "qwen3-coder-next")
|
||||
require.NotContains(t, ids, "claude-opus-4-7")
|
||||
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
|
||||
for _, id := range ids {
|
||||
require.NotContains(t, id, "kiro-")
|
||||
require.NotContains(t, id, "-agentic")
|
||||
require.NotContains(t, id, "-chat")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
socialAuthPortalURL = "https://app.kiro.dev"
|
||||
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
defaultIDCRegion = "us-east-1"
|
||||
BuilderIDStartURL = "https://view.awsapps.com/start"
|
||||
sessionTTL = 10 * time.Minute
|
||||
sessionCleanupEvery = 32
|
||||
sessionCleanupMin = 32
|
||||
)
|
||||
|
||||
var (
|
||||
socialAuthEndpointURL = socialAuthEndpoint
|
||||
oidcEndpointOverride = ""
|
||||
)
|
||||
|
||||
type SocialProvider string
|
||||
|
||||
const (
|
||||
SocialProviderGoogle SocialProvider = "Google"
|
||||
SocialProviderGitHub SocialProvider = "Github"
|
||||
)
|
||||
|
||||
type AuthSession struct {
|
||||
State string
|
||||
CodeVerifier string
|
||||
ProxyURL string
|
||||
CreatedAt time.Time
|
||||
AuthType string
|
||||
Provider string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Region string
|
||||
StartURL string
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
data map[string]*AuthSession
|
||||
setCount uint64
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
return &SessionStore{data: make(map[string]*AuthSession)}
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
session, ok := s.data[id]
|
||||
if ok && sessionExpired(session, now) {
|
||||
delete(s.data, id)
|
||||
return nil, false
|
||||
}
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(id string, session *AuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.setCount++
|
||||
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
|
||||
s.pruneExpiredLocked(time.Now())
|
||||
}
|
||||
s.data[id] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.data, id)
|
||||
}
|
||||
|
||||
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
|
||||
for id, session := range s.data {
|
||||
if sessionExpired(session, now) {
|
||||
delete(s.data, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sessionExpired(session *AuthSession, now time.Time) bool {
|
||||
if session == nil {
|
||||
return true
|
||||
}
|
||||
if session.CreatedAt.IsZero() {
|
||||
return true
|
||||
}
|
||||
return now.After(session.CreatedAt.Add(sessionTTL))
|
||||
}
|
||||
|
||||
type TokenData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
ExpiresAt string `json:"expiresAt,omitempty"`
|
||||
AuthMethod string `json:"authMethod,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
type socialTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
type registerClientResponse struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
type createTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
type userInfoResponse struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type deviceRegistration struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
type RefreshTokenInvalidError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *RefreshTokenInvalidError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
body := strings.TrimSpace(e.Body)
|
||||
if body == "" {
|
||||
return "kiro refresh token invalid (invalid_grant)"
|
||||
}
|
||||
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
|
||||
}
|
||||
|
||||
func GenerateSessionID() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
return randomURLSafe(16)
|
||||
}
|
||||
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
return randomURLSafe(32)
|
||||
}
|
||||
|
||||
func randomURLSafe(n int) (string, error) {
|
||||
buf := make([]byte, n)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
|
||||
params := url.Values{}
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("redirect_from", "KiroIDE")
|
||||
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
|
||||
}
|
||||
|
||||
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
|
||||
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
|
||||
if redirectURI == "" {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSpace(callbackPath)
|
||||
if path == "" {
|
||||
path = "/oauth/callback"
|
||||
} else if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
fullRedirectURI := redirectURI + path
|
||||
if option := strings.TrimSpace(loginOption); option != "" {
|
||||
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
|
||||
}
|
||||
return fullRedirectURI
|
||||
}
|
||||
|
||||
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
|
||||
payload := map[string]string{
|
||||
"code": code,
|
||||
"code_verifier": codeVerifier,
|
||||
"redirect_uri": redirectURI,
|
||||
}
|
||||
var resp socialTokenResponse
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
return &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
|
||||
payload := map[string]string{
|
||||
"refreshToken": refreshToken,
|
||||
}
|
||||
var resp socialTokenResponse
|
||||
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
return &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: provider,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]any{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||
"redirectUris": []string{redirectURI},
|
||||
"issuerUrl": issuerURL,
|
||||
}
|
||||
var resp registerClientResponse
|
||||
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scopes", strings.Join([]string{
|
||||
"codewhisperer:completions",
|
||||
"codewhisperer:analysis",
|
||||
"codewhisperer:conversations",
|
||||
"codewhisperer:transformations",
|
||||
"codewhisperer:taskassist",
|
||||
}, " "))
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
|
||||
}
|
||||
|
||||
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"code": code,
|
||||
"codeVerifier": codeVerifier,
|
||||
"redirectUri": redirectURI,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
var resp createTokenResponse
|
||||
accountKey := BuildAccountKey(clientID, "", "", "", 0)
|
||||
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
token := &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}
|
||||
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"refreshToken": refreshToken,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
var resp createTokenResponse
|
||||
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
|
||||
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
token := &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}
|
||||
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return ""
|
||||
}
|
||||
var resp userInfoResponse
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
}
|
||||
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(resp.Email)
|
||||
}
|
||||
|
||||
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
|
||||
var token TokenData
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
|
||||
}
|
||||
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
||||
if strings.TrimSpace(token.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
|
||||
var reg deviceRegistration
|
||||
if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse device registration: %w", err)
|
||||
}
|
||||
if reg.ClientID != "" {
|
||||
token.ClientID = reg.ClientID
|
||||
}
|
||||
if reg.ClientSecret != "" {
|
||||
token.ClientSecret = reg.ClientSecret
|
||||
}
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func getOIDCEndpoint(region string) string {
|
||||
if strings.TrimSpace(oidcEndpointOverride) != "" {
|
||||
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||
}
|
||||
|
||||
func oidcHeaders(accountKey, machineID string) map[string]string {
|
||||
headers := BuildOIDCHeaders(accountKey, machineID)
|
||||
if headers["amz-sdk-invocation-id"] == "" {
|
||||
headers["amz-sdk-invocation-id"] = uuid.NewString()
|
||||
}
|
||||
if headers["amz-sdk-request"] == "" {
|
||||
headers["amz-sdk-request"] = "attempt=1; max=4"
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
|
||||
client, err := newHTTPClient(proxyURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(encoded)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if payload != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
for key, value := range extraHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyText := strings.TrimSpace(string(respBody))
|
||||
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
|
||||
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
|
||||
}
|
||||
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
|
||||
}
|
||||
if out == nil || len(respBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(respBody, out)
|
||||
}
|
||||
|
||||
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
|
||||
_, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport := &http.Transport{}
|
||||
if parsed != nil {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/refreshToken", r.URL.Path)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := socialAuthEndpointURL
|
||||
socialAuthEndpointURL = server.URL
|
||||
t.Cleanup(func() { socialAuthEndpointURL = previous })
|
||||
|
||||
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
|
||||
require.Error(t, err)
|
||||
|
||||
var invalid *RefreshTokenInvalidError
|
||||
require.True(t, errors.As(err, &invalid))
|
||||
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||
require.Contains(t, invalid.Body, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/token", r.URL.Path)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
|
||||
require.Error(t, err)
|
||||
|
||||
var invalid *RefreshTokenInvalidError
|
||||
require.True(t, errors.As(err, &invalid))
|
||||
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||
require.Contains(t, invalid.Body, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
|
||||
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, profileArn, token.ProfileArn)
|
||||
require.Equal(t, "kiro@example.com", token.Email)
|
||||
}
|
||||
|
||||
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
|
||||
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, profileArn, token.ProfileArn)
|
||||
require.Equal(t, "kiro@example.com", token.Email)
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build unit
|
||||
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
|
||||
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
|
||||
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
|
||||
if got != want {
|
||||
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSocialTokenRedirectURI(t *testing.T) {
|
||||
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
|
||||
want := "http://localhost:49153/oauth/callback?login_option=github"
|
||||
if got != want {
|
||||
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
|
||||
|
||||
session, ok := store.Get("expired")
|
||||
if ok || session != nil {
|
||||
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
|
||||
}
|
||||
if _, exists := store.data["expired"]; exists {
|
||||
t.Fatalf("expired session should be deleted from the store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
now := time.Now()
|
||||
for i := 0; i < sessionCleanupMin; i++ {
|
||||
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
|
||||
}
|
||||
store.setCount = sessionCleanupEvery - 1
|
||||
|
||||
store.Set("fresh", &AuthSession{CreatedAt: now})
|
||||
|
||||
if len(store.data) != 1 {
|
||||
t.Fatalf("store size = %d, want 1", len(store.data))
|
||||
}
|
||||
if _, ok := store.data["fresh"]; !ok {
|
||||
t.Fatalf("fresh session should remain after pruning")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,368 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
|
||||
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
|
||||
|
||||
var cachedWebSearchDescription atomic.Value // stores string
|
||||
|
||||
type MCPRequest struct {
|
||||
ID string `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type MCPResponse struct {
|
||||
Result *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result,omitempty"`
|
||||
Error *struct {
|
||||
Code *int `json:"code,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type WebSearchResults struct {
|
||||
Results []WebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
type WebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Snippet *string `json:"snippet,omitempty"`
|
||||
PublishedDate *int64 `json:"publishedDate,omitempty"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Domain *string `json:"domain,omitempty"`
|
||||
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
|
||||
PublicDomain *bool `json:"publicDomain,omitempty"`
|
||||
}
|
||||
|
||||
type SearchIndicator struct {
|
||||
ToolUseID string
|
||||
Query string
|
||||
Results *WebSearchResults
|
||||
}
|
||||
|
||||
func GetCachedWebSearchDescription() string {
|
||||
if v := cachedWebSearchDescription.Load(); v != nil {
|
||||
return strings.TrimSpace(v.(string))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SetCachedWebSearchDescription(desc string) {
|
||||
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
|
||||
}
|
||||
|
||||
func BuildMcpEndpoint(region string) string {
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
}
|
||||
|
||||
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
|
||||
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, item := range resp.Result.Content {
|
||||
if item.Type != "" && item.Type != "text" {
|
||||
continue
|
||||
}
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
|
||||
return &results
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExtractSearchQuery(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
arr := messages.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
msg := arr[i]
|
||||
if msg.Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
text := extractSearchText(msg.Get("content"))
|
||||
const prefix = "Perform a web search for the query: "
|
||||
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||
if text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractSearchText(content gjson.Result) string {
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
if !content.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func GenerateToolUseID() string {
|
||||
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
|
||||
}
|
||||
|
||||
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body, err
|
||||
}
|
||||
rawTools, ok := payload["tools"].([]interface{})
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
replaced := make([]interface{}, 0, len(rawTools))
|
||||
for _, rawTool := range rawTools {
|
||||
tool, ok := rawTool.(map[string]interface{})
|
||||
if !ok {
|
||||
replaced = append(replaced, rawTool)
|
||||
continue
|
||||
}
|
||||
name := getInterfaceString(tool["name"])
|
||||
toolType := getInterfaceString(tool["type"])
|
||||
if !isWebSearchToolName(name, toolType) {
|
||||
replaced = append(replaced, rawTool)
|
||||
continue
|
||||
}
|
||||
replaced = append(replaced, map[string]interface{}{
|
||||
"name": "web_search",
|
||||
"description": minimalWebSearchDescription,
|
||||
"input_schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The search query to execute",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
payload["tools"] = replaced
|
||||
updated, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
|
||||
}
|
||||
|
||||
rawMessages, ok := payload["messages"].([]interface{})
|
||||
if !ok {
|
||||
return claudePayload, fmt.Errorf("claude payload missing messages array")
|
||||
}
|
||||
|
||||
assistantMsg := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": query},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userContent := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": formatToolResultText(results),
|
||||
},
|
||||
}
|
||||
if guidance := searchGuidanceText(); guidance != "" {
|
||||
userContent = append(userContent, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": guidance,
|
||||
})
|
||||
}
|
||||
userMsg := map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": userContent,
|
||||
}
|
||||
|
||||
rawMessages = append(rawMessages, assistantMsg, userMsg)
|
||||
payload["messages"] = rawMessages
|
||||
updated, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
|
||||
if len(searches) == 0 {
|
||||
return responsePayload, nil
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(responsePayload, &response); err != nil {
|
||||
return responsePayload, err
|
||||
}
|
||||
content, _ := response["content"].([]interface{})
|
||||
updated := make([]interface{}, 0, len(searches)*2+len(content))
|
||||
for _, search := range searches {
|
||||
updated = append(updated, map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": search.ToolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": search.Query},
|
||||
})
|
||||
updated = append(updated, map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"content": buildSearchResultContent(search.Results),
|
||||
})
|
||||
}
|
||||
updated = append(updated, content...)
|
||||
response["content"] = updated
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return responsePayload, err
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
|
||||
content := make([]map[string]interface{}, 0)
|
||||
if results == nil {
|
||||
return content
|
||||
}
|
||||
for _, result := range results.Results {
|
||||
snippet := ""
|
||||
if result.Snippet != nil {
|
||||
snippet = strings.TrimSpace(*result.Snippet)
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": result.Title,
|
||||
"url": result.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
|
||||
content := gjson.GetBytes(responsePayload, "content")
|
||||
if !content.IsArray() {
|
||||
return "", "", false
|
||||
}
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() != "tool_use" {
|
||||
continue
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if !isWebSearchToolName(name, "") {
|
||||
continue
|
||||
}
|
||||
query = strings.TrimSpace(block.Get("input.query").String())
|
||||
if query == "" {
|
||||
continue
|
||||
}
|
||||
return block.Get("id").String(), query, true
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func isWebSearchToolName(name, toolType string) bool {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
toolType = strings.ToLower(strings.TrimSpace(toolType))
|
||||
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
||||
return true
|
||||
}
|
||||
switch name {
|
||||
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getInterfaceString(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(val)
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(val))
|
||||
}
|
||||
}
|
||||
|
||||
func formatToolResultText(results *WebSearchResults) string {
|
||||
if results == nil || len(results.Results) == 0 {
|
||||
return "No search results found."
|
||||
}
|
||||
payload, err := json.MarshalIndent(results.Results, "", " ")
|
||||
if err != nil {
|
||||
return "Found search results, but failed to format them."
|
||||
}
|
||||
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
|
||||
}
|
||||
|
||||
func searchGuidanceText() string {
|
||||
now := time.Now()
|
||||
return fmt.Sprintf(`<search_guidance>
|
||||
Current date: %s (%s)
|
||||
|
||||
IMPORTANT: Evaluate the search results above carefully. If the results are:
|
||||
- Mostly spam, SEO junk, or unrelated websites
|
||||
- Missing actual information about the query topic
|
||||
- Outdated or not matching the requested time frame
|
||||
|
||||
Then you MUST use the web_search tool again with a refined query. Try:
|
||||
- Rephrasing in English for better coverage
|
||||
- Using more specific keywords
|
||||
- Adding date context
|
||||
|
||||
Do NOT apologize for bad results without first attempting a re-search.
|
||||
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BufferedStreamResult struct {
|
||||
StopReason string
|
||||
WebSearchQuery string
|
||||
WebSearchToolUseID string
|
||||
HasWebSearchToolUse bool
|
||||
WebSearchToolUseIndex int
|
||||
}
|
||||
|
||||
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if results != nil {
|
||||
for _, result := range results.Results {
|
||||
snippet := ""
|
||||
if result.Snippet != nil {
|
||||
snippet = strings.TrimSpace(*result.Snippet)
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": result.Title,
|
||||
"url": result.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
|
||||
events := []map[string]interface{}{
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": toolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
}
|
||||
|
||||
result := make([][]byte, 0, len(events))
|
||||
for _, event := range events {
|
||||
eventType, _ := event["type"].(string)
|
||||
payload, _ := json.Marshal(event)
|
||||
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
var currentToolName string
|
||||
currentToolIndex := -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "message_delta":
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
|
||||
result.StopReason = stopReason
|
||||
}
|
||||
}
|
||||
case "content_block_start":
|
||||
contentBlock, ok := event["content_block"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := contentBlock["type"].(string)
|
||||
if blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
currentToolName, _ = contentBlock["name"].(string)
|
||||
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
|
||||
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
case "content_block_delta":
|
||||
if currentToolName == "" {
|
||||
continue
|
||||
}
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
deltaType, _ := delta["type"].(string)
|
||||
if deltaType != "input_json_delta" {
|
||||
continue
|
||||
}
|
||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partialJSON)
|
||||
}
|
||||
case "content_block_stop":
|
||||
if !isWebSearchToolName(currentToolName, "") {
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
continue
|
||||
}
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
|
||||
result.WebSearchQuery = strings.TrimSpace(input["query"])
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
|
||||
filtered := make([][]byte, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
|
||||
if shouldForward {
|
||||
filtered = append(filtered, adjusted)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||
return filterSSEChunk(chunk, -1, offset)
|
||||
}
|
||||
|
||||
func MaxContentBlockIndex(chunks [][]byte) int {
|
||||
maxIndex := -1
|
||||
for _, chunk := range chunks {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
|
||||
maxIndex = int(idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return maxIndex
|
||||
}
|
||||
|
||||
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
var builder strings.Builder
|
||||
hasContent := false
|
||||
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
|
||||
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
}
|
||||
builder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||
continue
|
||||
}
|
||||
adjusted := adjustEventPayload(payload, indexOffset)
|
||||
if adjusted == "" {
|
||||
continue
|
||||
}
|
||||
builder.WriteString("data: " + adjusted + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasContent {
|
||||
return nil, false
|
||||
}
|
||||
return []byte(builder.String()), true
|
||||
}
|
||||
|
||||
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
|
||||
if payload == "" {
|
||||
return false
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return false
|
||||
}
|
||||
eventType, _ := event["type"].(string)
|
||||
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
|
||||
return true
|
||||
}
|
||||
if webSearchToolUseIndex < 0 {
|
||||
return false
|
||||
}
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func adjustEventPayload(payload string, indexOffset int) string {
|
||||
if payload == "" || indexOffset == 0 {
|
||||
return payload
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return payload
|
||||
}
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
event["index"] = int(idx) + indexOffset
|
||||
if adjusted, err := json.Marshal(event); err == nil {
|
||||
return string(adjusted)
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
|
||||
snippet := "result snippet"
|
||||
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
|
||||
Results: []WebSearchResult{
|
||||
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
|
||||
},
|
||||
}, 0)
|
||||
|
||||
require.Len(t, events, 5)
|
||||
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
|
||||
require.Contains(t, string(events[0]), `"input":{}`)
|
||||
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
|
||||
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
|
||||
require.NotContains(t, string(events[3]), `"tool_use_id"`)
|
||||
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
|
||||
}
|
||||
|
||||
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
|
||||
chunks := [][]byte{
|
||||
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||
}
|
||||
|
||||
result := AnalyzeBufferedStream(chunks)
|
||||
require.True(t, result.HasWebSearchToolUse)
|
||||
require.Equal(t, "golang concurrency", result.WebSearchQuery)
|
||||
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
|
||||
require.Equal(t, 1, result.WebSearchToolUseIndex)
|
||||
require.Equal(t, "tool_use", result.StopReason)
|
||||
}
|
||||
|
||||
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
|
||||
chunks := [][]byte{
|
||||
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||
}
|
||||
|
||||
filtered := FilterChunksForClient(chunks, 1, 2)
|
||||
require.NotEmpty(t, filtered)
|
||||
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
|
||||
require.NotContains(t, joined, `"type":"message_start"`)
|
||||
require.NotContains(t, joined, `"type":"message_delta"`)
|
||||
require.NotContains(t, joined, `"name":"web_search"`)
|
||||
require.Contains(t, joined, `"index":2`)
|
||||
require.Equal(t, 2, MaxContentBlockIndex(filtered))
|
||||
}
|
||||
|
||||
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
|
||||
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
|
||||
require.False(t, shouldForward)
|
||||
|
||||
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
|
||||
require.True(t, shouldForward)
|
||||
require.Contains(t, string(adjusted), `"index":3`)
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools":[{"type":"web_search_20250305","description":"old"}],
|
||||
"messages":[{"role":"user","content":"golang"}]
|
||||
}`)
|
||||
|
||||
updated, err := ReplaceWebSearchToolDescription(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
|
||||
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
|
||||
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
|
||||
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
|
||||
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
|
||||
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
|
||||
}
|
||||
|
||||
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[{"role":"user","content":"what is golang"}]
|
||||
}`)
|
||||
results := &WebSearchResults{
|
||||
Results: []WebSearchResult{
|
||||
{Title: "Go", URL: "https://go.dev"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
|
||||
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
|
||||
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
|
||||
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
|
||||
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
|
||||
}
|
||||
|
||||
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
|
||||
response := []byte(`{
|
||||
"content":[
|
||||
{"type":"text","text":"let me search"},
|
||||
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
|
||||
]
|
||||
}`)
|
||||
|
||||
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "srvtoolu_next", toolUseID)
|
||||
require.Equal(t, "golang concurrency", query)
|
||||
}
|
||||
|
||||
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
|
||||
response := []byte(`{
|
||||
"id":"msg_1",
|
||||
"type":"message",
|
||||
"role":"assistant",
|
||||
"model":"kiro",
|
||||
"content":[{"type":"text","text":"final"}],
|
||||
"stop_reason":"end_turn",
|
||||
"usage":{"input_tokens":1,"output_tokens":1}
|
||||
}`)
|
||||
|
||||
snippet := "result snippet"
|
||||
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
|
||||
{
|
||||
ToolUseID: "srvtoolu_test",
|
||||
Query: "golang",
|
||||
Results: &WebSearchResults{
|
||||
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &decoded))
|
||||
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
|
||||
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
|
||||
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
|
||||
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
|
||||
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
|
||||
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
|
||||
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
|
||||
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
|
||||
}
|
||||
|
||||
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
|
||||
resp := &MCPResponse{
|
||||
Result: &struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
}{
|
||||
Content: []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}{
|
||||
{
|
||||
Type: "text",
|
||||
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
results := ParseSearchResults(resp)
|
||||
require.NotNil(t, results)
|
||||
require.Len(t, results.Results, 1)
|
||||
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
|
||||
require.Equal(t, "doc-1", *results.Results[0].ID)
|
||||
require.Equal(t, "go.dev", *results.Results[0].Domain)
|
||||
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
|
||||
require.True(t, *results.Results[0].PublicDomain)
|
||||
}
|
||||
|
||||
func TestSearchGuidanceText_IsStructured(t *testing.T) {
|
||||
guidance := searchGuidanceText()
|
||||
require.Contains(t, guidance, "<search_guidance>")
|
||||
require.Contains(t, guidance, "Current date:")
|
||||
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
|
||||
require.Contains(t, guidance, "Rephrasing in English for better coverage")
|
||||
}
|
||||
@@ -0,0 +1,479 @@
|
||||
package kirocooldown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
MinRequestInterval = time.Second
|
||||
MaxRequestInterval = 2 * time.Second
|
||||
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
|
||||
ShortCooldown = time.Minute
|
||||
MaxCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
|
||||
redisTimeout = 3 * time.Second
|
||||
activeTTL = 10 * time.Second
|
||||
stateTTL = 25 * time.Hour
|
||||
keyPrefix = "kiro:cooldown:"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||
|
||||
reserveRequestScript = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
|
||||
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
|
||||
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
|
||||
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
|
||||
local interval_ms = tonumber(ARGV[1])
|
||||
local active_ttl_ms = tonumber(ARGV[2])
|
||||
local state_ttl_ms = tonumber(ARGV[3])
|
||||
|
||||
if cooldown_until_ms > now_ms then
|
||||
return {1, cooldown_until_ms - now_ms, cooldown_reason}
|
||||
end
|
||||
|
||||
if cooldown_until_ms > 0 then
|
||||
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
|
||||
end
|
||||
|
||||
local next_slot_ms = now_ms
|
||||
if last_request_ms > 0 then
|
||||
local candidate_ms = last_request_ms + interval_ms
|
||||
if candidate_ms > now_ms then
|
||||
next_slot_ms = candidate_ms
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
|
||||
if fail_count > 0 or cooldown_until_ms > now_ms then
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
else
|
||||
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
|
||||
end
|
||||
return {0, next_slot_ms - now_ms, ''}
|
||||
`)
|
||||
|
||||
mark429Script = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
|
||||
local short_cooldown_ms = tonumber(ARGV[1])
|
||||
local max_cooldown_ms = tonumber(ARGV[2])
|
||||
local state_ttl_ms = tonumber(ARGV[3])
|
||||
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
|
||||
if cooldown_ms > max_cooldown_ms then
|
||||
cooldown_ms = max_cooldown_ms
|
||||
end
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', fail_count,
|
||||
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||
'cooldown_reason', ARGV[4]
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
return cooldown_ms
|
||||
`)
|
||||
|
||||
markSuccessScript = redis.NewScript(`
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', 0,
|
||||
'cooldown_until_ms', 0,
|
||||
'cooldown_reason', ''
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
|
||||
return 1
|
||||
`)
|
||||
|
||||
markSuspendedScript = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local cooldown_ms = tonumber(ARGV[1])
|
||||
local state_ttl_ms = tonumber(ARGV[2])
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', 0,
|
||||
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||
'cooldown_reason', ARGV[3]
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
return cooldown_ms
|
||||
`)
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
remaining time.Duration
|
||||
reason string
|
||||
}
|
||||
|
||||
type State struct {
|
||||
Active bool
|
||||
Reason string
|
||||
CooldownUntil time.Time
|
||||
Remaining time.Duration
|
||||
FailCount int
|
||||
}
|
||||
|
||||
func NewError(remaining time.Duration, reason string) error {
|
||||
return &Error{remaining: remaining, reason: reason}
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.reason == "" {
|
||||
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
|
||||
}
|
||||
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
|
||||
}
|
||||
|
||||
func Calculate429Cooldown(retryCount int) time.Duration {
|
||||
if retryCount < 0 {
|
||||
retryCount = 0
|
||||
}
|
||||
cooldown := ShortCooldown * time.Duration(1<<retryCount)
|
||||
if cooldown > MaxCooldown {
|
||||
return MaxCooldown
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
rngMu sync.Mutex
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
func NewStore(client *redis.Client) *Store {
|
||||
return &Store{
|
||||
client: client,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
values, err := reserveRequestScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
s.nextInterval().Milliseconds(),
|
||||
activeTTL.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
|
||||
}
|
||||
parts, ok := values.([]interface{})
|
||||
if !ok || len(parts) != 3 {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
|
||||
}
|
||||
state, err := luaInt64(parts[0])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
|
||||
}
|
||||
waitMS, err := luaInt64(parts[1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
|
||||
}
|
||||
reason, err := luaString(parts[2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
|
||||
}
|
||||
if state == 1 {
|
||||
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
|
||||
}
|
||||
if waitMS <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return time.Duration(waitMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
|
||||
if err := s.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
if err := markSuccessScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
activeTTL.Milliseconds(),
|
||||
).Err(); err != nil {
|
||||
return fmt.Errorf("kiro cooldown mark success: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
result, err := mark429Script.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
ShortCooldown.Milliseconds(),
|
||||
MaxCooldown.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
CooldownReason429,
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||
}
|
||||
cooldownMS, err := luaInt64(result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||
}
|
||||
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
result, err := markSuspendedScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
LongCooldown.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
CooldownReasonSuspended,
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||
}
|
||||
cooldownMS, err := luaInt64(result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||
}
|
||||
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
values, err := s.client.HMGet(
|
||||
cacheCtx,
|
||||
RedisKey(tokenKey),
|
||||
"cooldown_until_ms",
|
||||
"cooldown_reason",
|
||||
"fail_count",
|
||||
).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
|
||||
}
|
||||
if len(values) != 3 {
|
||||
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
|
||||
}
|
||||
|
||||
cooldownUntilMS, err := luaInt64(values[0])
|
||||
if err != nil && values[0] != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
|
||||
}
|
||||
reason, err := luaString(values[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
|
||||
}
|
||||
failCount, err := luaInt64(values[2])
|
||||
if err != nil && values[2] != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
|
||||
}
|
||||
if cooldownUntilMS <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cooldownUntil := time.UnixMilli(cooldownUntilMS)
|
||||
remaining := time.Until(cooldownUntil)
|
||||
if remaining <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &State{
|
||||
Active: true,
|
||||
Reason: reason,
|
||||
CooldownUntil: cooldownUntil,
|
||||
Remaining: remaining,
|
||||
FailCount: int(failCount),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
uniqueKeys := make([]string, 0, len(tokenKeys))
|
||||
seen := make(map[string]struct{}, len(tokenKeys))
|
||||
for _, tokenKey := range tokenKeys {
|
||||
tokenKey = strings.TrimSpace(tokenKey)
|
||||
if tokenKey == "" {
|
||||
continue
|
||||
}
|
||||
redisKey := RedisKey(tokenKey)
|
||||
if _, ok := seen[redisKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[redisKey] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, redisKey)
|
||||
}
|
||||
if len(uniqueKeys) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
type candidate struct {
|
||||
redisKey string
|
||||
cooldownUntilMS int64
|
||||
failCount int64
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
var best *candidate
|
||||
|
||||
pipe := s.client.Pipeline()
|
||||
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
|
||||
for _, redisKey := range uniqueKeys {
|
||||
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
|
||||
}
|
||||
if _, err := pipe.Exec(cacheCtx); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
|
||||
}
|
||||
|
||||
for i, cmd := range cmds {
|
||||
values, err := cmd.Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
|
||||
}
|
||||
if len(values) != 3 {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
|
||||
}
|
||||
cooldownUntilMS, err := luaInt64(values[0])
|
||||
if err != nil && values[0] != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
|
||||
}
|
||||
reason, err := luaString(values[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
|
||||
}
|
||||
failCount, err := luaInt64(values[2])
|
||||
if err != nil && values[2] != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
|
||||
}
|
||||
if cooldownUntilMS <= now || reason != CooldownReason429 {
|
||||
continue
|
||||
}
|
||||
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
|
||||
if best == nil ||
|
||||
current.cooldownUntilMS < best.cooldownUntilMS ||
|
||||
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
|
||||
best = current
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
|
||||
}
|
||||
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func RedisKey(tokenKey string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
|
||||
digest := hex.EncodeToString(sum[:])
|
||||
return keyPrefix + "{" + digest + "}"
|
||||
}
|
||||
|
||||
func ActiveTTL() time.Duration {
|
||||
return activeTTL
|
||||
}
|
||||
|
||||
func StateTTL() time.Duration {
|
||||
return stateTTL
|
||||
}
|
||||
|
||||
func (s *Store) validate() error {
|
||||
if s == nil || s.client == nil {
|
||||
return ErrStoreUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) nextInterval() time.Duration {
|
||||
s.rngMu.Lock()
|
||||
defer s.rngMu.Unlock()
|
||||
if MaxRequestInterval <= MinRequestInterval {
|
||||
return MinRequestInterval
|
||||
}
|
||||
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
|
||||
}
|
||||
|
||||
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, redisTimeout)
|
||||
}
|
||||
|
||||
func luaInt64(v any) (int64, error) {
|
||||
switch n := v.(type) {
|
||||
case int64:
|
||||
return n, nil
|
||||
case int:
|
||||
return int64(n), nil
|
||||
case string:
|
||||
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
|
||||
case []byte:
|
||||
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
func luaString(v any) (string, error) {
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return s, nil
|
||||
case []byte:
|
||||
return string(s), nil
|
||||
case nil:
|
||||
return "", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported lua string type %T", v)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package kirocooldown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
|
||||
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
|
||||
|
||||
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
|
||||
}
|
||||
if cleared {
|
||||
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
|
||||
store := NewStore(nil)
|
||||
|
||||
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
|
||||
if err == nil {
|
||||
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
|
||||
}
|
||||
if cleared {
|
||||
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
|
||||
service.PlatformOpenAI: 1,
|
||||
service.PlatformGemini: 1,
|
||||
service.PlatformAntigravity: 2,
|
||||
service.PlatformKiro: 1,
|
||||
}
|
||||
|
||||
for platform, minCount := range requiredByPlatform {
|
||||
|
||||
@@ -41,6 +41,9 @@ func RegisterAdminRoutes(
|
||||
// Antigravity OAuth
|
||||
registerAntigravityOAuthRoutes(admin, h)
|
||||
|
||||
// Kiro OAuth / IDC
|
||||
registerKiroOAuthRoutes(admin, h)
|
||||
|
||||
// 代理管理
|
||||
registerProxyRoutes(admin, h)
|
||||
|
||||
@@ -315,6 +318,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
accounts.GET("/kiro/default-model-mapping", h.Admin.Account.GetKiroDefaultModelMapping)
|
||||
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
@@ -367,6 +371,17 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
|
||||
}
|
||||
}
|
||||
|
||||
func registerKiroOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
kiro := admin.Group("/kiro")
|
||||
{
|
||||
kiro.POST("/oauth/auth-url", h.Admin.KiroOAuth.GenerateAuthURL)
|
||||
kiro.POST("/oauth/idc-auth-url", h.Admin.KiroOAuth.GenerateIDCAuthURL)
|
||||
kiro.POST("/oauth/exchange-code", h.Admin.KiroOAuth.ExchangeCode)
|
||||
kiro.POST("/oauth/refresh-token", h.Admin.KiroOAuth.RefreshToken)
|
||||
kiro.POST("/oauth/import-token", h.Admin.KiroOAuth.ImportToken)
|
||||
}
|
||||
}
|
||||
|
||||
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
|
||||
@@ -48,6 +48,13 @@ type Account struct {
|
||||
TempUnschedulableUntil *time.Time
|
||||
TempUnschedulableReason string
|
||||
|
||||
KiroQuotaState string
|
||||
KiroQuotaReason string
|
||||
KiroQuotaResetAt *time.Time
|
||||
KiroRuntimeState string
|
||||
KiroRuntimeReason string
|
||||
KiroRuntimeResetAt *time.Time
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
@@ -164,6 +171,10 @@ func (a *Account) IsGemini() bool {
|
||||
return a.Platform == PlatformGemini
|
||||
}
|
||||
|
||||
func (a *Account) IsKiro() bool {
|
||||
return a.Platform == PlatformKiro
|
||||
}
|
||||
|
||||
func (a *Account) GeminiOAuthType() string {
|
||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||
return ""
|
||||
@@ -478,17 +489,17 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
|
||||
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
||||
if a.Credentials == nil {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
// 部分平台在未显式配置 model_mapping 时仍应使用默认映射,
|
||||
// 以限制可调度/可转发的模型集合。
|
||||
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||
return defaults
|
||||
}
|
||||
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||
return nil
|
||||
}
|
||||
if len(rawMapping) == 0 {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||
return defaults
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -510,13 +521,23 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
||||
return result
|
||||
}
|
||||
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||
return defaults
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultModelMappingForPlatform(platform string) map[string]string {
|
||||
switch platform {
|
||||
case domain.PlatformAntigravity:
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
case domain.PlatformKiro:
|
||||
return domain.DefaultKiroModelMapping
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mapPtr(m map[string]any) uintptr {
|
||||
if m == nil {
|
||||
return 0
|
||||
@@ -608,8 +629,8 @@ func resolveRequestedModelInMapping(mapping map[string]string, requestedModel st
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)。
|
||||
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时也会先回退到默认映射。
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
@@ -622,8 +643,8 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
|
||||
}
|
||||
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)。
|
||||
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时返回默认映射结果。
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
@@ -725,6 +746,9 @@ func (a *Account) GetBaseURL() string {
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
if a.Platform == PlatformKiro {
|
||||
return ""
|
||||
}
|
||||
return "https://api.anthropic.com"
|
||||
}
|
||||
if a.Platform == PlatformAntigravity {
|
||||
|
||||
@@ -180,7 +180,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
|
||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||
}
|
||||
}
|
||||
@@ -296,7 +296,7 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
|
||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
@@ -66,6 +67,7 @@ type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
kiroTokenProvider *KiroTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
@@ -77,6 +79,7 @@ func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
kiroTokenProvider *KiroTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
@@ -86,6 +89,7 @@ func NewAccountTestService(
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
kiroTokenProvider: kiroTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
@@ -192,6 +196,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.IsKiro() && account.Type == AccountTypeOAuth {
|
||||
return s.testKiroAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
@@ -240,6 +248,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" && account.Platform == PlatformKiro {
|
||||
return s.sendErrorAndEnd(c, "Kiro API Key accounts require a Base URL")
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
@@ -388,6 +399,149 @@ func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Con
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) testKiroAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
testModelID := strings.TrimSpace(modelID)
|
||||
if testModelID == "" {
|
||||
testModelID = "claude-sonnet-4-6"
|
||||
}
|
||||
if mappedModel := account.GetMappedModel(testModelID); strings.TrimSpace(mappedModel) != "" {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Kiro account type: %s", account.Type))
|
||||
}
|
||||
|
||||
if s.kiroTokenProvider == nil {
|
||||
return s.sendErrorAndEnd(c, "Kiro token provider not configured")
|
||||
}
|
||||
|
||||
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get Kiro access token: %s", err.Error()))
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
resp, err := s.executeKiroTestUpstream(ctx, account, payloadBytes, testModelID, accessToken)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, formatKiroTestError(resp.StatusCode, body, testModelID, account))
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
_, streamErr := kiropkg.StreamEventStreamAsAnthropic(ctx, resp.Body, pw, testModelID, estimateKiroInputTokens(payloadBytes))
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
return s.processClaudeStream(c, pr)
|
||||
}
|
||||
|
||||
func formatKiroTestError(statusCode int, body []byte, requestedModel string, account *Account) string {
|
||||
return fmt.Sprintf("API returned %d: %s", statusCode, string(body))
|
||||
}
|
||||
|
||||
func (s *AccountTestService) executeKiroTestUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string) (*http.Response, error) {
|
||||
modelID := kiropkg.MapModel(mappedModel)
|
||||
currentToken := token
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := buildResult.Payload
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
proxyURL := kiroProxyURL(account)
|
||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
maxRetries := 2
|
||||
for idx, endpoint := range endpoints {
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
|
||||
if idx+1 < len(endpoints) {
|
||||
_ = resp.Body.Close()
|
||||
break
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
|
||||
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = buildResult.Payload
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("kiro upstream endpoints exhausted")
|
||||
}
|
||||
|
||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||
region := bedrockRuntimeRegion(account)
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
)
|
||||
|
||||
func TestAccountTestService_KiroAPIKeyUsesGenericAnthropicCompatiblePath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 19,
|
||||
Name: "kiro-apikey-test",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"base_url": "https://kiro-upstream.example.com",
|
||||
"api_key": "kiro-api-key",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"invalid api key"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
req := upstream.requests[0]
|
||||
require.Equal(t, "kiro-upstream.example.com", req.URL.Host)
|
||||
require.Equal(t, "/v1/messages", req.URL.Path)
|
||||
require.Equal(t, "kiro-api-key", req.Header.Get("x-api-key"))
|
||||
require.Empty(t, req.Header.Get("Authorization"))
|
||||
require.Equal(t, claude.APIKeyBetaHeader, req.Header.Get("anthropic-beta"))
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroAPIKeyWithoutBaseURLErrors(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Name: "kiro-apikey-missing-base-url",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "kiro-api-key",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: &queuedHTTPUpstream{},
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Base URL")
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountTestService_KiroUsesKiroUpstreamInsteadOfAnthropic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "kiro-test",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/TESTSOCIAL",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{1: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"Invalid bearer token"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "gpt-4o", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
req := upstream.requests[0]
|
||||
require.Equal(t, "q.us-east-1.amazonaws.com", req.URL.Host)
|
||||
require.Equal(t, "/generateAssistantResponse", req.URL.Path)
|
||||
require.Equal(t, "Bearer kiro-access-token", req.Header.Get("Authorization"))
|
||||
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Empty(t, req.Header.Get("anthropic-version"))
|
||||
require.NotContains(t, req.URL.Host, "api.anthropic.com")
|
||||
}
|
||||
|
||||
func TestAccountTestService_Kiro429DoesNotFallbackToCodeWhispererEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Name: "kiro-fallback",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"api_region": "us-west-2",
|
||||
"region": "us-west-2",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTFALLBACK",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{2: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusTooManyRequests, `{"message":"slow down"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
|
||||
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
|
||||
require.Contains(t, err.Error(), "API returned 429")
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroIDCWithoutProfileArnOmitsProfileArnAndUsesDefaultRuntimeRegion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "kiro-idc-default-region",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"auth_method": "idc",
|
||||
"provider": "AWS",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{5: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
require.Equal(t, "q.us-east-1.amazonaws.com", upstream.requests[0].URL.Host)
|
||||
body, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||
require.NoError(t, readErr)
|
||||
require.NotContains(t, string(body), `"profileArn":`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroInvalidModelErrorPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "kiro-invalid-model",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTINVALIDMODEL",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, `API returned 400: {"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`, err.Error())
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "kiro-invalid-model-refresh",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{7: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "API returned 400")
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||
require.NoError(t, readErr)
|
||||
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroPreferredEndpointIsIgnored(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "kiro-preferred-endpoint",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"api_region": "us-west-2",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/PREFERRED",
|
||||
"preferred_endpoint": "codewhisperer",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
|
||||
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccount_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "kiro-builder-id",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "idc",
|
||||
"provider": "BuilderId",
|
||||
"region": "us-east-1",
|
||||
"client_id": "builder-client-id",
|
||||
},
|
||||
}
|
||||
|
||||
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(testPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(kiroPayload), `"profileArn":`)
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccount_KiroBuilderIDUsesCredentialProfileArn(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 33,
|
||||
Name: "kiro-builder-id-cached",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "builder-id",
|
||||
"provider": "BuilderId",
|
||||
"region": "us-east-1",
|
||||
"client_id": "builder-client-id",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED",
|
||||
},
|
||||
}
|
||||
|
||||
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(testPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(kiroPayload), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED"`)
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccount_KiroEnterpriseIDCOmitsMissingProfileArn(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 4,
|
||||
Name: "kiro-enterprise-idc",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "idc",
|
||||
"provider": "AWS",
|
||||
"region": "us-east-1",
|
||||
"client_id": "enterprise-client-id",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
|
||||
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(testPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(kiroPayload), `"profileArn":`)
|
||||
}
|
||||
@@ -103,10 +103,17 @@ type antigravityUsageCache struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// kiroUsageCache 缓存 Kiro 额度快照
|
||||
type kiroUsageCache struct {
|
||||
usageInfo *UsageInfo
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
kiroUsageErrorTTL = 1 * time.Minute // Kiro 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -118,8 +125,10 @@ type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
kiroUsageCache sync.Map // accountID -> *kiroUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
kiroUsageFlight singleflight.Group // 防止同一 Kiro 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
@@ -176,6 +185,23 @@ type AICredit struct {
|
||||
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
||||
}
|
||||
|
||||
// KiroCreditProgress 表示 Kiro 主额度或 Bonus 的用量进度。
|
||||
type KiroCreditProgress struct {
|
||||
CurrentUsage float64 `json:"current_usage"`
|
||||
UsageLimit float64 `json:"usage_limit"`
|
||||
PercentageUsed float64 `json:"percentage_used"`
|
||||
DaysRemaining int `json:"days_remaining,omitempty"`
|
||||
ExpiryDate *time.Time `json:"expiry_date,omitempty"`
|
||||
}
|
||||
|
||||
// KiroOverageInfo 表示 Kiro 账号的 overage 状态。
|
||||
type KiroOverageInfo struct {
|
||||
CurrentOverages float64 `json:"current_overages"`
|
||||
OverageCharges float64 `json:"overage_charges"`
|
||||
CurrencyCode string `json:"currency_code,omitempty"`
|
||||
CurrencySymbol string `json:"currency_symbol,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
Source string `json:"source,omitempty"` // "passive" or "active"
|
||||
@@ -203,6 +229,21 @@ type UsageInfo struct {
|
||||
// Antigravity AI Credits 余额
|
||||
AICredits []AICredit `json:"ai_credits,omitempty"`
|
||||
|
||||
// Kiro Credits 额度与 overage 信息
|
||||
KiroSubscriptionName string `json:"kiro_subscription_name,omitempty"`
|
||||
KiroSubscriptionType string `json:"kiro_subscription_type,omitempty"`
|
||||
KiroResetAt *time.Time `json:"kiro_reset_at,omitempty"`
|
||||
KiroOveragesEnabled bool `json:"kiro_overages_enabled,omitempty"`
|
||||
KiroCredit *KiroCreditProgress `json:"kiro_credit,omitempty"`
|
||||
KiroBonus *KiroCreditProgress `json:"kiro_bonus,omitempty"`
|
||||
KiroOverage *KiroOverageInfo `json:"kiro_overage,omitempty"`
|
||||
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
|
||||
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
|
||||
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
|
||||
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
|
||||
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
|
||||
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
@@ -266,6 +307,7 @@ type AccountUsageService struct {
|
||||
cache *UsageCache
|
||||
identityCache IdentityCache
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
kiroCooldownStore KiroCooldownStore
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
@@ -291,6 +333,13 @@ func NewAccountUsageService(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) SetKiroCooldownStore(store KiroCooldownStore) *AccountUsageService {
|
||||
if s != nil {
|
||||
s.kiroCooldownStore = store
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// GetUsage 获取账号使用量
|
||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
@@ -317,6 +366,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return usage, err
|
||||
}
|
||||
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
return s.getKiroUsage(ctx, account, "active", false)
|
||||
}
|
||||
|
||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||
if account.Platform == PlatformAntigravity {
|
||||
usage, err := s.getAntigravityUsage(ctx, account)
|
||||
@@ -425,6 +478,13 @@ func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformKiro {
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("passive usage only supported for Kiro OAuth accounts")
|
||||
}
|
||||
return s.getKiroUsage(ctx, account, "passive", false)
|
||||
}
|
||||
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroAPIKeyUnsupported(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9101,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.Nil(t, usage)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "does not support usage query")
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetPassiveUsage_KiroAPIKeyUnsupported(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9102,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
usage, err := svc.GetPassiveUsage(context.Background(), account.ID)
|
||||
require.Nil(t, usage)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Kiro OAuth")
|
||||
}
|
||||
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
requestedModel: "any-model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping falls back to default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping rejects model outside default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "auto",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping uses default upstream mapping",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: "claude-sonnet-4.6",
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
|
||||
@@ -1716,7 +1716,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
}
|
||||
|
||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||
@@ -2008,7 +2008,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
|
||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||
|
||||
@@ -198,6 +198,13 @@ func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
|
||||
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
|
||||
|
||||
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
|
||||
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
|
||||
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
|
||||
|
||||
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
|
||||
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
|
||||
|
||||
// Gemini 3.1 Pro
|
||||
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-6, // $2 per MTok
|
||||
@@ -278,13 +285,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
|
||||
return s.fallbackPrices["claude-sonnet-4.6"]
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-sonnet-4.5"]
|
||||
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-haiku-4.5"]
|
||||
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
|
||||
@@ -39,6 +39,7 @@ const (
|
||||
PlatformOpenAI = domain.PlatformOpenAI
|
||||
PlatformGemini = domain.PlatformGemini
|
||||
PlatformAntigravity = domain.PlatformAntigravity
|
||||
PlatformKiro = domain.PlatformKiro
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
|
||||
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -105,44 +109,63 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -102,44 +106,63 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
@@ -56,6 +57,7 @@ const (
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
postUsageBillingTimeout = 15 * time.Second
|
||||
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
|
||||
defaultKiroStreamKeepalive = 25 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -70,6 +72,7 @@ const (
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
type kiroCooldownRecoveryAttemptedKeyType struct{}
|
||||
|
||||
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
|
||||
type accountWithLoad struct {
|
||||
@@ -78,6 +81,7 @@ type accountWithLoad struct {
|
||||
}
|
||||
|
||||
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
||||
var kiroCooldownRecoveryAttemptedKey = kiroCooldownRecoveryAttemptedKeyType{}
|
||||
|
||||
var (
|
||||
windowCostPrefetchCacheHitTotal atomic.Int64
|
||||
@@ -554,6 +558,8 @@ type GatewayService struct {
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
kiroTokenProvider *KiroTokenProvider
|
||||
kiroCooldownStore KiroCooldownStore
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
@@ -592,6 +598,8 @@ func NewGatewayService(
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
kiroTokenProvider *KiroTokenProvider,
|
||||
kiroCooldownStore KiroCooldownStore,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
rpmCache RPMCache,
|
||||
digestStore *DigestSessionStore,
|
||||
@@ -624,6 +632,8 @@ func NewGatewayService(
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
kiroTokenProvider: kiroTokenProvider,
|
||||
kiroCooldownStore: kiroCooldownStore,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
@@ -902,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
|
||||
injectMetadata bool
|
||||
metadataUserID string
|
||||
stripSystemCacheControl bool
|
||||
preserveToolChoice bool
|
||||
}
|
||||
|
||||
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
|
||||
@@ -1116,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
|
||||
if !gjson.GetBytes(out, "max_tokens").Exists() {
|
||||
@@ -1967,6 +1984,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, useMixed) {
|
||||
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
|
||||
return s.SelectAccountWithLoadAwareness(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID)
|
||||
}
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
@@ -2346,14 +2367,91 @@ func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
return account.IsSchedulable()
|
||||
if !account.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
return s.isKiroRuntimeSchedulable(context.Background(), account)
|
||||
}
|
||||
|
||||
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
|
||||
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
return false
|
||||
}
|
||||
return s.isKiroRuntimeSchedulable(ctx, account)
|
||||
}
|
||||
|
||||
func (s *GatewayService) isKiroRuntimeSchedulable(ctx context.Context, account *Account) bool {
|
||||
if account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth || s == nil || s.kiroCooldownStore == nil {
|
||||
return true
|
||||
}
|
||||
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(account))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return state == nil || !state.Active
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryRecoverKiroCooldownPool(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) bool {
|
||||
if s == nil || s.kiroCooldownStore == nil || ctx.Value(kiroCooldownRecoveryAttemptedKey) == true {
|
||||
return false
|
||||
}
|
||||
tokenKeys := s.kiroTransientCooldownRecoveryKeys(ctx, accounts, requestedModel, excludedIDs, allowMixedScheduling)
|
||||
if len(tokenKeys) == 0 {
|
||||
return false
|
||||
}
|
||||
cleared, err := s.kiroCooldownStore.ClearEarliestTransientCooldown(ctx, tokenKeys)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery failed: %v", err)
|
||||
return false
|
||||
}
|
||||
if cleared {
|
||||
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery cleared one transient cooldown")
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
func (s *GatewayService) kiroTransientCooldownRecoveryKeys(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) []string {
|
||||
tokenKeys := make([]string, 0, len(accounts))
|
||||
eligible := 0
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc == nil || acc.Platform != PlatformKiro || acc.Type != AccountTypeOAuth {
|
||||
if allowMixedScheduling {
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) ||
|
||||
!s.isAccountSchedulableForWindowCost(ctx, acc, false) ||
|
||||
!s.isAccountSchedulableForRPM(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
eligible++
|
||||
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc))
|
||||
if err != nil || state == nil || !state.Active {
|
||||
return nil
|
||||
}
|
||||
if state.Reason != kirocooldown.CooldownReason429 {
|
||||
return nil
|
||||
}
|
||||
tokenKeys = append(tokenKeys, buildKiroAccountKey(acc))
|
||||
}
|
||||
if eligible == 0 || len(tokenKeys) != eligible {
|
||||
return nil
|
||||
}
|
||||
return tokenKeys
|
||||
}
|
||||
|
||||
// isAccountInGroup checks if the account belongs to the specified group.
|
||||
@@ -3232,6 +3330,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
|
||||
if selected == nil {
|
||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
||||
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, false) {
|
||||
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
|
||||
return s.selectAccountForModelWithPlatform(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||
}
|
||||
@@ -3611,6 +3713,17 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
return selectionFailureDiagnosis{Category: "excluded"}
|
||||
}
|
||||
if !acc.IsSchedulable() {
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||
}
|
||||
if acc.Platform == PlatformKiro && acc.Type == AccountTypeOAuth {
|
||||
if state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc)); err == nil && state != nil && state.Active {
|
||||
return selectionFailureDiagnosis{
|
||||
Category: "unschedulable",
|
||||
Detail: fmt.Sprintf("kiro_runtime_%s remaining=%s", state.Reason, state.Remaining.Truncate(time.Second)),
|
||||
}
|
||||
}
|
||||
}
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||
}
|
||||
@@ -3774,6 +3887,13 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth && s.kiroTokenProvider != nil {
|
||||
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
|
||||
accessToken := account.GetCredential("access_token")
|
||||
@@ -4343,11 +4463,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
@@ -4371,6 +4486,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return s.forwardBedrock(ctx, c, account, parsed, startTime)
|
||||
}
|
||||
|
||||
if account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
return s.forwardKiroMessages(ctx, c, account, parsed, startTime)
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||
if account.Platform == PlatformAnthropic && c != nil {
|
||||
@@ -4425,7 +4549,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
|
||||
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
|
||||
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{
|
||||
stripSystemCacheControl: !systemRewritten,
|
||||
preserveToolChoice: account.Platform == PlatformKiro,
|
||||
}
|
||||
if s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil && fp != nil {
|
||||
@@ -4462,7 +4589,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(reqModel); next != "" && next != reqModel {
|
||||
mappedModel = next
|
||||
mappingSource = "account"
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
mappingSource = "account"
|
||||
@@ -5967,6 +6099,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" && account.Platform == PlatformKiro {
|
||||
return nil, fmt.Errorf("kiro api key account requires base_url")
|
||||
}
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
@@ -7228,10 +7363,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
keepaliveInterval := s.streamKeepaliveIntervalForAccount(account)
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
@@ -8277,6 +8409,9 @@ type recordUsageOpts struct {
|
||||
// 长上下文计费(仅 Gemini 路径需要)
|
||||
LongContextThreshold int
|
||||
LongContextMultiplier float64
|
||||
|
||||
// Kiro 账号在上游返回 auto 等无法定价模型时使用保守计费兜底。
|
||||
IsKiroAccount bool
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -8414,6 +8549,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
opts.IsKiroAccount = account != nil && account.Platform == PlatformKiro
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
@@ -8492,6 +8628,28 @@ func (s *GatewayService) calculateRecordUsageCost(
|
||||
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
}
|
||||
|
||||
const kiroConservativeFallbackBillingModel = "claude-opus-4-6"
|
||||
|
||||
func shouldUseKiroConservativeBillingFallback(result *ForwardResult, billingModel string, opts *recordUsageOpts) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return opts != nil && opts.IsKiroAccount
|
||||
}
|
||||
|
||||
func (s *GatewayService) calculateKiroConservativeTokenCost(tokens UsageTokens, multiplier float64) *CostBreakdown {
|
||||
if s == nil || s.billingService == nil {
|
||||
return nil
|
||||
}
|
||||
cost, err := s.billingService.CalculateCost(kiroConservativeFallbackBillingModel, tokens, multiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate conservative Kiro fallback cost failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
return cost
|
||||
}
|
||||
|
||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||
@@ -8596,6 +8754,12 @@ func (s *GatewayService) calculateTokenCost(
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
if shouldUseKiroConservativeBillingFallback(result, billingModel, opts) {
|
||||
if fallback := s.calculateKiroConservativeTokenCost(tokens, multiplier); fallback != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Using conservative Kiro fallback pricing for model=%s", billingModel)
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
return &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
return cost
|
||||
@@ -8856,6 +9020,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射:
|
||||
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
|
||||
@@ -9486,6 +9654,19 @@ func reconcileCachedTokens(usage map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *GatewayService) streamKeepaliveIntervalForAccount(account *Account) time.Duration {
|
||||
if account != nil && account.Platform == PlatformKiro {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.KiroStreamKeepaliveInterval > 0 {
|
||||
return time.Duration(s.cfg.Gateway.KiroStreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
return defaultKiroStreamKeepalive
|
||||
}
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
return time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
|
||||
|
||||
// initDebugGatewayBodyFile 初始化网关调试日志文件。
|
||||
|
||||
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
|
||||
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
|
||||
// Anthropic API Key keeps the existing account-level override:
|
||||
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Kiro OAuth uses channel-level switch only.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
@@ -62,22 +64,37 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
return false
|
||||
}
|
||||
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
if account == nil {
|
||||
return false
|
||||
default: // "default" → follow channel config
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(account.Platform)
|
||||
}
|
||||
|
||||
switch {
|
||||
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
return false
|
||||
default:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
}
|
||||
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(platform)
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||
"message": map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
return flushSSEJSON(w, "message_start", evt)
|
||||
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
"name": toolNameWebSearch, "input": map[string]any{},
|
||||
},
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
|
||||
return err
|
||||
}
|
||||
inputJSON, err := json.Marshal(map[string]string{"query": query})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal query: %w", err)
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||
blocks := make([]map[string]string, 0, len(results))
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
|
||||
blocks := make([]map[string]any, 0, len(results))
|
||||
for _, r := range results {
|
||||
block := map[string]string{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
}
|
||||
if r.Snippet != "" {
|
||||
block["page_content"] = r.Snippet
|
||||
block := map[string]any{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
"encrypted_content": r.Snippet,
|
||||
"page_age": nil,
|
||||
}
|
||||
if r.PageAge != "" {
|
||||
block["page_age"] = r.PageAge
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +14,31 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"cache_creation_input_tokens":0`)
|
||||
require.Contains(t, body, `"cache_read_input_tokens":0`)
|
||||
}
|
||||
|
||||
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `event: content_block_start`)
|
||||
require.Contains(t, body, `"type":"server_tool_use"`)
|
||||
require.Contains(t, body, `"input":{}`)
|
||||
require.Contains(t, body, `event: content_block_delta`)
|
||||
require.Contains(t, body, `"type":"input_json_delta"`)
|
||||
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, body, `event: content_block_stop`)
|
||||
}
|
||||
|
||||
// --- isOnlyWebSearchToolInBody ---
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
|
||||
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
|
||||
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||
// Second result has no PageAge
|
||||
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||
_, hasPageAge := blocks[1]["page_age"]
|
||||
require.False(t, hasPageAge)
|
||||
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
|
||||
require.Nil(t, blocks[1]["page_age"])
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
|
||||
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||
_, hasContent := blocks[0]["page_content"]
|
||||
require.False(t, hasContent)
|
||||
require.Equal(t, "", blocks[0]["encrypted_content"])
|
||||
require.Nil(t, blocks[0]["page_age"])
|
||||
}
|
||||
|
||||
// --- buildTextSummary ---
|
||||
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
func newKiroOAuthAccount() *Account {
|
||||
return &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
|
||||
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
|
||||
// nil channelService + default mode → returns false
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 11,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(77, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(77)
|
||||
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 12,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(78, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(78)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,623 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type kiroUsageCooldownStore struct {
|
||||
state *kirocooldown.State
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||
return s.state, s.err
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func kiroFloatPtr(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
|
||||
},
|
||||
}
|
||||
|
||||
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_kiro_billing_normalized",
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
|
||||
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
|
||||
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
|
||||
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
|
||||
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "*/*", r.Header.Get("Accept"))
|
||||
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
|
||||
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
|
||||
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
|
||||
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageConfiguration": {"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentOveragesWithPrecision":2,
|
||||
"currentUsageWithPrecision":125,
|
||||
"freeTrialInfo":{
|
||||
"currentUsageWithPrecision":25,
|
||||
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
|
||||
"freeTrialStatus":"ACTIVE",
|
||||
"usageLimitWithPrecision":500
|
||||
},
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageCharges":0.08,
|
||||
"resourceType":"CREDIT",
|
||||
"usageLimitWithPrecision":2000
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, "active", usage.Source)
|
||||
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
|
||||
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
|
||||
require.True(t, usage.KiroOveragesEnabled)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
|
||||
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
|
||||
require.NotNil(t, usage.KiroBonus)
|
||||
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
|
||||
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
|
||||
require.NotNil(t, usage.KiroOverage)
|
||||
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
|
||||
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
|
||||
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
|
||||
require.NotNil(t, usage.KiroResetAt)
|
||||
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
|
||||
require.NotNil(t, usage.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 702,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":300,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer successServer.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.NotNil(t, firstUsage.KiroCredit)
|
||||
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
|
||||
|
||||
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
|
||||
}))
|
||||
defer failingServer.Close()
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Empty(t, usage.Error)
|
||||
require.Empty(t, usage.ErrorCode)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 703,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "BuilderId",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 707,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":64,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 709,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"api_region": "eu-west-1",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":11,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{"eu-west-1"}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 710,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":7,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 704,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(90 * time.Second),
|
||||
Remaining: 90 * time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cooldown", usage.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
|
||||
require.NotNil(t, usage.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: `{"message":"profileArn is required for this request."}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeForbidden, info.ErrorCode)
|
||||
require.False(t, info.NeedsReauth)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Body: `{"message":"overage exhausted for this billing window"}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
|
||||
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
|
||||
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 708,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
requestCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
|
||||
|
||||
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, secondUsage)
|
||||
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
|
||||
require.Equal(t, 1, requestCount)
|
||||
}
|
||||
|
||||
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
|
||||
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
|
||||
NextDateReset: "2099-03-13T12:00:00Z",
|
||||
OverageConfiguration: kiroOverageConfiguration{
|
||||
OverageStatus: "DISABLED",
|
||||
},
|
||||
UsageBreakdownList: []kiroUsageBreakdown{
|
||||
{
|
||||
ResourceType: "CREDIT",
|
||||
CurrentUsageWithPrecision: kiroFloatPtr(2000),
|
||||
UsageLimitWithPrecision: kiroFloatPtr(2000),
|
||||
CurrentOveragesWithPrecision: kiroFloatPtr(0),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
|
||||
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
|
||||
require.NotNil(t, info.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
|
||||
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(2 * time.Minute),
|
||||
Remaining: 2 * time.Minute,
|
||||
},
|
||||
})
|
||||
|
||||
account := &Account{
|
||||
ID: 705,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
|
||||
require.Equal(t, "cooldown", account.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
|
||||
require.NotNil(t, account.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 706,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset":"2099-03-13T12:00:00Z",
|
||||
"overageConfiguration":{"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":2000,
|
||||
"currentOveragesWithPrecision":4,
|
||||
"overageCharges":0.2,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
_, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
target := &Account{
|
||||
ID: account.ID,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
|
||||
|
||||
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
|
||||
require.NotNil(t, target.KiroQuotaResetAt)
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
|
||||
account := Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformKiro,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
require.Empty(t, account.GetBaseURL())
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
|
||||
|
||||
require.Equal(t, 25*time.Second, got)
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamKeepaliveInterval: 10,
|
||||
KiroStreamKeepaliveInterval: 25,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
|
||||
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, nil)
|
||||
|
||||
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
upstreamModel string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "kiro claude sonnet 4.6 uses pricing key format",
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
upstreamModel: "claude-sonnet-4.6",
|
||||
want: "claude-sonnet-4-6",
|
||||
},
|
||||
{
|
||||
name: "falls back to upstream when requested model empty",
|
||||
requestedModel: "",
|
||||
upstreamModel: "claude-haiku-4-5",
|
||||
want: "claude-haiku-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_billing_normalized",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_auto_fallback_cost",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "auto",
|
||||
UpstreamModel: "auto",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 601, Quota: 100},
|
||||
User: &User{ID: 701},
|
||||
Account: &Account{ID: 801, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
kiroErrorAuthError = "auth_error"
|
||||
kiroErrorMonthlyRequest = "monthly_request_count"
|
||||
kiroErrorProfileError = "profile_error"
|
||||
kiroErrorQuotaExhausted = "quota_exhausted"
|
||||
kiroErrorOverageExhausted = "overage_exhausted"
|
||||
kiroErrorRateLimited = "rate_limited"
|
||||
kiroErrorSuspended = "suspended"
|
||||
kiroErrorUsageForbidden = "usage_forbidden"
|
||||
kiroErrorUpstreamTransient = "upstream_transient"
|
||||
kiroErrorBadRequestSchema = "bad_request_schema"
|
||||
kiroErrorBadRequestToolPairing = "bad_request_tool_pairing"
|
||||
kiroErrorBadRequestInvalidModel = "bad_request_invalid_model"
|
||||
kiroErrorBadRequestAuth = "bad_request_auth"
|
||||
kiroErrorBadRequestQuota = "bad_request_quota"
|
||||
kiroErrorBadRequestUnknown = "bad_request_unknown"
|
||||
kiroErrorRefreshTokenInvalid = "refresh_token_invalid"
|
||||
|
||||
kiroQuotaStateNormal = "normal"
|
||||
kiroQuotaStateOverageActive = "overage_active"
|
||||
kiroQuotaStateCreditsExhausted = "credits_exhausted"
|
||||
kiroQuotaStateOverageExhausted = "overage_exhausted"
|
||||
)
|
||||
|
||||
type kiroErrorClassification struct {
|
||||
Category string
|
||||
StatusCode int
|
||||
Message string
|
||||
}
|
||||
|
||||
func classifyKiroHTTPError(statusCode int, body string) kiroErrorClassification {
|
||||
trimmed := strings.TrimSpace(body)
|
||||
lower := strings.ToLower(trimmed)
|
||||
|
||||
switch {
|
||||
case statusCode == http.StatusUnauthorized:
|
||||
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusPaymentRequired && looksLikeKiroMonthlyRequestCountError(trimmed):
|
||||
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusForbidden && isKiroSuspendedBody([]byte(trimmed)):
|
||||
return kiroErrorClassification{Category: kiroErrorSuspended, StatusCode: statusCode, Message: trimmed}
|
||||
case looksLikeKiroProfileError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorProfileError, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusBadRequest:
|
||||
return classifyKiroBadRequest(trimmed, lower)
|
||||
case statusCode == http.StatusForbidden && isKiroTokenErrorBody([]byte(trimmed)):
|
||||
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
|
||||
case looksLikeKiroOverageExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorOverageExhausted, StatusCode: statusCode, Message: trimmed}
|
||||
case looksLikeKiroQuotaExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusTooManyRequests:
|
||||
return kiroErrorClassification{Category: kiroErrorRateLimited, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusForbidden:
|
||||
return kiroErrorClassification{Category: kiroErrorUsageForbidden, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode >= 500:
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
|
||||
default:
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
|
||||
}
|
||||
}
|
||||
|
||||
func classifyKiroError(err error) kiroErrorClassification {
|
||||
if err == nil {
|
||||
return kiroErrorClassification{}
|
||||
}
|
||||
|
||||
var httpErr *kiroUsageHTTPError
|
||||
if errors.As(err, &httpErr) && httpErr != nil {
|
||||
return classifyKiroHTTPError(httpErr.StatusCode, httpErr.Body)
|
||||
}
|
||||
|
||||
errStr := strings.TrimSpace(err.Error())
|
||||
lower := strings.ToLower(errStr)
|
||||
switch {
|
||||
case looksLikeKiroInvalidGrantError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorRefreshTokenInvalid, Message: errStr}
|
||||
case looksLikeKiroMonthlyRequestCountError(errStr):
|
||||
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, Message: errStr}
|
||||
case looksLikeKiroProfileError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorProfileError, Message: errStr}
|
||||
case looksLikeKiroOverageExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorOverageExhausted, Message: errStr}
|
||||
case looksLikeKiroQuotaExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, Message: errStr}
|
||||
case strings.Contains(lower, "context deadline exceeded"),
|
||||
strings.Contains(lower, "timeout"),
|
||||
isNetErr(err):
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
|
||||
default:
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
|
||||
}
|
||||
}
|
||||
|
||||
func classifyKiroBadRequest(trimmed, lower string) kiroErrorClassification {
|
||||
switch {
|
||||
case looksLikeKiroBadRequestSchemaError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestSchema, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroBadRequestToolPairingError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestToolPairing, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroBadRequestInvalidModelError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestInvalidModel, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroInvalidGrantError(lower) || looksLikeKiroBadRequestAuthError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestAuth, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroQuotaExhaustedError(lower) || looksLikeKiroMonthlyRequestCountError(trimmed):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestQuota, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
default:
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestUnknown, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestSchemaError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "schema") ||
|
||||
strings.Contains(lower, "inputschema") ||
|
||||
strings.Contains(lower, "improperly formed request") ||
|
||||
strings.Contains(lower, "additionalproperties") ||
|
||||
(strings.Contains(lower, "properties") && strings.Contains(lower, "required"))
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestToolPairingError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "tool_use") ||
|
||||
strings.Contains(lower, "tool_result") ||
|
||||
strings.Contains(lower, "tooluseid") ||
|
||||
strings.Contains(lower, "toolresults") ||
|
||||
strings.Contains(lower, "must be paired") ||
|
||||
strings.Contains(lower, "missing tool result")
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestInvalidModelError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "invalid model") ||
|
||||
strings.Contains(lower, "invalid_model_id") ||
|
||||
strings.Contains(lower, "model not supported") ||
|
||||
strings.Contains(lower, "unsupportedmodel") ||
|
||||
strings.Contains(lower, "modelid")
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestAuthError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "invalid token") ||
|
||||
strings.Contains(lower, "expired token") ||
|
||||
strings.Contains(lower, "access token") ||
|
||||
strings.Contains(lower, "refresh token")
|
||||
}
|
||||
|
||||
func looksLikeKiroInvalidGrantError(lower string) bool {
|
||||
return strings.Contains(lower, "invalid_grant")
|
||||
}
|
||||
|
||||
func looksLikeKiroMonthlyRequestCountError(body string) bool {
|
||||
trimmed := strings.TrimSpace(body)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(trimmed, "MONTHLY_REQUEST_COUNT") {
|
||||
return true
|
||||
}
|
||||
if !gjson.Valid(trimmed) {
|
||||
return false
|
||||
}
|
||||
return gjson.Get(trimmed, "reason").String() == "MONTHLY_REQUEST_COUNT" ||
|
||||
gjson.Get(trimmed, "error.reason").String() == "MONTHLY_REQUEST_COUNT"
|
||||
}
|
||||
|
||||
func looksLikeKiroProfileError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return (strings.Contains(lower, "profilearn") && strings.Contains(lower, "required")) ||
|
||||
(strings.Contains(lower, "profile arn") && strings.Contains(lower, "required")) ||
|
||||
(strings.Contains(lower, "profile") && strings.Contains(lower, "not found")) ||
|
||||
(strings.Contains(lower, "invalid profile")) ||
|
||||
(strings.Contains(lower, "listavailableprofiles"))
|
||||
}
|
||||
|
||||
func looksLikeKiroQuotaExhaustedError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return (strings.Contains(lower, "credit") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "depleted"))) ||
|
||||
(strings.Contains(lower, "quota") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "exceeded") || strings.Contains(lower, "depleted"))) ||
|
||||
(strings.Contains(lower, "usage limit") && (strings.Contains(lower, "reached") || strings.Contains(lower, "exceeded"))) ||
|
||||
(strings.Contains(lower, "resource has been exhausted"))
|
||||
}
|
||||
|
||||
func looksLikeKiroOverageExhaustedError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "overage") &&
|
||||
(strings.Contains(lower, "exhaust") ||
|
||||
strings.Contains(lower, "disabled") ||
|
||||
strings.Contains(lower, "not enabled") ||
|
||||
strings.Contains(lower, "not allowed") ||
|
||||
strings.Contains(lower, "limit"))
|
||||
}
|
||||
|
||||
func isNetErr(err error) bool {
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr)
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyKiroHTTPErrorBadRequestCategories(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "schema",
|
||||
body: `{"message":"Improperly formed request: inputSchema.properties must be an object"}`,
|
||||
want: kiroErrorBadRequestSchema,
|
||||
},
|
||||
{
|
||||
name: "tool pairing",
|
||||
body: `{"message":"tool_use must be paired with a matching tool_result"}`,
|
||||
want: kiroErrorBadRequestToolPairing,
|
||||
},
|
||||
{
|
||||
name: "invalid model id",
|
||||
body: `{"message":"invalid modelId: model not supported"}`,
|
||||
want: kiroErrorBadRequestInvalidModel,
|
||||
},
|
||||
{
|
||||
name: "invalid model upstream",
|
||||
body: `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
|
||||
want: kiroErrorBadRequestInvalidModel,
|
||||
},
|
||||
{
|
||||
name: "invalid model reason",
|
||||
body: `{"message":"model route unavailable","reason":"INVALID_MODEL_ID"}`,
|
||||
want: kiroErrorBadRequestInvalidModel,
|
||||
},
|
||||
{
|
||||
name: "auth",
|
||||
body: `{"error":"invalid_grant","message":"Invalid refresh token provided"}`,
|
||||
want: kiroErrorBadRequestAuth,
|
||||
},
|
||||
{
|
||||
name: "quota",
|
||||
body: `{"message":"resource has been exhausted"}`,
|
||||
want: kiroErrorBadRequestQuota,
|
||||
},
|
||||
{
|
||||
name: "unknown",
|
||||
body: `{"message":"bad request"}`,
|
||||
want: kiroErrorBadRequestUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
classification := classifyKiroHTTPError(http.StatusBadRequest, tt.body)
|
||||
require.Equal(t, tt.want, classification.Category)
|
||||
require.Equal(t, http.StatusBadRequest, classification.StatusCode)
|
||||
require.Equal(t, tt.body, classification.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func buildKiroAccountKey(account *Account) string {
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
return kiropkg.BuildAccountKey(
|
||||
account.GetCredential("client_id"),
|
||||
account.GetCredential("client_id_hash"),
|
||||
account.GetCredential("refresh_token"),
|
||||
account.GetCredential("profile_arn"),
|
||||
account.ID,
|
||||
)
|
||||
}
|
||||
|
||||
func buildKiroMachineID(account *Account) string {
|
||||
if account == nil {
|
||||
return kiropkg.BuildMachineID("", "", "account:nil")
|
||||
}
|
||||
for _, key := range []string{"machine_id", "machineId"} {
|
||||
if machineID, ok := kiropkg.NormalizeMachineID(account.GetCredential(key)); ok {
|
||||
return machineID
|
||||
}
|
||||
}
|
||||
fallbackKey := buildKiroMachineIDFallbackKey(account)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
return kiropkg.BuildMachineID("", firstKiroCredential(account, "kiro_api_key", "kiroApiKey", "api_key"), fallbackKey)
|
||||
}
|
||||
return kiropkg.BuildMachineID(account.GetCredential("refresh_token"), "", fallbackKey)
|
||||
}
|
||||
|
||||
func firstKiroCredential(account *Account, keys ...string) string {
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
for _, key := range keys {
|
||||
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildKiroMachineIDFallbackKey(account *Account) string {
|
||||
if account == nil {
|
||||
return "account:nil"
|
||||
}
|
||||
if account.ID > 0 {
|
||||
return fmt.Sprintf("account:%d", account.ID)
|
||||
}
|
||||
for _, key := range []string{"client_id", "profile_arn"} {
|
||||
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
|
||||
return key + ":" + value
|
||||
}
|
||||
}
|
||||
if name := strings.TrimSpace(account.Name); name != "" {
|
||||
return "name:" + name
|
||||
}
|
||||
return "account:unknown"
|
||||
}
|
||||
|
||||
func buildKiroRequestID(resp *http.Response) string {
|
||||
if resp == nil {
|
||||
return ""
|
||||
}
|
||||
if requestID := strings.TrimSpace(resp.Header.Get("x-request-id")); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if requestID := strings.TrimSpace(resp.Header.Get("x-amzn-requestid")); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
return strings.TrimSpace(resp.Header.Get("x-amz-request-id"))
|
||||
}
|
||||
|
||||
func isKiroInvalidModelIDBody(respBody []byte) bool {
|
||||
var payload struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
Error struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(respBody, &payload) != nil {
|
||||
return looksLikeKiroBadRequestInvalidModelError(strings.ToLower(string(respBody)))
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(payload.Reason), "INVALID_MODEL_ID") ||
|
||||
strings.EqualFold(strings.TrimSpace(payload.Error.Reason), "INVALID_MODEL_ID") ||
|
||||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Message)) ||
|
||||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Error.Message))
|
||||
}
|
||||
|
||||
func isKiroSuspendedBody(respBody []byte) bool {
|
||||
body := string(respBody)
|
||||
return strings.Contains(body, "SUSPENDED") || strings.Contains(body, "TEMPORARILY_SUSPENDED")
|
||||
}
|
||||
|
||||
func isKiroTokenErrorBody(respBody []byte) bool {
|
||||
lower := strings.ToLower(string(respBody))
|
||||
return strings.Contains(lower, "token") ||
|
||||
strings.Contains(lower, "expired") ||
|
||||
strings.Contains(lower, "invalid") ||
|
||||
strings.Contains(lower, "unauthorized")
|
||||
}
|
||||
|
||||
func kiroProxyURL(account *Account) string {
|
||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func kiroAPIRegion(account *Account) string {
|
||||
if account == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
region := strings.TrimSpace(account.GetCredential("api_region"))
|
||||
if region == "" {
|
||||
region = kiroDefaultRegion
|
||||
}
|
||||
return region
|
||||
}
|
||||
|
||||
func applyKiroConditionalHeaders(req *http.Request, account *Account) {
|
||||
if req == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(account.GetCredential("auth_method")), "external_idp") {
|
||||
req.Header.Set("TokenType", "EXTERNAL_IDP")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(account.GetCredential("provider")), "Internal") {
|
||||
req.Header.Set("redirect-for-internal", "true")
|
||||
}
|
||||
}
|
||||
|
||||
func resolveKiroPayloadProfileArn(account *Account) string {
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||
}
|
||||
|
||||
func newKiroJSONRequest(ctx context.Context, endpointURL string, payload []byte, token, accountKey, machineID, amzTarget string, account *Account) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||
if amzTarget != "" {
|
||||
req.Header.Set("X-Amz-Target", amzTarget)
|
||||
}
|
||||
if account != nil {
|
||||
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||
if profileArn != "" {
|
||||
req.Header.Set("x-amzn-kiro-profile-arn", profileArn)
|
||||
}
|
||||
}
|
||||
applyKiroConditionalHeaders(req, account)
|
||||
return req, nil
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
|
||||
accountA := &Account{
|
||||
ID: 99,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-a",
|
||||
},
|
||||
}
|
||||
accountB := &Account{
|
||||
ID: 99,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-b",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, buildKiroAccountKey(accountA), buildKiroAccountKey(accountB))
|
||||
}
|
||||
|
||||
func TestBuildKiroMachineIDPrefersExplicitCredential(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"machineId": "2582956e-cc88-4669-b546-07adbffcb894",
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", buildKiroMachineID(account))
|
||||
}
|
||||
|
||||
func TestBuildKiroMachineIDDerivesFromRefreshToken(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiropkg.BuildMachineID("refresh-token", "", "account:102"), buildKiroMachineID(account))
|
||||
}
|
||||
|
||||
func TestBuildKiroMachineIDDerivesFromAPIKeyAccount(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"kiroApiKey": "kiro-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiropkg.BuildMachineID("", "kiro-api-key", "account:103"), buildKiroMachineID(account))
|
||||
}
|
||||
|
||||
func TestNewKiroJSONRequestAddsConditionalHeaders(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "external_idp",
|
||||
"provider": "Internal",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER",
|
||||
},
|
||||
}
|
||||
|
||||
req, err := newKiroJSONRequest(
|
||||
context.Background(),
|
||||
"https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
[]byte(`{"ok":true}`),
|
||||
"access-token",
|
||||
"account-key",
|
||||
buildKiroMachineID(account),
|
||||
"",
|
||||
account,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "EXTERNAL_IDP", req.Header.Get("TokenType"))
|
||||
require.Equal(t, "true", req.Header.Get("redirect-for-internal"))
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER", req.Header.Get("x-amzn-kiro-profile-arn"))
|
||||
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Equal(t, "true", req.Header.Get("x-amzn-codewhisperer-optout"))
|
||||
require.Contains(t, req.Header.Get("User-Agent"), "aws-sdk-js/1.0.34")
|
||||
require.Contains(t, req.Header.Get("User-Agent"), "md/nodejs#22.22.0")
|
||||
require.Contains(t, req.Header.Get("User-Agent"), buildKiroMachineID(account))
|
||||
require.Contains(t, req.Header.Get("X-Amz-User-Agent"), buildKiroMachineID(account))
|
||||
require.True(t, strings.Contains(req.Header.Get("User-Agent"), "api/codewhispererstreaming#1.0.34"))
|
||||
require.Empty(t, req.Header.Get("Anthropic-Beta"))
|
||||
}
|
||||
|
||||
func TestIsKiroInvalidModelIDBodyRecognizesKnownForms(t *testing.T) {
|
||||
tests := []string{
|
||||
`{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`,
|
||||
`{"message":"Invalid model. Please select a different model to continue."}`,
|
||||
`API Error: 400 {"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
|
||||
}
|
||||
|
||||
for _, body := range tests {
|
||||
require.True(t, isKiroInvalidModelIDBody([]byte(body)), body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/test",
|
||||
},
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-sonnet-4-6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
headers := http.Header{}
|
||||
headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-sonnet-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-sonnet-4-6",
|
||||
headers,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(payload), "CHUNKED WRITE PROTOCOL")
|
||||
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6-thinking",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
|
||||
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.NotContains(t, systemContent, "<thinking_mode>")
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"api_region": "eu-west-1",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
|
||||
"region": "ap-northeast-1",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, "eu-west-1", kiroAPIRegion(account))
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionIgnoresProfileARNRegionFallback(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionIgnoresOIDCRegionFallback(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"region": "ap-northeast-2",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
|
||||
}
|
||||
|
||||
func TestBuildKiroEndpointsUsesOnlyAmazonQEndpoint(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"api_region": "us-west-2",
|
||||
"preferred_endpoint": "cw",
|
||||
},
|
||||
}
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
require.Len(t, endpoints, 1)
|
||||
require.Equal(t, "AmazonQ", endpoints[0].Name)
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
|
||||
require.Empty(t, endpoints[0].AmzTarget)
|
||||
}
|
||||
|
||||
func TestBuildKiroEndpointsIgnoresPreferredEndpoint(t *testing.T) {
|
||||
for _, preferred := range []string{"codewhisperer", "cw", "unknown"} {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"api_region": "us-west-2",
|
||||
"preferred_endpoint": preferred,
|
||||
},
|
||||
}
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
require.Len(t, endpoints, 1)
|
||||
require.Equal(t, "AmazonQ", endpoints[0].Name)
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountKiroDefaultMappingRestrictsUnsupportedModels(t *testing.T) {
|
||||
account := &Account{Platform: PlatformKiro}
|
||||
|
||||
require.False(t, account.IsModelSupported("gpt-4o"))
|
||||
require.False(t, account.IsModelSupported("kiro-gpt-4o"))
|
||||
require.False(t, account.IsModelSupported("auto"))
|
||||
require.Equal(t, "claude-sonnet-4.6", account.GetMappedModel("claude-sonnet-4-6"))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCalculateTokenCost_KiroAutoUsesConservativeFallback(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
|
||||
svc := NewGatewayService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
result := &ForwardResult{
|
||||
Model: "auto",
|
||||
UpstreamModel: "auto",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
}
|
||||
|
||||
expected, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
cost := svc.calculateTokenCost(context.Background(), result, &APIKey{}, "auto", 1.1, &recordUsageOpts{IsKiroAccount: true})
|
||||
require.NotNil(t, cost)
|
||||
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-12)
|
||||
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-12)
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
)
|
||||
|
||||
const (
|
||||
// Kiro desktop social auth uses localhost loopback callbacks from a fixed
|
||||
// allowlist. Use one of the bundled ports from the official client.
|
||||
kiroSocialRedirectURI = "http://localhost:49153"
|
||||
// AWS IAM Identity Center native/public clients require an explicit loopback IP redirect URI.
|
||||
kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
|
||||
)
|
||||
|
||||
type KiroOAuthService struct {
|
||||
sessionStore *kiropkg.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
|
||||
return &KiroOAuthService{
|
||||
sessionStore: kiropkg.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) Stop() {}
|
||||
|
||||
type KiroAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
type KiroIDCAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
ClientID string `json:"client_id"`
|
||||
Region string `json:"region"`
|
||||
StartURL string `json:"start_url"`
|
||||
}
|
||||
|
||||
type KiroTokenInfo struct {
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ProfileArn string `json:"profile_arn,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDHash string `json:"client_id_hash,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
StartURL string `json:"start_url,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
type KiroGenerateAuthURLInput struct {
|
||||
ProxyID *int64
|
||||
Provider string
|
||||
}
|
||||
|
||||
type KiroExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
CallbackPath string
|
||||
LoginOption string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
type KiroGenerateIDCAuthURLInput struct {
|
||||
ProxyID *int64
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type KiroRefreshTokenInput struct {
|
||||
RefreshToken string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
StartURL string
|
||||
Region string
|
||||
ProfileArn string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
type KiroImportTokenInput struct {
|
||||
TokenJSON string
|
||||
DeviceRegistrationJSON string
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
|
||||
provider := strings.TrimSpace(input.Provider)
|
||||
if provider == "" {
|
||||
provider = string(kiropkg.SocialProviderGoogle)
|
||||
}
|
||||
if provider != string(kiropkg.SocialProviderGoogle) && provider != string(kiropkg.SocialProviderGitHub) {
|
||||
return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
|
||||
}
|
||||
state, err := kiropkg.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate state failed: %w", err)
|
||||
}
|
||||
codeVerifier, err := kiropkg.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate code verifier failed: %w", err)
|
||||
}
|
||||
sessionID := kiropkg.GenerateSessionID()
|
||||
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
AuthType: "social",
|
||||
Provider: provider,
|
||||
RedirectURI: kiroSocialRedirectURI,
|
||||
})
|
||||
return &KiroAuthURLResult{
|
||||
AuthURL: kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
return nil, fmt.Errorf("state invalid")
|
||||
}
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
|
||||
}
|
||||
|
||||
switch session.AuthType {
|
||||
case "social":
|
||||
token, err := kiropkg.CreateSocialToken(
|
||||
ctx,
|
||||
proxyURL,
|
||||
input.Code,
|
||||
session.CodeVerifier,
|
||||
buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.Provider = session.Provider
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
return toKiroTokenInfo(token), nil
|
||||
case "idc":
|
||||
token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
return toKiroTokenInfo(token), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
|
||||
}
|
||||
}
|
||||
|
||||
func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
|
||||
option := strings.ToLower(strings.TrimSpace(loginOption))
|
||||
if option == "" {
|
||||
switch provider {
|
||||
case string(kiropkg.SocialProviderGitHub):
|
||||
option = "github"
|
||||
case string(kiropkg.SocialProviderGoogle):
|
||||
option = "google"
|
||||
}
|
||||
}
|
||||
return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, option)
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
|
||||
startURL := strings.TrimSpace(input.StartURL)
|
||||
if startURL == "" {
|
||||
startURL = kiropkg.BuilderIDStartURL
|
||||
}
|
||||
region := strings.TrimSpace(input.Region)
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
state, err := kiropkg.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate state failed: %w", err)
|
||||
}
|
||||
codeVerifier, err := kiropkg.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate code verifier failed: %w", err)
|
||||
}
|
||||
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||
reg, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID := kiropkg.GenerateSessionID()
|
||||
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
AuthType: "idc",
|
||||
RedirectURI: kiroIDCRedirectURI,
|
||||
ClientID: reg.ClientID,
|
||||
ClientSecret: reg.ClientSecret,
|
||||
Region: region,
|
||||
StartURL: startURL,
|
||||
})
|
||||
return &KiroIDCAuthURLResult{
|
||||
AuthURL: kiropkg.BuildIDCAuthURL(reg.ClientID, kiroIDCRedirectURI, state, kiropkg.GenerateCodeChallenge(codeVerifier), region),
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
ClientID: reg.ClientID,
|
||||
Region: region,
|
||||
StartURL: startURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
|
||||
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||
authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
|
||||
if authMethod == "" {
|
||||
authMethod = "social"
|
||||
}
|
||||
|
||||
var token *kiropkg.TokenData
|
||||
var err error
|
||||
switch authMethod {
|
||||
case "idc":
|
||||
token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
|
||||
default:
|
||||
token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.ProfileArn == "" {
|
||||
token.ProfileArn = input.ProfileArn
|
||||
}
|
||||
if token.ClientID == "" {
|
||||
token.ClientID = input.ClientID
|
||||
}
|
||||
if token.ClientSecret == "" {
|
||||
token.ClientSecret = input.ClientSecret
|
||||
}
|
||||
if token.StartURL == "" {
|
||||
token.StartURL = input.StartURL
|
||||
}
|
||||
if token.Region == "" {
|
||||
token.Region = input.Region
|
||||
}
|
||||
return toKiroTokenInfo(token), nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("not a kiro oauth account")
|
||||
}
|
||||
return s.RefreshToken(ctx, &KiroRefreshTokenInput{
|
||||
RefreshToken: account.GetCredential("refresh_token"),
|
||||
AuthMethod: account.GetCredential("auth_method"),
|
||||
Provider: account.GetCredential("provider"),
|
||||
ClientID: account.GetCredential("client_id"),
|
||||
ClientSecret: account.GetCredential("client_secret"),
|
||||
StartURL: account.GetCredential("start_url"),
|
||||
Region: account.GetCredential("region"),
|
||||
ProfileArn: account.GetCredential("profile_arn"),
|
||||
ProxyID: account.ProxyID,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
|
||||
token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toKiroTokenInfo(token), nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
||||
if tokenInfo == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
creds := map[string]any{}
|
||||
if tokenInfo.AccessToken != "" {
|
||||
creds["access_token"] = tokenInfo.AccessToken
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.ProfileArn != "" {
|
||||
creds["profile_arn"] = tokenInfo.ProfileArn
|
||||
}
|
||||
if tokenInfo.ExpiresAt != "" {
|
||||
creds["expires_at"] = tokenInfo.ExpiresAt
|
||||
}
|
||||
if tokenInfo.AuthMethod != "" {
|
||||
creds["auth_method"] = tokenInfo.AuthMethod
|
||||
}
|
||||
if tokenInfo.Provider != "" {
|
||||
creds["provider"] = tokenInfo.Provider
|
||||
}
|
||||
if tokenInfo.ClientID != "" {
|
||||
creds["client_id"] = tokenInfo.ClientID
|
||||
}
|
||||
if tokenInfo.ClientSecret != "" {
|
||||
creds["client_secret"] = tokenInfo.ClientSecret
|
||||
}
|
||||
if tokenInfo.ClientIDHash != "" {
|
||||
creds["client_id_hash"] = tokenInfo.ClientIDHash
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.StartURL != "" {
|
||||
creds["start_url"] = tokenInfo.StartURL
|
||||
}
|
||||
if tokenInfo.Region != "" {
|
||||
creds["region"] = tokenInfo.Region
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
|
||||
func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
|
||||
if token == nil {
|
||||
return nil
|
||||
}
|
||||
return &KiroTokenInfo{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ProfileArn: token.ProfileArn,
|
||||
ExpiresAt: token.ExpiresAt,
|
||||
AuthMethod: token.AuthMethod,
|
||||
Provider: token.Provider,
|
||||
ClientID: token.ClientID,
|
||||
ClientSecret: token.ClientSecret,
|
||||
ClientIDHash: token.ClientIDHash,
|
||||
Email: token.Email,
|
||||
StartURL: token.StartURL,
|
||||
Region: token.Region,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||
if proxyID == nil || s.proxyRepo == nil {
|
||||
return "", nil
|
||||
}
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err != nil || proxy == nil {
|
||||
return "", err
|
||||
}
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestKiroIDCAuthRedirectURIUsesLoopbackIP(t *testing.T) {
|
||||
require.Equal(t, "http://127.0.0.1:9876/oauth/callback", kiroIDCRedirectURI)
|
||||
}
|
||||
|
||||
func TestKiroSocialAuthRedirectURIUsesLoopbackIP(t *testing.T) {
|
||||
require.Equal(t, "http://localhost:49153", kiroSocialRedirectURI)
|
||||
}
|
||||
|
||||
func TestBuildKiroSocialExchangeRedirectURIUsesProviderDefault(t *testing.T) {
|
||||
require.Equal(
|
||||
t,
|
||||
"http://localhost:49153/oauth/callback?login_option=github",
|
||||
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "", ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestBuildKiroSocialExchangeRedirectURIPreservesParsedCallbackData(t *testing.T) {
|
||||
require.Equal(
|
||||
t,
|
||||
"http://localhost:49153/signin/callback?login_option=google",
|
||||
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "/signin/callback", "google"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestKiroOAuthService_ExchangeCodeRejectsExpiredSession(t *testing.T) {
|
||||
svc := NewKiroOAuthService(nil)
|
||||
svc.sessionStore.Set("expired-session", &kiropkg.AuthSession{
|
||||
State: "expected-state",
|
||||
CreatedAt: time.Now().Add(-11 * time.Minute),
|
||||
})
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &KiroExchangeCodeInput{
|
||||
SessionID: "expired-session",
|
||||
State: "expected-state",
|
||||
Code: "auth-code",
|
||||
})
|
||||
require.EqualError(t, err, "session not found or expired")
|
||||
}
|
||||
@@ -0,0 +1,740 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type kiroEndpointConfig struct {
|
||||
URL string
|
||||
AmzTarget string
|
||||
Name string
|
||||
}
|
||||
|
||||
const kiroInvalidModelTempUnschedDuration = time.Minute
|
||||
|
||||
const (
|
||||
kiroRetryBaseDelay = 200 * time.Millisecond
|
||||
kiroRetryMaxDelay = 2 * time.Second
|
||||
)
|
||||
|
||||
var kiroRetrySleep = sleepWithContext
|
||||
|
||||
func kiroRetryBackoffDelay(attempt int) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := kiroRetryBaseDelay * time.Duration(1<<attempt)
|
||||
if delay > kiroRetryMaxDelay {
|
||||
delay = kiroRetryMaxDelay
|
||||
}
|
||||
jitterMax := delay / 4
|
||||
if jitterMax <= 0 {
|
||||
return delay
|
||||
}
|
||||
return delay + time.Duration(mathrand.Int63n(int64(jitterMax)+1))
|
||||
}
|
||||
|
||||
func sleepKiroRetry(ctx context.Context, attempt int) error {
|
||||
return kiroRetrySleep(ctx, kiroRetryBackoffDelay(attempt))
|
||||
}
|
||||
|
||||
func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) {
|
||||
if account == nil || parsed == nil {
|
||||
return nil, fmt.Errorf("kiro forward: missing account or request")
|
||||
}
|
||||
|
||||
originalModel := parsed.Model
|
||||
mappedModel := originalModel
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
body := parsed.Body
|
||||
if mappedModel != originalModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
}
|
||||
logger.L().Debug("gateway forward_kiro_messages: request prepared",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("auth_method", strings.TrimSpace(account.GetCredential("auth_method"))),
|
||||
zap.String("requested_model", originalModel),
|
||||
zap.String("mapped_model", mappedModel),
|
||||
zap.Bool("has_profile_arn", strings.TrimSpace(account.GetCredential("profile_arn")) != ""),
|
||||
)
|
||||
|
||||
if s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, body) {
|
||||
parsedForEmulation := *parsed
|
||||
parsedForEmulation.Body = body
|
||||
return s.handleWebSearchEmulation(ctx, c, account, &parsedForEmulation)
|
||||
}
|
||||
|
||||
if parsed.Stream {
|
||||
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: failoverErr.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||
})
|
||||
return nil, failoverErr
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
|
||||
}
|
||||
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamResult.usage == nil {
|
||||
streamResult.usage = &ClaudeUsage{}
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *streamResult.usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: streamResult.firstTokenMs,
|
||||
ClientDisconnect: streamResult.clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenType != "oauth" {
|
||||
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||
}
|
||||
if isOnlyWebSearchToolInBody(body) {
|
||||
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
switch {
|
||||
case errors.Is(webSearchErr, errKiroWebSearchFallback):
|
||||
case webSearchErr == nil:
|
||||
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||
c.Header("Content-Type", "application/json")
|
||||
if webSearchResult.RequestID != "" {
|
||||
c.Header("x-request-id", webSearchResult.RequestID)
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", webSearchResult.ResponseBody)
|
||||
return &ForwardResult{
|
||||
RequestID: webSearchResult.RequestID,
|
||||
Usage: webSearchResult.Usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
default:
|
||||
var httpErr *kiroWebSearchHTTPError
|
||||
if errors.As(webSearchErr, &httpErr) && httpErr.Response != nil {
|
||||
return nil, s.handleKiroHTTPError(ctx, httpErr.Response, c, account, mappedModel, body)
|
||||
}
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(webSearchErr, &failoverErr) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: failoverErr.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(webSearchErr.Error()),
|
||||
})
|
||||
return nil, failoverErr
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(webSearchErr.Error())
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||
}
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(body)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: failoverErr.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||
})
|
||||
return nil, failoverErr
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
|
||||
}
|
||||
|
||||
parseResult, err := kiropkg.ParseNonStreamingEventStreamWithContext(resp.Body, mappedModel, requestCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Failed to parse Kiro upstream response",
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "application/json")
|
||||
if requestID := resp.Header.Get("x-request-id"); requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", parseResult.ResponseBody)
|
||||
|
||||
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenType != "oauth" {
|
||||
return nil, 0, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(anthropicBody)
|
||||
if isOnlyWebSearchToolInBody(anthropicBody) {
|
||||
pr, pw := io.Pipe()
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "text/event-stream")
|
||||
go func() {
|
||||
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: headers,
|
||||
Body: pr,
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
return nil, inputTokens, err
|
||||
}
|
||||
return nil, inputTokens, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return resp, inputTokens, nil
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
wrappedHeaders := resp.Header.Clone()
|
||||
wrappedHeaders.Set("Content-Type", "text/event-stream")
|
||||
if requestID := buildKiroRequestID(resp); requestID != "" {
|
||||
wrappedHeaders.Set("x-request-id", requestID)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
_, streamErr := kiropkg.StreamEventStreamAsAnthropicWithContext(ctx, resp.Body, pw, mappedModel, inputTokens, requestCtx)
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: wrappedHeaders,
|
||||
Body: pr,
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||
var requestCtx kiropkg.KiroRequestContext
|
||||
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
|
||||
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||
return nil, requestCtx, failoverErr
|
||||
}
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
|
||||
modelID := kiropkg.MapModel(mappedModel)
|
||||
currentToken := token
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
payload := buildResult.Payload
|
||||
requestCtx = buildResult.Context
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
proxyURL := kiroProxyURL(account)
|
||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
maxRetries := 2
|
||||
|
||||
for idx, endpoint := range endpoints {
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
if attempt < maxRetries {
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
cooldown, err := s.markKiro429(ctx, accountKey)
|
||||
if err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
if idx+1 < len(endpoints) {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
break
|
||||
}
|
||||
resp.Header.Set("x-kiro-cooldown", cooldown.String())
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusRequestTimeout || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
|
||||
if attempt < maxRetries {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if idx+1 < len(endpoints) {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
break
|
||||
}
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusPaymentRequired {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, requestCtx, readErr
|
||||
}
|
||||
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||
if classification.Category == kiroErrorMonthlyRequest {
|
||||
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
|
||||
}
|
||||
return nil, requestCtx, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, requestCtx, readErr
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
|
||||
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
|
||||
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
payload = buildResult.Payload
|
||||
requestCtx = buildResult.Context
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if refreshErr != nil && isNonRetryableRefreshError(refreshErr) {
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
}
|
||||
|
||||
if classifyKiroHTTPError(resp.StatusCode, string(respBody)).Category == kiroErrorAuthError {
|
||||
s.markKiroAuthTemporarilyUnavailable(ctx, account, resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, requestCtx, readErr
|
||||
}
|
||||
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||
logKiroBadRequestClassification(classification, account, mappedModel, resp.Header, respBody)
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
}
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
}
|
||||
return nil, requestCtx, fmt.Errorf("kiro upstream endpoints exhausted")
|
||||
}
|
||||
|
||||
func buildKiroEndpoints(account *Account) []kiroEndpointConfig {
|
||||
region := kiroAPIRegion(account)
|
||||
return []kiroEndpointConfig{
|
||||
{
|
||||
URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
|
||||
Name: "AmazonQ",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) ([]byte, error) {
|
||||
result, err := buildKiroPayloadForAccountWithRepo(ctx, nil, account, anthropicBody, modelID, token, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Payload, nil
|
||||
}
|
||||
|
||||
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
|
||||
profileArn := resolveKiroPayloadProfileArn(account)
|
||||
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
|
||||
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
|
||||
}
|
||||
|
||||
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
|
||||
requestModel = strings.TrimSpace(requestModel)
|
||||
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
|
||||
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
|
||||
return next
|
||||
}
|
||||
return anthropicBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
reason := fmt.Sprintf("kiro auth failure (%d): %s", statusCode, strings.TrimSpace(body))
|
||||
_ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason)
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroMonthlyRequestCountRateLimited(ctx context.Context, account *Account, body string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
resetAt := nextKiroMonthlyResetUTC(time.Now())
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
logger.L().Warn("kiro monthly request count rate-limit failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Time("reset_at", resetAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
reason := "kiro monthly request count exhausted (402): MONTHLY_REQUEST_COUNT"
|
||||
if trimmed := strings.TrimSpace(body); trimmed != "" {
|
||||
reason = fmt.Sprintf("%s body=%s", reason, truncateForLog([]byte(trimmed), 512))
|
||||
}
|
||||
logger.L().Warn("kiro monthly request count rate-limited",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Time("reset_at", resetAt),
|
||||
zap.String("reason", reason),
|
||||
)
|
||||
}
|
||||
|
||||
func nextKiroMonthlyResetUTC(now time.Time) time.Time {
|
||||
utc := now.UTC()
|
||||
year, month, _ := utc.Date()
|
||||
return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func resetHTTPResponseBody(resp *http.Response, body []byte) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
resp.ContentLength = int64(len(body))
|
||||
}
|
||||
|
||||
func estimateKiroInputTokens(body []byte) int {
|
||||
if len(body) == 0 {
|
||||
return 0
|
||||
}
|
||||
if tokens := gjson.GetBytes(body, "metadata.input_tokens").Int(); tokens > 0 {
|
||||
return int(tokens)
|
||||
}
|
||||
tokens := len(body) / 4
|
||||
if tokens == 0 {
|
||||
return 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func kiroUsageToClaude(usage kiropkg.Usage, fallbackInput int) ClaudeUsage {
|
||||
inputTokens := usage.InputTokens
|
||||
if inputTokens == 0 {
|
||||
inputTokens = fallbackInput
|
||||
}
|
||||
return ClaudeUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroInvalidModelRateLimited(ctx context.Context, account *Account, mappedModel string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil || account.Type != AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
resetAt := time.Now().Add(kiroInvalidModelTempUnschedDuration)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
logger.L().Warn("kiro invalid model rate-limit failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
|
||||
zap.Time("reset_at", resetAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.L().Warn("kiro invalid model rate-limited",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
|
||||
zap.Time("reset_at", resetAt),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleKiroHTTPError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, mappedModel string, requestBody []byte) error {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg == "" {
|
||||
upstreamMsg = strings.TrimSpace(string(respBody))
|
||||
}
|
||||
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
logKiroBadRequestClassification(classification, account, "", resp.Header, respBody)
|
||||
}
|
||||
if classification.Category == kiroErrorMonthlyRequest {
|
||||
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
|
||||
}
|
||||
if classification.Category == kiroErrorBadRequestInvalidModel && account != nil && account.Type == AccountTypeOAuth {
|
||||
s.markKiroInvalidModelRateLimited(ctx, account, mappedModel)
|
||||
event := s.buildKiroInvalidModelUpstreamEvent(account, resp, upstreamMsg, mappedModel, requestBody, c)
|
||||
appendOpsUpstreamError(c, event)
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusPaymentRequired || s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
c.JSON(mapUpstreamStatusCode(resp.StatusCode), gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": coalesceKiroErrorMessage(resp.StatusCode, upstreamMsg),
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("kiro upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildKiroInvalidModelUpstreamEvent(account *Account, resp *http.Response, upstreamMsg, mappedModel string, requestBody []byte, c *gin.Context) OpsUpstreamErrorEvent {
|
||||
_ = s
|
||||
requestedModel := strings.TrimSpace(gjson.GetBytes(requestBody, "model").String())
|
||||
hasTools := gjson.GetBytes(requestBody, "tools").Exists()
|
||||
hasAdaptiveThinking := strings.EqualFold(strings.TrimSpace(gjson.GetBytes(requestBody, "thinking.type").String()), "adaptive")
|
||||
hasContext1MBeta := false
|
||||
if c != nil {
|
||||
hasContext1MBeta = strings.Contains(c.GetHeader("Anthropic-Beta"), "context-1m")
|
||||
}
|
||||
return OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
RequestedModel: requestedModel,
|
||||
MappedModel: strings.TrimSpace(mappedModel),
|
||||
KiroModelID: kiropkg.MapModel(mappedModel),
|
||||
HasTools: hasTools,
|
||||
HasAdaptiveThinking: hasAdaptiveThinking,
|
||||
HasContext1MBeta: hasContext1MBeta,
|
||||
}
|
||||
}
|
||||
|
||||
func logKiroBadRequestClassification(classification kiroErrorClassification, account *Account, model string, headers http.Header, body []byte) {
|
||||
if classification.StatusCode != http.StatusBadRequest {
|
||||
return
|
||||
}
|
||||
var accountID int64
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
logger.L().Warn("kiro upstream bad request classified",
|
||||
zap.String("category", classification.Category),
|
||||
zap.Int("status", classification.StatusCode),
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.String("model", strings.TrimSpace(model)),
|
||||
zap.String("request_id", headers.Get("x-request-id")),
|
||||
zap.String("body_excerpt", truncateForLog(body, 512)),
|
||||
)
|
||||
}
|
||||
|
||||
func coalesceKiroErrorMessage(statusCode int, upstreamMsg string) string {
|
||||
if upstreamMsg != "" {
|
||||
return upstreamMsg
|
||||
}
|
||||
switch statusCode {
|
||||
case http.StatusTooManyRequests:
|
||||
return "Rate limit exceeded"
|
||||
case http.StatusForbidden:
|
||||
return "Access denied"
|
||||
case http.StatusUnauthorized:
|
||||
return "Authentication failed"
|
||||
default:
|
||||
return "Upstream request failed"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
)
|
||||
|
||||
var errKiroCooldownStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||
|
||||
type KiroCooldownStore interface {
|
||||
ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||
MarkSuccess(ctx context.Context, tokenKey string) error
|
||||
Mark429(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||
MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||
GetState(ctx context.Context, tokenKey string) (*kirocooldown.State, error)
|
||||
ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error)
|
||||
}
|
||||
|
||||
func asKiroCooldownFailoverError(err error) *UpstreamFailoverError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var cooldownErr *kirocooldown.Error
|
||||
if !errors.As(err, &cooldownErr) {
|
||||
return nil
|
||||
}
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
ResponseBody: []byte(cooldownErr.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) checkAndWaitKiroCooldown(ctx context.Context, tokenKey string) error {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return errKiroCooldownStoreUnavailable
|
||||
}
|
||||
waitFor, err := s.kiroCooldownStore.ReserveRequest(ctx, tokenKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if waitFor <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(waitFor)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroSuccess(ctx context.Context, tokenKey string) error {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.MarkSuccess(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiro429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return 0, errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.Mark429(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return 0, errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.MarkSuspended(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func (s *GatewayService) getKiroCooldownState(ctx context.Context, tokenKey string) (*kirocooldown.State, error) {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return nil, errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.GetState(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func kiroRuntimeStateSnapshot(state *kirocooldown.State) (string, string, *time.Time) {
|
||||
if state == nil || !state.Active {
|
||||
return "", "", nil
|
||||
}
|
||||
resetAt := state.CooldownUntil
|
||||
switch state.Reason {
|
||||
case kirocooldown.CooldownReasonSuspended:
|
||||
return "suspended", state.Reason, &resetAt
|
||||
default:
|
||||
return "cooldown", state.Reason, &resetAt
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
//go:build integration
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const kiroCooldownRedisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestRedisKiroCooldownStoreSharesCooldownAcrossInstances(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
storeA := kirocooldown.NewStore(rdb)
|
||||
storeB := kirocooldown.NewStore(rdb)
|
||||
|
||||
cooldown, err := storeA.Mark429(ctx, "token-shared")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Minute, cooldown)
|
||||
|
||||
wait, err := storeB.ReserveRequest(ctx, "token-shared")
|
||||
require.Zero(t, wait)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), kirocooldown.CooldownReason429)
|
||||
|
||||
require.NoError(t, storeB.MarkSuccess(ctx, "token-shared"))
|
||||
|
||||
wait, err = storeA.ReserveRequest(ctx, "token-shared")
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, wait, 0*time.Second)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreSharesReservationAcrossInstances(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
storeA := kirocooldown.NewStore(rdb)
|
||||
storeB := kirocooldown.NewStore(rdb)
|
||||
|
||||
wait, err := storeA.ReserveRequest(ctx, "token-rate")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, wait)
|
||||
|
||||
wait, err = storeB.ReserveRequest(ctx, "token-rate")
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, wait, 0*time.Millisecond)
|
||||
require.LessOrEqual(t, wait, kirocooldown.MaxRequestInterval)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreSharesSuspendedStateAcrossInstances(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
storeA := kirocooldown.NewStore(rdb)
|
||||
storeB := kirocooldown.NewStore(rdb)
|
||||
|
||||
cooldown, err := storeA.MarkSuspended(ctx, "token-suspended")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, kirocooldown.LongCooldown, cooldown)
|
||||
|
||||
wait, err := storeB.ReserveRequest(ctx, "token-suspended")
|
||||
require.Zero(t, wait)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), kirocooldown.CooldownReasonSuspended)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreSuspendedResetsFailCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
store := kirocooldown.NewStore(rdb)
|
||||
|
||||
_, err := store.Mark429(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
_, err = store.Mark429(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
|
||||
cooldown, err := store.MarkSuspended(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, kirocooldown.LongCooldown, cooldown)
|
||||
|
||||
cooldown, err = store.Mark429(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Minute, cooldown)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreReserveDifferentTokenIgnoresOldCooldown(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
store := kirocooldown.NewStore(rdb)
|
||||
|
||||
_, err := store.Mark429(ctx, "token-old")
|
||||
require.NoError(t, err)
|
||||
|
||||
wait, err := store.ReserveRequest(ctx, "token-new")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, wait)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreUsesExpectedTTLs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
store := kirocooldown.NewStore(rdb)
|
||||
|
||||
_, err := store.ReserveRequest(ctx, "token-ttl-active")
|
||||
require.NoError(t, err)
|
||||
activeTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-active")).Result()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, activeTTL, 0*time.Second)
|
||||
require.LessOrEqual(t, activeTTL, kirocooldown.ActiveTTL())
|
||||
|
||||
_, err = store.MarkSuspended(ctx, "token-ttl-state")
|
||||
require.NoError(t, err)
|
||||
stateTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-state")).Result()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, stateTTL, 24*time.Hour)
|
||||
require.LessOrEqual(t, stateTTL, kirocooldown.StateTTL())
|
||||
}
|
||||
|
||||
func startKiroCooldownRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureKiroCooldownDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, kiroCooldownRedisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
host, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", host, port.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureKiroCooldownDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if kiroCooldownDockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过依赖 testcontainers 的 Kiro cooldown 集成测试")
|
||||
}
|
||||
|
||||
func kiroCooldownDockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func kiroCooldownUserHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
@@ -0,0 +1,583 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubKiroCooldownStore struct {
|
||||
reserveWait time.Duration
|
||||
reserveErr error
|
||||
successErr error
|
||||
mark429TTL time.Duration
|
||||
mark429Err error
|
||||
suspendedTTL time.Duration
|
||||
suspendedErr error
|
||||
state *kirocooldown.State
|
||||
stateErr error
|
||||
clearCalled bool
|
||||
clearKeys []string
|
||||
clearResult bool
|
||||
clearErr error
|
||||
}
|
||||
|
||||
type recordingKiroTempUnschedRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
called bool
|
||||
id int64
|
||||
until time.Time
|
||||
reason string
|
||||
rateCalled bool
|
||||
rateID int64
|
||||
rateLimitReset time.Time
|
||||
rateLimitedCall int
|
||||
}
|
||||
|
||||
func (r *recordingKiroTempUnschedRepo) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||||
r.called = true
|
||||
r.id = id
|
||||
r.until = until
|
||||
r.reason = reason
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingKiroTempUnschedRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateCalled = true
|
||||
r.rateID = id
|
||||
r.rateLimitReset = resetAt
|
||||
r.rateLimitedCall++
|
||||
return nil
|
||||
}
|
||||
|
||||
type recordingKiroErrorRepo struct {
|
||||
recordingKiroTempUnschedRepo
|
||||
setErrorCalls int
|
||||
errorID int64
|
||||
errorMsg string
|
||||
}
|
||||
|
||||
func (r *recordingKiroErrorRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
r.errorID = id
|
||||
r.errorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||
return s.reserveWait, s.reserveErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) MarkSuccess(context.Context, string) error {
|
||||
return s.successErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||
return s.mark429TTL, s.mark429Err
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||
return s.suspendedTTL, s.suspendedErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||
if s.clearCalled && s.clearResult {
|
||||
return nil, nil
|
||||
}
|
||||
return s.state, s.stateErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) ClearEarliestTransientCooldown(_ context.Context, tokenKeys []string) (bool, error) {
|
||||
s.clearCalled = true
|
||||
s.clearKeys = append([]string(nil), tokenKeys...)
|
||||
return s.clearResult, s.clearErr
|
||||
}
|
||||
|
||||
func TestCalculateKiro429Cooldown(t *testing.T) {
|
||||
require.Equal(t, time.Minute, kirocooldown.Calculate429Cooldown(0))
|
||||
require.Equal(t, 2*time.Minute, kirocooldown.Calculate429Cooldown(1))
|
||||
require.Equal(t, 4*time.Minute, kirocooldown.Calculate429Cooldown(2))
|
||||
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(3))
|
||||
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(10))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownReturnsNilWithoutWait(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
}
|
||||
|
||||
require.NoError(t, svc.checkAndWaitKiroCooldown(context.Background(), "token1"))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownPropagatesReserveError(t *testing.T) {
|
||||
expected := errors.New("redis unavailable")
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{reserveErr: expected},
|
||||
}
|
||||
|
||||
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
|
||||
require.ErrorIs(t, err, expected)
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownRequiresStore(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
|
||||
require.ErrorIs(t, err, errKiroCooldownStoreUnavailable)
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownWaitsAndHonorsContext(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{reserveWait: 200 * time.Millisecond},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := svc.checkAndWaitKiroCooldown(ctx, "token1")
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestAsKiroCooldownFailoverError(t *testing.T) {
|
||||
err := kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429)
|
||||
|
||||
var cooldownErr *kirocooldown.Error
|
||||
require.ErrorAs(t, err, &cooldownErr)
|
||||
|
||||
failoverErr := asKiroCooldownFailoverError(err)
|
||||
require.NotNil(t, failoverErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
|
||||
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||
}
|
||||
|
||||
func TestAsKiroCooldownFailoverErrorIgnoresNonCooldownErrors(t *testing.T) {
|
||||
require.Nil(t, asKiroCooldownFailoverError(errors.New("redis unavailable")))
|
||||
}
|
||||
|
||||
func TestGatewayServiceTryRecoverKiroCooldownPoolClearsOnlyTransientCooldown(t *testing.T) {
|
||||
store := &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(time.Minute),
|
||||
Remaining: time.Minute,
|
||||
},
|
||||
clearResult: true,
|
||||
}
|
||||
svc := &GatewayService{kiroCooldownStore: store}
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
}
|
||||
|
||||
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
|
||||
require.True(t, recovered)
|
||||
require.True(t, store.clearCalled)
|
||||
require.Len(t, store.clearKeys, 1)
|
||||
require.Equal(t, buildKiroAccountKey(&accounts[0]), store.clearKeys[0])
|
||||
}
|
||||
|
||||
func TestGatewayServiceTryRecoverKiroCooldownPoolSkipsSuspended(t *testing.T) {
|
||||
store := &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReasonSuspended,
|
||||
CooldownUntil: time.Now().Add(time.Hour),
|
||||
Remaining: time.Hour,
|
||||
},
|
||||
clearResult: true,
|
||||
}
|
||||
svc := &GatewayService{kiroCooldownStore: store}
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
}
|
||||
|
||||
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
|
||||
require.False(t, recovered)
|
||||
require.False(t, store.clearCalled)
|
||||
}
|
||||
|
||||
func TestSelectAccountWithLoadAwarenessRecoversKiroCooldownPool(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
account := Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
store := &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(time.Minute),
|
||||
Remaining: time.Minute,
|
||||
},
|
||||
clearResult: true,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: &mockAccountRepoForGemini{accounts: []Account{account}},
|
||||
concurrencyService: NewConcurrencyService(&mockConcurrencyCache{}),
|
||||
cfg: cfg,
|
||||
kiroCooldownStore: store,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformKiro)
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "", nil, "", 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, account.ID, result.Account.ID)
|
||||
require.True(t, store.clearCalled)
|
||||
require.Equal(t, []string{buildKiroAccountKey(&account)}, store.clearKeys)
|
||||
}
|
||||
|
||||
func TestClassifyKiroHTTPErrorMonthlyRequestCount(t *testing.T) {
|
||||
tests := []string{
|
||||
`{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
|
||||
`{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}`,
|
||||
`API returned 402: {"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
|
||||
}
|
||||
|
||||
for _, body := range tests {
|
||||
classification := classifyKiroHTTPError(http.StatusPaymentRequired, body)
|
||||
require.Equal(t, kiroErrorMonthlyRequest, classification.Category)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyKiroHTTPErrorPlain402IsTransient(t *testing.T) {
|
||||
classification := classifyKiroHTTPError(http.StatusPaymentRequired, `{"message":"payment required"}`)
|
||||
require.Equal(t, kiroErrorUpstreamTransient, classification.Category)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{
|
||||
reserveErr: kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429),
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
|
||||
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-opus-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||
require.NoError(t, readErr)
|
||||
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
|
||||
}
|
||||
|
||||
func TestHandleKiroHTTPErrorOAuthInvalidModelRateLimitsAndFailovers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Anthropic-Beta", "context-1m-2025-08-07")
|
||||
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Name: "kiro-oauth",
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
svc := &GatewayService{accountRepo: repo}
|
||||
requestBody := []byte(`{"model":"claude-opus-4-7","tools":[{"name":"search"}],"thinking":{"type":"adaptive"}}`)
|
||||
resp := newJSONResponse(http.StatusBadRequest, `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`)
|
||||
resp.Header.Set("x-request-id", "req-invalid-model")
|
||||
|
||||
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", requestBody)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "Invalid model")
|
||||
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||
|
||||
require.False(t, repo.called)
|
||||
require.True(t, repo.rateCalled)
|
||||
require.Equal(t, account.ID, repo.rateID)
|
||||
require.WithinDuration(t, time.Now().Add(kiroInvalidModelTempUnschedDuration), repo.rateLimitReset, 5*time.Second)
|
||||
|
||||
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, PlatformKiro, events[0].Platform)
|
||||
require.Equal(t, account.ID, events[0].AccountID)
|
||||
require.Equal(t, account.Name, events[0].AccountName)
|
||||
require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
|
||||
require.Equal(t, "req-invalid-model", events[0].UpstreamRequestID)
|
||||
require.Equal(t, "failover", events[0].Kind)
|
||||
require.Equal(t, "claude-opus-4-7", events[0].RequestedModel)
|
||||
require.Equal(t, "claude-opus-4.6", events[0].MappedModel)
|
||||
require.Equal(t, "claude-opus-4.6", events[0].KiroModelID)
|
||||
require.True(t, events[0].HasTools)
|
||||
require.True(t, events[0].HasAdaptiveThinking)
|
||||
require.True(t, events[0].HasContext1MBeta)
|
||||
}
|
||||
|
||||
func TestHandleKiroHTTPErrorAPIKeyInvalidModelDoesNotFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
account := &Account{
|
||||
ID: 43,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
svc := &GatewayService{accountRepo: repo}
|
||||
resp := newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model. Please select a different model to continue."}`)
|
||||
|
||||
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", []byte(`{"model":"claude-opus-4-7"}`))
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.NotErrorAs(t, err, &failoverErr)
|
||||
require.False(t, repo.called)
|
||||
require.False(t, repo.rateCalled)
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestNextKiroMonthlyResetUTC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
want time.Time
|
||||
}{
|
||||
{
|
||||
name: "middle of month",
|
||||
now: time.Date(2026, time.April, 27, 10, 30, 45, 123, time.FixedZone("CST", 8*3600)),
|
||||
want: time.Date(2026, time.May, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "december rolls year",
|
||||
now: time.Date(2026, time.December, 31, 23, 59, 59, 0, time.UTC),
|
||||
want: time.Date(2027, time.January, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, nextKiroMonthlyResetUTC(tt.now))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailovers(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusPaymentRequired, `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "MONTHLY_REQUEST_COUNT")
|
||||
require.False(t, repo.called)
|
||||
require.True(t, repo.rateCalled)
|
||||
require.Equal(t, account.ID, repo.rateID)
|
||||
require.Equal(t, nextKiroMonthlyResetUTC(time.Now()), repo.rateLimitReset)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusPaymentRequired, `{"message":"payment required"}`),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
|
||||
require.False(t, repo.called)
|
||||
require.False(t, repo.rateCalled)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedule(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "old-refresh",
|
||||
},
|
||||
}
|
||||
repo := &recordingKiroErrorRepo{
|
||||
recordingKiroTempUnschedRepo: recordingKiroTempUnschedRepo{
|
||||
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
||||
accountsByID: map[int64]*Account{account.ID: account},
|
||||
},
|
||||
},
|
||||
}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"message":"token expired"}`),
|
||||
},
|
||||
}
|
||||
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
kiroTokenProvider: provider,
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, account.ID, repo.errorID)
|
||||
require.Contains(t, repo.errorMsg, "invalid_grant")
|
||||
require.False(t, repo.called, "non-retryable refresh errors should not mark temporary unschedulable")
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionSkipsActiveKiroCooldown(t *testing.T) {
|
||||
now := time.Now().Add(2 * time.Minute)
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: now,
|
||||
Remaining: 2 * time.Minute,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
require.False(t, svc.isAccountSchedulableForSelection(account))
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
kiroTokenRefreshSkew = 3 * time.Minute
|
||||
kiroTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
type KiroTokenCache = GeminiTokenCache
|
||||
|
||||
type kiroAccountTokenRefresher interface {
|
||||
RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error)
|
||||
BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any
|
||||
}
|
||||
|
||||
type KiroTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache KiroTokenCache
|
||||
kiroOAuthService kiroAccountTokenRefresher
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewKiroTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache KiroTokenCache,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
) *KiroTokenProvider {
|
||||
return &KiroTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
kiroOAuthService: kiroOAuthService,
|
||||
refreshPolicy: GeminiProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a kiro oauth account")
|
||||
}
|
||||
|
||||
cacheKey := KiroTokenCacheKey(account)
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= kiroTokenRefreshSkew
|
||||
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, kiroTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > kiroTokenCacheSkew:
|
||||
ttl = until - kiroTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func KiroTokenCacheKey(account *Account) string {
|
||||
if account == nil {
|
||||
return "kiro:account:0"
|
||||
}
|
||||
if clientIDHash := strings.TrimSpace(account.GetCredential("client_id_hash")); clientIDHash != "" {
|
||||
return "kiro:" + clientIDHash
|
||||
}
|
||||
if clientID := strings.TrimSpace(account.GetCredential("client_id")); clientID != "" {
|
||||
return "kiro:client:" + clientID
|
||||
}
|
||||
return "kiro:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) ForceRefreshAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a kiro oauth account")
|
||||
}
|
||||
if p.kiroOAuthService == nil {
|
||||
return "", errors.New("kiro oauth service is nil")
|
||||
}
|
||||
|
||||
cacheKey := KiroTokenCacheKey(account)
|
||||
lockHeld := false
|
||||
if p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
lockHeld = true
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
if p.accountRepo != nil {
|
||||
if latestAccount, err := p.accountRepo.GetByID(ctx, account.ID); err == nil && latestAccount != nil {
|
||||
account = latestAccount
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := p.kiroOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
if !lockHeld {
|
||||
if latestAccount, stale := CheckTokenVersion(ctx, account, p.accountRepo); stale && latestAccount != nil {
|
||||
account = latestAccount
|
||||
if accessToken := strings.TrimSpace(account.GetCredential("access_token")); accessToken != "" {
|
||||
_ = p.cacheAccessToken(ctx, account, accessToken)
|
||||
return accessToken, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if isNonRetryableRefreshError(err) && p.accountRepo != nil {
|
||||
errorMsg := "Token refresh failed (non-retryable): " + err.Error()
|
||||
_ = p.accountRepo.SetError(ctx, account.ID, errorMsg)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
newCredentials := MergeCredentials(account.Credentials, p.kiroOAuthService.BuildAccountCredentials(tokenInfo))
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
if err := persistAccountCredentials(ctx, p.accountRepo, account, newCredentials); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" {
|
||||
accessToken = strings.TrimSpace(tokenInfo.AccessToken)
|
||||
}
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found after kiro refresh")
|
||||
}
|
||||
if err := p.cacheAccessToken(ctx, account, accessToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) cacheAccessToken(ctx context.Context, account *Account, accessToken string) error {
|
||||
if p.tokenCache == nil || account == nil || strings.TrimSpace(accessToken) == "" {
|
||||
return nil
|
||||
}
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > kiroTokenCacheSkew:
|
||||
ttl = until - kiroTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
return p.tokenCache.SetAccessToken(ctx, KiroTokenCacheKey(account), accessToken, ttl)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type kiroTokenProviderRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
setErrorID int64
|
||||
setErrorMsg string
|
||||
}
|
||||
|
||||
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
r.setErrorID = id
|
||||
r.setErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
type kiroTokenProviderSequenceRepo struct {
|
||||
kiroTokenProviderRepo
|
||||
accounts []*Account
|
||||
reads int
|
||||
}
|
||||
|
||||
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||
if len(r.accounts) == 0 {
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
idx := r.reads
|
||||
if idx >= len(r.accounts) {
|
||||
idx = len(r.accounts) - 1
|
||||
}
|
||||
r.reads++
|
||||
return r.accounts[idx], nil
|
||||
}
|
||||
|
||||
type stubKiroAccountTokenRefresher struct {
|
||||
tokenInfo *KiroTokenInfo
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
||||
if tokenInfo == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": tokenInfo.ExpiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
||||
}
|
||||
repo := &kiroTokenProviderRepo{
|
||||
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
||||
accountsByID: map[int64]*Account{account.ID: account},
|
||||
},
|
||||
}
|
||||
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||
|
||||
token, err := provider.ForceRefreshAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, token)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, account.ID, repo.setErrorID)
|
||||
require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
|
||||
require.Contains(t, repo.setErrorMsg, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
|
||||
usedAccount := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
||||
}
|
||||
latestAccount := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
|
||||
}
|
||||
repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
|
||||
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||
|
||||
token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fresh-access", token)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const kiroRefreshWindow = 15 * time.Minute
|
||||
|
||||
type KiroTokenRefresher struct {
|
||||
kiroOAuthService *KiroOAuthService
|
||||
}
|
||||
|
||||
func NewKiroTokenRefresher(kiroOAuthService *KiroOAuthService) *KiroTokenRefresher {
|
||||
return &KiroTokenRefresher{
|
||||
kiroOAuthService: kiroOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) CacheKey(account *Account) string {
|
||||
return KiroTokenCacheKey(account)
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Until(*expiresAt) <= kiroRefreshWindow
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.kiroOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.kiroOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
return MergeCredentials(account.Credentials, newCredentials), nil
|
||||
}
|
||||
@@ -0,0 +1,608 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
kiroUsageOrigin = "AI_EDITOR"
|
||||
kiroUsageResourceType = "AGENTIC_REQUEST"
|
||||
|
||||
kiroDefaultRegion = "us-east-1"
|
||||
)
|
||||
|
||||
var resolveKiroRuntimeEndpoint = kiroRuntimeEndpoint
|
||||
|
||||
type kiroUsageLimitsResponse struct {
|
||||
NextDateReset any `json:"nextDateReset"`
|
||||
OverageConfiguration kiroOverageConfiguration `json:"overageConfiguration"`
|
||||
SubscriptionInfo kiroSubscriptionInfo `json:"subscriptionInfo"`
|
||||
UsageBreakdownList []kiroUsageBreakdown `json:"usageBreakdownList"`
|
||||
}
|
||||
|
||||
type kiroOverageConfiguration struct {
|
||||
OverageStatus string `json:"overageStatus"`
|
||||
}
|
||||
|
||||
type kiroSubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type kiroUsageBreakdown struct {
|
||||
Currency string `json:"currency"`
|
||||
CurrentOverages *float64 `json:"currentOverages"`
|
||||
CurrentOveragesWithPrecision *float64 `json:"currentOveragesWithPrecision"`
|
||||
CurrentUsage *float64 `json:"currentUsage"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
|
||||
DisplayName string `json:"displayName"`
|
||||
DisplayNamePlural string `json:"displayNamePlural"`
|
||||
FreeTrialInfo *kiroFreeTrialInfo `json:"freeTrialInfo"`
|
||||
NextDateReset any `json:"nextDateReset"`
|
||||
OverageCharges *float64 `json:"overageCharges"`
|
||||
ResourceType string `json:"resourceType"`
|
||||
UsageLimit *float64 `json:"usageLimit"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
|
||||
}
|
||||
|
||||
type kiroFreeTrialInfo struct {
|
||||
CurrentUsage *float64 `json:"currentUsage"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
|
||||
FreeTrialExpiry any `json:"freeTrialExpiry"`
|
||||
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||
UsageLimit *float64 `json:"usageLimit"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
|
||||
}
|
||||
|
||||
type kiroUsageHTTPError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *kiroUsageHTTPError) Error() string {
|
||||
if e == nil {
|
||||
return "kiro usage request failed"
|
||||
}
|
||||
if strings.TrimSpace(e.Body) == "" {
|
||||
return fmt.Sprintf("kiro usage request failed (status %d)", e.StatusCode)
|
||||
}
|
||||
return fmt.Sprintf("kiro usage request failed (status %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getKiroUsage(ctx context.Context, account *Account, source string, forceRefresh bool) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
if account == nil {
|
||||
return &UsageInfo{
|
||||
Source: source,
|
||||
UpdatedAt: &now,
|
||||
Error: "account is nil",
|
||||
ErrorCode: errorCodeNetworkError,
|
||||
}, nil
|
||||
}
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return &UsageInfo{
|
||||
Source: source,
|
||||
UpdatedAt: &now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cached, hasCached := s.getCachedKiroUsage(account.ID)
|
||||
if hasCached && (cached.ErrorCode != "" || cached.Error != "") {
|
||||
cached.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, cached)
|
||||
return cached, nil
|
||||
}
|
||||
if !forceRefresh && hasCached {
|
||||
cached.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, cached)
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
flightKey := fmt.Sprintf("kiro-usage:%d", account.ID)
|
||||
result, fetchErr, _ := s.cache.kiroUsageFlight.Do(flightKey, func() (any, error) {
|
||||
if !forceRefresh {
|
||||
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
usage, err := s.fetchAndCacheKiroUsage(ctx, account, source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return usage, nil
|
||||
})
|
||||
if fetchErr == nil {
|
||||
if usage, ok := result.(*UsageInfo); ok && usage != nil {
|
||||
usage.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, usage)
|
||||
if source == "active" {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
degraded := buildKiroDegradedUsage(fetchErr)
|
||||
degraded.Source = source
|
||||
if hasCached {
|
||||
cached.Error = degraded.Error
|
||||
cached.ErrorCode = degraded.ErrorCode
|
||||
cached.NeedsReauth = degraded.NeedsReauth
|
||||
cached.KiroQuotaState = degraded.KiroQuotaState
|
||||
cached.KiroQuotaReason = degraded.KiroQuotaReason
|
||||
cached.KiroQuotaResetAt = degraded.KiroQuotaResetAt
|
||||
cached.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, cached)
|
||||
return cached, nil
|
||||
}
|
||||
s.storeKiroUsageSnapshot(account.ID, degraded)
|
||||
s.attachKiroRuntimeState(ctx, account, degraded)
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) fetchAndCacheKiroUsage(ctx context.Context, account *Account, source string) (*UsageInfo, error) {
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
|
||||
region := kiroAPIRegion(account)
|
||||
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||
|
||||
resp, err := s.requestKiroUsageLimits(ctx, account, region, profileArn, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage := mapKiroUsageToInfo(resp)
|
||||
usage.Source = source
|
||||
s.storeKiroUsageSnapshot(account.ID, usage)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) storeKiroUsageSnapshot(accountID int64, usage *UsageInfo) {
|
||||
if s == nil || s.cache == nil || accountID <= 0 || usage == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
s.cache.kiroUsageCache.Store(accountID, &kiroUsageCache{
|
||||
usageInfo: cloneUsageInfo(usage),
|
||||
timestamp: now,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getCachedKiroUsage(accountID int64) (*UsageInfo, bool) {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
cached, ok := s.cache.kiroUsageCache.Load(accountID)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cache, ok := cached.(*kiroUsageCache)
|
||||
if !ok || cache == nil || cache.usageInfo == nil {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(cache.timestamp) >= kiroCacheTTL(cache.usageInfo) {
|
||||
return nil, false
|
||||
}
|
||||
return cloneUsageInfo(cache.usageInfo), true
|
||||
}
|
||||
|
||||
func kiroCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return kiroUsageErrorTTL
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return kiroUsageErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
func cloneUsageInfo(info *UsageInfo) *UsageInfo {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *info
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) requestKiroUsageLimits(ctx context.Context, account *Account, region, profileArn, token string) (*kiroUsageLimitsResponse, error) {
|
||||
endpoint := resolveKiroRuntimeEndpoint(region)
|
||||
reqURL, err := url.Parse(endpoint + "/getUsageLimits")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build kiro usage url failed: %w", err)
|
||||
}
|
||||
q := reqURL.Query()
|
||||
q.Set("origin", kiroUsageOrigin)
|
||||
if profileArn = strings.TrimSpace(profileArn); profileArn != "" {
|
||||
q.Set("profileArn", profileArn)
|
||||
}
|
||||
q.Set("resourceType", kiroUsageResourceType)
|
||||
reqURL.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kiro usage request failed: %w", err)
|
||||
}
|
||||
s.applyKiroRuntimeHeaders(req, account, token)
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: accountProxyURL(account),
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: isLoopbackEndpoint(endpoint),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kiro usage client failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro usage request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read kiro usage response failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
|
||||
}
|
||||
|
||||
var parsed kiroUsageLimitsResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("decode kiro usage response failed: %w", err)
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) applyKiroRuntimeHeaders(req *http.Request, account *Account, token string) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
machineID := buildKiroMachineID(account)
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
applyKiroConditionalHeaders(req, account)
|
||||
}
|
||||
|
||||
func accountProxyURL(account *Account) string {
|
||||
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
|
||||
func kiroRuntimeEndpoint(region string) string {
|
||||
region = strings.TrimSpace(region)
|
||||
if region == "" {
|
||||
region = kiroDefaultRegion
|
||||
}
|
||||
switch region {
|
||||
case "us-east-1":
|
||||
return "https://q.us-east-1.amazonaws.com"
|
||||
case "eu-central-1":
|
||||
return "https://q.eu-central-1.amazonaws.com"
|
||||
case "us-gov-east-1":
|
||||
return "https://q-fips.us-gov-east-1.amazonaws.com"
|
||||
case "us-gov-west-1":
|
||||
return "https://q-fips.us-gov-west-1.amazonaws.com"
|
||||
case "us-iso-east-1":
|
||||
return "https://q.us-iso-east-1.c2s.ic.gov"
|
||||
case "us-isob-east-1":
|
||||
return "https://q.us-isob-east-1.sc2s.sgov.gov"
|
||||
case "us-isof-south-1":
|
||||
return "https://q.us-isof-south-1.csp.hci.ic.gov"
|
||||
case "us-isof-east-1":
|
||||
return "https://q.us-isof-east-1.csp.hci.ic.gov"
|
||||
default:
|
||||
if strings.HasPrefix(region, "us-gov-") {
|
||||
return "https://q-fips." + region + ".amazonaws.com"
|
||||
}
|
||||
return "https://q." + region + ".amazonaws.com"
|
||||
}
|
||||
}
|
||||
|
||||
func isLoopbackEndpoint(raw string) bool {
|
||||
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func mapKiroUsageToInfo(resp *kiroUsageLimitsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
if resp == nil {
|
||||
return &UsageInfo{UpdatedAt: &now}
|
||||
}
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
KiroSubscriptionName: strings.TrimSpace(resp.SubscriptionInfo.SubscriptionTitle),
|
||||
KiroSubscriptionType: strings.TrimSpace(resp.SubscriptionInfo.Type),
|
||||
KiroOveragesEnabled: strings.EqualFold(strings.TrimSpace(resp.OverageConfiguration.OverageStatus), "ENABLED"),
|
||||
}
|
||||
|
||||
resetAt := parseKiroTimestamp(resp.NextDateReset)
|
||||
if credit := selectKiroCreditBreakdown(resp.UsageBreakdownList); credit != nil {
|
||||
if breakdownReset := parseKiroTimestamp(credit.NextDateReset); breakdownReset != nil {
|
||||
resetAt = breakdownReset
|
||||
}
|
||||
info.KiroCredit = &KiroCreditProgress{
|
||||
CurrentUsage: selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage),
|
||||
UsageLimit: selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit),
|
||||
PercentageUsed: percentageOrZero(selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage), selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit)),
|
||||
}
|
||||
info.KiroOverage = &KiroOverageInfo{
|
||||
CurrentOverages: selectKiroFloat(credit.CurrentOveragesWithPrecision, credit.CurrentOverages),
|
||||
OverageCharges: selectKiroFloat(credit.OverageCharges, nil),
|
||||
CurrencyCode: strings.TrimSpace(credit.Currency),
|
||||
CurrencySymbol: kiroCurrencySymbol(strings.TrimSpace(credit.Currency)),
|
||||
}
|
||||
if ft := credit.FreeTrialInfo; ft != nil && strings.EqualFold(strings.TrimSpace(ft.FreeTrialStatus), "ACTIVE") {
|
||||
expiry := parseKiroTimestamp(ft.FreeTrialExpiry)
|
||||
daysRemaining := 0
|
||||
if expiry != nil {
|
||||
daysRemaining = int(time.Until(*expiry).Hours() / 24)
|
||||
if time.Until(*expiry)%(24*time.Hour) != 0 {
|
||||
daysRemaining++
|
||||
}
|
||||
if daysRemaining < 0 {
|
||||
daysRemaining = 0
|
||||
}
|
||||
}
|
||||
current := selectKiroFloat(ft.CurrentUsageWithPrecision, ft.CurrentUsage)
|
||||
limit := selectKiroFloat(ft.UsageLimitWithPrecision, ft.UsageLimit)
|
||||
info.KiroBonus = &KiroCreditProgress{
|
||||
CurrentUsage: current,
|
||||
UsageLimit: limit,
|
||||
PercentageUsed: percentageOrZero(current, limit),
|
||||
DaysRemaining: daysRemaining,
|
||||
ExpiryDate: expiry,
|
||||
}
|
||||
}
|
||||
}
|
||||
info.KiroResetAt = resetAt
|
||||
setKiroQuotaStateFromUsage(info)
|
||||
return info
|
||||
}
|
||||
|
||||
func selectKiroCreditBreakdown(items []kiroUsageBreakdown) *kiroUsageBreakdown {
|
||||
for i := range items {
|
||||
if strings.EqualFold(strings.TrimSpace(items[i].ResourceType), "CREDIT") {
|
||||
return &items[i]
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &items[0]
|
||||
}
|
||||
|
||||
func selectKiroFloat(preferred *float64, fallback *float64) float64 {
|
||||
switch {
|
||||
case preferred != nil:
|
||||
return *preferred
|
||||
case fallback != nil:
|
||||
return *fallback
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func percentageOrZero(current, limit float64) float64 {
|
||||
if limit <= 0 {
|
||||
return 0
|
||||
}
|
||||
return current / limit * 100
|
||||
}
|
||||
|
||||
func parseKiroTimestamp(raw any) *time.Time {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, trimmed); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
||||
return unixishToTime(i)
|
||||
}
|
||||
if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
|
||||
return unixishFloatToTime(f)
|
||||
}
|
||||
case float64:
|
||||
return unixishFloatToTime(v)
|
||||
case int64:
|
||||
return unixishToTime(v)
|
||||
case int:
|
||||
return unixishToTime(int64(v))
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return unixishToTime(i)
|
||||
}
|
||||
if f, err := v.Float64(); err == nil {
|
||||
return unixishFloatToTime(f)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unixishFloatToTime(v float64) *time.Time {
|
||||
if v <= 0 {
|
||||
return nil
|
||||
}
|
||||
if v >= 1e12 {
|
||||
t := time.UnixMilli(int64(v))
|
||||
return &t
|
||||
}
|
||||
t := time.Unix(int64(v), 0)
|
||||
return &t
|
||||
}
|
||||
|
||||
func unixishToTime(v int64) *time.Time {
|
||||
if v <= 0 {
|
||||
return nil
|
||||
}
|
||||
if v >= 1e12 {
|
||||
t := time.UnixMilli(v)
|
||||
return &t
|
||||
}
|
||||
t := time.Unix(v, 0)
|
||||
return &t
|
||||
}
|
||||
|
||||
func kiroCurrencySymbol(code string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(code)) {
|
||||
case "USD":
|
||||
return "$"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildKiroDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: "usage API error",
|
||||
ErrorCode: errorCodeNetworkError,
|
||||
}
|
||||
if err == nil {
|
||||
return info
|
||||
}
|
||||
|
||||
info.Error = fmt.Sprintf("usage API error: %v", err)
|
||||
|
||||
classification := classifyKiroError(err)
|
||||
switch classification.Category {
|
||||
case kiroErrorAuthError:
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case kiroErrorRateLimited:
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
case kiroErrorQuotaExhausted:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
|
||||
info.KiroQuotaReason = classification.Message
|
||||
case kiroErrorOverageExhausted:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
info.KiroQuotaState = kiroQuotaStateOverageExhausted
|
||||
info.KiroQuotaReason = classification.Message
|
||||
case kiroErrorSuspended, kiroErrorUsageForbidden, kiroErrorProfileError:
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) attachKiroRuntimeState(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
if s == nil || usage == nil || account == nil || account.Platform != PlatformKiro || s.kiroCooldownStore == nil {
|
||||
return
|
||||
}
|
||||
usage.KiroRuntimeState = ""
|
||||
usage.KiroRuntimeReason = ""
|
||||
usage.KiroRuntimeResetAt = nil
|
||||
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
|
||||
if err != nil || state == nil {
|
||||
return
|
||||
}
|
||||
usage.KiroRuntimeState, usage.KiroRuntimeReason, usage.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) EnrichAccountWithKiroRuntimeState(ctx context.Context, account *Account) {
|
||||
if s == nil || account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
account.KiroQuotaState = ""
|
||||
account.KiroQuotaReason = ""
|
||||
account.KiroQuotaResetAt = nil
|
||||
account.KiroRuntimeState = ""
|
||||
account.KiroRuntimeReason = ""
|
||||
account.KiroRuntimeResetAt = nil
|
||||
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
|
||||
account.KiroQuotaState = usage.KiroQuotaState
|
||||
account.KiroQuotaReason = usage.KiroQuotaReason
|
||||
account.KiroQuotaResetAt = usage.KiroQuotaResetAt
|
||||
}
|
||||
if s.kiroCooldownStore == nil {
|
||||
return
|
||||
}
|
||||
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
|
||||
if err != nil || state == nil {
|
||||
return
|
||||
}
|
||||
account.KiroRuntimeState, account.KiroRuntimeReason, account.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
|
||||
}
|
||||
|
||||
func setKiroQuotaStateFromUsage(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
info.KiroQuotaState = ""
|
||||
info.KiroQuotaReason = ""
|
||||
info.KiroQuotaResetAt = nil
|
||||
|
||||
creditExhausted := false
|
||||
if info.KiroCredit != nil && info.KiroCredit.UsageLimit > 0 {
|
||||
creditExhausted = info.KiroCredit.CurrentUsage >= info.KiroCredit.UsageLimit
|
||||
}
|
||||
overageActive := info.KiroOverage != nil &&
|
||||
(info.KiroOverage.CurrentOverages > 0 || info.KiroOverage.OverageCharges > 0)
|
||||
|
||||
switch {
|
||||
case info.KiroOveragesEnabled && (overageActive || creditExhausted):
|
||||
info.KiroQuotaState = kiroQuotaStateOverageActive
|
||||
info.KiroQuotaReason = "overages_enabled"
|
||||
info.KiroQuotaResetAt = info.KiroResetAt
|
||||
case creditExhausted:
|
||||
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
|
||||
info.KiroQuotaReason = "credits_exhausted"
|
||||
info.KiroQuotaResetAt = info.KiroResetAt
|
||||
default:
|
||||
info.KiroQuotaState = kiroQuotaStateNormal
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,458 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
)
|
||||
|
||||
const kiroMaxWebSearchIterations = 5
|
||||
|
||||
var (
|
||||
errKiroWebSearchFallback = errors.New("kiro web search fallback")
|
||||
kiroWebSearchDescCache sync.Map
|
||||
)
|
||||
|
||||
type kiroWebSearchExecution struct {
|
||||
ResponseBody []byte
|
||||
Usage ClaudeUsage
|
||||
RequestID string
|
||||
}
|
||||
|
||||
type kiroWebSearchHTTPError struct {
|
||||
Response *http.Response
|
||||
}
|
||||
|
||||
type kiroStreamChunkCollector struct {
|
||||
chunks [][]byte
|
||||
}
|
||||
|
||||
func (e *kiroWebSearchHTTPError) Error() string {
|
||||
if e == nil || e.Response == nil {
|
||||
return "kiro web search http error"
|
||||
}
|
||||
return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
|
||||
}
|
||||
|
||||
func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
|
||||
if len(p) > 0 {
|
||||
w.chunks = append(w.chunks, append([]byte(nil), p...))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
|
||||
collector := &kiroStreamChunkCollector{}
|
||||
result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return collector.chunks, result, nil
|
||||
}
|
||||
|
||||
func writeSSEChunks(w io.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if len(chunk) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
msgID = "msg_" + kiropkg.GenerateToolUseID()
|
||||
}
|
||||
if strings.TrimSpace(model) == "" {
|
||||
model = "kiro"
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": inputTokens,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||
) error {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
|
||||
if err != nil {
|
||||
currentBody = anthropicBody
|
||||
}
|
||||
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
nextContentBlockIndex := 0
|
||||
|
||||
if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
|
||||
s.prefetchKiroWebSearchDescription(ctx, account, token)
|
||||
|
||||
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
|
||||
if strings.TrimSpace(nextToken) != "" {
|
||||
token = nextToken
|
||||
}
|
||||
if mcpErr != nil {
|
||||
results = nil
|
||||
}
|
||||
|
||||
if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil {
|
||||
return err
|
||||
}
|
||||
nextContentBlockIndex += 2
|
||||
|
||||
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
|
||||
if err != nil {
|
||||
return errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return &kiroWebSearchHTTPError{Response: resp}
|
||||
}
|
||||
|
||||
chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
|
||||
}()
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
|
||||
analysis := kiropkg.AnalyzeBufferedStream(chunks)
|
||||
if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations {
|
||||
filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex)
|
||||
if err := writeSSEChunks(w, filtered); err != nil {
|
||||
return err
|
||||
}
|
||||
if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex {
|
||||
nextContentBlockIndex = maxIndex + 1
|
||||
}
|
||||
query = analysis.WebSearchQuery
|
||||
if strings.TrimSpace(analysis.WebSearchToolUseID) == "" {
|
||||
currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
} else {
|
||||
currentToolUseID = analysis.WebSearchToolUseID
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex)
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(adjusted); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("kiro web search exceeded max iterations")
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil, errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
|
||||
if err != nil {
|
||||
currentBody = anthropicBody
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(anthropicBody)
|
||||
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
searches := make([]kiropkg.SearchIndicator, 0, 2)
|
||||
requestID := ""
|
||||
|
||||
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
|
||||
s.prefetchKiroWebSearchDescription(ctx, account, token)
|
||||
|
||||
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
|
||||
if strings.TrimSpace(nextToken) != "" {
|
||||
token = nextToken
|
||||
}
|
||||
if mcpErr != nil {
|
||||
results = nil
|
||||
}
|
||||
searches = append(searches, kiropkg.SearchIndicator{
|
||||
ToolUseID: currentToolUseID,
|
||||
Query: query,
|
||||
Results: results,
|
||||
})
|
||||
|
||||
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
|
||||
if err != nil {
|
||||
return nil, errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &kiroWebSearchHTTPError{Response: resp}
|
||||
}
|
||||
|
||||
parseResult, parseErr := func() (*kiropkg.ParseResult, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
|
||||
}()
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if requestID == "" {
|
||||
requestID = buildKiroRequestID(resp)
|
||||
}
|
||||
|
||||
nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody)
|
||||
if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations {
|
||||
finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches)
|
||||
if injectErr == nil {
|
||||
parseResult.ResponseBody = finalBody
|
||||
}
|
||||
return &kiroWebSearchExecution{
|
||||
ResponseBody: parseResult.ResponseBody,
|
||||
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
|
||||
RequestID: requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
query = nextQuery
|
||||
if strings.TrimSpace(nextToolUseID) == "" {
|
||||
nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
}
|
||||
currentToolUseID = nextToolUseID
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("kiro web search exceeded max iterations")
|
||||
}
|
||||
|
||||
func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
|
||||
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
|
||||
if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
|
||||
if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
|
||||
kiropkg.SetCachedWebSearchDescription(desc)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reqBody, _ := json.Marshal(kiropkg.MCPRequest{
|
||||
ID: "tools_list",
|
||||
JSONRPC: "2.0",
|
||||
Method: "tools/list",
|
||||
})
|
||||
resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
|
||||
if err != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var result kiropkg.MCPResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
return
|
||||
}
|
||||
for _, tool := range result.Result.Tools {
|
||||
if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
|
||||
kiroWebSearchDescCache.Store(endpoint, tool.Description)
|
||||
kiropkg.SetCachedWebSearchDescription(tool.Description)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
|
||||
reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
|
||||
if err != nil {
|
||||
return nil, token, err
|
||||
}
|
||||
|
||||
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
|
||||
resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
|
||||
if err != nil {
|
||||
return nil, nextToken, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nextToken, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed kiropkg.MCPResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, nextToken, err
|
||||
}
|
||||
if parsed.Error != nil {
|
||||
msg := "unknown error"
|
||||
if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" {
|
||||
msg = strings.TrimSpace(*parsed.Error.Message)
|
||||
}
|
||||
code := 0
|
||||
if parsed.Error.Code != nil {
|
||||
code = *parsed.Error.Code
|
||||
}
|
||||
return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return kiropkg.ParseSearchResults(&parsed), nextToken, nil
|
||||
}
|
||||
|
||||
func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
|
||||
return kiropkg.MCPRequest{
|
||||
ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
|
||||
JSONRPC: "2.0",
|
||||
Method: "tools/call",
|
||||
Params: map[string]interface{}{
|
||||
"name": "web_search",
|
||||
"arguments": map[string]interface{}{
|
||||
"query": query,
|
||||
"_meta": map[string]interface{}{
|
||||
"_isValid": true,
|
||||
"_activePath": []string{"query"},
|
||||
"_completedPaths": [][]string{{"query"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
|
||||
currentToken := token
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
proxyURL := kiroProxyURL(account)
|
||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
|
||||
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||
return nil, currentToken, failoverErr
|
||||
}
|
||||
return nil, currentToken, err
|
||||
}
|
||||
|
||||
req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
|
||||
if err != nil {
|
||||
return nil, currentToken, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
return nil, currentToken, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, currentToken, readErr
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
|
||||
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
|
||||
return nil, currentToken, err
|
||||
}
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) {
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
if s.kiroTokenProvider == nil {
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, currentToken, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
if _, err := s.markKiro429(ctx, accountKey); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, currentToken, err
|
||||
}
|
||||
}
|
||||
if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
|
||||
if attempt < 2 {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, currentToken, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, currentToken, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
|
||||
return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildKiroWebSearchMCPRequest_UsesUnderscoredMetaKeys(t *testing.T) {
|
||||
req := buildKiroWebSearchMCPRequest("golang concurrency")
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "tools/call", gjson.GetBytes(body, "method").String())
|
||||
require.Equal(t, "web_search", gjson.GetBytes(body, "params.name").String())
|
||||
require.Equal(t, "golang concurrency", gjson.GetBytes(body, "params.arguments.query").String())
|
||||
require.True(t, gjson.GetBytes(body, "params.arguments._meta._isValid").Bool())
|
||||
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._activePath.0").String())
|
||||
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._completedPaths.0.0").String())
|
||||
require.False(t, gjson.GetBytes(body, "params.arguments._meta.isValid").Exists())
|
||||
require.False(t, gjson.GetBytes(body, "params.arguments._meta.activePath").Exists())
|
||||
require.False(t, gjson.GetBytes(body, "params.arguments._meta.completedPaths").Exists())
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
privacyCalls := 0
|
||||
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
|
||||
privacyCalls++
|
||||
|
||||
@@ -289,6 +289,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
|
||||
out := *ev
|
||||
|
||||
out.Platform = strings.TrimSpace(out.Platform)
|
||||
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
|
||||
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
|
||||
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
|
||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||
|
||||
|
||||
@@ -89,6 +89,14 @@ type OpsUpstreamErrorEvent struct {
|
||||
AccountID int64 `json:"account_id,omitempty"`
|
||||
AccountName string `json:"account_name,omitempty"`
|
||||
|
||||
// Model diagnostics.
|
||||
RequestedModel string `json:"requested_model,omitempty"`
|
||||
MappedModel string `json:"mapped_model,omitempty"`
|
||||
KiroModelID string `json:"kiro_model_id,omitempty"`
|
||||
HasTools bool `json:"has_tools,omitempty"`
|
||||
HasAdaptiveThinking bool `json:"has_adaptive_thinking,omitempty"`
|
||||
HasContext1MBeta bool `json:"has_context_1m_beta,omitempty"`
|
||||
|
||||
// Outcome
|
||||
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
||||
|
||||
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-2.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = canonicalModelNameForPricing(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
model = strings.TrimPrefix(model, "publishers/google/models/")
|
||||
@@ -628,7 +628,31 @@ func normalizeModelNameForPricing(model string) string {
|
||||
if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" {
|
||||
return canonical
|
||||
}
|
||||
return model
|
||||
return canonicalModelNameForPricing(model)
|
||||
}
|
||||
|
||||
func canonicalModelNameForPricing(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch model {
|
||||
case "claude-opus-4.5":
|
||||
return "claude-opus-4-5"
|
||||
case "claude-opus-4.6":
|
||||
return "claude-opus-4-6"
|
||||
case "claude-opus-4.7":
|
||||
return "claude-opus-4-7"
|
||||
case "claude-sonnet-4.5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "claude-sonnet-4.6":
|
||||
return "claude-sonnet-4-6"
|
||||
case "claude-haiku-4.5":
|
||||
return "claude-haiku-4-5"
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func lastSegment(model string) string {
|
||||
@@ -674,8 +698,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
|
||||
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
|
||||
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
|
||||
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
|
||||
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
|
||||
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
|
||||
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
|
||||
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
|
||||
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
|
||||
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
|
||||
@@ -713,6 +739,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "sonnet"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
|
||||
fallbackName = "sonnet-4.6"
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "sonnet-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
@@ -722,6 +750,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "haiku"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "haiku-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
fallbackName = "haiku-3.5"
|
||||
default:
|
||||
|
||||
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
|
||||
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
|
||||
@@ -42,6 +42,9 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
|
||||
// Antigravity 同样可能有两种缓存键
|
||||
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
||||
case PlatformKiro:
|
||||
keysToDelete = append(keysToDelete, KiroTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "kiro:"+accountIDKey)
|
||||
case PlatformOpenAI:
|
||||
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
||||
case PlatformAnthropic:
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
)
|
||||
|
||||
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
|
||||
@@ -44,6 +45,7 @@ func NewTokenRefreshService(
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
cacheInvalidator TokenCacheInvalidator,
|
||||
schedulerCache SchedulerCache,
|
||||
cfg *config.Config,
|
||||
@@ -64,6 +66,7 @@ func NewTokenRefreshService(
|
||||
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||
kiroRefresher := NewKiroTokenRefresher(kiroOAuthService)
|
||||
|
||||
// 注册平台特定的刷新器(TokenRefresher 接口)
|
||||
s.refreshers = []TokenRefresher{
|
||||
@@ -71,6 +74,7 @@ func NewTokenRefreshService(
|
||||
openAIRefresher,
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
kiroRefresher,
|
||||
}
|
||||
|
||||
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
||||
@@ -79,6 +83,7 @@ func NewTokenRefreshService(
|
||||
openAIRefresher,
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
kiroRefresher,
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -415,6 +420,10 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var kiroInvalidGrant *kiropkg.RefreshTokenInvalidError
|
||||
if errors.As(err, &kiroInvalidGrant) {
|
||||
return true
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
nonRetryable := []string{
|
||||
"invalid_grant", // refresh_token 已失效
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Platform: PlatformGemini,
|
||||
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformGemini,
|
||||
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformGemini,
|
||||
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformOpenAI, // OpenAI OAuth 账户
|
||||
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
resetAt := time.Now().Add(30 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformGemini,
|
||||
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformGemini,
|
||||
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 15,
|
||||
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Platform: tt.platform,
|
||||
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 18,
|
||||
Platform: PlatformOpenAI,
|
||||
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
|
||||
@@ -22,9 +22,9 @@ func optionalNonEqualStringPtr(value, compare string) *string {
|
||||
|
||||
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
|
||||
return trimmed
|
||||
return normalizeModelNameForPricing(trimmed)
|
||||
}
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
return normalizeModelNameForPricing(upstreamModel)
|
||||
}
|
||||
|
||||
func optionalInt64Ptr(v int64) *int64 {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -52,6 +53,7 @@ func ProvideTokenRefreshService(
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
cacheInvalidator TokenCacheInvalidator,
|
||||
schedulerCache SchedulerCache,
|
||||
cfg *config.Config,
|
||||
@@ -60,7 +62,7 @@ func ProvideTokenRefreshService(
|
||||
proxyRepo ProxyRepository,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
||||
@@ -129,6 +131,23 @@ func ProvideAntigravityTokenProvider(
|
||||
return p
|
||||
}
|
||||
|
||||
func ProvideKiroTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *KiroTokenProvider {
|
||||
p := NewKiroTokenProvider(accountRepo, tokenCache, kiroOAuthService)
|
||||
executor := NewKiroTokenRefresher(kiroOAuthService)
|
||||
p.SetRefreshAPI(refreshAPI, executor)
|
||||
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
|
||||
return p
|
||||
}
|
||||
|
||||
func ProvideKiroCooldownStore(redisClient *redis.Client) KiroCooldownStore {
|
||||
return kirocooldown.NewStore(redisClient)
|
||||
}
|
||||
|
||||
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||
@@ -457,8 +476,11 @@ var ProviderSet = wire.NewSet(
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewKiroOAuthService,
|
||||
ProvideOAuthRefreshAPI,
|
||||
ProvideGeminiTokenProvider,
|
||||
ProvideKiroTokenProvider,
|
||||
ProvideKiroCooldownStore,
|
||||
NewGeminiMessagesCompatService,
|
||||
ProvideAntigravityTokenProvider,
|
||||
ProvideOpenAITokenProvider,
|
||||
|
||||
Reference in New Issue
Block a user