Merge remote-tracking branch 'pr/2131' into release/v0.1.133
# Conflicts: # backend/cmd/server/wire_gen.go # backend/internal/config/config.go # backend/internal/service/gateway_service.go # backend/internal/service/pricing_service.go # backend/internal/service/wire.go # deploy/config.example.yaml # frontend/src/views/admin/AccountsView.vue
This commit is contained in:
@@ -93,6 +93,7 @@ func provideCleanup(
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
kiroOAuth *service.KiroOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
@@ -216,6 +217,10 @@ func provideCleanup(
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"KiroOAuthService", func() error {
|
||||
kiroOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIWSPool", func() error {
|
||||
if openAIGateway != nil {
|
||||
openAIGateway.CloseOpenAIWSPool()
|
||||
|
||||
@@ -146,13 +146,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
|
||||
kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
@@ -166,6 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
|
||||
kiroOAuthHandler := admin.NewKiroOAuthHandler(kiroOAuthService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
@@ -179,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
kiroCooldownStore := service.ProvideKiroCooldownStore(redisClient)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@@ -236,7 +240,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
contentModerationHandler := admin.NewContentModerationHandler(contentModerationService)
|
||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, kiroOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -260,13 +264,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -316,6 +320,7 @@ func provideCleanup(
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
kiroOAuth *service.KiroOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
@@ -438,6 +443,10 @@ func provideCleanup(
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"KiroOAuthService", func() error {
|
||||
kiroOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIWSPool", func() error {
|
||||
if openAIGateway != nil {
|
||||
openAIGateway.CloseOpenAIWSPool()
|
||||
|
||||
@@ -36,6 +36,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
)
|
||||
@@ -72,6 +73,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
openAIOAuthSvc,
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // kiroOAuth
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
|
||||
@@ -682,6 +682,8 @@ type GatewayConfig struct {
|
||||
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
|
||||
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用
|
||||
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
|
||||
// KiroStreamKeepaliveInterval: Kiro 流式 keepalive 间隔(秒),0使用默认 25 秒
|
||||
KiroStreamKeepaliveInterval int `mapstructure:"kiro_stream_keepalive_interval"`
|
||||
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
|
||||
MaxLineSize int `mapstructure:"max_line_size"`
|
||||
|
||||
@@ -1752,6 +1754,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
|
||||
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.kiro_stream_keepalive_interval", 25)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
@@ -2369,6 +2372,13 @@ func (c *Config) Validate() error {
|
||||
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
|
||||
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
|
||||
}
|
||||
if c.Gateway.KiroStreamKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be non-negative")
|
||||
}
|
||||
if c.Gateway.KiroStreamKeepaliveInterval != 0 &&
|
||||
(c.Gateway.KiroStreamKeepaliveInterval < 5 || c.Gateway.KiroStreamKeepaliveInterval > 30) {
|
||||
return fmt.Errorf("gateway.kiro_stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||
}
|
||||
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||||
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformKiro = "kiro"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
@@ -117,6 +118,21 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultKiroModelMapping 是 Kiro 平台的默认模型映射。
|
||||
// 键为对外暴露/允许请求的模型名,值为实际发送到 Kiro 上游的模型名。
|
||||
var DefaultKiroModelMapping = map[string]string{
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4-6-thinking": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4.5",
|
||||
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
|
||||
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package domain
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -24,3 +27,56 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expected := map[string]string{
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4-6-thinking": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4-6-thinking": "claude-sonnet-4.6",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4.5",
|
||||
"claude-opus-4-5-20251101-thinking": "claude-opus-4.5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-5-20250929-thinking": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4.5",
|
||||
"claude-haiku-4-5-20251001-thinking": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
if len(DefaultKiroModelMapping) != len(expected) {
|
||||
t.Fatalf("expected %d Kiro mappings, got %d", len(expected), len(DefaultKiroModelMapping))
|
||||
}
|
||||
for model, want := range expected {
|
||||
if got := DefaultKiroModelMapping[model]; got != want {
|
||||
t.Fatalf("unexpected Kiro mapping for %q: got %q want %q", model, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range []string{
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
"deepseek-3-2",
|
||||
"minimax-m2-1",
|
||||
"qwen3-coder-next",
|
||||
"claude-opus-4-7",
|
||||
"claude-sonnet-4-6-chat",
|
||||
} {
|
||||
if _, ok := DefaultKiroModelMapping[model]; ok {
|
||||
t.Fatalf("did not expect %q to remain in DefaultKiroModelMapping", model)
|
||||
}
|
||||
}
|
||||
for model := range DefaultKiroModelMapping {
|
||||
if strings.HasSuffix(model, "-agentic") {
|
||||
t.Fatalf("did not expect agentic Kiro mapping %q", model)
|
||||
}
|
||||
if strings.HasSuffix(model, "-chat") {
|
||||
t.Fatalf("did not expect chat-only Kiro mapping %q", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -179,6 +180,9 @@ type AccountWithConcurrency struct {
|
||||
const accountListGroupUngroupedQueryValue = "ungrouped"
|
||||
|
||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||
if h.accountUsageService != nil {
|
||||
h.accountUsageService.EnrichAccountWithKiroRuntimeState(ctx, account)
|
||||
}
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(account),
|
||||
CurrentConcurrency: 0,
|
||||
@@ -351,6 +355,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
result := make([]AccountWithConcurrency, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if h.accountUsageService != nil {
|
||||
h.accountUsageService.EnrichAccountWithKiroRuntimeState(c.Request.Context(), acc)
|
||||
}
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(acc),
|
||||
CurrentConcurrency: concurrencyCounts[acc.ID],
|
||||
@@ -1953,6 +1960,18 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Kiro accounts
|
||||
if account.Platform == service.PlatformKiro {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, kiropkg.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, buildMappedKiroModels(mapping))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
@@ -1994,6 +2013,28 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
func buildMappedKiroModels(mapping map[string]string) []kiropkg.Model {
|
||||
models := make([]kiropkg.Model, 0, len(mapping))
|
||||
for requestedModel := range mapping {
|
||||
var found bool
|
||||
for _, dm := range kiropkg.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
models = append(models, kiropkg.Model{
|
||||
ID: requestedModel,
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||
@@ -2206,6 +2247,12 @@ func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
|
||||
response.Success(c, domain.DefaultAntigravityModelMapping)
|
||||
}
|
||||
|
||||
// GetKiroDefaultModelMapping 获取 Kiro 平台的默认模型映射
|
||||
// GET /api/v1/admin/accounts/kiro/default-model-mapping
|
||||
func (h *AccountHandler) GetKiroDefaultModelMapping(c *gin.Context) {
|
||||
response.Success(c, domain.DefaultKiroModelMapping)
|
||||
}
|
||||
|
||||
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
|
||||
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
|
||||
func sanitizeExtraBaseRPM(extra map[string]any) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "kiro-oauth",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 47,
|
||||
Name: "kiro-oauth-mapped",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "kiro-apikey",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 46,
|
||||
Name: "kiro-apikey-defaults",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -123,7 +123,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
|
||||
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(createReq))
|
||||
|
||||
updateReq := UpdateGroupRequest{Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type KiroOAuthHandler struct {
|
||||
kiroOAuthService *service.KiroOAuthService
|
||||
}
|
||||
|
||||
func NewKiroOAuthHandler(kiroOAuthService *service.KiroOAuthService) *KiroOAuthHandler {
|
||||
return &KiroOAuthHandler{kiroOAuthService: kiroOAuthService}
|
||||
}
|
||||
|
||||
type KiroGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req KiroGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.kiroOAuthService.GenerateAuthURL(c.Request.Context(), &service.KiroGenerateAuthURLInput{
|
||||
ProxyID: req.ProxyID,
|
||||
Provider: req.Provider,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "生成授权链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type KiroGenerateIDCAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
StartURL string `json:"start_url"`
|
||||
Region string `json:"region"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) GenerateIDCAuthURL(c *gin.Context) {
|
||||
var req KiroGenerateIDCAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.kiroOAuthService.GenerateIDCAuthURL(c.Request.Context(), &service.KiroGenerateIDCAuthURLInput{
|
||||
ProxyID: req.ProxyID,
|
||||
StartURL: req.StartURL,
|
||||
Region: req.Region,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "生成 IDC 授权链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type KiroExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
CallbackPath string `json:"callback_path"`
|
||||
LoginOption string `json:"login_option"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req KiroExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
tokenInfo, err := h.kiroOAuthService.ExchangeCode(c.Request.Context(), &service.KiroExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
CallbackPath: req.CallbackPath,
|
||||
LoginOption: req.LoginOption,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Token 交换失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
type KiroRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
Provider string `json:"provider"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
StartURL string `json:"start_url"`
|
||||
Region string `json:"region"`
|
||||
ProfileArn string `json:"profile_arn"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req KiroRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
tokenInfo, err := h.kiroOAuthService.RefreshToken(c.Request.Context(), &service.KiroRefreshTokenInput{
|
||||
RefreshToken: req.RefreshToken,
|
||||
AuthMethod: req.AuthMethod,
|
||||
Provider: req.Provider,
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
StartURL: req.StartURL,
|
||||
Region: req.Region,
|
||||
ProfileArn: req.ProfileArn,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "刷新 Kiro Token 失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
type KiroImportTokenRequest struct {
|
||||
TokenJSON string `json:"token_json" binding:"required"`
|
||||
DeviceRegistrationJSON string `json:"device_registration_json"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
|
||||
var req KiroImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
|
||||
TokenJSON: req.TokenJSON,
|
||||
DeviceRegistrationJSON: req.DeviceRegistrationJSON,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
@@ -230,6 +230,12 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
TempUnschedulableUntil: a.TempUnschedulableUntil,
|
||||
TempUnschedulableReason: a.TempUnschedulableReason,
|
||||
KiroQuotaState: a.KiroQuotaState,
|
||||
KiroQuotaReason: a.KiroQuotaReason,
|
||||
KiroQuotaResetAt: a.KiroQuotaResetAt,
|
||||
KiroRuntimeState: a.KiroRuntimeState,
|
||||
KiroRuntimeReason: a.KiroRuntimeReason,
|
||||
KiroRuntimeResetAt: a.KiroRuntimeResetAt,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
|
||||
@@ -183,6 +183,12 @@ type Account struct {
|
||||
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
|
||||
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
|
||||
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
|
||||
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
|
||||
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
|
||||
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
|
||||
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
|
||||
@@ -17,6 +17,7 @@ type AdminHandlers struct {
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
KiroOAuth *admin.KiroOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
|
||||
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
kiroOAuthHandler *admin.KiroOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
@@ -52,6 +53,7 @@ func ProvideAdminHandlers(
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
KiroOAuth: kiroOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
@@ -156,6 +158,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewKiroOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
|
||||
@@ -36,11 +36,12 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformKiro = "kiro"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RuntimeFingerprint struct {
|
||||
OIDCSDKVersion string
|
||||
RuntimeSDKVersion string
|
||||
StreamingSDKVersion string
|
||||
OSType string
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string
|
||||
}
|
||||
|
||||
type runtimeFingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*RuntimeFingerprint
|
||||
}
|
||||
|
||||
var (
|
||||
globalRuntimeFingerprintManager *runtimeFingerprintManager
|
||||
globalRuntimeFingerprintManagerOnce sync.Once
|
||||
|
||||
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
|
||||
runtimeSDKVersions = []string{"1.0.0"}
|
||||
streamingSDKVersions = []string{"1.0.34"}
|
||||
osTypes = []string{"darwin", "win32"}
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"24.6.0"},
|
||||
"win32": {"10.0.22631"},
|
||||
}
|
||||
nodeVersions = []string{"22.22.0"}
|
||||
kiroVersions = []string{
|
||||
"0.11.132", "0.11.131", "0.11.130",
|
||||
}
|
||||
)
|
||||
|
||||
func globalRuntimeFingerprints() *runtimeFingerprintManager {
|
||||
globalRuntimeFingerprintManagerOnce.Do(func() {
|
||||
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
|
||||
fingerprints: make(map[string]*RuntimeFingerprint),
|
||||
}
|
||||
})
|
||||
return globalRuntimeFingerprintManager
|
||||
}
|
||||
|
||||
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
|
||||
lookupKey := fingerprintLookupKey(accountKey, "runtime")
|
||||
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
|
||||
|
||||
m.mu.RLock()
|
||||
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||
m.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||
return fp
|
||||
}
|
||||
fp := generateRuntimeFingerprint(lookupKey, machineID)
|
||||
m.fingerprints[lookupKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
|
||||
hash := sha256.Sum256([]byte(accountKey))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
osType := goOSToNodePlatform(runtime.GOOS)
|
||||
if !containsString(osTypes, osType) {
|
||||
osType = osTypes[rng.Intn(len(osTypes))]
|
||||
}
|
||||
osVersionPool := osVersions[osType]
|
||||
if len(osVersionPool) == 0 {
|
||||
osVersionPool = osVersions["darwin"]
|
||||
}
|
||||
|
||||
return &RuntimeFingerprint{
|
||||
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||
OSType: osType,
|
||||
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
|
||||
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||
KiroHash: machineID,
|
||||
}
|
||||
}
|
||||
|
||||
func goOSToNodePlatform(goos string) string {
|
||||
switch strings.TrimSpace(goos) {
|
||||
case "windows":
|
||||
return "win32"
|
||||
default:
|
||||
return strings.TrimSpace(goos)
|
||||
}
|
||||
}
|
||||
|
||||
func containsString(items []string, target string) bool {
|
||||
for _, item := range items {
|
||||
if item == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
|
||||
switch {
|
||||
case strings.TrimSpace(clientIDHash) != "":
|
||||
return clientIDHash
|
||||
case strings.TrimSpace(clientID) != "":
|
||||
return shortSHA(clientID)
|
||||
case strings.TrimSpace(refreshToken) != "":
|
||||
return shortSHA(refreshToken)
|
||||
case strings.TrimSpace(profileArn) != "":
|
||||
return shortSHA(profileArn)
|
||||
case accountID > 0:
|
||||
return shortSHA(fmt.Sprintf("account:%d", accountID))
|
||||
default:
|
||||
return shortSHA(uuid.NewString())
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeMachineID(machineID string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(machineID)
|
||||
if len(trimmed) == 64 && isHexString(trimmed) {
|
||||
return strings.ToLower(trimmed), true
|
||||
}
|
||||
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
|
||||
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
|
||||
normalized := strings.ToLower(withoutDashes)
|
||||
return normalized + normalized, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
|
||||
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
|
||||
return sha256Hex("KotlinNativeAPI/" + refreshToken)
|
||||
}
|
||||
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
|
||||
return sha256Hex("KiroAPIKey/" + apiKey)
|
||||
}
|
||||
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
|
||||
return sha256Hex("KiroFallback/" + fallbackKey)
|
||||
}
|
||||
return sha256Hex("KiroFallback/default")
|
||||
}
|
||||
|
||||
func shortSHA(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func sha256Hex(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func isHexString(value string) bool {
|
||||
for _, c := range value {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
|
||||
if normalized, ok := NormalizeMachineID(machineID); ok {
|
||||
return normalized
|
||||
}
|
||||
return BuildMachineID("", "", fallbackKey)
|
||||
}
|
||||
|
||||
func fingerprintLookupKey(accountKey, fallback string) string {
|
||||
key := strings.TrimSpace(accountKey)
|
||||
if key != "" {
|
||||
return key
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func BuildRuntimeUserAgent(accountKey, machineID string) string {
|
||||
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
|
||||
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
|
||||
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
|
||||
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
|
||||
"amz-sdk-invocation-id": uuid.NewString(),
|
||||
"amz-sdk-request": "attempt=1; max=4",
|
||||
}
|
||||
}
|
||||
|
||||
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
|
||||
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
}
|
||||
}
|
||||
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := baseDelay << attempt
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
const jitterFactor = 0.3
|
||||
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
|
||||
return time.Duration(float64(delay) * jitter)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildLoginHeadersStable(t *testing.T) {
|
||||
headers1 := BuildLoginHeaders("", "")
|
||||
headers2 := BuildLoginHeaders("", "")
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
|
||||
require.Equal(t, "application/json", headers1["Content-Type"])
|
||||
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
|
||||
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
|
||||
}
|
||||
|
||||
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
|
||||
machineIDA := BuildMachineID("refresh-a", "", "")
|
||||
machineIDB := BuildMachineID("refresh-b", "", "")
|
||||
headers1 := BuildLoginHeaders("account-a", machineIDA)
|
||||
headers2 := BuildLoginHeaders("account-a", machineIDA)
|
||||
headers3 := BuildLoginHeaders("account-a", machineIDB)
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
|
||||
require.Contains(t, headers1["User-Agent"], machineIDA)
|
||||
}
|
||||
|
||||
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
|
||||
machineID := BuildMachineID("", "", "oidc-machine")
|
||||
headers1 := BuildOIDCHeaders("account-a", machineID)
|
||||
headers2 := BuildOIDCHeaders("account-a", machineID)
|
||||
headers3 := BuildOIDCHeaders("account-b", machineID)
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
|
||||
}
|
||||
|
||||
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
|
||||
key1 := BuildAccountKey("", "", "", "", 42)
|
||||
key2 := BuildAccountKey("", "", "", "", 42)
|
||||
key3 := BuildAccountKey("", "", "", "", 43)
|
||||
|
||||
require.Equal(t, key1, key2)
|
||||
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
|
||||
require.NotEqual(t, key1, key3)
|
||||
}
|
||||
|
||||
func TestBuildMachineID(t *testing.T) {
|
||||
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
|
||||
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
|
||||
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
|
||||
|
||||
fallback1 := BuildMachineID("", "", "account:1")
|
||||
fallback2 := BuildMachineID("", "", "account:1")
|
||||
fallback3 := BuildMachineID("", "", "account:2")
|
||||
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
|
||||
require.Equal(t, fallback1, fallback2)
|
||||
require.NotEqual(t, fallback1, fallback3)
|
||||
require.Len(t, fallback1, 64)
|
||||
}
|
||||
|
||||
func TestNormalizeMachineID(t *testing.T) {
|
||||
hex64 := strings.Repeat("A", 64)
|
||||
normalized, ok := NormalizeMachineID(hex64)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, strings.ToLower(hex64), normalized)
|
||||
|
||||
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
|
||||
|
||||
_, ok = NormalizeMachineID("not-a-machine-id")
|
||||
require.False(t, ok)
|
||||
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func expectedKiroMachineID(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package kiro
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
var DefaultModels = []Model{
|
||||
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
|
||||
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
|
||||
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
|
||||
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
|
||||
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
|
||||
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
|
||||
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
|
||||
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
|
||||
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
|
||||
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
|
||||
ids := make([]string, 0, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
|
||||
require.Equal(t, []string{
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-6-thinking",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6-thinking",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5-20251101-thinking",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5-20250929-thinking",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5-20251001-thinking",
|
||||
}, ids)
|
||||
|
||||
require.Contains(t, ids, "claude-sonnet-4-6")
|
||||
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
|
||||
require.NotContains(t, ids, "auto")
|
||||
require.NotContains(t, ids, "claude-sonnet-4")
|
||||
require.NotContains(t, ids, "gpt-4o")
|
||||
require.NotContains(t, ids, "deepseek-3-2")
|
||||
require.NotContains(t, ids, "minimax-m2-1")
|
||||
require.NotContains(t, ids, "qwen3-coder-next")
|
||||
require.NotContains(t, ids, "claude-opus-4-7")
|
||||
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
|
||||
for _, id := range ids {
|
||||
require.NotContains(t, id, "kiro-")
|
||||
require.NotContains(t, id, "-agentic")
|
||||
require.NotContains(t, id, "-chat")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
socialAuthPortalURL = "https://app.kiro.dev"
|
||||
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
defaultIDCRegion = "us-east-1"
|
||||
BuilderIDStartURL = "https://view.awsapps.com/start"
|
||||
sessionTTL = 10 * time.Minute
|
||||
sessionCleanupEvery = 32
|
||||
sessionCleanupMin = 32
|
||||
)
|
||||
|
||||
var (
|
||||
socialAuthEndpointURL = socialAuthEndpoint
|
||||
oidcEndpointOverride = ""
|
||||
)
|
||||
|
||||
type SocialProvider string
|
||||
|
||||
const (
|
||||
SocialProviderGoogle SocialProvider = "Google"
|
||||
SocialProviderGitHub SocialProvider = "Github"
|
||||
)
|
||||
|
||||
type AuthSession struct {
|
||||
State string
|
||||
CodeVerifier string
|
||||
ProxyURL string
|
||||
CreatedAt time.Time
|
||||
AuthType string
|
||||
Provider string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Region string
|
||||
StartURL string
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
data map[string]*AuthSession
|
||||
setCount uint64
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
return &SessionStore{data: make(map[string]*AuthSession)}
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
session, ok := s.data[id]
|
||||
if ok && sessionExpired(session, now) {
|
||||
delete(s.data, id)
|
||||
return nil, false
|
||||
}
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(id string, session *AuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.setCount++
|
||||
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
|
||||
s.pruneExpiredLocked(time.Now())
|
||||
}
|
||||
s.data[id] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.data, id)
|
||||
}
|
||||
|
||||
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
|
||||
for id, session := range s.data {
|
||||
if sessionExpired(session, now) {
|
||||
delete(s.data, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sessionExpired(session *AuthSession, now time.Time) bool {
|
||||
if session == nil {
|
||||
return true
|
||||
}
|
||||
if session.CreatedAt.IsZero() {
|
||||
return true
|
||||
}
|
||||
return now.After(session.CreatedAt.Add(sessionTTL))
|
||||
}
|
||||
|
||||
type TokenData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
ExpiresAt string `json:"expiresAt,omitempty"`
|
||||
AuthMethod string `json:"authMethod,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
type socialTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
type registerClientResponse struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
type createTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
type userInfoResponse struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type deviceRegistration struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
type RefreshTokenInvalidError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *RefreshTokenInvalidError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
body := strings.TrimSpace(e.Body)
|
||||
if body == "" {
|
||||
return "kiro refresh token invalid (invalid_grant)"
|
||||
}
|
||||
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
|
||||
}
|
||||
|
||||
func GenerateSessionID() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
return randomURLSafe(16)
|
||||
}
|
||||
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
return randomURLSafe(32)
|
||||
}
|
||||
|
||||
func randomURLSafe(n int) (string, error) {
|
||||
buf := make([]byte, n)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
|
||||
params := url.Values{}
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("redirect_from", "KiroIDE")
|
||||
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
|
||||
}
|
||||
|
||||
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
|
||||
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
|
||||
if redirectURI == "" {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSpace(callbackPath)
|
||||
if path == "" {
|
||||
path = "/oauth/callback"
|
||||
} else if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
fullRedirectURI := redirectURI + path
|
||||
if option := strings.TrimSpace(loginOption); option != "" {
|
||||
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
|
||||
}
|
||||
return fullRedirectURI
|
||||
}
|
||||
|
||||
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
|
||||
payload := map[string]string{
|
||||
"code": code,
|
||||
"code_verifier": codeVerifier,
|
||||
"redirect_uri": redirectURI,
|
||||
}
|
||||
var resp socialTokenResponse
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
return &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
|
||||
payload := map[string]string{
|
||||
"refreshToken": refreshToken,
|
||||
}
|
||||
var resp socialTokenResponse
|
||||
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
return &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: provider,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]any{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||
"redirectUris": []string{redirectURI},
|
||||
"issuerUrl": issuerURL,
|
||||
}
|
||||
var resp registerClientResponse
|
||||
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scopes", strings.Join([]string{
|
||||
"codewhisperer:completions",
|
||||
"codewhisperer:analysis",
|
||||
"codewhisperer:conversations",
|
||||
"codewhisperer:transformations",
|
||||
"codewhisperer:taskassist",
|
||||
}, " "))
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
|
||||
}
|
||||
|
||||
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"code": code,
|
||||
"codeVerifier": codeVerifier,
|
||||
"redirectUri": redirectURI,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
var resp createTokenResponse
|
||||
accountKey := BuildAccountKey(clientID, "", "", "", 0)
|
||||
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
token := &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}
|
||||
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"refreshToken": refreshToken,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
var resp createTokenResponse
|
||||
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
|
||||
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
token := &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}
|
||||
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return ""
|
||||
}
|
||||
var resp userInfoResponse
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
}
|
||||
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(resp.Email)
|
||||
}
|
||||
|
||||
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
|
||||
var token TokenData
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
|
||||
}
|
||||
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
||||
if strings.TrimSpace(token.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
|
||||
var reg deviceRegistration
|
||||
if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse device registration: %w", err)
|
||||
}
|
||||
if reg.ClientID != "" {
|
||||
token.ClientID = reg.ClientID
|
||||
}
|
||||
if reg.ClientSecret != "" {
|
||||
token.ClientSecret = reg.ClientSecret
|
||||
}
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func getOIDCEndpoint(region string) string {
|
||||
if strings.TrimSpace(oidcEndpointOverride) != "" {
|
||||
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||
}
|
||||
|
||||
func oidcHeaders(accountKey, machineID string) map[string]string {
|
||||
headers := BuildOIDCHeaders(accountKey, machineID)
|
||||
if headers["amz-sdk-invocation-id"] == "" {
|
||||
headers["amz-sdk-invocation-id"] = uuid.NewString()
|
||||
}
|
||||
if headers["amz-sdk-request"] == "" {
|
||||
headers["amz-sdk-request"] = "attempt=1; max=4"
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
|
||||
client, err := newHTTPClient(proxyURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(encoded)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if payload != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
for key, value := range extraHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyText := strings.TrimSpace(string(respBody))
|
||||
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
|
||||
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
|
||||
}
|
||||
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
|
||||
}
|
||||
if out == nil || len(respBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(respBody, out)
|
||||
}
|
||||
|
||||
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
|
||||
_, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport := &http.Transport{}
|
||||
if parsed != nil {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/refreshToken", r.URL.Path)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := socialAuthEndpointURL
|
||||
socialAuthEndpointURL = server.URL
|
||||
t.Cleanup(func() { socialAuthEndpointURL = previous })
|
||||
|
||||
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
|
||||
require.Error(t, err)
|
||||
|
||||
var invalid *RefreshTokenInvalidError
|
||||
require.True(t, errors.As(err, &invalid))
|
||||
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||
require.Contains(t, invalid.Body, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/token", r.URL.Path)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
|
||||
require.Error(t, err)
|
||||
|
||||
var invalid *RefreshTokenInvalidError
|
||||
require.True(t, errors.As(err, &invalid))
|
||||
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||
require.Contains(t, invalid.Body, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
|
||||
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, profileArn, token.ProfileArn)
|
||||
require.Equal(t, "kiro@example.com", token.Email)
|
||||
}
|
||||
|
||||
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
|
||||
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, profileArn, token.ProfileArn)
|
||||
require.Equal(t, "kiro@example.com", token.Email)
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build unit
|
||||
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
|
||||
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
|
||||
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
|
||||
if got != want {
|
||||
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSocialTokenRedirectURI(t *testing.T) {
|
||||
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
|
||||
want := "http://localhost:49153/oauth/callback?login_option=github"
|
||||
if got != want {
|
||||
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
|
||||
|
||||
session, ok := store.Get("expired")
|
||||
if ok || session != nil {
|
||||
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
|
||||
}
|
||||
if _, exists := store.data["expired"]; exists {
|
||||
t.Fatalf("expired session should be deleted from the store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
now := time.Now()
|
||||
for i := 0; i < sessionCleanupMin; i++ {
|
||||
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
|
||||
}
|
||||
store.setCount = sessionCleanupEvery - 1
|
||||
|
||||
store.Set("fresh", &AuthSession{CreatedAt: now})
|
||||
|
||||
if len(store.data) != 1 {
|
||||
t.Fatalf("store size = %d, want 1", len(store.data))
|
||||
}
|
||||
if _, ok := store.data["fresh"]; !ok {
|
||||
t.Fatalf("fresh session should remain after pruning")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,368 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
|
||||
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
|
||||
|
||||
var cachedWebSearchDescription atomic.Value // stores string
|
||||
|
||||
type MCPRequest struct {
|
||||
ID string `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type MCPResponse struct {
|
||||
Result *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result,omitempty"`
|
||||
Error *struct {
|
||||
Code *int `json:"code,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type WebSearchResults struct {
|
||||
Results []WebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
type WebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Snippet *string `json:"snippet,omitempty"`
|
||||
PublishedDate *int64 `json:"publishedDate,omitempty"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Domain *string `json:"domain,omitempty"`
|
||||
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
|
||||
PublicDomain *bool `json:"publicDomain,omitempty"`
|
||||
}
|
||||
|
||||
type SearchIndicator struct {
|
||||
ToolUseID string
|
||||
Query string
|
||||
Results *WebSearchResults
|
||||
}
|
||||
|
||||
func GetCachedWebSearchDescription() string {
|
||||
if v := cachedWebSearchDescription.Load(); v != nil {
|
||||
return strings.TrimSpace(v.(string))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SetCachedWebSearchDescription(desc string) {
|
||||
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
|
||||
}
|
||||
|
||||
func BuildMcpEndpoint(region string) string {
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
}
|
||||
|
||||
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
|
||||
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, item := range resp.Result.Content {
|
||||
if item.Type != "" && item.Type != "text" {
|
||||
continue
|
||||
}
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
|
||||
return &results
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExtractSearchQuery(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
arr := messages.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
msg := arr[i]
|
||||
if msg.Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
text := extractSearchText(msg.Get("content"))
|
||||
const prefix = "Perform a web search for the query: "
|
||||
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||
if text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractSearchText(content gjson.Result) string {
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
if !content.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func GenerateToolUseID() string {
|
||||
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
|
||||
}
|
||||
|
||||
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body, err
|
||||
}
|
||||
rawTools, ok := payload["tools"].([]interface{})
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
replaced := make([]interface{}, 0, len(rawTools))
|
||||
for _, rawTool := range rawTools {
|
||||
tool, ok := rawTool.(map[string]interface{})
|
||||
if !ok {
|
||||
replaced = append(replaced, rawTool)
|
||||
continue
|
||||
}
|
||||
name := getInterfaceString(tool["name"])
|
||||
toolType := getInterfaceString(tool["type"])
|
||||
if !isWebSearchToolName(name, toolType) {
|
||||
replaced = append(replaced, rawTool)
|
||||
continue
|
||||
}
|
||||
replaced = append(replaced, map[string]interface{}{
|
||||
"name": "web_search",
|
||||
"description": minimalWebSearchDescription,
|
||||
"input_schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The search query to execute",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
payload["tools"] = replaced
|
||||
updated, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
|
||||
}
|
||||
|
||||
rawMessages, ok := payload["messages"].([]interface{})
|
||||
if !ok {
|
||||
return claudePayload, fmt.Errorf("claude payload missing messages array")
|
||||
}
|
||||
|
||||
assistantMsg := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": query},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userContent := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": formatToolResultText(results),
|
||||
},
|
||||
}
|
||||
if guidance := searchGuidanceText(); guidance != "" {
|
||||
userContent = append(userContent, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": guidance,
|
||||
})
|
||||
}
|
||||
userMsg := map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": userContent,
|
||||
}
|
||||
|
||||
rawMessages = append(rawMessages, assistantMsg, userMsg)
|
||||
payload["messages"] = rawMessages
|
||||
updated, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
|
||||
if len(searches) == 0 {
|
||||
return responsePayload, nil
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(responsePayload, &response); err != nil {
|
||||
return responsePayload, err
|
||||
}
|
||||
content, _ := response["content"].([]interface{})
|
||||
updated := make([]interface{}, 0, len(searches)*2+len(content))
|
||||
for _, search := range searches {
|
||||
updated = append(updated, map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": search.ToolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": search.Query},
|
||||
})
|
||||
updated = append(updated, map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"content": buildSearchResultContent(search.Results),
|
||||
})
|
||||
}
|
||||
updated = append(updated, content...)
|
||||
response["content"] = updated
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return responsePayload, err
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
|
||||
content := make([]map[string]interface{}, 0)
|
||||
if results == nil {
|
||||
return content
|
||||
}
|
||||
for _, result := range results.Results {
|
||||
snippet := ""
|
||||
if result.Snippet != nil {
|
||||
snippet = strings.TrimSpace(*result.Snippet)
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": result.Title,
|
||||
"url": result.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
|
||||
content := gjson.GetBytes(responsePayload, "content")
|
||||
if !content.IsArray() {
|
||||
return "", "", false
|
||||
}
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() != "tool_use" {
|
||||
continue
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if !isWebSearchToolName(name, "") {
|
||||
continue
|
||||
}
|
||||
query = strings.TrimSpace(block.Get("input.query").String())
|
||||
if query == "" {
|
||||
continue
|
||||
}
|
||||
return block.Get("id").String(), query, true
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func isWebSearchToolName(name, toolType string) bool {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
toolType = strings.ToLower(strings.TrimSpace(toolType))
|
||||
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
||||
return true
|
||||
}
|
||||
switch name {
|
||||
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getInterfaceString(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(val)
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(val))
|
||||
}
|
||||
}
|
||||
|
||||
func formatToolResultText(results *WebSearchResults) string {
|
||||
if results == nil || len(results.Results) == 0 {
|
||||
return "No search results found."
|
||||
}
|
||||
payload, err := json.MarshalIndent(results.Results, "", " ")
|
||||
if err != nil {
|
||||
return "Found search results, but failed to format them."
|
||||
}
|
||||
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
|
||||
}
|
||||
|
||||
func searchGuidanceText() string {
|
||||
now := time.Now()
|
||||
return fmt.Sprintf(`<search_guidance>
|
||||
Current date: %s (%s)
|
||||
|
||||
IMPORTANT: Evaluate the search results above carefully. If the results are:
|
||||
- Mostly spam, SEO junk, or unrelated websites
|
||||
- Missing actual information about the query topic
|
||||
- Outdated or not matching the requested time frame
|
||||
|
||||
Then you MUST use the web_search tool again with a refined query. Try:
|
||||
- Rephrasing in English for better coverage
|
||||
- Using more specific keywords
|
||||
- Adding date context
|
||||
|
||||
Do NOT apologize for bad results without first attempting a re-search.
|
||||
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BufferedStreamResult struct {
|
||||
StopReason string
|
||||
WebSearchQuery string
|
||||
WebSearchToolUseID string
|
||||
HasWebSearchToolUse bool
|
||||
WebSearchToolUseIndex int
|
||||
}
|
||||
|
||||
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if results != nil {
|
||||
for _, result := range results.Results {
|
||||
snippet := ""
|
||||
if result.Snippet != nil {
|
||||
snippet = strings.TrimSpace(*result.Snippet)
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": result.Title,
|
||||
"url": result.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
|
||||
events := []map[string]interface{}{
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": toolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
}
|
||||
|
||||
result := make([][]byte, 0, len(events))
|
||||
for _, event := range events {
|
||||
eventType, _ := event["type"].(string)
|
||||
payload, _ := json.Marshal(event)
|
||||
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
var currentToolName string
|
||||
currentToolIndex := -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "message_delta":
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
|
||||
result.StopReason = stopReason
|
||||
}
|
||||
}
|
||||
case "content_block_start":
|
||||
contentBlock, ok := event["content_block"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := contentBlock["type"].(string)
|
||||
if blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
currentToolName, _ = contentBlock["name"].(string)
|
||||
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
|
||||
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
case "content_block_delta":
|
||||
if currentToolName == "" {
|
||||
continue
|
||||
}
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
deltaType, _ := delta["type"].(string)
|
||||
if deltaType != "input_json_delta" {
|
||||
continue
|
||||
}
|
||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partialJSON)
|
||||
}
|
||||
case "content_block_stop":
|
||||
if !isWebSearchToolName(currentToolName, "") {
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
continue
|
||||
}
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
|
||||
result.WebSearchQuery = strings.TrimSpace(input["query"])
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
|
||||
filtered := make([][]byte, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
|
||||
if shouldForward {
|
||||
filtered = append(filtered, adjusted)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||
return filterSSEChunk(chunk, -1, offset)
|
||||
}
|
||||
|
||||
func MaxContentBlockIndex(chunks [][]byte) int {
|
||||
maxIndex := -1
|
||||
for _, chunk := range chunks {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
|
||||
maxIndex = int(idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return maxIndex
|
||||
}
|
||||
|
||||
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
var builder strings.Builder
|
||||
hasContent := false
|
||||
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
|
||||
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
}
|
||||
builder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||
continue
|
||||
}
|
||||
adjusted := adjustEventPayload(payload, indexOffset)
|
||||
if adjusted == "" {
|
||||
continue
|
||||
}
|
||||
builder.WriteString("data: " + adjusted + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasContent {
|
||||
return nil, false
|
||||
}
|
||||
return []byte(builder.String()), true
|
||||
}
|
||||
|
||||
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
|
||||
if payload == "" {
|
||||
return false
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return false
|
||||
}
|
||||
eventType, _ := event["type"].(string)
|
||||
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
|
||||
return true
|
||||
}
|
||||
if webSearchToolUseIndex < 0 {
|
||||
return false
|
||||
}
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func adjustEventPayload(payload string, indexOffset int) string {
|
||||
if payload == "" || indexOffset == 0 {
|
||||
return payload
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return payload
|
||||
}
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
event["index"] = int(idx) + indexOffset
|
||||
if adjusted, err := json.Marshal(event); err == nil {
|
||||
return string(adjusted)
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
|
||||
snippet := "result snippet"
|
||||
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
|
||||
Results: []WebSearchResult{
|
||||
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
|
||||
},
|
||||
}, 0)
|
||||
|
||||
require.Len(t, events, 5)
|
||||
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
|
||||
require.Contains(t, string(events[0]), `"input":{}`)
|
||||
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
|
||||
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
|
||||
require.NotContains(t, string(events[3]), `"tool_use_id"`)
|
||||
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
|
||||
}
|
||||
|
||||
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
|
||||
chunks := [][]byte{
|
||||
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||
}
|
||||
|
||||
result := AnalyzeBufferedStream(chunks)
|
||||
require.True(t, result.HasWebSearchToolUse)
|
||||
require.Equal(t, "golang concurrency", result.WebSearchQuery)
|
||||
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
|
||||
require.Equal(t, 1, result.WebSearchToolUseIndex)
|
||||
require.Equal(t, "tool_use", result.StopReason)
|
||||
}
|
||||
|
||||
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
|
||||
chunks := [][]byte{
|
||||
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||
}
|
||||
|
||||
filtered := FilterChunksForClient(chunks, 1, 2)
|
||||
require.NotEmpty(t, filtered)
|
||||
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
|
||||
require.NotContains(t, joined, `"type":"message_start"`)
|
||||
require.NotContains(t, joined, `"type":"message_delta"`)
|
||||
require.NotContains(t, joined, `"name":"web_search"`)
|
||||
require.Contains(t, joined, `"index":2`)
|
||||
require.Equal(t, 2, MaxContentBlockIndex(filtered))
|
||||
}
|
||||
|
||||
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
|
||||
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
|
||||
require.False(t, shouldForward)
|
||||
|
||||
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
|
||||
require.True(t, shouldForward)
|
||||
require.Contains(t, string(adjusted), `"index":3`)
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools":[{"type":"web_search_20250305","description":"old"}],
|
||||
"messages":[{"role":"user","content":"golang"}]
|
||||
}`)
|
||||
|
||||
updated, err := ReplaceWebSearchToolDescription(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
|
||||
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
|
||||
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
|
||||
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
|
||||
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
|
||||
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
|
||||
}
|
||||
|
||||
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[{"role":"user","content":"what is golang"}]
|
||||
}`)
|
||||
results := &WebSearchResults{
|
||||
Results: []WebSearchResult{
|
||||
{Title: "Go", URL: "https://go.dev"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
|
||||
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
|
||||
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
|
||||
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
|
||||
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
|
||||
}
|
||||
|
||||
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
|
||||
response := []byte(`{
|
||||
"content":[
|
||||
{"type":"text","text":"let me search"},
|
||||
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
|
||||
]
|
||||
}`)
|
||||
|
||||
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "srvtoolu_next", toolUseID)
|
||||
require.Equal(t, "golang concurrency", query)
|
||||
}
|
||||
|
||||
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
|
||||
response := []byte(`{
|
||||
"id":"msg_1",
|
||||
"type":"message",
|
||||
"role":"assistant",
|
||||
"model":"kiro",
|
||||
"content":[{"type":"text","text":"final"}],
|
||||
"stop_reason":"end_turn",
|
||||
"usage":{"input_tokens":1,"output_tokens":1}
|
||||
}`)
|
||||
|
||||
snippet := "result snippet"
|
||||
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
|
||||
{
|
||||
ToolUseID: "srvtoolu_test",
|
||||
Query: "golang",
|
||||
Results: &WebSearchResults{
|
||||
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &decoded))
|
||||
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
|
||||
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
|
||||
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
|
||||
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
|
||||
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
|
||||
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
|
||||
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
|
||||
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
|
||||
}
|
||||
|
||||
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
|
||||
resp := &MCPResponse{
|
||||
Result: &struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
}{
|
||||
Content: []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}{
|
||||
{
|
||||
Type: "text",
|
||||
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
results := ParseSearchResults(resp)
|
||||
require.NotNil(t, results)
|
||||
require.Len(t, results.Results, 1)
|
||||
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
|
||||
require.Equal(t, "doc-1", *results.Results[0].ID)
|
||||
require.Equal(t, "go.dev", *results.Results[0].Domain)
|
||||
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
|
||||
require.True(t, *results.Results[0].PublicDomain)
|
||||
}
|
||||
|
||||
func TestSearchGuidanceText_IsStructured(t *testing.T) {
|
||||
guidance := searchGuidanceText()
|
||||
require.Contains(t, guidance, "<search_guidance>")
|
||||
require.Contains(t, guidance, "Current date:")
|
||||
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
|
||||
require.Contains(t, guidance, "Rephrasing in English for better coverage")
|
||||
}
|
||||
@@ -0,0 +1,479 @@
|
||||
package kirocooldown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
MinRequestInterval = time.Second
|
||||
MaxRequestInterval = 2 * time.Second
|
||||
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
|
||||
ShortCooldown = time.Minute
|
||||
MaxCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
|
||||
redisTimeout = 3 * time.Second
|
||||
activeTTL = 10 * time.Second
|
||||
stateTTL = 25 * time.Hour
|
||||
keyPrefix = "kiro:cooldown:"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||
|
||||
reserveRequestScript = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
|
||||
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
|
||||
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
|
||||
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
|
||||
local interval_ms = tonumber(ARGV[1])
|
||||
local active_ttl_ms = tonumber(ARGV[2])
|
||||
local state_ttl_ms = tonumber(ARGV[3])
|
||||
|
||||
if cooldown_until_ms > now_ms then
|
||||
return {1, cooldown_until_ms - now_ms, cooldown_reason}
|
||||
end
|
||||
|
||||
if cooldown_until_ms > 0 then
|
||||
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
|
||||
end
|
||||
|
||||
local next_slot_ms = now_ms
|
||||
if last_request_ms > 0 then
|
||||
local candidate_ms = last_request_ms + interval_ms
|
||||
if candidate_ms > now_ms then
|
||||
next_slot_ms = candidate_ms
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
|
||||
if fail_count > 0 or cooldown_until_ms > now_ms then
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
else
|
||||
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
|
||||
end
|
||||
return {0, next_slot_ms - now_ms, ''}
|
||||
`)
|
||||
|
||||
mark429Script = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
|
||||
local short_cooldown_ms = tonumber(ARGV[1])
|
||||
local max_cooldown_ms = tonumber(ARGV[2])
|
||||
local state_ttl_ms = tonumber(ARGV[3])
|
||||
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
|
||||
if cooldown_ms > max_cooldown_ms then
|
||||
cooldown_ms = max_cooldown_ms
|
||||
end
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', fail_count,
|
||||
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||
'cooldown_reason', ARGV[4]
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
return cooldown_ms
|
||||
`)
|
||||
|
||||
markSuccessScript = redis.NewScript(`
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', 0,
|
||||
'cooldown_until_ms', 0,
|
||||
'cooldown_reason', ''
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
|
||||
return 1
|
||||
`)
|
||||
|
||||
markSuspendedScript = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local cooldown_ms = tonumber(ARGV[1])
|
||||
local state_ttl_ms = tonumber(ARGV[2])
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', 0,
|
||||
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||
'cooldown_reason', ARGV[3]
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
return cooldown_ms
|
||||
`)
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
remaining time.Duration
|
||||
reason string
|
||||
}
|
||||
|
||||
type State struct {
|
||||
Active bool
|
||||
Reason string
|
||||
CooldownUntil time.Time
|
||||
Remaining time.Duration
|
||||
FailCount int
|
||||
}
|
||||
|
||||
func NewError(remaining time.Duration, reason string) error {
|
||||
return &Error{remaining: remaining, reason: reason}
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.reason == "" {
|
||||
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
|
||||
}
|
||||
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
|
||||
}
|
||||
|
||||
func Calculate429Cooldown(retryCount int) time.Duration {
|
||||
if retryCount < 0 {
|
||||
retryCount = 0
|
||||
}
|
||||
cooldown := ShortCooldown * time.Duration(1<<retryCount)
|
||||
if cooldown > MaxCooldown {
|
||||
return MaxCooldown
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
rngMu sync.Mutex
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
func NewStore(client *redis.Client) *Store {
|
||||
return &Store{
|
||||
client: client,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
values, err := reserveRequestScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
s.nextInterval().Milliseconds(),
|
||||
activeTTL.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
|
||||
}
|
||||
parts, ok := values.([]interface{})
|
||||
if !ok || len(parts) != 3 {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
|
||||
}
|
||||
state, err := luaInt64(parts[0])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
|
||||
}
|
||||
waitMS, err := luaInt64(parts[1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
|
||||
}
|
||||
reason, err := luaString(parts[2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
|
||||
}
|
||||
if state == 1 {
|
||||
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
|
||||
}
|
||||
if waitMS <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return time.Duration(waitMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
|
||||
if err := s.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
if err := markSuccessScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
activeTTL.Milliseconds(),
|
||||
).Err(); err != nil {
|
||||
return fmt.Errorf("kiro cooldown mark success: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
result, err := mark429Script.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
ShortCooldown.Milliseconds(),
|
||||
MaxCooldown.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
CooldownReason429,
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||
}
|
||||
cooldownMS, err := luaInt64(result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||
}
|
||||
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
result, err := markSuspendedScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
LongCooldown.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
CooldownReasonSuspended,
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||
}
|
||||
cooldownMS, err := luaInt64(result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||
}
|
||||
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
values, err := s.client.HMGet(
|
||||
cacheCtx,
|
||||
RedisKey(tokenKey),
|
||||
"cooldown_until_ms",
|
||||
"cooldown_reason",
|
||||
"fail_count",
|
||||
).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
|
||||
}
|
||||
if len(values) != 3 {
|
||||
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
|
||||
}
|
||||
|
||||
cooldownUntilMS, err := luaInt64(values[0])
|
||||
if err != nil && values[0] != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
|
||||
}
|
||||
reason, err := luaString(values[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
|
||||
}
|
||||
failCount, err := luaInt64(values[2])
|
||||
if err != nil && values[2] != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
|
||||
}
|
||||
if cooldownUntilMS <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cooldownUntil := time.UnixMilli(cooldownUntilMS)
|
||||
remaining := time.Until(cooldownUntil)
|
||||
if remaining <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &State{
|
||||
Active: true,
|
||||
Reason: reason,
|
||||
CooldownUntil: cooldownUntil,
|
||||
Remaining: remaining,
|
||||
FailCount: int(failCount),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
uniqueKeys := make([]string, 0, len(tokenKeys))
|
||||
seen := make(map[string]struct{}, len(tokenKeys))
|
||||
for _, tokenKey := range tokenKeys {
|
||||
tokenKey = strings.TrimSpace(tokenKey)
|
||||
if tokenKey == "" {
|
||||
continue
|
||||
}
|
||||
redisKey := RedisKey(tokenKey)
|
||||
if _, ok := seen[redisKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[redisKey] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, redisKey)
|
||||
}
|
||||
if len(uniqueKeys) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
type candidate struct {
|
||||
redisKey string
|
||||
cooldownUntilMS int64
|
||||
failCount int64
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
var best *candidate
|
||||
|
||||
pipe := s.client.Pipeline()
|
||||
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
|
||||
for _, redisKey := range uniqueKeys {
|
||||
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
|
||||
}
|
||||
if _, err := pipe.Exec(cacheCtx); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
|
||||
}
|
||||
|
||||
for i, cmd := range cmds {
|
||||
values, err := cmd.Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
|
||||
}
|
||||
if len(values) != 3 {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
|
||||
}
|
||||
cooldownUntilMS, err := luaInt64(values[0])
|
||||
if err != nil && values[0] != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
|
||||
}
|
||||
reason, err := luaString(values[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
|
||||
}
|
||||
failCount, err := luaInt64(values[2])
|
||||
if err != nil && values[2] != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
|
||||
}
|
||||
if cooldownUntilMS <= now || reason != CooldownReason429 {
|
||||
continue
|
||||
}
|
||||
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
|
||||
if best == nil ||
|
||||
current.cooldownUntilMS < best.cooldownUntilMS ||
|
||||
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
|
||||
best = current
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
|
||||
}
|
||||
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func RedisKey(tokenKey string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
|
||||
digest := hex.EncodeToString(sum[:])
|
||||
return keyPrefix + "{" + digest + "}"
|
||||
}
|
||||
|
||||
func ActiveTTL() time.Duration {
|
||||
return activeTTL
|
||||
}
|
||||
|
||||
func StateTTL() time.Duration {
|
||||
return stateTTL
|
||||
}
|
||||
|
||||
func (s *Store) validate() error {
|
||||
if s == nil || s.client == nil {
|
||||
return ErrStoreUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) nextInterval() time.Duration {
|
||||
s.rngMu.Lock()
|
||||
defer s.rngMu.Unlock()
|
||||
if MaxRequestInterval <= MinRequestInterval {
|
||||
return MinRequestInterval
|
||||
}
|
||||
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
|
||||
}
|
||||
|
||||
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, redisTimeout)
|
||||
}
|
||||
|
||||
func luaInt64(v any) (int64, error) {
|
||||
switch n := v.(type) {
|
||||
case int64:
|
||||
return n, nil
|
||||
case int:
|
||||
return int64(n), nil
|
||||
case string:
|
||||
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
|
||||
case []byte:
|
||||
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
func luaString(v any) (string, error) {
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return s, nil
|
||||
case []byte:
|
||||
return string(s), nil
|
||||
case nil:
|
||||
return "", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported lua string type %T", v)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package kirocooldown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
|
||||
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
|
||||
|
||||
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
|
||||
}
|
||||
if cleared {
|
||||
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
|
||||
store := NewStore(nil)
|
||||
|
||||
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
|
||||
if err == nil {
|
||||
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
|
||||
}
|
||||
if cleared {
|
||||
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
|
||||
service.PlatformOpenAI: 1,
|
||||
service.PlatformGemini: 1,
|
||||
service.PlatformAntigravity: 2,
|
||||
service.PlatformKiro: 1,
|
||||
}
|
||||
|
||||
for platform, minCount := range requiredByPlatform {
|
||||
|
||||
@@ -41,6 +41,9 @@ func RegisterAdminRoutes(
|
||||
// Antigravity OAuth
|
||||
registerAntigravityOAuthRoutes(admin, h)
|
||||
|
||||
// Kiro OAuth / IDC
|
||||
registerKiroOAuthRoutes(admin, h)
|
||||
|
||||
// 代理管理
|
||||
registerProxyRoutes(admin, h)
|
||||
|
||||
@@ -315,6 +318,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
accounts.GET("/kiro/default-model-mapping", h.Admin.Account.GetKiroDefaultModelMapping)
|
||||
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
@@ -367,6 +371,17 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
|
||||
}
|
||||
}
|
||||
|
||||
func registerKiroOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
kiro := admin.Group("/kiro")
|
||||
{
|
||||
kiro.POST("/oauth/auth-url", h.Admin.KiroOAuth.GenerateAuthURL)
|
||||
kiro.POST("/oauth/idc-auth-url", h.Admin.KiroOAuth.GenerateIDCAuthURL)
|
||||
kiro.POST("/oauth/exchange-code", h.Admin.KiroOAuth.ExchangeCode)
|
||||
kiro.POST("/oauth/refresh-token", h.Admin.KiroOAuth.RefreshToken)
|
||||
kiro.POST("/oauth/import-token", h.Admin.KiroOAuth.ImportToken)
|
||||
}
|
||||
}
|
||||
|
||||
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
|
||||
@@ -48,6 +48,13 @@ type Account struct {
|
||||
TempUnschedulableUntil *time.Time
|
||||
TempUnschedulableReason string
|
||||
|
||||
KiroQuotaState string
|
||||
KiroQuotaReason string
|
||||
KiroQuotaResetAt *time.Time
|
||||
KiroRuntimeState string
|
||||
KiroRuntimeReason string
|
||||
KiroRuntimeResetAt *time.Time
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
@@ -164,6 +171,10 @@ func (a *Account) IsGemini() bool {
|
||||
return a.Platform == PlatformGemini
|
||||
}
|
||||
|
||||
func (a *Account) IsKiro() bool {
|
||||
return a.Platform == PlatformKiro
|
||||
}
|
||||
|
||||
func (a *Account) GeminiOAuthType() string {
|
||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||
return ""
|
||||
@@ -478,17 +489,17 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
|
||||
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
||||
if a.Credentials == nil {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
// 部分平台在未显式配置 model_mapping 时仍应使用默认映射,
|
||||
// 以限制可调度/可转发的模型集合。
|
||||
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||
return defaults
|
||||
}
|
||||
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||
return nil
|
||||
}
|
||||
if len(rawMapping) == 0 {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||
return defaults
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -510,13 +521,23 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
||||
return result
|
||||
}
|
||||
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
if defaults := defaultModelMappingForPlatform(a.Platform); defaults != nil {
|
||||
return defaults
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultModelMappingForPlatform(platform string) map[string]string {
|
||||
switch platform {
|
||||
case domain.PlatformAntigravity:
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
case domain.PlatformKiro:
|
||||
return domain.DefaultKiroModelMapping
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mapPtr(m map[string]any) uintptr {
|
||||
if m == nil {
|
||||
return 0
|
||||
@@ -608,8 +629,8 @@ func resolveRequestedModelInMapping(mapping map[string]string, requestedModel st
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)。
|
||||
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时也会先回退到默认映射。
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
@@ -622,8 +643,8 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
|
||||
}
|
||||
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)。
|
||||
// 对带默认映射的平台(如 Antigravity/Kiro),未显式配置时返回默认映射结果。
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
@@ -725,6 +746,9 @@ func (a *Account) GetBaseURL() string {
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
if a.Platform == PlatformKiro {
|
||||
return ""
|
||||
}
|
||||
return "https://api.anthropic.com"
|
||||
}
|
||||
if a.Platform == PlatformAntigravity {
|
||||
|
||||
@@ -180,7 +180,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
|
||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||
}
|
||||
}
|
||||
@@ -296,7 +296,7 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini || g.Platform == PlatformKiro) {
|
||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
@@ -66,6 +67,7 @@ type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
kiroTokenProvider *KiroTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
@@ -77,6 +79,7 @@ func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
kiroTokenProvider *KiroTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
@@ -86,6 +89,7 @@ func NewAccountTestService(
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
kiroTokenProvider: kiroTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
@@ -192,6 +196,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.IsKiro() && account.Type == AccountTypeOAuth {
|
||||
return s.testKiroAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
@@ -240,6 +248,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" && account.Platform == PlatformKiro {
|
||||
return s.sendErrorAndEnd(c, "Kiro API Key accounts require a Base URL")
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
@@ -388,6 +399,149 @@ func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Con
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) testKiroAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
testModelID := strings.TrimSpace(modelID)
|
||||
if testModelID == "" {
|
||||
testModelID = "claude-sonnet-4-6"
|
||||
}
|
||||
if mappedModel := account.GetMappedModel(testModelID); strings.TrimSpace(mappedModel) != "" {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Kiro account type: %s", account.Type))
|
||||
}
|
||||
|
||||
if s.kiroTokenProvider == nil {
|
||||
return s.sendErrorAndEnd(c, "Kiro token provider not configured")
|
||||
}
|
||||
|
||||
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get Kiro access token: %s", err.Error()))
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
resp, err := s.executeKiroTestUpstream(ctx, account, payloadBytes, testModelID, accessToken)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, formatKiroTestError(resp.StatusCode, body, testModelID, account))
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
_, streamErr := kiropkg.StreamEventStreamAsAnthropic(ctx, resp.Body, pw, testModelID, estimateKiroInputTokens(payloadBytes))
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
return s.processClaudeStream(c, pr)
|
||||
}
|
||||
|
||||
func formatKiroTestError(statusCode int, body []byte, requestedModel string, account *Account) string {
|
||||
return fmt.Sprintf("API returned %d: %s", statusCode, string(body))
|
||||
}
|
||||
|
||||
func (s *AccountTestService) executeKiroTestUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string) (*http.Response, error) {
|
||||
modelID := kiropkg.MapModel(mappedModel)
|
||||
currentToken := token
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := buildResult.Payload
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
proxyURL := kiroProxyURL(account)
|
||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
maxRetries := 2
|
||||
for idx, endpoint := range endpoints {
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
|
||||
if idx+1 < len(endpoints) {
|
||||
_ = resp.Body.Close()
|
||||
break
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
|
||||
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = buildResult.Payload
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("kiro upstream endpoints exhausted")
|
||||
}
|
||||
|
||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||
region := bedrockRuntimeRegion(account)
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
)
|
||||
|
||||
func TestAccountTestService_KiroAPIKeyUsesGenericAnthropicCompatiblePath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 19,
|
||||
Name: "kiro-apikey-test",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"base_url": "https://kiro-upstream.example.com",
|
||||
"api_key": "kiro-api-key",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"invalid api key"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
req := upstream.requests[0]
|
||||
require.Equal(t, "kiro-upstream.example.com", req.URL.Host)
|
||||
require.Equal(t, "/v1/messages", req.URL.Path)
|
||||
require.Equal(t, "kiro-api-key", req.Header.Get("x-api-key"))
|
||||
require.Empty(t, req.Header.Get("Authorization"))
|
||||
require.Equal(t, claude.APIKeyBetaHeader, req.Header.Get("anthropic-beta"))
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroAPIKeyWithoutBaseURLErrors(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Name: "kiro-apikey-missing-base-url",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "kiro-api-key",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: &queuedHTTPUpstream{},
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Base URL")
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountTestService_KiroUsesKiroUpstreamInsteadOfAnthropic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "kiro-test",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/TESTSOCIAL",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{1: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"type":"authentication_error","message":"Invalid bearer token"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "gpt-4o", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
req := upstream.requests[0]
|
||||
require.Equal(t, "q.us-east-1.amazonaws.com", req.URL.Host)
|
||||
require.Equal(t, "/generateAssistantResponse", req.URL.Path)
|
||||
require.Equal(t, "Bearer kiro-access-token", req.Header.Get("Authorization"))
|
||||
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Empty(t, req.Header.Get("anthropic-version"))
|
||||
require.NotContains(t, req.URL.Host, "api.anthropic.com")
|
||||
}
|
||||
|
||||
func TestAccountTestService_Kiro429DoesNotFallbackToCodeWhispererEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Name: "kiro-fallback",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"api_region": "us-west-2",
|
||||
"region": "us-west-2",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTFALLBACK",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{2: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusTooManyRequests, `{"message":"slow down"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
|
||||
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
|
||||
require.Contains(t, err.Error(), "API returned 429")
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroIDCWithoutProfileArnOmitsProfileArnAndUsesDefaultRuntimeRegion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "kiro-idc-default-region",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"auth_method": "idc",
|
||||
"provider": "AWS",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{5: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
require.Equal(t, "q.us-east-1.amazonaws.com", upstream.requests[0].URL.Host)
|
||||
body, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||
require.NoError(t, readErr)
|
||||
require.NotContains(t, string(body), `"profileArn":`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroInvalidModelErrorPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "kiro-invalid-model",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/TESTINVALIDMODEL",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, `API returned 400: {"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`, err.Error())
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "kiro-invalid-model-refresh",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{7: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-opus-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "API returned 400")
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||
require.NoError(t, readErr)
|
||||
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
|
||||
}
|
||||
|
||||
func TestAccountTestService_KiroPreferredEndpointIsIgnored(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "kiro-preferred-endpoint",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"api_region": "us-west-2",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/PREFERRED",
|
||||
"preferred_endpoint": "codewhisperer",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{6: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"type":"error","error":{"message":"Invalid bearer token"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
kiroTokenProvider: NewKiroTokenProvider(nil, nil, nil),
|
||||
httpUpstream: upstream,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
err := svc.TestAccountConnection(ctx, account.ID, "claude-sonnet-4-6", "", AccountTestModeDefault)
|
||||
require.Error(t, err)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com", upstream.requests[0].URL.Host)
|
||||
require.Empty(t, upstream.requests[0].Header.Get("X-Amz-Target"))
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccount_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "kiro-builder-id",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "idc",
|
||||
"provider": "BuilderId",
|
||||
"region": "us-east-1",
|
||||
"client_id": "builder-client-id",
|
||||
},
|
||||
}
|
||||
|
||||
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(testPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(kiroPayload), `"profileArn":`)
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccount_KiroBuilderIDUsesCredentialProfileArn(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 33,
|
||||
Name: "kiro-builder-id-cached",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "builder-id",
|
||||
"provider": "BuilderId",
|
||||
"region": "us-east-1",
|
||||
"client_id": "builder-client-id",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED",
|
||||
},
|
||||
}
|
||||
|
||||
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(testPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(kiroPayload), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/CACHED"`)
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccount_KiroEnterpriseIDCOmitsMissingProfileArn(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 4,
|
||||
Name: "kiro-enterprise-idc",
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "idc",
|
||||
"provider": "AWS",
|
||||
"region": "us-east-1",
|
||||
"client_id": "enterprise-client-id",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
|
||||
testPayload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(testPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
kiroPayload, err := buildKiroPayloadForAccount(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "kiro-access-token", "claude-sonnet-4-6", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(kiroPayload), `"profileArn":`)
|
||||
}
|
||||
@@ -103,10 +103,17 @@ type antigravityUsageCache struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// kiroUsageCache 缓存 Kiro 额度快照
|
||||
type kiroUsageCache struct {
|
||||
usageInfo *UsageInfo
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
kiroUsageErrorTTL = 1 * time.Minute // Kiro 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -118,8 +125,10 @@ type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
kiroUsageCache sync.Map // accountID -> *kiroUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
kiroUsageFlight singleflight.Group // 防止同一 Kiro 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
@@ -176,6 +185,23 @@ type AICredit struct {
|
||||
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
||||
}
|
||||
|
||||
// KiroCreditProgress 表示 Kiro 主额度或 Bonus 的用量进度。
|
||||
type KiroCreditProgress struct {
|
||||
CurrentUsage float64 `json:"current_usage"`
|
||||
UsageLimit float64 `json:"usage_limit"`
|
||||
PercentageUsed float64 `json:"percentage_used"`
|
||||
DaysRemaining int `json:"days_remaining,omitempty"`
|
||||
ExpiryDate *time.Time `json:"expiry_date,omitempty"`
|
||||
}
|
||||
|
||||
// KiroOverageInfo 表示 Kiro 账号的 overage 状态。
|
||||
type KiroOverageInfo struct {
|
||||
CurrentOverages float64 `json:"current_overages"`
|
||||
OverageCharges float64 `json:"overage_charges"`
|
||||
CurrencyCode string `json:"currency_code,omitempty"`
|
||||
CurrencySymbol string `json:"currency_symbol,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
Source string `json:"source,omitempty"` // "passive" or "active"
|
||||
@@ -203,6 +229,21 @@ type UsageInfo struct {
|
||||
// Antigravity AI Credits 余额
|
||||
AICredits []AICredit `json:"ai_credits,omitempty"`
|
||||
|
||||
// Kiro Credits 额度与 overage 信息
|
||||
KiroSubscriptionName string `json:"kiro_subscription_name,omitempty"`
|
||||
KiroSubscriptionType string `json:"kiro_subscription_type,omitempty"`
|
||||
KiroResetAt *time.Time `json:"kiro_reset_at,omitempty"`
|
||||
KiroOveragesEnabled bool `json:"kiro_overages_enabled,omitempty"`
|
||||
KiroCredit *KiroCreditProgress `json:"kiro_credit,omitempty"`
|
||||
KiroBonus *KiroCreditProgress `json:"kiro_bonus,omitempty"`
|
||||
KiroOverage *KiroOverageInfo `json:"kiro_overage,omitempty"`
|
||||
KiroQuotaState string `json:"kiro_quota_state,omitempty"`
|
||||
KiroQuotaReason string `json:"kiro_quota_reason,omitempty"`
|
||||
KiroQuotaResetAt *time.Time `json:"kiro_quota_reset_at,omitempty"`
|
||||
KiroRuntimeState string `json:"kiro_runtime_state,omitempty"`
|
||||
KiroRuntimeReason string `json:"kiro_runtime_reason,omitempty"`
|
||||
KiroRuntimeResetAt *time.Time `json:"kiro_runtime_reset_at,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
@@ -266,6 +307,7 @@ type AccountUsageService struct {
|
||||
cache *UsageCache
|
||||
identityCache IdentityCache
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
kiroCooldownStore KiroCooldownStore
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
@@ -291,6 +333,13 @@ func NewAccountUsageService(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) SetKiroCooldownStore(store KiroCooldownStore) *AccountUsageService {
|
||||
if s != nil {
|
||||
s.kiroCooldownStore = store
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// GetUsage 获取账号使用量
|
||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
@@ -317,6 +366,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return usage, err
|
||||
}
|
||||
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
return s.getKiroUsage(ctx, account, "active", false)
|
||||
}
|
||||
|
||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||
if account.Platform == PlatformAntigravity {
|
||||
usage, err := s.getAntigravityUsage(ctx, account)
|
||||
@@ -425,6 +478,13 @@ func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformKiro {
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("passive usage only supported for Kiro OAuth accounts")
|
||||
}
|
||||
return s.getKiroUsage(ctx, account, "passive", false)
|
||||
}
|
||||
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroAPIKeyUnsupported(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9101,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.Nil(t, usage)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "does not support usage query")
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetPassiveUsage_KiroAPIKeyUnsupported(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9102,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
usage, err := svc.GetPassiveUsage(context.Background(), account.ID)
|
||||
require.Nil(t, usage)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Kiro OAuth")
|
||||
}
|
||||
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
requestedModel: "any-model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping falls back to default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping rejects model outside default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "auto",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping uses default upstream mapping",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: "claude-sonnet-4.6",
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
|
||||
@@ -1716,7 +1716,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
}
|
||||
|
||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||
@@ -2008,7 +2008,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
|
||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini || group.Platform == PlatformKiro) && len(accountIDsToCopy) > 0 {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||
|
||||
@@ -198,6 +198,13 @@ func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
|
||||
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
|
||||
|
||||
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
|
||||
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
|
||||
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
|
||||
|
||||
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
|
||||
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
|
||||
|
||||
// Gemini 3.1 Pro
|
||||
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-6, // $2 per MTok
|
||||
@@ -278,13 +285,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
|
||||
return s.fallbackPrices["claude-sonnet-4.6"]
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-sonnet-4.5"]
|
||||
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-haiku-4.5"]
|
||||
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
|
||||
@@ -39,6 +39,7 @@ const (
|
||||
PlatformOpenAI = domain.PlatformOpenAI
|
||||
PlatformGemini = domain.PlatformGemini
|
||||
PlatformAntigravity = domain.PlatformAntigravity
|
||||
PlatformKiro = domain.PlatformKiro
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
|
||||
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -105,6 +109,24 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -126,7 +148,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -144,6 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 12. Handle error response with failover
|
||||
|
||||
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -102,6 +106,24 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -123,7 +145,7 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -141,6 +163,7 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 12. Handle error response with failover
|
||||
|
||||
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
@@ -56,6 +57,7 @@ const (
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
postUsageBillingTimeout = 15 * time.Second
|
||||
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
|
||||
defaultKiroStreamKeepalive = 25 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -70,6 +72,7 @@ const (
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
type kiroCooldownRecoveryAttemptedKeyType struct{}
|
||||
|
||||
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
|
||||
type accountWithLoad struct {
|
||||
@@ -78,6 +81,7 @@ type accountWithLoad struct {
|
||||
}
|
||||
|
||||
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
||||
var kiroCooldownRecoveryAttemptedKey = kiroCooldownRecoveryAttemptedKeyType{}
|
||||
|
||||
var (
|
||||
windowCostPrefetchCacheHitTotal atomic.Int64
|
||||
@@ -554,6 +558,8 @@ type GatewayService struct {
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
kiroTokenProvider *KiroTokenProvider
|
||||
kiroCooldownStore KiroCooldownStore
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
@@ -592,6 +598,8 @@ func NewGatewayService(
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
kiroTokenProvider *KiroTokenProvider,
|
||||
kiroCooldownStore KiroCooldownStore,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
rpmCache RPMCache,
|
||||
digestStore *DigestSessionStore,
|
||||
@@ -624,6 +632,8 @@ func NewGatewayService(
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
kiroTokenProvider: kiroTokenProvider,
|
||||
kiroCooldownStore: kiroCooldownStore,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
@@ -902,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
|
||||
injectMetadata bool
|
||||
metadataUserID string
|
||||
stripSystemCacheControl bool
|
||||
preserveToolChoice bool
|
||||
}
|
||||
|
||||
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
|
||||
@@ -1116,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
|
||||
if !gjson.GetBytes(out, "max_tokens").Exists() {
|
||||
@@ -1967,6 +1984,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, useMixed) {
|
||||
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
|
||||
return s.SelectAccountWithLoadAwareness(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID)
|
||||
}
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
@@ -2346,14 +2367,91 @@ func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
return account.IsSchedulable()
|
||||
if !account.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
return s.isKiroRuntimeSchedulable(context.Background(), account)
|
||||
}
|
||||
|
||||
func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
|
||||
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
return false
|
||||
}
|
||||
return s.isKiroRuntimeSchedulable(ctx, account)
|
||||
}
|
||||
|
||||
func (s *GatewayService) isKiroRuntimeSchedulable(ctx context.Context, account *Account) bool {
|
||||
if account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth || s == nil || s.kiroCooldownStore == nil {
|
||||
return true
|
||||
}
|
||||
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(account))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return state == nil || !state.Active
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryRecoverKiroCooldownPool(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) bool {
|
||||
if s == nil || s.kiroCooldownStore == nil || ctx.Value(kiroCooldownRecoveryAttemptedKey) == true {
|
||||
return false
|
||||
}
|
||||
tokenKeys := s.kiroTransientCooldownRecoveryKeys(ctx, accounts, requestedModel, excludedIDs, allowMixedScheduling)
|
||||
if len(tokenKeys) == 0 {
|
||||
return false
|
||||
}
|
||||
cleared, err := s.kiroCooldownStore.ClearEarliestTransientCooldown(ctx, tokenKeys)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery failed: %v", err)
|
||||
return false
|
||||
}
|
||||
if cleared {
|
||||
logger.LegacyPrintf("service.gateway", "Kiro cooldown pool recovery cleared one transient cooldown")
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
func (s *GatewayService) kiroTransientCooldownRecoveryKeys(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, allowMixedScheduling bool) []string {
|
||||
tokenKeys := make([]string, 0, len(accounts))
|
||||
eligible := 0
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc == nil || acc.Platform != PlatformKiro || acc.Type != AccountTypeOAuth {
|
||||
if allowMixedScheduling {
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) ||
|
||||
!s.isAccountSchedulableForWindowCost(ctx, acc, false) ||
|
||||
!s.isAccountSchedulableForRPM(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
eligible++
|
||||
state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc))
|
||||
if err != nil || state == nil || !state.Active {
|
||||
return nil
|
||||
}
|
||||
if state.Reason != kirocooldown.CooldownReason429 {
|
||||
return nil
|
||||
}
|
||||
tokenKeys = append(tokenKeys, buildKiroAccountKey(acc))
|
||||
}
|
||||
if eligible == 0 || len(tokenKeys) != eligible {
|
||||
return nil
|
||||
}
|
||||
return tokenKeys
|
||||
}
|
||||
|
||||
// isAccountInGroup checks if the account belongs to the specified group.
|
||||
@@ -3232,6 +3330,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
|
||||
if selected == nil {
|
||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
||||
if s.tryRecoverKiroCooldownPool(ctx, accounts, requestedModel, excludedIDs, false) {
|
||||
retryCtx := context.WithValue(ctx, kiroCooldownRecoveryAttemptedKey, true)
|
||||
return s.selectAccountForModelWithPlatform(retryCtx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||
}
|
||||
@@ -3611,6 +3713,17 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
return selectionFailureDiagnosis{Category: "excluded"}
|
||||
}
|
||||
if !acc.IsSchedulable() {
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||
}
|
||||
if acc.Platform == PlatformKiro && acc.Type == AccountTypeOAuth {
|
||||
if state, err := s.getKiroCooldownState(ctx, buildKiroAccountKey(acc)); err == nil && state != nil && state.Active {
|
||||
return selectionFailureDiagnosis{
|
||||
Category: "unschedulable",
|
||||
Detail: fmt.Sprintf("kiro_runtime_%s remaining=%s", state.Reason, state.Remaining.Truncate(time.Second)),
|
||||
}
|
||||
}
|
||||
}
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||
}
|
||||
@@ -3774,6 +3887,13 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth && s.kiroTokenProvider != nil {
|
||||
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
|
||||
accessToken := account.GetCredential("access_token")
|
||||
@@ -4343,11 +4463,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
@@ -4371,6 +4486,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return s.forwardBedrock(ctx, c, account, parsed, startTime)
|
||||
}
|
||||
|
||||
if account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
return s.forwardKiroMessages(ctx, c, account, parsed, startTime)
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||
if account.Platform == PlatformAnthropic && c != nil {
|
||||
@@ -4425,7 +4549,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
|
||||
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
|
||||
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{
|
||||
stripSystemCacheControl: !systemRewritten,
|
||||
preserveToolChoice: account.Platform == PlatformKiro,
|
||||
}
|
||||
if s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil && fp != nil {
|
||||
@@ -4462,7 +4589,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(reqModel); next != "" && next != reqModel {
|
||||
mappedModel = next
|
||||
mappingSource = "account"
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
mappingSource = "account"
|
||||
@@ -5967,6 +6099,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" && account.Platform == PlatformKiro {
|
||||
return nil, fmt.Errorf("kiro api key account requires base_url")
|
||||
}
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
@@ -7228,10 +7363,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
keepaliveInterval := s.streamKeepaliveIntervalForAccount(account)
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
@@ -8277,6 +8409,9 @@ type recordUsageOpts struct {
|
||||
// 长上下文计费(仅 Gemini 路径需要)
|
||||
LongContextThreshold int
|
||||
LongContextMultiplier float64
|
||||
|
||||
// Kiro 账号在上游返回 auto 等无法定价模型时使用保守计费兜底。
|
||||
IsKiroAccount bool
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -8414,6 +8549,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
opts.IsKiroAccount = account != nil && account.Platform == PlatformKiro
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
@@ -8492,6 +8628,28 @@ func (s *GatewayService) calculateRecordUsageCost(
|
||||
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
}
|
||||
|
||||
const kiroConservativeFallbackBillingModel = "claude-opus-4-6"
|
||||
|
||||
func shouldUseKiroConservativeBillingFallback(result *ForwardResult, billingModel string, opts *recordUsageOpts) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return opts != nil && opts.IsKiroAccount
|
||||
}
|
||||
|
||||
func (s *GatewayService) calculateKiroConservativeTokenCost(tokens UsageTokens, multiplier float64) *CostBreakdown {
|
||||
if s == nil || s.billingService == nil {
|
||||
return nil
|
||||
}
|
||||
cost, err := s.billingService.CalculateCost(kiroConservativeFallbackBillingModel, tokens, multiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate conservative Kiro fallback cost failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
return cost
|
||||
}
|
||||
|
||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||
@@ -8596,6 +8754,12 @@ func (s *GatewayService) calculateTokenCost(
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
if shouldUseKiroConservativeBillingFallback(result, billingModel, opts) {
|
||||
if fallback := s.calculateKiroConservativeTokenCost(tokens, multiplier); fallback != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Using conservative Kiro fallback pricing for model=%s", billingModel)
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
return &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
return cost
|
||||
@@ -8856,6 +9020,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射:
|
||||
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
|
||||
@@ -9486,6 +9654,19 @@ func reconcileCachedTokens(usage map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *GatewayService) streamKeepaliveIntervalForAccount(account *Account) time.Duration {
|
||||
if account != nil && account.Platform == PlatformKiro {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.KiroStreamKeepaliveInterval > 0 {
|
||||
return time.Duration(s.cfg.Gateway.KiroStreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
return defaultKiroStreamKeepalive
|
||||
}
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
return time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
|
||||
|
||||
// initDebugGatewayBodyFile 初始化网关调试日志文件。
|
||||
|
||||
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
|
||||
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
|
||||
// Anthropic API Key keeps the existing account-level override:
|
||||
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Kiro OAuth uses channel-level switch only.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
@@ -62,13 +64,29 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
return false
|
||||
}
|
||||
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
return false
|
||||
default: // "default" → follow channel config
|
||||
default:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
}
|
||||
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
@@ -76,8 +94,7 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(account.Platform)
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(platform)
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||
"message": map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
return flushSSEJSON(w, "message_start", evt)
|
||||
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
"name": toolNameWebSearch, "input": map[string]any{},
|
||||
},
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
|
||||
return err
|
||||
}
|
||||
inputJSON, err := json.Marshal(map[string]string{"query": query})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal query: %w", err)
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||
blocks := make([]map[string]string, 0, len(results))
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
|
||||
blocks := make([]map[string]any, 0, len(results))
|
||||
for _, r := range results {
|
||||
block := map[string]string{
|
||||
block := map[string]any{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
}
|
||||
if r.Snippet != "" {
|
||||
block["page_content"] = r.Snippet
|
||||
"encrypted_content": r.Snippet,
|
||||
"page_age": nil,
|
||||
}
|
||||
if r.PageAge != "" {
|
||||
block["page_age"] = r.PageAge
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +14,31 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"cache_creation_input_tokens":0`)
|
||||
require.Contains(t, body, `"cache_read_input_tokens":0`)
|
||||
}
|
||||
|
||||
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `event: content_block_start`)
|
||||
require.Contains(t, body, `"type":"server_tool_use"`)
|
||||
require.Contains(t, body, `"input":{}`)
|
||||
require.Contains(t, body, `event: content_block_delta`)
|
||||
require.Contains(t, body, `"type":"input_json_delta"`)
|
||||
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, body, `event: content_block_stop`)
|
||||
}
|
||||
|
||||
// --- isOnlyWebSearchToolInBody ---
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
|
||||
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
|
||||
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||
// Second result has no PageAge
|
||||
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||
_, hasPageAge := blocks[1]["page_age"]
|
||||
require.False(t, hasPageAge)
|
||||
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
|
||||
require.Nil(t, blocks[1]["page_age"])
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
|
||||
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||
_, hasContent := blocks[0]["page_content"]
|
||||
require.False(t, hasContent)
|
||||
require.Equal(t, "", blocks[0]["encrypted_content"])
|
||||
require.Nil(t, blocks[0]["page_age"])
|
||||
}
|
||||
|
||||
// --- buildTextSummary ---
|
||||
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
func newKiroOAuthAccount() *Account {
|
||||
return &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
|
||||
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
|
||||
// nil channelService + default mode → returns false
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 11,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(77, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(77)
|
||||
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 12,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(78, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(78)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,623 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type kiroUsageCooldownStore struct {
|
||||
state *kirocooldown.State
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||
return s.state, s.err
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func kiroFloatPtr(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
|
||||
},
|
||||
}
|
||||
|
||||
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_kiro_billing_normalized",
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
|
||||
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
|
||||
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
|
||||
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
|
||||
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "*/*", r.Header.Get("Accept"))
|
||||
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
|
||||
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
|
||||
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
|
||||
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageConfiguration": {"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentOveragesWithPrecision":2,
|
||||
"currentUsageWithPrecision":125,
|
||||
"freeTrialInfo":{
|
||||
"currentUsageWithPrecision":25,
|
||||
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
|
||||
"freeTrialStatus":"ACTIVE",
|
||||
"usageLimitWithPrecision":500
|
||||
},
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageCharges":0.08,
|
||||
"resourceType":"CREDIT",
|
||||
"usageLimitWithPrecision":2000
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, "active", usage.Source)
|
||||
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
|
||||
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
|
||||
require.True(t, usage.KiroOveragesEnabled)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
|
||||
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
|
||||
require.NotNil(t, usage.KiroBonus)
|
||||
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
|
||||
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
|
||||
require.NotNil(t, usage.KiroOverage)
|
||||
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
|
||||
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
|
||||
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
|
||||
require.NotNil(t, usage.KiroResetAt)
|
||||
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
|
||||
require.NotNil(t, usage.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 702,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":300,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer successServer.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.NotNil(t, firstUsage.KiroCredit)
|
||||
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
|
||||
|
||||
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
|
||||
}))
|
||||
defer failingServer.Close()
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Empty(t, usage.Error)
|
||||
require.Empty(t, usage.ErrorCode)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 703,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "BuilderId",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 707,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":64,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 709,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"api_region": "eu-west-1",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":11,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{"eu-west-1"}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 710,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":7,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 704,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(90 * time.Second),
|
||||
Remaining: 90 * time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cooldown", usage.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
|
||||
require.NotNil(t, usage.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: `{"message":"profileArn is required for this request."}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeForbidden, info.ErrorCode)
|
||||
require.False(t, info.NeedsReauth)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Body: `{"message":"overage exhausted for this billing window"}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
|
||||
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
|
||||
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 708,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
requestCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
|
||||
|
||||
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, secondUsage)
|
||||
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
|
||||
require.Equal(t, 1, requestCount)
|
||||
}
|
||||
|
||||
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
|
||||
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
|
||||
NextDateReset: "2099-03-13T12:00:00Z",
|
||||
OverageConfiguration: kiroOverageConfiguration{
|
||||
OverageStatus: "DISABLED",
|
||||
},
|
||||
UsageBreakdownList: []kiroUsageBreakdown{
|
||||
{
|
||||
ResourceType: "CREDIT",
|
||||
CurrentUsageWithPrecision: kiroFloatPtr(2000),
|
||||
UsageLimitWithPrecision: kiroFloatPtr(2000),
|
||||
CurrentOveragesWithPrecision: kiroFloatPtr(0),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
|
||||
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
|
||||
require.NotNil(t, info.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
|
||||
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(2 * time.Minute),
|
||||
Remaining: 2 * time.Minute,
|
||||
},
|
||||
})
|
||||
|
||||
account := &Account{
|
||||
ID: 705,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
|
||||
require.Equal(t, "cooldown", account.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
|
||||
require.NotNil(t, account.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 706,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset":"2099-03-13T12:00:00Z",
|
||||
"overageConfiguration":{"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":2000,
|
||||
"currentOveragesWithPrecision":4,
|
||||
"overageCharges":0.2,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
_, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
target := &Account{
|
||||
ID: account.ID,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
|
||||
|
||||
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
|
||||
require.NotNil(t, target.KiroQuotaResetAt)
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
|
||||
account := Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformKiro,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
require.Empty(t, account.GetBaseURL())
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
|
||||
|
||||
require.Equal(t, 25*time.Second, got)
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamKeepaliveInterval: 10,
|
||||
KiroStreamKeepaliveInterval: 25,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
|
||||
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, nil)
|
||||
|
||||
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
upstreamModel string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "kiro claude sonnet 4.6 uses pricing key format",
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
upstreamModel: "claude-sonnet-4.6",
|
||||
want: "claude-sonnet-4-6",
|
||||
},
|
||||
{
|
||||
name: "falls back to upstream when requested model empty",
|
||||
requestedModel: "",
|
||||
upstreamModel: "claude-haiku-4-5",
|
||||
want: "claude-haiku-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_billing_normalized",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_auto_fallback_cost",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "auto",
|
||||
UpstreamModel: "auto",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 601, Quota: 100},
|
||||
User: &User{ID: 701},
|
||||
Account: &Account{ID: 801, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
kiroErrorAuthError = "auth_error"
|
||||
kiroErrorMonthlyRequest = "monthly_request_count"
|
||||
kiroErrorProfileError = "profile_error"
|
||||
kiroErrorQuotaExhausted = "quota_exhausted"
|
||||
kiroErrorOverageExhausted = "overage_exhausted"
|
||||
kiroErrorRateLimited = "rate_limited"
|
||||
kiroErrorSuspended = "suspended"
|
||||
kiroErrorUsageForbidden = "usage_forbidden"
|
||||
kiroErrorUpstreamTransient = "upstream_transient"
|
||||
kiroErrorBadRequestSchema = "bad_request_schema"
|
||||
kiroErrorBadRequestToolPairing = "bad_request_tool_pairing"
|
||||
kiroErrorBadRequestInvalidModel = "bad_request_invalid_model"
|
||||
kiroErrorBadRequestAuth = "bad_request_auth"
|
||||
kiroErrorBadRequestQuota = "bad_request_quota"
|
||||
kiroErrorBadRequestUnknown = "bad_request_unknown"
|
||||
kiroErrorRefreshTokenInvalid = "refresh_token_invalid"
|
||||
|
||||
kiroQuotaStateNormal = "normal"
|
||||
kiroQuotaStateOverageActive = "overage_active"
|
||||
kiroQuotaStateCreditsExhausted = "credits_exhausted"
|
||||
kiroQuotaStateOverageExhausted = "overage_exhausted"
|
||||
)
|
||||
|
||||
type kiroErrorClassification struct {
|
||||
Category string
|
||||
StatusCode int
|
||||
Message string
|
||||
}
|
||||
|
||||
func classifyKiroHTTPError(statusCode int, body string) kiroErrorClassification {
|
||||
trimmed := strings.TrimSpace(body)
|
||||
lower := strings.ToLower(trimmed)
|
||||
|
||||
switch {
|
||||
case statusCode == http.StatusUnauthorized:
|
||||
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusPaymentRequired && looksLikeKiroMonthlyRequestCountError(trimmed):
|
||||
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusForbidden && isKiroSuspendedBody([]byte(trimmed)):
|
||||
return kiroErrorClassification{Category: kiroErrorSuspended, StatusCode: statusCode, Message: trimmed}
|
||||
case looksLikeKiroProfileError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorProfileError, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusBadRequest:
|
||||
return classifyKiroBadRequest(trimmed, lower)
|
||||
case statusCode == http.StatusForbidden && isKiroTokenErrorBody([]byte(trimmed)):
|
||||
return kiroErrorClassification{Category: kiroErrorAuthError, StatusCode: statusCode, Message: trimmed}
|
||||
case looksLikeKiroOverageExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorOverageExhausted, StatusCode: statusCode, Message: trimmed}
|
||||
case looksLikeKiroQuotaExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusTooManyRequests:
|
||||
return kiroErrorClassification{Category: kiroErrorRateLimited, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode == http.StatusForbidden:
|
||||
return kiroErrorClassification{Category: kiroErrorUsageForbidden, StatusCode: statusCode, Message: trimmed}
|
||||
case statusCode >= 500:
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
|
||||
default:
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, StatusCode: statusCode, Message: trimmed}
|
||||
}
|
||||
}
|
||||
|
||||
func classifyKiroError(err error) kiroErrorClassification {
|
||||
if err == nil {
|
||||
return kiroErrorClassification{}
|
||||
}
|
||||
|
||||
var httpErr *kiroUsageHTTPError
|
||||
if errors.As(err, &httpErr) && httpErr != nil {
|
||||
return classifyKiroHTTPError(httpErr.StatusCode, httpErr.Body)
|
||||
}
|
||||
|
||||
errStr := strings.TrimSpace(err.Error())
|
||||
lower := strings.ToLower(errStr)
|
||||
switch {
|
||||
case looksLikeKiroInvalidGrantError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorRefreshTokenInvalid, Message: errStr}
|
||||
case looksLikeKiroMonthlyRequestCountError(errStr):
|
||||
return kiroErrorClassification{Category: kiroErrorMonthlyRequest, Message: errStr}
|
||||
case looksLikeKiroProfileError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorProfileError, Message: errStr}
|
||||
case looksLikeKiroOverageExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorOverageExhausted, Message: errStr}
|
||||
case looksLikeKiroQuotaExhaustedError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorQuotaExhausted, Message: errStr}
|
||||
case strings.Contains(lower, "context deadline exceeded"),
|
||||
strings.Contains(lower, "timeout"),
|
||||
isNetErr(err):
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
|
||||
default:
|
||||
return kiroErrorClassification{Category: kiroErrorUpstreamTransient, Message: errStr}
|
||||
}
|
||||
}
|
||||
|
||||
func classifyKiroBadRequest(trimmed, lower string) kiroErrorClassification {
|
||||
switch {
|
||||
case looksLikeKiroBadRequestSchemaError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestSchema, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroBadRequestToolPairingError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestToolPairing, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroBadRequestInvalidModelError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestInvalidModel, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroInvalidGrantError(lower) || looksLikeKiroBadRequestAuthError(lower):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestAuth, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
case looksLikeKiroQuotaExhaustedError(lower) || looksLikeKiroMonthlyRequestCountError(trimmed):
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestQuota, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
default:
|
||||
return kiroErrorClassification{Category: kiroErrorBadRequestUnknown, StatusCode: http.StatusBadRequest, Message: trimmed}
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestSchemaError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "schema") ||
|
||||
strings.Contains(lower, "inputschema") ||
|
||||
strings.Contains(lower, "improperly formed request") ||
|
||||
strings.Contains(lower, "additionalproperties") ||
|
||||
(strings.Contains(lower, "properties") && strings.Contains(lower, "required"))
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestToolPairingError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "tool_use") ||
|
||||
strings.Contains(lower, "tool_result") ||
|
||||
strings.Contains(lower, "tooluseid") ||
|
||||
strings.Contains(lower, "toolresults") ||
|
||||
strings.Contains(lower, "must be paired") ||
|
||||
strings.Contains(lower, "missing tool result")
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestInvalidModelError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "invalid model") ||
|
||||
strings.Contains(lower, "invalid_model_id") ||
|
||||
strings.Contains(lower, "model not supported") ||
|
||||
strings.Contains(lower, "unsupportedmodel") ||
|
||||
strings.Contains(lower, "modelid")
|
||||
}
|
||||
|
||||
func looksLikeKiroBadRequestAuthError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "invalid token") ||
|
||||
strings.Contains(lower, "expired token") ||
|
||||
strings.Contains(lower, "access token") ||
|
||||
strings.Contains(lower, "refresh token")
|
||||
}
|
||||
|
||||
func looksLikeKiroInvalidGrantError(lower string) bool {
|
||||
return strings.Contains(lower, "invalid_grant")
|
||||
}
|
||||
|
||||
func looksLikeKiroMonthlyRequestCountError(body string) bool {
|
||||
trimmed := strings.TrimSpace(body)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(trimmed, "MONTHLY_REQUEST_COUNT") {
|
||||
return true
|
||||
}
|
||||
if !gjson.Valid(trimmed) {
|
||||
return false
|
||||
}
|
||||
return gjson.Get(trimmed, "reason").String() == "MONTHLY_REQUEST_COUNT" ||
|
||||
gjson.Get(trimmed, "error.reason").String() == "MONTHLY_REQUEST_COUNT"
|
||||
}
|
||||
|
||||
func looksLikeKiroProfileError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return (strings.Contains(lower, "profilearn") && strings.Contains(lower, "required")) ||
|
||||
(strings.Contains(lower, "profile arn") && strings.Contains(lower, "required")) ||
|
||||
(strings.Contains(lower, "profile") && strings.Contains(lower, "not found")) ||
|
||||
(strings.Contains(lower, "invalid profile")) ||
|
||||
(strings.Contains(lower, "listavailableprofiles"))
|
||||
}
|
||||
|
||||
func looksLikeKiroQuotaExhaustedError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return (strings.Contains(lower, "credit") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "depleted"))) ||
|
||||
(strings.Contains(lower, "quota") && (strings.Contains(lower, "exhaust") || strings.Contains(lower, "exceeded") || strings.Contains(lower, "depleted"))) ||
|
||||
(strings.Contains(lower, "usage limit") && (strings.Contains(lower, "reached") || strings.Contains(lower, "exceeded"))) ||
|
||||
(strings.Contains(lower, "resource has been exhausted"))
|
||||
}
|
||||
|
||||
func looksLikeKiroOverageExhaustedError(lower string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "overage") &&
|
||||
(strings.Contains(lower, "exhaust") ||
|
||||
strings.Contains(lower, "disabled") ||
|
||||
strings.Contains(lower, "not enabled") ||
|
||||
strings.Contains(lower, "not allowed") ||
|
||||
strings.Contains(lower, "limit"))
|
||||
}
|
||||
|
||||
func isNetErr(err error) bool {
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr)
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyKiroHTTPErrorBadRequestCategories(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "schema",
|
||||
body: `{"message":"Improperly formed request: inputSchema.properties must be an object"}`,
|
||||
want: kiroErrorBadRequestSchema,
|
||||
},
|
||||
{
|
||||
name: "tool pairing",
|
||||
body: `{"message":"tool_use must be paired with a matching tool_result"}`,
|
||||
want: kiroErrorBadRequestToolPairing,
|
||||
},
|
||||
{
|
||||
name: "invalid model id",
|
||||
body: `{"message":"invalid modelId: model not supported"}`,
|
||||
want: kiroErrorBadRequestInvalidModel,
|
||||
},
|
||||
{
|
||||
name: "invalid model upstream",
|
||||
body: `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
|
||||
want: kiroErrorBadRequestInvalidModel,
|
||||
},
|
||||
{
|
||||
name: "invalid model reason",
|
||||
body: `{"message":"model route unavailable","reason":"INVALID_MODEL_ID"}`,
|
||||
want: kiroErrorBadRequestInvalidModel,
|
||||
},
|
||||
{
|
||||
name: "auth",
|
||||
body: `{"error":"invalid_grant","message":"Invalid refresh token provided"}`,
|
||||
want: kiroErrorBadRequestAuth,
|
||||
},
|
||||
{
|
||||
name: "quota",
|
||||
body: `{"message":"resource has been exhausted"}`,
|
||||
want: kiroErrorBadRequestQuota,
|
||||
},
|
||||
{
|
||||
name: "unknown",
|
||||
body: `{"message":"bad request"}`,
|
||||
want: kiroErrorBadRequestUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
classification := classifyKiroHTTPError(http.StatusBadRequest, tt.body)
|
||||
require.Equal(t, tt.want, classification.Category)
|
||||
require.Equal(t, http.StatusBadRequest, classification.StatusCode)
|
||||
require.Equal(t, tt.body, classification.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func buildKiroAccountKey(account *Account) string {
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
return kiropkg.BuildAccountKey(
|
||||
account.GetCredential("client_id"),
|
||||
account.GetCredential("client_id_hash"),
|
||||
account.GetCredential("refresh_token"),
|
||||
account.GetCredential("profile_arn"),
|
||||
account.ID,
|
||||
)
|
||||
}
|
||||
|
||||
func buildKiroMachineID(account *Account) string {
|
||||
if account == nil {
|
||||
return kiropkg.BuildMachineID("", "", "account:nil")
|
||||
}
|
||||
for _, key := range []string{"machine_id", "machineId"} {
|
||||
if machineID, ok := kiropkg.NormalizeMachineID(account.GetCredential(key)); ok {
|
||||
return machineID
|
||||
}
|
||||
}
|
||||
fallbackKey := buildKiroMachineIDFallbackKey(account)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
return kiropkg.BuildMachineID("", firstKiroCredential(account, "kiro_api_key", "kiroApiKey", "api_key"), fallbackKey)
|
||||
}
|
||||
return kiropkg.BuildMachineID(account.GetCredential("refresh_token"), "", fallbackKey)
|
||||
}
|
||||
|
||||
func firstKiroCredential(account *Account, keys ...string) string {
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
for _, key := range keys {
|
||||
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildKiroMachineIDFallbackKey(account *Account) string {
|
||||
if account == nil {
|
||||
return "account:nil"
|
||||
}
|
||||
if account.ID > 0 {
|
||||
return fmt.Sprintf("account:%d", account.ID)
|
||||
}
|
||||
for _, key := range []string{"client_id", "profile_arn"} {
|
||||
if value := strings.TrimSpace(account.GetCredential(key)); value != "" {
|
||||
return key + ":" + value
|
||||
}
|
||||
}
|
||||
if name := strings.TrimSpace(account.Name); name != "" {
|
||||
return "name:" + name
|
||||
}
|
||||
return "account:unknown"
|
||||
}
|
||||
|
||||
func buildKiroRequestID(resp *http.Response) string {
|
||||
if resp == nil {
|
||||
return ""
|
||||
}
|
||||
if requestID := strings.TrimSpace(resp.Header.Get("x-request-id")); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if requestID := strings.TrimSpace(resp.Header.Get("x-amzn-requestid")); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
return strings.TrimSpace(resp.Header.Get("x-amz-request-id"))
|
||||
}
|
||||
|
||||
func isKiroInvalidModelIDBody(respBody []byte) bool {
|
||||
var payload struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
Error struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(respBody, &payload) != nil {
|
||||
return looksLikeKiroBadRequestInvalidModelError(strings.ToLower(string(respBody)))
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(payload.Reason), "INVALID_MODEL_ID") ||
|
||||
strings.EqualFold(strings.TrimSpace(payload.Error.Reason), "INVALID_MODEL_ID") ||
|
||||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Message)) ||
|
||||
looksLikeKiroBadRequestInvalidModelError(strings.ToLower(payload.Error.Message))
|
||||
}
|
||||
|
||||
func isKiroSuspendedBody(respBody []byte) bool {
|
||||
body := string(respBody)
|
||||
return strings.Contains(body, "SUSPENDED") || strings.Contains(body, "TEMPORARILY_SUSPENDED")
|
||||
}
|
||||
|
||||
func isKiroTokenErrorBody(respBody []byte) bool {
|
||||
lower := strings.ToLower(string(respBody))
|
||||
return strings.Contains(lower, "token") ||
|
||||
strings.Contains(lower, "expired") ||
|
||||
strings.Contains(lower, "invalid") ||
|
||||
strings.Contains(lower, "unauthorized")
|
||||
}
|
||||
|
||||
func kiroProxyURL(account *Account) string {
|
||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func kiroAPIRegion(account *Account) string {
|
||||
if account == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
region := strings.TrimSpace(account.GetCredential("api_region"))
|
||||
if region == "" {
|
||||
region = kiroDefaultRegion
|
||||
}
|
||||
return region
|
||||
}
|
||||
|
||||
func applyKiroConditionalHeaders(req *http.Request, account *Account) {
|
||||
if req == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(account.GetCredential("auth_method")), "external_idp") {
|
||||
req.Header.Set("TokenType", "EXTERNAL_IDP")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(account.GetCredential("provider")), "Internal") {
|
||||
req.Header.Set("redirect-for-internal", "true")
|
||||
}
|
||||
}
|
||||
|
||||
func resolveKiroPayloadProfileArn(account *Account) string {
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||
}
|
||||
|
||||
func newKiroJSONRequest(ctx context.Context, endpointURL string, payload []byte, token, accountKey, machineID, amzTarget string, account *Account) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||
if amzTarget != "" {
|
||||
req.Header.Set("X-Amz-Target", amzTarget)
|
||||
}
|
||||
if account != nil {
|
||||
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||
if profileArn != "" {
|
||||
req.Header.Set("x-amzn-kiro-profile-arn", profileArn)
|
||||
}
|
||||
}
|
||||
applyKiroConditionalHeaders(req, account)
|
||||
return req, nil
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
|
||||
accountA := &Account{
|
||||
ID: 99,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-a",
|
||||
},
|
||||
}
|
||||
accountB := &Account{
|
||||
ID: 99,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-b",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, buildKiroAccountKey(accountA), buildKiroAccountKey(accountB))
|
||||
}
|
||||
|
||||
func TestBuildKiroMachineIDPrefersExplicitCredential(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"machineId": "2582956e-cc88-4669-b546-07adbffcb894",
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", buildKiroMachineID(account))
|
||||
}
|
||||
|
||||
func TestBuildKiroMachineIDDerivesFromRefreshToken(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiropkg.BuildMachineID("refresh-token", "", "account:102"), buildKiroMachineID(account))
|
||||
}
|
||||
|
||||
func TestBuildKiroMachineIDDerivesFromAPIKeyAccount(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"kiroApiKey": "kiro-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiropkg.BuildMachineID("", "kiro-api-key", "account:103"), buildKiroMachineID(account))
|
||||
}
|
||||
|
||||
func TestNewKiroJSONRequestAddsConditionalHeaders(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"auth_method": "external_idp",
|
||||
"provider": "Internal",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER",
|
||||
},
|
||||
}
|
||||
|
||||
req, err := newKiroJSONRequest(
|
||||
context.Background(),
|
||||
"https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
[]byte(`{"ok":true}`),
|
||||
"access-token",
|
||||
"account-key",
|
||||
buildKiroMachineID(account),
|
||||
"",
|
||||
account,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "EXTERNAL_IDP", req.Header.Get("TokenType"))
|
||||
require.Equal(t, "true", req.Header.Get("redirect-for-internal"))
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/HEADER", req.Header.Get("x-amzn-kiro-profile-arn"))
|
||||
require.Equal(t, "vibe", req.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Equal(t, "true", req.Header.Get("x-amzn-codewhisperer-optout"))
|
||||
require.Contains(t, req.Header.Get("User-Agent"), "aws-sdk-js/1.0.34")
|
||||
require.Contains(t, req.Header.Get("User-Agent"), "md/nodejs#22.22.0")
|
||||
require.Contains(t, req.Header.Get("User-Agent"), buildKiroMachineID(account))
|
||||
require.Contains(t, req.Header.Get("X-Amz-User-Agent"), buildKiroMachineID(account))
|
||||
require.True(t, strings.Contains(req.Header.Get("User-Agent"), "api/codewhispererstreaming#1.0.34"))
|
||||
require.Empty(t, req.Header.Get("Anthropic-Beta"))
|
||||
}
|
||||
|
||||
func TestIsKiroInvalidModelIDBodyRecognizesKnownForms(t *testing.T) {
|
||||
tests := []string{
|
||||
`{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`,
|
||||
`{"message":"Invalid model. Please select a different model to continue."}`,
|
||||
`API Error: 400 {"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`,
|
||||
}
|
||||
|
||||
for _, body := range tests {
|
||||
require.True(t, isKiroInvalidModelIDBody([]byte(body)), body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/test",
|
||||
},
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-sonnet-4-6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
headers := http.Header{}
|
||||
headers.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-sonnet-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-sonnet-4-6",
|
||||
headers,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(payload), "CHUNKED WRITE PROTOCOL")
|
||||
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6-thinking",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
|
||||
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.NotContains(t, systemContent, "<thinking_mode>")
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"api_region": "eu-west-1",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
|
||||
"region": "ap-northeast-1",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, "eu-west-1", kiroAPIRegion(account))
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionIgnoresProfileARNRegionFallback(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:us-west-2:123456789012:profile/test",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionIgnoresOIDCRegionFallback(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"region": "ap-northeast-2",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, kiroDefaultRegion, kiroAPIRegion(account))
|
||||
}
|
||||
|
||||
func TestBuildKiroEndpointsUsesOnlyAmazonQEndpoint(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"api_region": "us-west-2",
|
||||
"preferred_endpoint": "cw",
|
||||
},
|
||||
}
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
require.Len(t, endpoints, 1)
|
||||
require.Equal(t, "AmazonQ", endpoints[0].Name)
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
|
||||
require.Empty(t, endpoints[0].AmzTarget)
|
||||
}
|
||||
|
||||
func TestBuildKiroEndpointsIgnoresPreferredEndpoint(t *testing.T) {
|
||||
for _, preferred := range []string{"codewhisperer", "cw", "unknown"} {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"api_region": "us-west-2",
|
||||
"preferred_endpoint": preferred,
|
||||
},
|
||||
}
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
require.Len(t, endpoints, 1)
|
||||
require.Equal(t, "AmazonQ", endpoints[0].Name)
|
||||
require.Equal(t, "q.us-west-2.amazonaws.com/generateAssistantResponse", endpoints[0].URL[8:])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountKiroDefaultMappingRestrictsUnsupportedModels(t *testing.T) {
|
||||
account := &Account{Platform: PlatformKiro}
|
||||
|
||||
require.False(t, account.IsModelSupported("gpt-4o"))
|
||||
require.False(t, account.IsModelSupported("kiro-gpt-4o"))
|
||||
require.False(t, account.IsModelSupported("auto"))
|
||||
require.Equal(t, "claude-sonnet-4.6", account.GetMappedModel("claude-sonnet-4-6"))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCalculateTokenCost_KiroAutoUsesConservativeFallback(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
|
||||
svc := NewGatewayService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
result := &ForwardResult{
|
||||
Model: "auto",
|
||||
UpstreamModel: "auto",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
}
|
||||
|
||||
expected, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
cost := svc.calculateTokenCost(context.Background(), result, &APIKey{}, "auto", 1.1, &recordUsageOpts{IsKiroAccount: true})
|
||||
require.NotNil(t, cost)
|
||||
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-12)
|
||||
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-12)
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
)
|
||||
|
||||
const (
|
||||
// Kiro desktop social auth uses localhost loopback callbacks from a fixed
|
||||
// allowlist. Use one of the bundled ports from the official client.
|
||||
kiroSocialRedirectURI = "http://localhost:49153"
|
||||
// AWS IAM Identity Center native/public clients require an explicit loopback IP redirect URI.
|
||||
kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
|
||||
)
|
||||
|
||||
type KiroOAuthService struct {
|
||||
sessionStore *kiropkg.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
|
||||
return &KiroOAuthService{
|
||||
sessionStore: kiropkg.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) Stop() {}
|
||||
|
||||
type KiroAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
type KiroIDCAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
ClientID string `json:"client_id"`
|
||||
Region string `json:"region"`
|
||||
StartURL string `json:"start_url"`
|
||||
}
|
||||
|
||||
type KiroTokenInfo struct {
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ProfileArn string `json:"profile_arn,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDHash string `json:"client_id_hash,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
StartURL string `json:"start_url,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
type KiroGenerateAuthURLInput struct {
|
||||
ProxyID *int64
|
||||
Provider string
|
||||
}
|
||||
|
||||
type KiroExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
CallbackPath string
|
||||
LoginOption string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
type KiroGenerateIDCAuthURLInput struct {
|
||||
ProxyID *int64
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type KiroRefreshTokenInput struct {
|
||||
RefreshToken string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
StartURL string
|
||||
Region string
|
||||
ProfileArn string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
type KiroImportTokenInput struct {
|
||||
TokenJSON string
|
||||
DeviceRegistrationJSON string
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
|
||||
provider := strings.TrimSpace(input.Provider)
|
||||
if provider == "" {
|
||||
provider = string(kiropkg.SocialProviderGoogle)
|
||||
}
|
||||
if provider != string(kiropkg.SocialProviderGoogle) && provider != string(kiropkg.SocialProviderGitHub) {
|
||||
return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
|
||||
}
|
||||
state, err := kiropkg.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate state failed: %w", err)
|
||||
}
|
||||
codeVerifier, err := kiropkg.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate code verifier failed: %w", err)
|
||||
}
|
||||
sessionID := kiropkg.GenerateSessionID()
|
||||
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
AuthType: "social",
|
||||
Provider: provider,
|
||||
RedirectURI: kiroSocialRedirectURI,
|
||||
})
|
||||
return &KiroAuthURLResult{
|
||||
AuthURL: kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
return nil, fmt.Errorf("state invalid")
|
||||
}
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
|
||||
}
|
||||
|
||||
switch session.AuthType {
|
||||
case "social":
|
||||
token, err := kiropkg.CreateSocialToken(
|
||||
ctx,
|
||||
proxyURL,
|
||||
input.Code,
|
||||
session.CodeVerifier,
|
||||
buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.Provider = session.Provider
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
return toKiroTokenInfo(token), nil
|
||||
case "idc":
|
||||
token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
return toKiroTokenInfo(token), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
|
||||
}
|
||||
}
|
||||
|
||||
func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
|
||||
option := strings.ToLower(strings.TrimSpace(loginOption))
|
||||
if option == "" {
|
||||
switch provider {
|
||||
case string(kiropkg.SocialProviderGitHub):
|
||||
option = "github"
|
||||
case string(kiropkg.SocialProviderGoogle):
|
||||
option = "google"
|
||||
}
|
||||
}
|
||||
return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, option)
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
|
||||
startURL := strings.TrimSpace(input.StartURL)
|
||||
if startURL == "" {
|
||||
startURL = kiropkg.BuilderIDStartURL
|
||||
}
|
||||
region := strings.TrimSpace(input.Region)
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
state, err := kiropkg.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate state failed: %w", err)
|
||||
}
|
||||
codeVerifier, err := kiropkg.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate code verifier failed: %w", err)
|
||||
}
|
||||
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||
reg, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID := kiropkg.GenerateSessionID()
|
||||
s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
AuthType: "idc",
|
||||
RedirectURI: kiroIDCRedirectURI,
|
||||
ClientID: reg.ClientID,
|
||||
ClientSecret: reg.ClientSecret,
|
||||
Region: region,
|
||||
StartURL: startURL,
|
||||
})
|
||||
return &KiroIDCAuthURLResult{
|
||||
AuthURL: kiropkg.BuildIDCAuthURL(reg.ClientID, kiroIDCRedirectURI, state, kiropkg.GenerateCodeChallenge(codeVerifier), region),
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
ClientID: reg.ClientID,
|
||||
Region: region,
|
||||
StartURL: startURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
|
||||
proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
|
||||
authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
|
||||
if authMethod == "" {
|
||||
authMethod = "social"
|
||||
}
|
||||
|
||||
var token *kiropkg.TokenData
|
||||
var err error
|
||||
switch authMethod {
|
||||
case "idc":
|
||||
token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
|
||||
default:
|
||||
token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.ProfileArn == "" {
|
||||
token.ProfileArn = input.ProfileArn
|
||||
}
|
||||
if token.ClientID == "" {
|
||||
token.ClientID = input.ClientID
|
||||
}
|
||||
if token.ClientSecret == "" {
|
||||
token.ClientSecret = input.ClientSecret
|
||||
}
|
||||
if token.StartURL == "" {
|
||||
token.StartURL = input.StartURL
|
||||
}
|
||||
if token.Region == "" {
|
||||
token.Region = input.Region
|
||||
}
|
||||
return toKiroTokenInfo(token), nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("not a kiro oauth account")
|
||||
}
|
||||
return s.RefreshToken(ctx, &KiroRefreshTokenInput{
|
||||
RefreshToken: account.GetCredential("refresh_token"),
|
||||
AuthMethod: account.GetCredential("auth_method"),
|
||||
Provider: account.GetCredential("provider"),
|
||||
ClientID: account.GetCredential("client_id"),
|
||||
ClientSecret: account.GetCredential("client_secret"),
|
||||
StartURL: account.GetCredential("start_url"),
|
||||
Region: account.GetCredential("region"),
|
||||
ProfileArn: account.GetCredential("profile_arn"),
|
||||
ProxyID: account.ProxyID,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
|
||||
token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toKiroTokenInfo(token), nil
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
||||
if tokenInfo == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
creds := map[string]any{}
|
||||
if tokenInfo.AccessToken != "" {
|
||||
creds["access_token"] = tokenInfo.AccessToken
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.ProfileArn != "" {
|
||||
creds["profile_arn"] = tokenInfo.ProfileArn
|
||||
}
|
||||
if tokenInfo.ExpiresAt != "" {
|
||||
creds["expires_at"] = tokenInfo.ExpiresAt
|
||||
}
|
||||
if tokenInfo.AuthMethod != "" {
|
||||
creds["auth_method"] = tokenInfo.AuthMethod
|
||||
}
|
||||
if tokenInfo.Provider != "" {
|
||||
creds["provider"] = tokenInfo.Provider
|
||||
}
|
||||
if tokenInfo.ClientID != "" {
|
||||
creds["client_id"] = tokenInfo.ClientID
|
||||
}
|
||||
if tokenInfo.ClientSecret != "" {
|
||||
creds["client_secret"] = tokenInfo.ClientSecret
|
||||
}
|
||||
if tokenInfo.ClientIDHash != "" {
|
||||
creds["client_id_hash"] = tokenInfo.ClientIDHash
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.StartURL != "" {
|
||||
creds["start_url"] = tokenInfo.StartURL
|
||||
}
|
||||
if tokenInfo.Region != "" {
|
||||
creds["region"] = tokenInfo.Region
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
|
||||
func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
|
||||
if token == nil {
|
||||
return nil
|
||||
}
|
||||
return &KiroTokenInfo{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ProfileArn: token.ProfileArn,
|
||||
ExpiresAt: token.ExpiresAt,
|
||||
AuthMethod: token.AuthMethod,
|
||||
Provider: token.Provider,
|
||||
ClientID: token.ClientID,
|
||||
ClientSecret: token.ClientSecret,
|
||||
ClientIDHash: token.ClientIDHash,
|
||||
Email: token.Email,
|
||||
StartURL: token.StartURL,
|
||||
Region: token.Region,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||
if proxyID == nil || s.proxyRepo == nil {
|
||||
return "", nil
|
||||
}
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err != nil || proxy == nil {
|
||||
return "", err
|
||||
}
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestKiroIDCAuthRedirectURIUsesLoopbackIP(t *testing.T) {
|
||||
require.Equal(t, "http://127.0.0.1:9876/oauth/callback", kiroIDCRedirectURI)
|
||||
}
|
||||
|
||||
func TestKiroSocialAuthRedirectURIUsesLoopbackIP(t *testing.T) {
|
||||
require.Equal(t, "http://localhost:49153", kiroSocialRedirectURI)
|
||||
}
|
||||
|
||||
func TestBuildKiroSocialExchangeRedirectURIUsesProviderDefault(t *testing.T) {
|
||||
require.Equal(
|
||||
t,
|
||||
"http://localhost:49153/oauth/callback?login_option=github",
|
||||
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "", ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestBuildKiroSocialExchangeRedirectURIPreservesParsedCallbackData(t *testing.T) {
|
||||
require.Equal(
|
||||
t,
|
||||
"http://localhost:49153/signin/callback?login_option=google",
|
||||
buildKiroSocialExchangeRedirectURI("http://localhost:49153", "Github", "/signin/callback", "google"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestKiroOAuthService_ExchangeCodeRejectsExpiredSession(t *testing.T) {
|
||||
svc := NewKiroOAuthService(nil)
|
||||
svc.sessionStore.Set("expired-session", &kiropkg.AuthSession{
|
||||
State: "expected-state",
|
||||
CreatedAt: time.Now().Add(-11 * time.Minute),
|
||||
})
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &KiroExchangeCodeInput{
|
||||
SessionID: "expired-session",
|
||||
State: "expected-state",
|
||||
Code: "auth-code",
|
||||
})
|
||||
require.EqualError(t, err, "session not found or expired")
|
||||
}
|
||||
@@ -0,0 +1,740 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type kiroEndpointConfig struct {
|
||||
URL string
|
||||
AmzTarget string
|
||||
Name string
|
||||
}
|
||||
|
||||
const kiroInvalidModelTempUnschedDuration = time.Minute
|
||||
|
||||
const (
|
||||
kiroRetryBaseDelay = 200 * time.Millisecond
|
||||
kiroRetryMaxDelay = 2 * time.Second
|
||||
)
|
||||
|
||||
var kiroRetrySleep = sleepWithContext
|
||||
|
||||
func kiroRetryBackoffDelay(attempt int) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := kiroRetryBaseDelay * time.Duration(1<<attempt)
|
||||
if delay > kiroRetryMaxDelay {
|
||||
delay = kiroRetryMaxDelay
|
||||
}
|
||||
jitterMax := delay / 4
|
||||
if jitterMax <= 0 {
|
||||
return delay
|
||||
}
|
||||
return delay + time.Duration(mathrand.Int63n(int64(jitterMax)+1))
|
||||
}
|
||||
|
||||
func sleepKiroRetry(ctx context.Context, attempt int) error {
|
||||
return kiroRetrySleep(ctx, kiroRetryBackoffDelay(attempt))
|
||||
}
|
||||
|
||||
func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) {
|
||||
if account == nil || parsed == nil {
|
||||
return nil, fmt.Errorf("kiro forward: missing account or request")
|
||||
}
|
||||
|
||||
originalModel := parsed.Model
|
||||
mappedModel := originalModel
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
body := parsed.Body
|
||||
if mappedModel != originalModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
}
|
||||
logger.L().Debug("gateway forward_kiro_messages: request prepared",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("auth_method", strings.TrimSpace(account.GetCredential("auth_method"))),
|
||||
zap.String("requested_model", originalModel),
|
||||
zap.String("mapped_model", mappedModel),
|
||||
zap.Bool("has_profile_arn", strings.TrimSpace(account.GetCredential("profile_arn")) != ""),
|
||||
)
|
||||
|
||||
if s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, body) {
|
||||
parsedForEmulation := *parsed
|
||||
parsedForEmulation.Body = body
|
||||
return s.handleWebSearchEmulation(ctx, c, account, &parsedForEmulation)
|
||||
}
|
||||
|
||||
if parsed.Stream {
|
||||
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: failoverErr.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||
})
|
||||
return nil, failoverErr
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
|
||||
}
|
||||
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamResult.usage == nil {
|
||||
streamResult.usage = &ClaudeUsage{}
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *streamResult.usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: streamResult.firstTokenMs,
|
||||
ClientDisconnect: streamResult.clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenType != "oauth" {
|
||||
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||
}
|
||||
if isOnlyWebSearchToolInBody(body) {
|
||||
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
switch {
|
||||
case errors.Is(webSearchErr, errKiroWebSearchFallback):
|
||||
case webSearchErr == nil:
|
||||
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||
c.Header("Content-Type", "application/json")
|
||||
if webSearchResult.RequestID != "" {
|
||||
c.Header("x-request-id", webSearchResult.RequestID)
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", webSearchResult.ResponseBody)
|
||||
return &ForwardResult{
|
||||
RequestID: webSearchResult.RequestID,
|
||||
Usage: webSearchResult.Usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
default:
|
||||
var httpErr *kiroWebSearchHTTPError
|
||||
if errors.As(webSearchErr, &httpErr) && httpErr.Response != nil {
|
||||
return nil, s.handleKiroHTTPError(ctx, httpErr.Response, c, account, mappedModel, body)
|
||||
}
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(webSearchErr, &failoverErr) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: failoverErr.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(webSearchErr.Error()),
|
||||
})
|
||||
return nil, failoverErr
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(webSearchErr.Error())
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||
}
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(body)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: failoverErr.StatusCode,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||
})
|
||||
return nil, failoverErr
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("kiro upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, s.handleKiroHTTPError(ctx, resp, c, account, mappedModel, body)
|
||||
}
|
||||
|
||||
parseResult, err := kiropkg.ParseNonStreamingEventStreamWithContext(resp.Body, mappedModel, requestCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Failed to parse Kiro upstream response",
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "application/json")
|
||||
if requestID := resp.Header.Get("x-request-id"); requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", parseResult.ResponseBody)
|
||||
|
||||
upstreamModel := normalizeModelNameForPricing(kiropkg.MapModel(mappedModel))
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenType != "oauth" {
|
||||
return nil, 0, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(anthropicBody)
|
||||
if isOnlyWebSearchToolInBody(anthropicBody) {
|
||||
pr, pw := io.Pipe()
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "text/event-stream")
|
||||
go func() {
|
||||
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: headers,
|
||||
Body: pr,
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
return nil, inputTokens, err
|
||||
}
|
||||
return nil, inputTokens, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return resp, inputTokens, nil
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
wrappedHeaders := resp.Header.Clone()
|
||||
wrappedHeaders.Set("Content-Type", "text/event-stream")
|
||||
if requestID := buildKiroRequestID(resp); requestID != "" {
|
||||
wrappedHeaders.Set("x-request-id", requestID)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
_, streamErr := kiropkg.StreamEventStreamAsAnthropicWithContext(ctx, resp.Body, pw, mappedModel, inputTokens, requestCtx)
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
}
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: wrappedHeaders,
|
||||
Body: pr,
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||
var requestCtx kiropkg.KiroRequestContext
|
||||
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
|
||||
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||
return nil, requestCtx, failoverErr
|
||||
}
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
|
||||
modelID := kiropkg.MapModel(mappedModel)
|
||||
currentToken := token
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
payload := buildResult.Payload
|
||||
requestCtx = buildResult.Context
|
||||
|
||||
endpoints := buildKiroEndpoints(account)
|
||||
proxyURL := kiroProxyURL(account)
|
||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
maxRetries := 2
|
||||
|
||||
for idx, endpoint := range endpoints {
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
req, err := newKiroJSONRequest(ctx, endpoint.URL, payload, currentToken, accountKey, buildKiroMachineID(account), endpoint.AmzTarget, account)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
if attempt < maxRetries {
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
cooldown, err := s.markKiro429(ctx, accountKey)
|
||||
if err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
if idx+1 < len(endpoints) {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
break
|
||||
}
|
||||
resp.Header.Set("x-kiro-cooldown", cooldown.String())
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusRequestTimeout || (resp.StatusCode >= 500 && resp.StatusCode < 600) {
|
||||
if attempt < maxRetries {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if idx+1 < len(endpoints) {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
break
|
||||
}
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusPaymentRequired {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, requestCtx, readErr
|
||||
}
|
||||
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||
if classification.Category == kiroErrorMonthlyRequest {
|
||||
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
|
||||
}
|
||||
return nil, requestCtx, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, requestCtx, readErr
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
|
||||
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if s.kiroTokenProvider != nil && (resp.StatusCode == http.StatusUnauthorized || isKiroTokenErrorBody(respBody)) && attempt < maxRetries {
|
||||
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
payload = buildResult.Payload
|
||||
requestCtx = buildResult.Context
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, requestCtx, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if refreshErr != nil && isNonRetryableRefreshError(refreshErr) {
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
}
|
||||
|
||||
if classifyKiroHTTPError(resp.StatusCode, string(respBody)).Category == kiroErrorAuthError {
|
||||
s.markKiroAuthTemporarilyUnavailable(ctx, account, resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, requestCtx, readErr
|
||||
}
|
||||
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||
logKiroBadRequestClassification(classification, account, mappedModel, resp.Header, respBody)
|
||||
resetHTTPResponseBody(resp, respBody)
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
}
|
||||
return resp, requestCtx, nil
|
||||
}
|
||||
}
|
||||
return nil, requestCtx, fmt.Errorf("kiro upstream endpoints exhausted")
|
||||
}
|
||||
|
||||
func buildKiroEndpoints(account *Account) []kiroEndpointConfig {
|
||||
region := kiroAPIRegion(account)
|
||||
return []kiroEndpointConfig{
|
||||
{
|
||||
URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
|
||||
Name: "AmazonQ",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) ([]byte, error) {
|
||||
result, err := buildKiroPayloadForAccountWithRepo(ctx, nil, account, anthropicBody, modelID, token, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Payload, nil
|
||||
}
|
||||
|
||||
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
|
||||
profileArn := resolveKiroPayloadProfileArn(account)
|
||||
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
|
||||
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
|
||||
}
|
||||
|
||||
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
|
||||
requestModel = strings.TrimSpace(requestModel)
|
||||
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
|
||||
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
|
||||
return next
|
||||
}
|
||||
return anthropicBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
reason := fmt.Sprintf("kiro auth failure (%d): %s", statusCode, strings.TrimSpace(body))
|
||||
_ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason)
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroMonthlyRequestCountRateLimited(ctx context.Context, account *Account, body string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
resetAt := nextKiroMonthlyResetUTC(time.Now())
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
logger.L().Warn("kiro monthly request count rate-limit failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Time("reset_at", resetAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
reason := "kiro monthly request count exhausted (402): MONTHLY_REQUEST_COUNT"
|
||||
if trimmed := strings.TrimSpace(body); trimmed != "" {
|
||||
reason = fmt.Sprintf("%s body=%s", reason, truncateForLog([]byte(trimmed), 512))
|
||||
}
|
||||
logger.L().Warn("kiro monthly request count rate-limited",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Time("reset_at", resetAt),
|
||||
zap.String("reason", reason),
|
||||
)
|
||||
}
|
||||
|
||||
func nextKiroMonthlyResetUTC(now time.Time) time.Time {
|
||||
utc := now.UTC()
|
||||
year, month, _ := utc.Date()
|
||||
return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func resetHTTPResponseBody(resp *http.Response, body []byte) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
resp.ContentLength = int64(len(body))
|
||||
}
|
||||
|
||||
func estimateKiroInputTokens(body []byte) int {
|
||||
if len(body) == 0 {
|
||||
return 0
|
||||
}
|
||||
if tokens := gjson.GetBytes(body, "metadata.input_tokens").Int(); tokens > 0 {
|
||||
return int(tokens)
|
||||
}
|
||||
tokens := len(body) / 4
|
||||
if tokens == 0 {
|
||||
return 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func kiroUsageToClaude(usage kiropkg.Usage, fallbackInput int) ClaudeUsage {
|
||||
inputTokens := usage.InputTokens
|
||||
if inputTokens == 0 {
|
||||
inputTokens = fallbackInput
|
||||
}
|
||||
return ClaudeUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroInvalidModelRateLimited(ctx context.Context, account *Account, mappedModel string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil || account.Type != AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
resetAt := time.Now().Add(kiroInvalidModelTempUnschedDuration)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
logger.L().Warn("kiro invalid model rate-limit failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
|
||||
zap.Time("reset_at", resetAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.L().Warn("kiro invalid model rate-limited",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("mapped_model", strings.TrimSpace(mappedModel)),
|
||||
zap.Time("reset_at", resetAt),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleKiroHTTPError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, mappedModel string, requestBody []byte) error {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg == "" {
|
||||
upstreamMsg = strings.TrimSpace(string(respBody))
|
||||
}
|
||||
classification := classifyKiroHTTPError(resp.StatusCode, string(respBody))
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
logKiroBadRequestClassification(classification, account, "", resp.Header, respBody)
|
||||
}
|
||||
if classification.Category == kiroErrorMonthlyRequest {
|
||||
s.markKiroMonthlyRequestCountRateLimited(ctx, account, string(respBody))
|
||||
}
|
||||
if classification.Category == kiroErrorBadRequestInvalidModel && account != nil && account.Type == AccountTypeOAuth {
|
||||
s.markKiroInvalidModelRateLimited(ctx, account, mappedModel)
|
||||
event := s.buildKiroInvalidModelUpstreamEvent(account, resp, upstreamMsg, mappedModel, requestBody, c)
|
||||
appendOpsUpstreamError(c, event)
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusPaymentRequired || s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
c.JSON(mapUpstreamStatusCode(resp.StatusCode), gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": coalesceKiroErrorMessage(resp.StatusCode, upstreamMsg),
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("kiro upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildKiroInvalidModelUpstreamEvent(account *Account, resp *http.Response, upstreamMsg, mappedModel string, requestBody []byte, c *gin.Context) OpsUpstreamErrorEvent {
|
||||
_ = s
|
||||
requestedModel := strings.TrimSpace(gjson.GetBytes(requestBody, "model").String())
|
||||
hasTools := gjson.GetBytes(requestBody, "tools").Exists()
|
||||
hasAdaptiveThinking := strings.EqualFold(strings.TrimSpace(gjson.GetBytes(requestBody, "thinking.type").String()), "adaptive")
|
||||
hasContext1MBeta := false
|
||||
if c != nil {
|
||||
hasContext1MBeta = strings.Contains(c.GetHeader("Anthropic-Beta"), "context-1m")
|
||||
}
|
||||
return OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
RequestedModel: requestedModel,
|
||||
MappedModel: strings.TrimSpace(mappedModel),
|
||||
KiroModelID: kiropkg.MapModel(mappedModel),
|
||||
HasTools: hasTools,
|
||||
HasAdaptiveThinking: hasAdaptiveThinking,
|
||||
HasContext1MBeta: hasContext1MBeta,
|
||||
}
|
||||
}
|
||||
|
||||
func logKiroBadRequestClassification(classification kiroErrorClassification, account *Account, model string, headers http.Header, body []byte) {
|
||||
if classification.StatusCode != http.StatusBadRequest {
|
||||
return
|
||||
}
|
||||
var accountID int64
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
logger.L().Warn("kiro upstream bad request classified",
|
||||
zap.String("category", classification.Category),
|
||||
zap.Int("status", classification.StatusCode),
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.String("model", strings.TrimSpace(model)),
|
||||
zap.String("request_id", headers.Get("x-request-id")),
|
||||
zap.String("body_excerpt", truncateForLog(body, 512)),
|
||||
)
|
||||
}
|
||||
|
||||
func coalesceKiroErrorMessage(statusCode int, upstreamMsg string) string {
|
||||
if upstreamMsg != "" {
|
||||
return upstreamMsg
|
||||
}
|
||||
switch statusCode {
|
||||
case http.StatusTooManyRequests:
|
||||
return "Rate limit exceeded"
|
||||
case http.StatusForbidden:
|
||||
return "Access denied"
|
||||
case http.StatusUnauthorized:
|
||||
return "Authentication failed"
|
||||
default:
|
||||
return "Upstream request failed"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
)
|
||||
|
||||
var errKiroCooldownStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||
|
||||
type KiroCooldownStore interface {
|
||||
ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||
MarkSuccess(ctx context.Context, tokenKey string) error
|
||||
Mark429(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||
MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error)
|
||||
GetState(ctx context.Context, tokenKey string) (*kirocooldown.State, error)
|
||||
ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error)
|
||||
}
|
||||
|
||||
func asKiroCooldownFailoverError(err error) *UpstreamFailoverError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var cooldownErr *kirocooldown.Error
|
||||
if !errors.As(err, &cooldownErr) {
|
||||
return nil
|
||||
}
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
ResponseBody: []byte(cooldownErr.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) checkAndWaitKiroCooldown(ctx context.Context, tokenKey string) error {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return errKiroCooldownStoreUnavailable
|
||||
}
|
||||
waitFor, err := s.kiroCooldownStore.ReserveRequest(ctx, tokenKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if waitFor <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(waitFor)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroSuccess(ctx context.Context, tokenKey string) error {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.MarkSuccess(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiro429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return 0, errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.Mark429(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return 0, errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.MarkSuspended(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func (s *GatewayService) getKiroCooldownState(ctx context.Context, tokenKey string) (*kirocooldown.State, error) {
|
||||
if s == nil || s.kiroCooldownStore == nil {
|
||||
return nil, errKiroCooldownStoreUnavailable
|
||||
}
|
||||
return s.kiroCooldownStore.GetState(ctx, tokenKey)
|
||||
}
|
||||
|
||||
func kiroRuntimeStateSnapshot(state *kirocooldown.State) (string, string, *time.Time) {
|
||||
if state == nil || !state.Active {
|
||||
return "", "", nil
|
||||
}
|
||||
resetAt := state.CooldownUntil
|
||||
switch state.Reason {
|
||||
case kirocooldown.CooldownReasonSuspended:
|
||||
return "suspended", state.Reason, &resetAt
|
||||
default:
|
||||
return "cooldown", state.Reason, &resetAt
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
//go:build integration
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const kiroCooldownRedisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestRedisKiroCooldownStoreSharesCooldownAcrossInstances(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
storeA := kirocooldown.NewStore(rdb)
|
||||
storeB := kirocooldown.NewStore(rdb)
|
||||
|
||||
cooldown, err := storeA.Mark429(ctx, "token-shared")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Minute, cooldown)
|
||||
|
||||
wait, err := storeB.ReserveRequest(ctx, "token-shared")
|
||||
require.Zero(t, wait)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), kirocooldown.CooldownReason429)
|
||||
|
||||
require.NoError(t, storeB.MarkSuccess(ctx, "token-shared"))
|
||||
|
||||
wait, err = storeA.ReserveRequest(ctx, "token-shared")
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, wait, 0*time.Second)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreSharesReservationAcrossInstances(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
storeA := kirocooldown.NewStore(rdb)
|
||||
storeB := kirocooldown.NewStore(rdb)
|
||||
|
||||
wait, err := storeA.ReserveRequest(ctx, "token-rate")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, wait)
|
||||
|
||||
wait, err = storeB.ReserveRequest(ctx, "token-rate")
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, wait, 0*time.Millisecond)
|
||||
require.LessOrEqual(t, wait, kirocooldown.MaxRequestInterval)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreSharesSuspendedStateAcrossInstances(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
storeA := kirocooldown.NewStore(rdb)
|
||||
storeB := kirocooldown.NewStore(rdb)
|
||||
|
||||
cooldown, err := storeA.MarkSuspended(ctx, "token-suspended")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, kirocooldown.LongCooldown, cooldown)
|
||||
|
||||
wait, err := storeB.ReserveRequest(ctx, "token-suspended")
|
||||
require.Zero(t, wait)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), kirocooldown.CooldownReasonSuspended)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreSuspendedResetsFailCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
store := kirocooldown.NewStore(rdb)
|
||||
|
||||
_, err := store.Mark429(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
_, err = store.Mark429(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
|
||||
cooldown, err := store.MarkSuspended(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, kirocooldown.LongCooldown, cooldown)
|
||||
|
||||
cooldown, err = store.Mark429(ctx, "token-reset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Minute, cooldown)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreReserveDifferentTokenIgnoresOldCooldown(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
store := kirocooldown.NewStore(rdb)
|
||||
|
||||
_, err := store.Mark429(ctx, "token-old")
|
||||
require.NoError(t, err)
|
||||
|
||||
wait, err := store.ReserveRequest(ctx, "token-new")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, wait)
|
||||
}
|
||||
|
||||
func TestRedisKiroCooldownStoreUsesExpectedTTLs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startKiroCooldownRedis(t, ctx)
|
||||
store := kirocooldown.NewStore(rdb)
|
||||
|
||||
_, err := store.ReserveRequest(ctx, "token-ttl-active")
|
||||
require.NoError(t, err)
|
||||
activeTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-active")).Result()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, activeTTL, 0*time.Second)
|
||||
require.LessOrEqual(t, activeTTL, kirocooldown.ActiveTTL())
|
||||
|
||||
_, err = store.MarkSuspended(ctx, "token-ttl-state")
|
||||
require.NoError(t, err)
|
||||
stateTTL, err := rdb.PTTL(ctx, kirocooldown.RedisKey("token-ttl-state")).Result()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, stateTTL, 24*time.Hour)
|
||||
require.LessOrEqual(t, stateTTL, kirocooldown.StateTTL())
|
||||
}
|
||||
|
||||
func startKiroCooldownRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureKiroCooldownDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, kiroCooldownRedisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
host, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", host, port.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureKiroCooldownDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if kiroCooldownDockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过依赖 testcontainers 的 Kiro cooldown 集成测试")
|
||||
}
|
||||
|
||||
func kiroCooldownDockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(kiroCooldownUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func kiroCooldownUserHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
@@ -0,0 +1,583 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubKiroCooldownStore struct {
|
||||
reserveWait time.Duration
|
||||
reserveErr error
|
||||
successErr error
|
||||
mark429TTL time.Duration
|
||||
mark429Err error
|
||||
suspendedTTL time.Duration
|
||||
suspendedErr error
|
||||
state *kirocooldown.State
|
||||
stateErr error
|
||||
clearCalled bool
|
||||
clearKeys []string
|
||||
clearResult bool
|
||||
clearErr error
|
||||
}
|
||||
|
||||
type recordingKiroTempUnschedRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
called bool
|
||||
id int64
|
||||
until time.Time
|
||||
reason string
|
||||
rateCalled bool
|
||||
rateID int64
|
||||
rateLimitReset time.Time
|
||||
rateLimitedCall int
|
||||
}
|
||||
|
||||
func (r *recordingKiroTempUnschedRepo) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||||
r.called = true
|
||||
r.id = id
|
||||
r.until = until
|
||||
r.reason = reason
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingKiroTempUnschedRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateCalled = true
|
||||
r.rateID = id
|
||||
r.rateLimitReset = resetAt
|
||||
r.rateLimitedCall++
|
||||
return nil
|
||||
}
|
||||
|
||||
type recordingKiroErrorRepo struct {
|
||||
recordingKiroTempUnschedRepo
|
||||
setErrorCalls int
|
||||
errorID int64
|
||||
errorMsg string
|
||||
}
|
||||
|
||||
func (r *recordingKiroErrorRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
r.errorID = id
|
||||
r.errorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||
return s.reserveWait, s.reserveErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) MarkSuccess(context.Context, string) error {
|
||||
return s.successErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||
return s.mark429TTL, s.mark429Err
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||
return s.suspendedTTL, s.suspendedErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||
if s.clearCalled && s.clearResult {
|
||||
return nil, nil
|
||||
}
|
||||
return s.state, s.stateErr
|
||||
}
|
||||
|
||||
func (s *stubKiroCooldownStore) ClearEarliestTransientCooldown(_ context.Context, tokenKeys []string) (bool, error) {
|
||||
s.clearCalled = true
|
||||
s.clearKeys = append([]string(nil), tokenKeys...)
|
||||
return s.clearResult, s.clearErr
|
||||
}
|
||||
|
||||
func TestCalculateKiro429Cooldown(t *testing.T) {
|
||||
require.Equal(t, time.Minute, kirocooldown.Calculate429Cooldown(0))
|
||||
require.Equal(t, 2*time.Minute, kirocooldown.Calculate429Cooldown(1))
|
||||
require.Equal(t, 4*time.Minute, kirocooldown.Calculate429Cooldown(2))
|
||||
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(3))
|
||||
require.Equal(t, 5*time.Minute, kirocooldown.Calculate429Cooldown(10))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownReturnsNilWithoutWait(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
}
|
||||
|
||||
require.NoError(t, svc.checkAndWaitKiroCooldown(context.Background(), "token1"))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownPropagatesReserveError(t *testing.T) {
|
||||
expected := errors.New("redis unavailable")
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{reserveErr: expected},
|
||||
}
|
||||
|
||||
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
|
||||
require.ErrorIs(t, err, expected)
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownRequiresStore(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
err := svc.checkAndWaitKiroCooldown(context.Background(), "token1")
|
||||
require.ErrorIs(t, err, errKiroCooldownStoreUnavailable)
|
||||
}
|
||||
|
||||
func TestGatewayServiceCheckAndWaitKiroCooldownWaitsAndHonorsContext(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{reserveWait: 200 * time.Millisecond},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := svc.checkAndWaitKiroCooldown(ctx, "token1")
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestAsKiroCooldownFailoverError(t *testing.T) {
|
||||
err := kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429)
|
||||
|
||||
var cooldownErr *kirocooldown.Error
|
||||
require.ErrorAs(t, err, &cooldownErr)
|
||||
|
||||
failoverErr := asKiroCooldownFailoverError(err)
|
||||
require.NotNil(t, failoverErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
|
||||
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||
}
|
||||
|
||||
func TestAsKiroCooldownFailoverErrorIgnoresNonCooldownErrors(t *testing.T) {
|
||||
require.Nil(t, asKiroCooldownFailoverError(errors.New("redis unavailable")))
|
||||
}
|
||||
|
||||
func TestGatewayServiceTryRecoverKiroCooldownPoolClearsOnlyTransientCooldown(t *testing.T) {
|
||||
store := &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(time.Minute),
|
||||
Remaining: time.Minute,
|
||||
},
|
||||
clearResult: true,
|
||||
}
|
||||
svc := &GatewayService{kiroCooldownStore: store}
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
}
|
||||
|
||||
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
|
||||
require.True(t, recovered)
|
||||
require.True(t, store.clearCalled)
|
||||
require.Len(t, store.clearKeys, 1)
|
||||
require.Equal(t, buildKiroAccountKey(&accounts[0]), store.clearKeys[0])
|
||||
}
|
||||
|
||||
func TestGatewayServiceTryRecoverKiroCooldownPoolSkipsSuspended(t *testing.T) {
|
||||
store := &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReasonSuspended,
|
||||
CooldownUntil: time.Now().Add(time.Hour),
|
||||
Remaining: time.Hour,
|
||||
},
|
||||
clearResult: true,
|
||||
}
|
||||
svc := &GatewayService{kiroCooldownStore: store}
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
}
|
||||
|
||||
recovered := svc.tryRecoverKiroCooldownPool(context.Background(), accounts, "", nil, false)
|
||||
require.False(t, recovered)
|
||||
require.False(t, store.clearCalled)
|
||||
}
|
||||
|
||||
func TestSelectAccountWithLoadAwarenessRecoversKiroCooldownPool(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
account := Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
store := &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(time.Minute),
|
||||
Remaining: time.Minute,
|
||||
},
|
||||
clearResult: true,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: &mockAccountRepoForGemini{accounts: []Account{account}},
|
||||
concurrencyService: NewConcurrencyService(&mockConcurrencyCache{}),
|
||||
cfg: cfg,
|
||||
kiroCooldownStore: store,
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformKiro)
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "", nil, "", 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, account.ID, result.Account.ID)
|
||||
require.True(t, store.clearCalled)
|
||||
require.Equal(t, []string{buildKiroAccountKey(&account)}, store.clearKeys)
|
||||
}
|
||||
|
||||
func TestClassifyKiroHTTPErrorMonthlyRequestCount(t *testing.T) {
|
||||
tests := []string{
|
||||
`{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
|
||||
`{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}`,
|
||||
`API returned 402: {"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`,
|
||||
}
|
||||
|
||||
for _, body := range tests {
|
||||
classification := classifyKiroHTTPError(http.StatusPaymentRequired, body)
|
||||
require.Equal(t, kiroErrorMonthlyRequest, classification.Category)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyKiroHTTPErrorPlain402IsTransient(t *testing.T) {
|
||||
classification := classifyKiroHTTPError(http.StatusPaymentRequired, `{"message":"payment required"}`)
|
||||
require.Equal(t, kiroErrorUpstreamTransient, classification.Category)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{
|
||||
reserveErr: kirocooldown.NewError(32500*time.Millisecond, kirocooldown.CooldownReason429),
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||
require.Equal(t, "kiro token is in cooldown for 33s (reason: rate_limit_exceeded)", string(failoverErr.ResponseBody))
|
||||
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE",
|
||||
},
|
||||
}
|
||||
repo := &mockAccountRepoForGemini{accountsByID: map[int64]*Account{account.ID: account}}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model ID. Please select a different model to continue.","reason":"INVALID_MODEL_ID"}`),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-opus-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
|
||||
firstBody, readErr := io.ReadAll(upstream.requests[0].Body)
|
||||
require.NoError(t, readErr)
|
||||
require.Contains(t, string(firstBody), `"profileArn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE"`)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/STALE", account.GetCredential("profile_arn"))
|
||||
}
|
||||
|
||||
func TestHandleKiroHTTPErrorOAuthInvalidModelRateLimitsAndFailovers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Anthropic-Beta", "context-1m-2025-08-07")
|
||||
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Name: "kiro-oauth",
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
svc := &GatewayService{accountRepo: repo}
|
||||
requestBody := []byte(`{"model":"claude-opus-4-7","tools":[{"name":"search"}],"thinking":{"type":"adaptive"}}`)
|
||||
resp := newJSONResponse(http.StatusBadRequest, `{"error":{"message":"Invalid model. Please select a different model to continue.","type":"upstream_error"}}`)
|
||||
resp.Header.Set("x-request-id", "req-invalid-model")
|
||||
|
||||
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", requestBody)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "Invalid model")
|
||||
require.False(t, failoverErr.RetryableOnSameAccount)
|
||||
|
||||
require.False(t, repo.called)
|
||||
require.True(t, repo.rateCalled)
|
||||
require.Equal(t, account.ID, repo.rateID)
|
||||
require.WithinDuration(t, time.Now().Add(kiroInvalidModelTempUnschedDuration), repo.rateLimitReset, 5*time.Second)
|
||||
|
||||
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, PlatformKiro, events[0].Platform)
|
||||
require.Equal(t, account.ID, events[0].AccountID)
|
||||
require.Equal(t, account.Name, events[0].AccountName)
|
||||
require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
|
||||
require.Equal(t, "req-invalid-model", events[0].UpstreamRequestID)
|
||||
require.Equal(t, "failover", events[0].Kind)
|
||||
require.Equal(t, "claude-opus-4-7", events[0].RequestedModel)
|
||||
require.Equal(t, "claude-opus-4.6", events[0].MappedModel)
|
||||
require.Equal(t, "claude-opus-4.6", events[0].KiroModelID)
|
||||
require.True(t, events[0].HasTools)
|
||||
require.True(t, events[0].HasAdaptiveThinking)
|
||||
require.True(t, events[0].HasContext1MBeta)
|
||||
}
|
||||
|
||||
func TestHandleKiroHTTPErrorAPIKeyInvalidModelDoesNotFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
account := &Account{
|
||||
ID: 43,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
svc := &GatewayService{accountRepo: repo}
|
||||
resp := newJSONResponse(http.StatusBadRequest, `{"message":"Invalid model. Please select a different model to continue."}`)
|
||||
|
||||
err := svc.handleKiroHTTPError(context.Background(), resp, c, account, "claude-opus-4.6", []byte(`{"model":"claude-opus-4-7"}`))
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.NotErrorAs(t, err, &failoverErr)
|
||||
require.False(t, repo.called)
|
||||
require.False(t, repo.rateCalled)
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestNextKiroMonthlyResetUTC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
want time.Time
|
||||
}{
|
||||
{
|
||||
name: "middle of month",
|
||||
now: time.Date(2026, time.April, 27, 10, 30, 45, 123, time.FixedZone("CST", 8*3600)),
|
||||
want: time.Date(2026, time.May, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "december rolls year",
|
||||
now: time.Date(2026, time.December, 31, 23, 59, 59, 0, time.UTC),
|
||||
want: time.Date(2027, time.January, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, nextKiroMonthlyResetUTC(tt.now))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailovers(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusPaymentRequired, `{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}`),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "MONTHLY_REQUEST_COUNT")
|
||||
require.False(t, repo.called)
|
||||
require.True(t, repo.rateCalled)
|
||||
require.Equal(t, account.ID, repo.rateID)
|
||||
require.Equal(t, nextKiroMonthlyResetUTC(time.Now()), repo.rateLimitReset)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
repo := &recordingKiroTempUnschedRepo{}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusPaymentRequired, `{"message":"payment required"}`),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusPaymentRequired, failoverErr.StatusCode)
|
||||
require.False(t, repo.called)
|
||||
require.False(t, repo.rateCalled)
|
||||
}
|
||||
|
||||
func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedule(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "old-refresh",
|
||||
},
|
||||
}
|
||||
repo := &recordingKiroErrorRepo{
|
||||
recordingKiroTempUnschedRepo: recordingKiroTempUnschedRepo{
|
||||
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
||||
accountsByID: map[int64]*Account{account.ID: account},
|
||||
},
|
||||
},
|
||||
}
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"message":"token expired"}`),
|
||||
},
|
||||
}
|
||||
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
kiroCooldownStore: &stubKiroCooldownStore{},
|
||||
tlsFPProfileService: &TLSFingerprintProfileService{},
|
||||
kiroTokenProvider: provider,
|
||||
}
|
||||
|
||||
payload, err := createTestPayload("claude-sonnet-4-6")
|
||||
require.NoError(t, err)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, account.ID, repo.errorID)
|
||||
require.Contains(t, repo.errorMsg, "invalid_grant")
|
||||
require.False(t, repo.called, "non-retryable refresh errors should not mark temporary unschedulable")
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionSkipsActiveKiroCooldown(t *testing.T) {
|
||||
now := time.Now().Add(2 * time.Minute)
|
||||
svc := &GatewayService{
|
||||
kiroCooldownStore: &stubKiroCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: now,
|
||||
Remaining: 2 * time.Minute,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
require.False(t, svc.isAccountSchedulableForSelection(account))
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
kiroTokenRefreshSkew = 3 * time.Minute
|
||||
kiroTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
type KiroTokenCache = GeminiTokenCache
|
||||
|
||||
type kiroAccountTokenRefresher interface {
|
||||
RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error)
|
||||
BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any
|
||||
}
|
||||
|
||||
type KiroTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache KiroTokenCache
|
||||
kiroOAuthService kiroAccountTokenRefresher
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewKiroTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache KiroTokenCache,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
) *KiroTokenProvider {
|
||||
return &KiroTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
kiroOAuthService: kiroOAuthService,
|
||||
refreshPolicy: GeminiProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a kiro oauth account")
|
||||
}
|
||||
|
||||
cacheKey := KiroTokenCacheKey(account)
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= kiroTokenRefreshSkew
|
||||
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, kiroTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > kiroTokenCacheSkew:
|
||||
ttl = until - kiroTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func KiroTokenCacheKey(account *Account) string {
|
||||
if account == nil {
|
||||
return "kiro:account:0"
|
||||
}
|
||||
if clientIDHash := strings.TrimSpace(account.GetCredential("client_id_hash")); clientIDHash != "" {
|
||||
return "kiro:" + clientIDHash
|
||||
}
|
||||
if clientID := strings.TrimSpace(account.GetCredential("client_id")); clientID != "" {
|
||||
return "kiro:client:" + clientID
|
||||
}
|
||||
return "kiro:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) ForceRefreshAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a kiro oauth account")
|
||||
}
|
||||
if p.kiroOAuthService == nil {
|
||||
return "", errors.New("kiro oauth service is nil")
|
||||
}
|
||||
|
||||
cacheKey := KiroTokenCacheKey(account)
|
||||
lockHeld := false
|
||||
if p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
lockHeld = true
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
if p.accountRepo != nil {
|
||||
if latestAccount, err := p.accountRepo.GetByID(ctx, account.ID); err == nil && latestAccount != nil {
|
||||
account = latestAccount
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := p.kiroOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
if !lockHeld {
|
||||
if latestAccount, stale := CheckTokenVersion(ctx, account, p.accountRepo); stale && latestAccount != nil {
|
||||
account = latestAccount
|
||||
if accessToken := strings.TrimSpace(account.GetCredential("access_token")); accessToken != "" {
|
||||
_ = p.cacheAccessToken(ctx, account, accessToken)
|
||||
return accessToken, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if isNonRetryableRefreshError(err) && p.accountRepo != nil {
|
||||
errorMsg := "Token refresh failed (non-retryable): " + err.Error()
|
||||
_ = p.accountRepo.SetError(ctx, account.ID, errorMsg)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
newCredentials := MergeCredentials(account.Credentials, p.kiroOAuthService.BuildAccountCredentials(tokenInfo))
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
if err := persistAccountCredentials(ctx, p.accountRepo, account, newCredentials); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" {
|
||||
accessToken = strings.TrimSpace(tokenInfo.AccessToken)
|
||||
}
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found after kiro refresh")
|
||||
}
|
||||
if err := p.cacheAccessToken(ctx, account, accessToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *KiroTokenProvider) cacheAccessToken(ctx context.Context, account *Account, accessToken string) error {
|
||||
if p.tokenCache == nil || account == nil || strings.TrimSpace(accessToken) == "" {
|
||||
return nil
|
||||
}
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > kiroTokenCacheSkew:
|
||||
ttl = until - kiroTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
return p.tokenCache.SetAccessToken(ctx, KiroTokenCacheKey(account), accessToken, ttl)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type kiroTokenProviderRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
setErrorID int64
|
||||
setErrorMsg string
|
||||
}
|
||||
|
||||
func (r *kiroTokenProviderRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
r.setErrorID = id
|
||||
r.setErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
type kiroTokenProviderSequenceRepo struct {
|
||||
kiroTokenProviderRepo
|
||||
accounts []*Account
|
||||
reads int
|
||||
}
|
||||
|
||||
func (r *kiroTokenProviderSequenceRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||
if len(r.accounts) == 0 {
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
idx := r.reads
|
||||
if idx >= len(r.accounts) {
|
||||
idx = len(r.accounts) - 1
|
||||
}
|
||||
r.reads++
|
||||
return r.accounts[idx], nil
|
||||
}
|
||||
|
||||
type stubKiroAccountTokenRefresher struct {
|
||||
tokenInfo *KiroTokenInfo
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubKiroAccountTokenRefresher) RefreshAccountToken(context.Context, *Account) (*KiroTokenInfo, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *stubKiroAccountTokenRefresher) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
|
||||
if tokenInfo == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": tokenInfo.ExpiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroTokenProviderForceRefreshInvalidGrantSetsError(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
||||
}
|
||||
repo := &kiroTokenProviderRepo{
|
||||
mockAccountRepoForGemini: mockAccountRepoForGemini{
|
||||
accountsByID: map[int64]*Account{account.ID: account},
|
||||
},
|
||||
}
|
||||
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||
|
||||
token, err := provider.ForceRefreshAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, token)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, account.ID, repo.setErrorID)
|
||||
require.Contains(t, repo.setErrorMsg, "Token refresh failed (non-retryable)")
|
||||
require.Contains(t, repo.setErrorMsg, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestKiroTokenProviderForceRefreshRaceRecoveryDoesNotSetError(t *testing.T) {
|
||||
usedAccount := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-refresh"},
|
||||
}
|
||||
latestAccount := &Account{
|
||||
ID: 42,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "new-refresh", "access_token": "fresh-access", "_token_version": int64(2)},
|
||||
}
|
||||
repo := &kiroTokenProviderSequenceRepo{accounts: []*Account{usedAccount, latestAccount}}
|
||||
provider := NewKiroTokenProvider(repo, nil, nil)
|
||||
provider.kiroOAuthService = &stubKiroAccountTokenRefresher{err: errors.New("invalid_grant: token revoked")}
|
||||
|
||||
token, err := provider.ForceRefreshAccessToken(context.Background(), usedAccount)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fresh-access", token)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const kiroRefreshWindow = 15 * time.Minute
|
||||
|
||||
type KiroTokenRefresher struct {
|
||||
kiroOAuthService *KiroOAuthService
|
||||
}
|
||||
|
||||
func NewKiroTokenRefresher(kiroOAuthService *KiroOAuthService) *KiroTokenRefresher {
|
||||
return &KiroTokenRefresher{
|
||||
kiroOAuthService: kiroOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) CacheKey(account *Account) string {
|
||||
return KiroTokenCacheKey(account)
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Until(*expiresAt) <= kiroRefreshWindow
|
||||
}
|
||||
|
||||
func (r *KiroTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.kiroOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.kiroOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
return MergeCredentials(account.Credentials, newCredentials), nil
|
||||
}
|
||||
@@ -0,0 +1,608 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
kiroUsageOrigin = "AI_EDITOR"
|
||||
kiroUsageResourceType = "AGENTIC_REQUEST"
|
||||
|
||||
kiroDefaultRegion = "us-east-1"
|
||||
)
|
||||
|
||||
var resolveKiroRuntimeEndpoint = kiroRuntimeEndpoint
|
||||
|
||||
type kiroUsageLimitsResponse struct {
|
||||
NextDateReset any `json:"nextDateReset"`
|
||||
OverageConfiguration kiroOverageConfiguration `json:"overageConfiguration"`
|
||||
SubscriptionInfo kiroSubscriptionInfo `json:"subscriptionInfo"`
|
||||
UsageBreakdownList []kiroUsageBreakdown `json:"usageBreakdownList"`
|
||||
}
|
||||
|
||||
type kiroOverageConfiguration struct {
|
||||
OverageStatus string `json:"overageStatus"`
|
||||
}
|
||||
|
||||
type kiroSubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type kiroUsageBreakdown struct {
|
||||
Currency string `json:"currency"`
|
||||
CurrentOverages *float64 `json:"currentOverages"`
|
||||
CurrentOveragesWithPrecision *float64 `json:"currentOveragesWithPrecision"`
|
||||
CurrentUsage *float64 `json:"currentUsage"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
|
||||
DisplayName string `json:"displayName"`
|
||||
DisplayNamePlural string `json:"displayNamePlural"`
|
||||
FreeTrialInfo *kiroFreeTrialInfo `json:"freeTrialInfo"`
|
||||
NextDateReset any `json:"nextDateReset"`
|
||||
OverageCharges *float64 `json:"overageCharges"`
|
||||
ResourceType string `json:"resourceType"`
|
||||
UsageLimit *float64 `json:"usageLimit"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
|
||||
}
|
||||
|
||||
type kiroFreeTrialInfo struct {
|
||||
CurrentUsage *float64 `json:"currentUsage"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision"`
|
||||
FreeTrialExpiry any `json:"freeTrialExpiry"`
|
||||
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||
UsageLimit *float64 `json:"usageLimit"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision"`
|
||||
}
|
||||
|
||||
type kiroUsageHTTPError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *kiroUsageHTTPError) Error() string {
|
||||
if e == nil {
|
||||
return "kiro usage request failed"
|
||||
}
|
||||
if strings.TrimSpace(e.Body) == "" {
|
||||
return fmt.Sprintf("kiro usage request failed (status %d)", e.StatusCode)
|
||||
}
|
||||
return fmt.Sprintf("kiro usage request failed (status %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getKiroUsage(ctx context.Context, account *Account, source string, forceRefresh bool) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
if account == nil {
|
||||
return &UsageInfo{
|
||||
Source: source,
|
||||
UpdatedAt: &now,
|
||||
Error: "account is nil",
|
||||
ErrorCode: errorCodeNetworkError,
|
||||
}, nil
|
||||
}
|
||||
if account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return &UsageInfo{
|
||||
Source: source,
|
||||
UpdatedAt: &now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cached, hasCached := s.getCachedKiroUsage(account.ID)
|
||||
if hasCached && (cached.ErrorCode != "" || cached.Error != "") {
|
||||
cached.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, cached)
|
||||
return cached, nil
|
||||
}
|
||||
if !forceRefresh && hasCached {
|
||||
cached.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, cached)
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
flightKey := fmt.Sprintf("kiro-usage:%d", account.ID)
|
||||
result, fetchErr, _ := s.cache.kiroUsageFlight.Do(flightKey, func() (any, error) {
|
||||
if !forceRefresh {
|
||||
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
usage, err := s.fetchAndCacheKiroUsage(ctx, account, source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return usage, nil
|
||||
})
|
||||
if fetchErr == nil {
|
||||
if usage, ok := result.(*UsageInfo); ok && usage != nil {
|
||||
usage.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, usage)
|
||||
if source == "active" {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
degraded := buildKiroDegradedUsage(fetchErr)
|
||||
degraded.Source = source
|
||||
if hasCached {
|
||||
cached.Error = degraded.Error
|
||||
cached.ErrorCode = degraded.ErrorCode
|
||||
cached.NeedsReauth = degraded.NeedsReauth
|
||||
cached.KiroQuotaState = degraded.KiroQuotaState
|
||||
cached.KiroQuotaReason = degraded.KiroQuotaReason
|
||||
cached.KiroQuotaResetAt = degraded.KiroQuotaResetAt
|
||||
cached.Source = source
|
||||
s.attachKiroRuntimeState(ctx, account, cached)
|
||||
return cached, nil
|
||||
}
|
||||
s.storeKiroUsageSnapshot(account.ID, degraded)
|
||||
s.attachKiroRuntimeState(ctx, account, degraded)
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) fetchAndCacheKiroUsage(ctx context.Context, account *Account, source string) (*UsageInfo, error) {
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
|
||||
region := kiroAPIRegion(account)
|
||||
profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
|
||||
|
||||
resp, err := s.requestKiroUsageLimits(ctx, account, region, profileArn, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage := mapKiroUsageToInfo(resp)
|
||||
usage.Source = source
|
||||
s.storeKiroUsageSnapshot(account.ID, usage)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) storeKiroUsageSnapshot(accountID int64, usage *UsageInfo) {
|
||||
if s == nil || s.cache == nil || accountID <= 0 || usage == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
s.cache.kiroUsageCache.Store(accountID, &kiroUsageCache{
|
||||
usageInfo: cloneUsageInfo(usage),
|
||||
timestamp: now,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getCachedKiroUsage(accountID int64) (*UsageInfo, bool) {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
cached, ok := s.cache.kiroUsageCache.Load(accountID)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cache, ok := cached.(*kiroUsageCache)
|
||||
if !ok || cache == nil || cache.usageInfo == nil {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(cache.timestamp) >= kiroCacheTTL(cache.usageInfo) {
|
||||
return nil, false
|
||||
}
|
||||
return cloneUsageInfo(cache.usageInfo), true
|
||||
}
|
||||
|
||||
func kiroCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return kiroUsageErrorTTL
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return kiroUsageErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
func cloneUsageInfo(info *UsageInfo) *UsageInfo {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *info
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) requestKiroUsageLimits(ctx context.Context, account *Account, region, profileArn, token string) (*kiroUsageLimitsResponse, error) {
|
||||
endpoint := resolveKiroRuntimeEndpoint(region)
|
||||
reqURL, err := url.Parse(endpoint + "/getUsageLimits")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build kiro usage url failed: %w", err)
|
||||
}
|
||||
q := reqURL.Query()
|
||||
q.Set("origin", kiroUsageOrigin)
|
||||
if profileArn = strings.TrimSpace(profileArn); profileArn != "" {
|
||||
q.Set("profileArn", profileArn)
|
||||
}
|
||||
q.Set("resourceType", kiroUsageResourceType)
|
||||
reqURL.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kiro usage request failed: %w", err)
|
||||
}
|
||||
s.applyKiroRuntimeHeaders(req, account, token)
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: accountProxyURL(account),
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: isLoopbackEndpoint(endpoint),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kiro usage client failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro usage request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read kiro usage response failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
|
||||
}
|
||||
|
||||
var parsed kiroUsageLimitsResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("decode kiro usage response failed: %w", err)
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) applyKiroRuntimeHeaders(req *http.Request, account *Account, token string) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
machineID := buildKiroMachineID(account)
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
applyKiroConditionalHeaders(req, account)
|
||||
}
|
||||
|
||||
func accountProxyURL(account *Account) string {
|
||||
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
|
||||
func kiroRuntimeEndpoint(region string) string {
|
||||
region = strings.TrimSpace(region)
|
||||
if region == "" {
|
||||
region = kiroDefaultRegion
|
||||
}
|
||||
switch region {
|
||||
case "us-east-1":
|
||||
return "https://q.us-east-1.amazonaws.com"
|
||||
case "eu-central-1":
|
||||
return "https://q.eu-central-1.amazonaws.com"
|
||||
case "us-gov-east-1":
|
||||
return "https://q-fips.us-gov-east-1.amazonaws.com"
|
||||
case "us-gov-west-1":
|
||||
return "https://q-fips.us-gov-west-1.amazonaws.com"
|
||||
case "us-iso-east-1":
|
||||
return "https://q.us-iso-east-1.c2s.ic.gov"
|
||||
case "us-isob-east-1":
|
||||
return "https://q.us-isob-east-1.sc2s.sgov.gov"
|
||||
case "us-isof-south-1":
|
||||
return "https://q.us-isof-south-1.csp.hci.ic.gov"
|
||||
case "us-isof-east-1":
|
||||
return "https://q.us-isof-east-1.csp.hci.ic.gov"
|
||||
default:
|
||||
if strings.HasPrefix(region, "us-gov-") {
|
||||
return "https://q-fips." + region + ".amazonaws.com"
|
||||
}
|
||||
return "https://q." + region + ".amazonaws.com"
|
||||
}
|
||||
}
|
||||
|
||||
func isLoopbackEndpoint(raw string) bool {
|
||||
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func mapKiroUsageToInfo(resp *kiroUsageLimitsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
if resp == nil {
|
||||
return &UsageInfo{UpdatedAt: &now}
|
||||
}
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
KiroSubscriptionName: strings.TrimSpace(resp.SubscriptionInfo.SubscriptionTitle),
|
||||
KiroSubscriptionType: strings.TrimSpace(resp.SubscriptionInfo.Type),
|
||||
KiroOveragesEnabled: strings.EqualFold(strings.TrimSpace(resp.OverageConfiguration.OverageStatus), "ENABLED"),
|
||||
}
|
||||
|
||||
resetAt := parseKiroTimestamp(resp.NextDateReset)
|
||||
if credit := selectKiroCreditBreakdown(resp.UsageBreakdownList); credit != nil {
|
||||
if breakdownReset := parseKiroTimestamp(credit.NextDateReset); breakdownReset != nil {
|
||||
resetAt = breakdownReset
|
||||
}
|
||||
info.KiroCredit = &KiroCreditProgress{
|
||||
CurrentUsage: selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage),
|
||||
UsageLimit: selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit),
|
||||
PercentageUsed: percentageOrZero(selectKiroFloat(credit.CurrentUsageWithPrecision, credit.CurrentUsage), selectKiroFloat(credit.UsageLimitWithPrecision, credit.UsageLimit)),
|
||||
}
|
||||
info.KiroOverage = &KiroOverageInfo{
|
||||
CurrentOverages: selectKiroFloat(credit.CurrentOveragesWithPrecision, credit.CurrentOverages),
|
||||
OverageCharges: selectKiroFloat(credit.OverageCharges, nil),
|
||||
CurrencyCode: strings.TrimSpace(credit.Currency),
|
||||
CurrencySymbol: kiroCurrencySymbol(strings.TrimSpace(credit.Currency)),
|
||||
}
|
||||
if ft := credit.FreeTrialInfo; ft != nil && strings.EqualFold(strings.TrimSpace(ft.FreeTrialStatus), "ACTIVE") {
|
||||
expiry := parseKiroTimestamp(ft.FreeTrialExpiry)
|
||||
daysRemaining := 0
|
||||
if expiry != nil {
|
||||
daysRemaining = int(time.Until(*expiry).Hours() / 24)
|
||||
if time.Until(*expiry)%(24*time.Hour) != 0 {
|
||||
daysRemaining++
|
||||
}
|
||||
if daysRemaining < 0 {
|
||||
daysRemaining = 0
|
||||
}
|
||||
}
|
||||
current := selectKiroFloat(ft.CurrentUsageWithPrecision, ft.CurrentUsage)
|
||||
limit := selectKiroFloat(ft.UsageLimitWithPrecision, ft.UsageLimit)
|
||||
info.KiroBonus = &KiroCreditProgress{
|
||||
CurrentUsage: current,
|
||||
UsageLimit: limit,
|
||||
PercentageUsed: percentageOrZero(current, limit),
|
||||
DaysRemaining: daysRemaining,
|
||||
ExpiryDate: expiry,
|
||||
}
|
||||
}
|
||||
}
|
||||
info.KiroResetAt = resetAt
|
||||
setKiroQuotaStateFromUsage(info)
|
||||
return info
|
||||
}
|
||||
|
||||
func selectKiroCreditBreakdown(items []kiroUsageBreakdown) *kiroUsageBreakdown {
|
||||
for i := range items {
|
||||
if strings.EqualFold(strings.TrimSpace(items[i].ResourceType), "CREDIT") {
|
||||
return &items[i]
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &items[0]
|
||||
}
|
||||
|
||||
func selectKiroFloat(preferred *float64, fallback *float64) float64 {
|
||||
switch {
|
||||
case preferred != nil:
|
||||
return *preferred
|
||||
case fallback != nil:
|
||||
return *fallback
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func percentageOrZero(current, limit float64) float64 {
|
||||
if limit <= 0 {
|
||||
return 0
|
||||
}
|
||||
return current / limit * 100
|
||||
}
|
||||
|
||||
func parseKiroTimestamp(raw any) *time.Time {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, trimmed); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
||||
return unixishToTime(i)
|
||||
}
|
||||
if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
|
||||
return unixishFloatToTime(f)
|
||||
}
|
||||
case float64:
|
||||
return unixishFloatToTime(v)
|
||||
case int64:
|
||||
return unixishToTime(v)
|
||||
case int:
|
||||
return unixishToTime(int64(v))
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return unixishToTime(i)
|
||||
}
|
||||
if f, err := v.Float64(); err == nil {
|
||||
return unixishFloatToTime(f)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unixishFloatToTime(v float64) *time.Time {
|
||||
if v <= 0 {
|
||||
return nil
|
||||
}
|
||||
if v >= 1e12 {
|
||||
t := time.UnixMilli(int64(v))
|
||||
return &t
|
||||
}
|
||||
t := time.Unix(int64(v), 0)
|
||||
return &t
|
||||
}
|
||||
|
||||
func unixishToTime(v int64) *time.Time {
|
||||
if v <= 0 {
|
||||
return nil
|
||||
}
|
||||
if v >= 1e12 {
|
||||
t := time.UnixMilli(v)
|
||||
return &t
|
||||
}
|
||||
t := time.Unix(v, 0)
|
||||
return &t
|
||||
}
|
||||
|
||||
func kiroCurrencySymbol(code string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(code)) {
|
||||
case "USD":
|
||||
return "$"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildKiroDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: "usage API error",
|
||||
ErrorCode: errorCodeNetworkError,
|
||||
}
|
||||
if err == nil {
|
||||
return info
|
||||
}
|
||||
|
||||
info.Error = fmt.Sprintf("usage API error: %v", err)
|
||||
|
||||
classification := classifyKiroError(err)
|
||||
switch classification.Category {
|
||||
case kiroErrorAuthError:
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case kiroErrorRateLimited:
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
case kiroErrorQuotaExhausted:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
|
||||
info.KiroQuotaReason = classification.Message
|
||||
case kiroErrorOverageExhausted:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
info.KiroQuotaState = kiroQuotaStateOverageExhausted
|
||||
info.KiroQuotaReason = classification.Message
|
||||
case kiroErrorSuspended, kiroErrorUsageForbidden, kiroErrorProfileError:
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) attachKiroRuntimeState(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
if s == nil || usage == nil || account == nil || account.Platform != PlatformKiro || s.kiroCooldownStore == nil {
|
||||
return
|
||||
}
|
||||
usage.KiroRuntimeState = ""
|
||||
usage.KiroRuntimeReason = ""
|
||||
usage.KiroRuntimeResetAt = nil
|
||||
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
|
||||
if err != nil || state == nil {
|
||||
return
|
||||
}
|
||||
usage.KiroRuntimeState, usage.KiroRuntimeReason, usage.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) EnrichAccountWithKiroRuntimeState(ctx context.Context, account *Account) {
|
||||
if s == nil || account == nil || account.Platform != PlatformKiro || account.Type != AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
account.KiroQuotaState = ""
|
||||
account.KiroQuotaReason = ""
|
||||
account.KiroQuotaResetAt = nil
|
||||
account.KiroRuntimeState = ""
|
||||
account.KiroRuntimeReason = ""
|
||||
account.KiroRuntimeResetAt = nil
|
||||
if usage, ok := s.getCachedKiroUsage(account.ID); ok {
|
||||
account.KiroQuotaState = usage.KiroQuotaState
|
||||
account.KiroQuotaReason = usage.KiroQuotaReason
|
||||
account.KiroQuotaResetAt = usage.KiroQuotaResetAt
|
||||
}
|
||||
if s.kiroCooldownStore == nil {
|
||||
return
|
||||
}
|
||||
state, err := s.kiroCooldownStore.GetState(ctx, buildKiroAccountKey(account))
|
||||
if err != nil || state == nil {
|
||||
return
|
||||
}
|
||||
account.KiroRuntimeState, account.KiroRuntimeReason, account.KiroRuntimeResetAt = kiroRuntimeStateSnapshot(state)
|
||||
}
|
||||
|
||||
func setKiroQuotaStateFromUsage(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
info.KiroQuotaState = ""
|
||||
info.KiroQuotaReason = ""
|
||||
info.KiroQuotaResetAt = nil
|
||||
|
||||
creditExhausted := false
|
||||
if info.KiroCredit != nil && info.KiroCredit.UsageLimit > 0 {
|
||||
creditExhausted = info.KiroCredit.CurrentUsage >= info.KiroCredit.UsageLimit
|
||||
}
|
||||
overageActive := info.KiroOverage != nil &&
|
||||
(info.KiroOverage.CurrentOverages > 0 || info.KiroOverage.OverageCharges > 0)
|
||||
|
||||
switch {
|
||||
case info.KiroOveragesEnabled && (overageActive || creditExhausted):
|
||||
info.KiroQuotaState = kiroQuotaStateOverageActive
|
||||
info.KiroQuotaReason = "overages_enabled"
|
||||
info.KiroQuotaResetAt = info.KiroResetAt
|
||||
case creditExhausted:
|
||||
info.KiroQuotaState = kiroQuotaStateCreditsExhausted
|
||||
info.KiroQuotaReason = "credits_exhausted"
|
||||
info.KiroQuotaResetAt = info.KiroResetAt
|
||||
default:
|
||||
info.KiroQuotaState = kiroQuotaStateNormal
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,458 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
)
|
||||
|
||||
const kiroMaxWebSearchIterations = 5
|
||||
|
||||
var (
|
||||
errKiroWebSearchFallback = errors.New("kiro web search fallback")
|
||||
kiroWebSearchDescCache sync.Map
|
||||
)
|
||||
|
||||
type kiroWebSearchExecution struct {
|
||||
ResponseBody []byte
|
||||
Usage ClaudeUsage
|
||||
RequestID string
|
||||
}
|
||||
|
||||
type kiroWebSearchHTTPError struct {
|
||||
Response *http.Response
|
||||
}
|
||||
|
||||
type kiroStreamChunkCollector struct {
|
||||
chunks [][]byte
|
||||
}
|
||||
|
||||
func (e *kiroWebSearchHTTPError) Error() string {
|
||||
if e == nil || e.Response == nil {
|
||||
return "kiro web search http error"
|
||||
}
|
||||
return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
|
||||
}
|
||||
|
||||
func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
|
||||
if len(p) > 0 {
|
||||
w.chunks = append(w.chunks, append([]byte(nil), p...))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
|
||||
collector := &kiroStreamChunkCollector{}
|
||||
result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return collector.chunks, result, nil
|
||||
}
|
||||
|
||||
func writeSSEChunks(w io.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if len(chunk) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
msgID = "msg_" + kiropkg.GenerateToolUseID()
|
||||
}
|
||||
if strings.TrimSpace(model) == "" {
|
||||
model = "kiro"
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": inputTokens,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||
) error {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
|
||||
if err != nil {
|
||||
currentBody = anthropicBody
|
||||
}
|
||||
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
nextContentBlockIndex := 0
|
||||
|
||||
if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
|
||||
s.prefetchKiroWebSearchDescription(ctx, account, token)
|
||||
|
||||
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
|
||||
if strings.TrimSpace(nextToken) != "" {
|
||||
token = nextToken
|
||||
}
|
||||
if mcpErr != nil {
|
||||
results = nil
|
||||
}
|
||||
|
||||
if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil {
|
||||
return err
|
||||
}
|
||||
nextContentBlockIndex += 2
|
||||
|
||||
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
|
||||
if err != nil {
|
||||
return errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return &kiroWebSearchHTTPError{Response: resp}
|
||||
}
|
||||
|
||||
chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
|
||||
}()
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
|
||||
analysis := kiropkg.AnalyzeBufferedStream(chunks)
|
||||
if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations {
|
||||
filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex)
|
||||
if err := writeSSEChunks(w, filtered); err != nil {
|
||||
return err
|
||||
}
|
||||
if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex {
|
||||
nextContentBlockIndex = maxIndex + 1
|
||||
}
|
||||
query = analysis.WebSearchQuery
|
||||
if strings.TrimSpace(analysis.WebSearchToolUseID) == "" {
|
||||
currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
} else {
|
||||
currentToolUseID = analysis.WebSearchToolUseID
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex)
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(adjusted); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("kiro web search exceeded max iterations")
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil, errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
|
||||
if err != nil {
|
||||
currentBody = anthropicBody
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(anthropicBody)
|
||||
currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
searches := make([]kiropkg.SearchIndicator, 0, 2)
|
||||
requestID := ""
|
||||
|
||||
for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ {
|
||||
s.prefetchKiroWebSearchDescription(ctx, account, token)
|
||||
|
||||
results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query)
|
||||
if strings.TrimSpace(nextToken) != "" {
|
||||
token = nextToken
|
||||
}
|
||||
if mcpErr != nil {
|
||||
results = nil
|
||||
}
|
||||
searches = append(searches, kiropkg.SearchIndicator{
|
||||
ToolUseID: currentToolUseID,
|
||||
Query: query,
|
||||
Results: results,
|
||||
})
|
||||
|
||||
currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results)
|
||||
if err != nil {
|
||||
return nil, errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &kiroWebSearchHTTPError{Response: resp}
|
||||
}
|
||||
|
||||
parseResult, parseErr := func() (*kiropkg.ParseResult, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
|
||||
}()
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if requestID == "" {
|
||||
requestID = buildKiroRequestID(resp)
|
||||
}
|
||||
|
||||
nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody)
|
||||
if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations {
|
||||
finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches)
|
||||
if injectErr == nil {
|
||||
parseResult.ResponseBody = finalBody
|
||||
}
|
||||
return &kiroWebSearchExecution{
|
||||
ResponseBody: parseResult.ResponseBody,
|
||||
Usage: kiroUsageToClaude(parseResult.Usage, inputTokens),
|
||||
RequestID: requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
query = nextQuery
|
||||
if strings.TrimSpace(nextToolUseID) == "" {
|
||||
nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
|
||||
}
|
||||
currentToolUseID = nextToolUseID
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("kiro web search exceeded max iterations")
|
||||
}
|
||||
|
||||
func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
|
||||
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
|
||||
if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
|
||||
if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
|
||||
kiropkg.SetCachedWebSearchDescription(desc)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reqBody, _ := json.Marshal(kiropkg.MCPRequest{
|
||||
ID: "tools_list",
|
||||
JSONRPC: "2.0",
|
||||
Method: "tools/list",
|
||||
})
|
||||
resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
|
||||
if err != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var result kiropkg.MCPResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
return
|
||||
}
|
||||
for _, tool := range result.Result.Tools {
|
||||
if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
|
||||
kiroWebSearchDescCache.Store(endpoint, tool.Description)
|
||||
kiropkg.SetCachedWebSearchDescription(tool.Description)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
|
||||
reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
|
||||
if err != nil {
|
||||
return nil, token, err
|
||||
}
|
||||
|
||||
endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
|
||||
resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
|
||||
if err != nil {
|
||||
return nil, nextToken, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nextToken, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed kiropkg.MCPResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, nextToken, err
|
||||
}
|
||||
if parsed.Error != nil {
|
||||
msg := "unknown error"
|
||||
if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" {
|
||||
msg = strings.TrimSpace(*parsed.Error.Message)
|
||||
}
|
||||
code := 0
|
||||
if parsed.Error.Code != nil {
|
||||
code = *parsed.Error.Code
|
||||
}
|
||||
return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return kiropkg.ParseSearchResults(&parsed), nextToken, nil
|
||||
}
|
||||
|
||||
func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
|
||||
return kiropkg.MCPRequest{
|
||||
ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
|
||||
JSONRPC: "2.0",
|
||||
Method: "tools/call",
|
||||
Params: map[string]interface{}{
|
||||
"name": "web_search",
|
||||
"arguments": map[string]interface{}{
|
||||
"query": query,
|
||||
"_meta": map[string]interface{}{
|
||||
"_isValid": true,
|
||||
"_activePath": []string{"query"},
|
||||
"_completedPaths": [][]string{{"query"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
|
||||
currentToken := token
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
proxyURL := kiroProxyURL(account)
|
||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
|
||||
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||
return nil, currentToken, failoverErr
|
||||
}
|
||||
return nil, currentToken, err
|
||||
}
|
||||
|
||||
req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
|
||||
if err != nil {
|
||||
return nil, currentToken, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
return nil, currentToken, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, currentToken, readErr
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) {
|
||||
if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
|
||||
return nil, currentToken, err
|
||||
}
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) {
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
if s.kiroTokenProvider == nil {
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
resp.Body = io.NopCloser(strings.NewReader(string(respBody)))
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, currentToken, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
if _, err := s.markKiro429(ctx, accountKey); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, currentToken, err
|
||||
}
|
||||
}
|
||||
if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
|
||||
if attempt < 2 {
|
||||
_ = resp.Body.Close()
|
||||
if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
|
||||
return nil, currentToken, sleepErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
if err := s.markKiroSuccess(ctx, accountKey); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, currentToken, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, currentToken, nil
|
||||
}
|
||||
|
||||
return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildKiroWebSearchMCPRequest_UsesUnderscoredMetaKeys(t *testing.T) {
|
||||
req := buildKiroWebSearchMCPRequest("golang concurrency")
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "tools/call", gjson.GetBytes(body, "method").String())
|
||||
require.Equal(t, "web_search", gjson.GetBytes(body, "params.name").String())
|
||||
require.Equal(t, "golang concurrency", gjson.GetBytes(body, "params.arguments.query").String())
|
||||
require.True(t, gjson.GetBytes(body, "params.arguments._meta._isValid").Bool())
|
||||
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._activePath.0").String())
|
||||
require.Equal(t, "query", gjson.GetBytes(body, "params.arguments._meta._completedPaths.0.0").String())
|
||||
require.False(t, gjson.GetBytes(body, "params.arguments._meta.isValid").Exists())
|
||||
require.False(t, gjson.GetBytes(body, "params.arguments._meta.activePath").Exists())
|
||||
require.False(t, gjson.GetBytes(body, "params.arguments._meta.completedPaths").Exists())
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
privacyCalls := 0
|
||||
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
|
||||
privacyCalls++
|
||||
|
||||
@@ -289,6 +289,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
|
||||
out := *ev
|
||||
|
||||
out.Platform = strings.TrimSpace(out.Platform)
|
||||
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
|
||||
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
|
||||
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
|
||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||
|
||||
|
||||
@@ -89,6 +89,14 @@ type OpsUpstreamErrorEvent struct {
|
||||
AccountID int64 `json:"account_id,omitempty"`
|
||||
AccountName string `json:"account_name,omitempty"`
|
||||
|
||||
// Model diagnostics.
|
||||
RequestedModel string `json:"requested_model,omitempty"`
|
||||
MappedModel string `json:"mapped_model,omitempty"`
|
||||
KiroModelID string `json:"kiro_model_id,omitempty"`
|
||||
HasTools bool `json:"has_tools,omitempty"`
|
||||
HasAdaptiveThinking bool `json:"has_adaptive_thinking,omitempty"`
|
||||
HasContext1MBeta bool `json:"has_context_1m_beta,omitempty"`
|
||||
|
||||
// Outcome
|
||||
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
||||
|
||||
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-2.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = canonicalModelNameForPricing(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
model = strings.TrimPrefix(model, "publishers/google/models/")
|
||||
@@ -628,7 +628,31 @@ func normalizeModelNameForPricing(model string) string {
|
||||
if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" {
|
||||
return canonical
|
||||
}
|
||||
return canonicalModelNameForPricing(model)
|
||||
}
|
||||
|
||||
func canonicalModelNameForPricing(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch model {
|
||||
case "claude-opus-4.5":
|
||||
return "claude-opus-4-5"
|
||||
case "claude-opus-4.6":
|
||||
return "claude-opus-4-6"
|
||||
case "claude-opus-4.7":
|
||||
return "claude-opus-4-7"
|
||||
case "claude-sonnet-4.5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "claude-sonnet-4.6":
|
||||
return "claude-sonnet-4-6"
|
||||
case "claude-haiku-4.5":
|
||||
return "claude-haiku-4-5"
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func lastSegment(model string) string {
|
||||
@@ -674,8 +698,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
|
||||
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
|
||||
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
|
||||
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
|
||||
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
|
||||
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
|
||||
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
|
||||
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
|
||||
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
|
||||
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
|
||||
@@ -713,6 +739,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "sonnet"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
|
||||
fallbackName = "sonnet-4.6"
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "sonnet-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
@@ -722,6 +750,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "haiku"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "haiku-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
fallbackName = "haiku-3.5"
|
||||
default:
|
||||
|
||||
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
|
||||
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
|
||||
@@ -42,6 +42,9 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
|
||||
// Antigravity 同样可能有两种缓存键
|
||||
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
||||
case PlatformKiro:
|
||||
keysToDelete = append(keysToDelete, KiroTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "kiro:"+accountIDKey)
|
||||
case PlatformOpenAI:
|
||||
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
||||
case PlatformAnthropic:
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
)
|
||||
|
||||
// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间
|
||||
@@ -44,6 +45,7 @@ func NewTokenRefreshService(
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
cacheInvalidator TokenCacheInvalidator,
|
||||
schedulerCache SchedulerCache,
|
||||
cfg *config.Config,
|
||||
@@ -64,6 +66,7 @@ func NewTokenRefreshService(
|
||||
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||
kiroRefresher := NewKiroTokenRefresher(kiroOAuthService)
|
||||
|
||||
// 注册平台特定的刷新器(TokenRefresher 接口)
|
||||
s.refreshers = []TokenRefresher{
|
||||
@@ -71,6 +74,7 @@ func NewTokenRefreshService(
|
||||
openAIRefresher,
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
kiroRefresher,
|
||||
}
|
||||
|
||||
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
||||
@@ -79,6 +83,7 @@ func NewTokenRefreshService(
|
||||
openAIRefresher,
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
kiroRefresher,
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -415,6 +420,10 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var kiroInvalidGrant *kiropkg.RefreshTokenInvalidError
|
||||
if errors.As(err, &kiroInvalidGrant) {
|
||||
return true
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
nonRetryable := []string{
|
||||
"invalid_grant", // refresh_token 已失效
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Platform: PlatformGemini,
|
||||
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformGemini,
|
||||
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformGemini,
|
||||
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformOpenAI, // OpenAI OAuth 账户
|
||||
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
resetAt := time.Now().Add(30 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformGemini,
|
||||
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformGemini,
|
||||
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 15,
|
||||
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Platform: tt.platform,
|
||||
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 18,
|
||||
Platform: PlatformOpenAI,
|
||||
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
|
||||
@@ -22,9 +22,9 @@ func optionalNonEqualStringPtr(value, compare string) *string {
|
||||
|
||||
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
|
||||
return trimmed
|
||||
return normalizeModelNameForPricing(trimmed)
|
||||
}
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
return normalizeModelNameForPricing(upstreamModel)
|
||||
}
|
||||
|
||||
func optionalInt64Ptr(v int64) *int64 {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -52,6 +53,7 @@ func ProvideTokenRefreshService(
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
cacheInvalidator TokenCacheInvalidator,
|
||||
schedulerCache SchedulerCache,
|
||||
cfg *config.Config,
|
||||
@@ -60,7 +62,7 @@ func ProvideTokenRefreshService(
|
||||
proxyRepo ProxyRepository,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, kiroOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
||||
@@ -129,6 +131,23 @@ func ProvideAntigravityTokenProvider(
|
||||
return p
|
||||
}
|
||||
|
||||
func ProvideKiroTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
kiroOAuthService *KiroOAuthService,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *KiroTokenProvider {
|
||||
p := NewKiroTokenProvider(accountRepo, tokenCache, kiroOAuthService)
|
||||
executor := NewKiroTokenRefresher(kiroOAuthService)
|
||||
p.SetRefreshAPI(refreshAPI, executor)
|
||||
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
|
||||
return p
|
||||
}
|
||||
|
||||
func ProvideKiroCooldownStore(redisClient *redis.Client) KiroCooldownStore {
|
||||
return kirocooldown.NewStore(redisClient)
|
||||
}
|
||||
|
||||
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||
@@ -457,8 +476,11 @@ var ProviderSet = wire.NewSet(
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewKiroOAuthService,
|
||||
ProvideOAuthRefreshAPI,
|
||||
ProvideGeminiTokenProvider,
|
||||
ProvideKiroTokenProvider,
|
||||
ProvideKiroCooldownStore,
|
||||
NewGeminiMessagesCompatService,
|
||||
ProvideAntigravityTokenProvider,
|
||||
ProvideOpenAITokenProvider,
|
||||
|
||||
@@ -89,6 +89,10 @@ security:
|
||||
enabled: false
|
||||
# Allowed upstream hosts for API proxying
|
||||
# 允许代理的上游 API 主机列表
|
||||
# If you enable Kiro OAuth / IDC, also allow Kiro auth and AWS SSO OIDC hosts.
|
||||
# 如果启用 Kiro OAuth / IDC,请同时放行 Kiro 鉴权域名和 AWS SSO OIDC 域名。
|
||||
# If you enable Kiro runtime forwarding, also allow the corresponding AWS Q API endpoint.
|
||||
# 如果启用 Kiro 运行时转发,还需要放行对应的 AWS Q API 域名。
|
||||
upstream_hosts:
|
||||
- "api.openai.com"
|
||||
- "api.anthropic.com"
|
||||
@@ -97,6 +101,11 @@ security:
|
||||
- "api.minimaxi.com"
|
||||
- "generativelanguage.googleapis.com"
|
||||
- "cloudcode-pa.googleapis.com"
|
||||
- "prod.us-east-1.auth.desktop.kiro.dev"
|
||||
- "oidc.us-east-1.amazonaws.com"
|
||||
- "oidc.*.amazonaws.com"
|
||||
- "device.sso.*.amazonaws.com"
|
||||
- "q.*.amazonaws.com"
|
||||
- "*.openai.azure.com"
|
||||
# Allowed hosts for pricing data download
|
||||
# 允许下载定价数据的主机列表
|
||||
@@ -372,6 +381,9 @@ gateway:
|
||||
# Max image requests waiting in this process when overflow_mode=wait, 0=unlimited
|
||||
# wait 模式当前进程允许排队等待的图片请求数,0=不限制
|
||||
max_waiting_requests: 100
|
||||
# Kiro stream keepalive interval (seconds), 0=use default 25
|
||||
# Kiro 流式 keepalive 间隔(秒),0=使用默认 25
|
||||
kiro_stream_keepalive_interval: 25
|
||||
# SSE max line size in bytes (default: 40MB)
|
||||
# SSE 单行最大字节数(默认 40MB)
|
||||
max_line_size: 41943040
|
||||
|
||||
@@ -565,6 +565,13 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getKiroDefaultModelMapping(): Promise<Record<string, string>> {
|
||||
const { data } = await apiClient.get<Record<string, string>>(
|
||||
'/admin/accounts/kiro/default-model-mapping'
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OpenAI token using refresh token
|
||||
* @param refreshToken - The refresh token
|
||||
|
||||
@@ -17,6 +17,7 @@ import subscriptionsAPI from './subscriptions'
|
||||
import usageAPI from './usage'
|
||||
import geminiAPI from './gemini'
|
||||
import antigravityAPI from './antigravity'
|
||||
import kiroAPI from './kiro'
|
||||
import userAttributesAPI from './userAttributes'
|
||||
import opsAPI from './ops'
|
||||
import errorPassthroughAPI from './errorPassthrough'
|
||||
@@ -50,6 +51,7 @@ export const adminAPI = {
|
||||
usage: usageAPI,
|
||||
gemini: geminiAPI,
|
||||
antigravity: antigravityAPI,
|
||||
kiro: kiroAPI,
|
||||
userAttributes: userAttributesAPI,
|
||||
ops: opsAPI,
|
||||
errorPassthrough: errorPassthroughAPI,
|
||||
@@ -81,6 +83,7 @@ export {
|
||||
usageAPI,
|
||||
geminiAPI,
|
||||
antigravityAPI,
|
||||
kiroAPI,
|
||||
userAttributesAPI,
|
||||
opsAPI,
|
||||
errorPassthroughAPI,
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
import { apiClient } from '../client'
|
||||
|
||||
export interface KiroAuthUrlResponse {
|
||||
auth_url: string
|
||||
session_id: string
|
||||
state: string
|
||||
}
|
||||
|
||||
export interface KiroIDCAuthUrlResponse extends KiroAuthUrlResponse {
|
||||
client_id?: string
|
||||
region?: string
|
||||
start_url?: string
|
||||
}
|
||||
|
||||
export interface KiroTokenInfo {
|
||||
access_token?: string
|
||||
refresh_token?: string
|
||||
profile_arn?: string
|
||||
expires_at?: string
|
||||
auth_method?: string
|
||||
provider?: string
|
||||
client_id?: string
|
||||
client_secret?: string
|
||||
client_id_hash?: string
|
||||
email?: string
|
||||
start_url?: string
|
||||
region?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export async function generateAuthUrl(payload: {
|
||||
proxy_id?: number
|
||||
provider?: string
|
||||
}): Promise<KiroAuthUrlResponse> {
|
||||
const { data } = await apiClient.post<KiroAuthUrlResponse>('/admin/kiro/oauth/auth-url', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function generateIDCAuthUrl(payload: {
|
||||
proxy_id?: number
|
||||
start_url?: string
|
||||
region?: string
|
||||
}): Promise<KiroIDCAuthUrlResponse> {
|
||||
const { data } = await apiClient.post<KiroIDCAuthUrlResponse>('/admin/kiro/oauth/idc-auth-url', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function exchangeCode(payload: {
|
||||
session_id: string
|
||||
state: string
|
||||
code: string
|
||||
callback_path?: string
|
||||
login_option?: string
|
||||
proxy_id?: number
|
||||
}): Promise<KiroTokenInfo> {
|
||||
const { data } = await apiClient.post<KiroTokenInfo>('/admin/kiro/oauth/exchange-code', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function refreshToken(payload: {
|
||||
refresh_token: string
|
||||
auth_method?: string
|
||||
provider?: string
|
||||
client_id?: string
|
||||
client_secret?: string
|
||||
start_url?: string
|
||||
region?: string
|
||||
profile_arn?: string
|
||||
proxy_id?: number
|
||||
}): Promise<KiroTokenInfo> {
|
||||
const { data } = await apiClient.post<KiroTokenInfo>('/admin/kiro/oauth/refresh-token', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function importToken(payload: {
|
||||
token_json: string
|
||||
device_registration_json?: string
|
||||
}): Promise<KiroTokenInfo> {
|
||||
const { data } = await apiClient.post<KiroTokenInfo>('/admin/kiro/oauth/import-token', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
export default {
|
||||
generateAuthUrl,
|
||||
generateIDCAuthUrl,
|
||||
exchangeCode,
|
||||
refreshToken,
|
||||
importToken
|
||||
}
|
||||
@@ -12,6 +12,11 @@
|
||||
<span class="text-[11px] text-gray-400 dark:text-gray-500">{{ overloadCountdown }}</span>
|
||||
</div>
|
||||
|
||||
<div v-else-if="kiroQuotaBadgeLabel" class="flex flex-col items-center gap-1">
|
||||
<span :class="['badge text-xs', kiroQuotaBadgeClass]">{{ kiroQuotaBadgeLabel }}</span>
|
||||
<span v-if="kiroQuotaHint" class="text-[11px] text-gray-400 dark:text-gray-500">{{ kiroQuotaHint }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Main Status Badge (shown when not rate limited/overloaded) -->
|
||||
<template v-else>
|
||||
<button
|
||||
@@ -69,7 +74,7 @@
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 w-56 -translate-x-1/2 whitespace-normal rounded bg-gray-900 px-3 py-2 text-center text-xs leading-relaxed text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatDateTime(account.rate_limit_reset_at) }) }}
|
||||
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatDateTime(activeKiroRuntimeResetAt || account.rate_limit_reset_at) }) }}
|
||||
<div
|
||||
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700"
|
||||
></div>
|
||||
@@ -172,11 +177,37 @@ const emit = defineEmits<{
|
||||
}>()
|
||||
|
||||
// Computed: is rate limited (429)
|
||||
const activeKiroRuntimeResetAt = computed(() => {
|
||||
if (props.account.platform !== 'kiro') return null
|
||||
if (props.account.kiro_runtime_state !== 'cooldown') return null
|
||||
if (!props.account.kiro_runtime_reset_at) return null
|
||||
const resetAt = new Date(props.account.kiro_runtime_reset_at)
|
||||
if (Number.isNaN(resetAt.getTime()) || resetAt <= new Date()) return null
|
||||
return props.account.kiro_runtime_reset_at
|
||||
})
|
||||
|
||||
const isRateLimited = computed(() => {
|
||||
if (activeKiroRuntimeResetAt.value) return true
|
||||
if (!props.account.rate_limit_reset_at) return false
|
||||
return new Date(props.account.rate_limit_reset_at) > new Date()
|
||||
})
|
||||
|
||||
const isKiroRuntimeSuspended = computed(() => {
|
||||
if (props.account.platform !== 'kiro') return false
|
||||
if (props.account.kiro_runtime_state !== 'suspended') return false
|
||||
if (!props.account.kiro_runtime_reset_at) return true
|
||||
const resetAt = new Date(props.account.kiro_runtime_reset_at)
|
||||
return Number.isNaN(resetAt.getTime()) || resetAt > new Date()
|
||||
})
|
||||
|
||||
const activeKiroQuotaResetAt = computed(() => {
|
||||
if (props.account.platform !== 'kiro') return null
|
||||
if (!props.account.kiro_quota_reset_at) return null
|
||||
const resetAt = new Date(props.account.kiro_quota_reset_at)
|
||||
if (Number.isNaN(resetAt.getTime()) || resetAt <= new Date()) return null
|
||||
return props.account.kiro_quota_reset_at
|
||||
})
|
||||
|
||||
type AccountModelStatusItem = {
|
||||
kind: 'rate_limit' | 'credits_exhausted' | 'credits_active'
|
||||
model: string
|
||||
@@ -281,7 +312,7 @@ const isTempUnschedulable = computed(() => {
|
||||
|
||||
// Computed: has error status
|
||||
const hasError = computed(() => {
|
||||
return props.account.status === 'error'
|
||||
return props.account.status === 'error' || isKiroRuntimeSuspended.value
|
||||
})
|
||||
|
||||
const isQuotaExceeded = computed(() => {
|
||||
@@ -296,7 +327,7 @@ const isQuotaExceeded = computed(() => {
|
||||
|
||||
// Computed: countdown text for rate limit (429)
|
||||
const rateLimitCountdown = computed(() => {
|
||||
return formatCountdown(props.account.rate_limit_reset_at)
|
||||
return formatCountdown(activeKiroRuntimeResetAt.value || props.account.rate_limit_reset_at)
|
||||
})
|
||||
|
||||
const rateLimitResumeText = computed(() => {
|
||||
@@ -309,8 +340,45 @@ const overloadCountdown = computed(() => {
|
||||
return formatCountdownWithSuffix(props.account.overload_until)
|
||||
})
|
||||
|
||||
const kiroQuotaBadgeLabel = computed(() => {
|
||||
if (props.account.platform !== 'kiro') return ''
|
||||
switch (props.account.kiro_quota_state) {
|
||||
case 'credits_exhausted':
|
||||
return t('admin.accounts.status.creditsExhausted')
|
||||
case 'overage_exhausted':
|
||||
return t('admin.accounts.status.overageExhausted')
|
||||
default:
|
||||
return ''
|
||||
}
|
||||
})
|
||||
|
||||
const kiroQuotaBadgeClass = computed(() => {
|
||||
switch (props.account.kiro_quota_state) {
|
||||
case 'credits_exhausted':
|
||||
case 'overage_exhausted':
|
||||
return 'badge-danger'
|
||||
default:
|
||||
return 'badge-gray'
|
||||
}
|
||||
})
|
||||
|
||||
const kiroQuotaHint = computed(() => {
|
||||
if (!activeKiroQuotaResetAt.value) return ''
|
||||
switch (props.account.kiro_quota_state) {
|
||||
case 'credits_exhausted':
|
||||
return t('admin.accounts.status.creditsExhaustedUntil', { time: formatDateTime(activeKiroQuotaResetAt.value) })
|
||||
case 'overage_exhausted':
|
||||
return t('admin.accounts.status.overageExhaustedUntil', { time: formatDateTime(activeKiroQuotaResetAt.value) })
|
||||
default:
|
||||
return ''
|
||||
}
|
||||
})
|
||||
|
||||
// Computed: status badge class
|
||||
const statusClass = computed(() => {
|
||||
if (isKiroRuntimeSuspended.value) {
|
||||
return 'badge-danger'
|
||||
}
|
||||
if (hasError.value) {
|
||||
return 'badge-danger'
|
||||
}
|
||||
@@ -331,6 +399,9 @@ const statusClass = computed(() => {
|
||||
|
||||
// Computed: status text
|
||||
const statusText = computed(() => {
|
||||
if (isKiroRuntimeSuspended.value) {
|
||||
return t('admin.accounts.forbidden')
|
||||
}
|
||||
if (hasError.value) {
|
||||
return t('admin.accounts.status.error')
|
||||
}
|
||||
|
||||
@@ -395,6 +395,72 @@
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Kiro platform: show credits + bonus + overage summary -->
|
||||
<template v-else-if="account.platform === 'kiro' && account.type === 'oauth'">
|
||||
<div v-if="loading" class="space-y-1.5">
|
||||
<div class="h-4 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="space-y-1">
|
||||
<div class="h-1.5 w-32 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-1.5 w-28 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else-if="error" class="text-xs text-red-500">
|
||||
{{ error }}
|
||||
</div>
|
||||
<div v-else-if="kiroUsageAvailable || kiroStatusBadgeLabel" class="space-y-2">
|
||||
<div v-if="kiroStatusBadgeLabel" class="flex flex-wrap items-center gap-x-2 gap-y-1">
|
||||
<span
|
||||
:class="[
|
||||
'inline-flex items-center gap-1 text-[10px] font-medium',
|
||||
kiroStatusToneClass
|
||||
]"
|
||||
:title="usageInfo?.error || undefined"
|
||||
>
|
||||
<span class="h-1.5 w-1.5 rounded-full bg-current opacity-80"></span>
|
||||
{{ kiroStatusBadgeLabel }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div v-if="kiroStatusHint" class="text-[9px] leading-tight text-gray-500 dark:text-gray-400">
|
||||
{{ kiroStatusHint }}
|
||||
</div>
|
||||
|
||||
<div v-if="usageInfo?.kiro_credit" class="space-y-1">
|
||||
<div class="flex items-baseline justify-between gap-2 text-[11px] text-gray-600 dark:text-gray-300">
|
||||
<span class="font-medium tracking-[0.01em]">{{ t('admin.accounts.usageWindow.kiroCredits') }}</span>
|
||||
<span class="font-semibold tabular-nums text-gray-700 dark:text-gray-200">{{ formatKiroAmount(usageInfo.kiro_credit.current_usage) }} / {{ formatKiroAmount(usageInfo.kiro_credit.usage_limit) }}</span>
|
||||
</div>
|
||||
<div class="h-1.5 overflow-hidden rounded-full bg-gray-200 dark:bg-gray-700">
|
||||
<div class="h-full rounded-full bg-amber-500 transition-all" :style="{ width: `${kiroCreditPercent}%` }"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="usageInfo?.kiro_bonus" class="space-y-1">
|
||||
<div class="flex items-baseline justify-between gap-2 text-[11px] text-gray-600 dark:text-gray-300">
|
||||
<span class="font-medium tracking-[0.01em]">{{ t('admin.accounts.usageWindow.kiroBonus') }}</span>
|
||||
<span class="font-semibold tabular-nums text-gray-700 dark:text-gray-200">{{ formatKiroAmount(usageInfo.kiro_bonus.current_usage) }} / {{ formatKiroAmount(usageInfo.kiro_bonus.usage_limit) }}</span>
|
||||
</div>
|
||||
<div class="h-1.5 overflow-hidden rounded-full bg-gray-200 dark:bg-gray-700">
|
||||
<div class="h-full rounded-full bg-emerald-500 transition-all" :style="{ width: `${kiroBonusPercent}%` }"></div>
|
||||
</div>
|
||||
<div v-if="kiroBonusMeta" class="text-[9px] leading-tight text-gray-500 dark:text-gray-400">
|
||||
{{ kiroBonusMeta }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-wrap items-center gap-x-3 gap-y-1 text-[10px] text-gray-500 dark:text-gray-400">
|
||||
<span v-if="kiroResetDisplay" class="inline-flex items-center gap-1">
|
||||
<span class="text-gray-400 dark:text-gray-500">{{ t('admin.accounts.usageWindow.kiroReset') }}</span>
|
||||
<span class="font-medium tabular-nums text-gray-600 dark:text-gray-300">{{ kiroResetDisplay }}</span>
|
||||
</span>
|
||||
<span v-if="kiroOverageSummary" class="inline-flex items-center gap-1 font-medium">
|
||||
{{ kiroOverageSummary }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="text-xs text-gray-400">-</div>
|
||||
</template>
|
||||
|
||||
<!-- Other accounts: no usage window -->
|
||||
<template v-else>
|
||||
<div class="text-xs text-gray-400">-</div>
|
||||
@@ -498,6 +564,10 @@ const props = withDefaults(
|
||||
}
|
||||
)
|
||||
|
||||
const emit = defineEmits<{
|
||||
kiroUsageMeta: [meta: { plan_type?: string; kiro_overages_enabled: boolean }]
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const desktopViewportQuery = '(min-width: 768px)'
|
||||
|
||||
@@ -510,7 +580,9 @@ const error = ref<string | null>(null)
|
||||
const usageInfo = ref<AccountUsageInfo | null>(null)
|
||||
const rootRef = ref<HTMLElement | null>(null)
|
||||
const isDesktopViewport = ref(
|
||||
typeof window === 'undefined' ? true : window.matchMedia(desktopViewportQuery).matches
|
||||
typeof window === 'undefined' || typeof window.matchMedia !== 'function'
|
||||
? true
|
||||
: window.matchMedia(desktopViewportQuery).matches
|
||||
)
|
||||
const hasEnteredViewport = ref(false)
|
||||
const pendingAutoLoad = ref(false)
|
||||
@@ -531,6 +603,9 @@ const shouldFetchUsage = computed(() => {
|
||||
if (props.account.platform === 'anthropic') {
|
||||
return props.account.type === 'oauth' || props.account.type === 'setup-token'
|
||||
}
|
||||
if (props.account.platform === 'kiro') {
|
||||
return props.account.type === 'oauth'
|
||||
}
|
||||
if (props.account.platform === 'gemini') {
|
||||
return true
|
||||
}
|
||||
@@ -984,6 +1059,179 @@ const isAnthropicOAuthOrSetupToken = computed(() => {
|
||||
return props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')
|
||||
})
|
||||
|
||||
const isKiroOAuth = computed(() => {
|
||||
return props.account.platform === 'kiro' && props.account.type === 'oauth'
|
||||
})
|
||||
|
||||
const defaultUsageSource = computed<'passive' | 'active' | undefined>(() => {
|
||||
if (isAnthropicOAuthOrSetupToken.value || isKiroOAuth.value) {
|
||||
return 'passive'
|
||||
}
|
||||
return undefined
|
||||
})
|
||||
|
||||
const manualRefreshUsageSource = computed<'passive' | 'active' | undefined>(() => {
|
||||
if (isKiroOAuth.value) {
|
||||
return 'active'
|
||||
}
|
||||
return defaultUsageSource.value
|
||||
})
|
||||
|
||||
const kiroUsageAvailable = computed(() => {
|
||||
return !!(
|
||||
usageInfo.value?.kiro_credit ||
|
||||
usageInfo.value?.kiro_bonus ||
|
||||
usageInfo.value?.kiro_overage ||
|
||||
usageInfo.value?.kiro_reset_at ||
|
||||
usageInfo.value?.kiro_overages_enabled
|
||||
)
|
||||
})
|
||||
|
||||
const syncKiroUsageMeta = (info?: AccountUsageInfo | null) => {
|
||||
if (!isKiroOAuth.value) return
|
||||
|
||||
const planType = (
|
||||
info?.kiro_subscription_name ||
|
||||
info?.kiro_subscription_type ||
|
||||
''
|
||||
).trim()
|
||||
|
||||
emit('kiroUsageMeta', {
|
||||
...(planType ? { plan_type: planType } : {}),
|
||||
kiro_overages_enabled: info?.kiro_overages_enabled === true
|
||||
})
|
||||
}
|
||||
|
||||
const clampPercent = (value?: number | null) => {
|
||||
if (value == null || !Number.isFinite(value)) return 0
|
||||
return Math.max(0, Math.min(100, value))
|
||||
}
|
||||
|
||||
const kiroCreditPercent = computed(() => clampPercent(usageInfo.value?.kiro_credit?.percentage_used))
|
||||
const kiroBonusPercent = computed(() => clampPercent(usageInfo.value?.kiro_bonus?.percentage_used))
|
||||
|
||||
const formatKiroAmount = (value?: number | null) => {
|
||||
if (value == null || !Number.isFinite(value)) return '0'
|
||||
if (Math.abs(value) >= 1000 || Number.isInteger(value)) {
|
||||
return formatCompactNumber(value, { allowBillions: false })
|
||||
}
|
||||
return value.toFixed(2).replace(/\.?0+$/, '')
|
||||
}
|
||||
|
||||
const kiroResetDisplay = computed(() => {
|
||||
const raw = usageInfo.value?.kiro_reset_at
|
||||
if (!raw) return ''
|
||||
const parsed = new Date(raw)
|
||||
if (Number.isNaN(parsed.getTime())) return ''
|
||||
return parsed.toLocaleDateString()
|
||||
})
|
||||
|
||||
const kiroBonusMeta = computed(() => {
|
||||
const bonus = usageInfo.value?.kiro_bonus
|
||||
if (!bonus) return ''
|
||||
if ((bonus.days_remaining ?? 0) > 0) {
|
||||
return t('admin.accounts.usageWindow.kiroDaysLeft', { days: bonus.days_remaining })
|
||||
}
|
||||
if (bonus.expiry_date) {
|
||||
const parsed = new Date(bonus.expiry_date)
|
||||
if (!Number.isNaN(parsed.getTime())) {
|
||||
return `${t('admin.accounts.usageWindow.kiroExpires')} ${parsed.toLocaleDateString()}`
|
||||
}
|
||||
}
|
||||
return ''
|
||||
})
|
||||
|
||||
const kiroRuntimeResetDisplay = computed(() => {
|
||||
const raw = usageInfo.value?.kiro_runtime_reset_at
|
||||
if (!raw) return ''
|
||||
const parsed = new Date(raw)
|
||||
if (Number.isNaN(parsed.getTime())) return ''
|
||||
return parsed.toLocaleString()
|
||||
})
|
||||
|
||||
const kiroQuotaResetDisplay = computed(() => {
|
||||
const raw = usageInfo.value?.kiro_quota_reset_at
|
||||
if (!raw) return ''
|
||||
const parsed = new Date(raw)
|
||||
if (Number.isNaN(parsed.getTime())) return ''
|
||||
return parsed.toLocaleString()
|
||||
})
|
||||
|
||||
const isKiroProfileError = computed(() => {
|
||||
if (!isKiroOAuth.value) return false
|
||||
const err = (usageInfo.value?.error || '').toLowerCase()
|
||||
return err.includes('profilearn is required') ||
|
||||
(err.includes('profile arn') && err.includes('required')) ||
|
||||
err.includes('profilearn') ||
|
||||
err.includes('listavailableprofiles')
|
||||
})
|
||||
|
||||
const isKiroUsageForbidden = computed(() => {
|
||||
if (!isKiroOAuth.value) return false
|
||||
return usageInfo.value?.error_code === 'forbidden' && !usageInfo.value?.needs_reauth && !isKiroProfileError.value
|
||||
})
|
||||
|
||||
const kiroQuotaState = computed(() => usageInfo.value?.kiro_quota_state || '')
|
||||
|
||||
const kiroStatusBadgeLabel = computed(() => {
|
||||
const runtimeState = usageInfo.value?.kiro_runtime_state
|
||||
if (runtimeState === 'suspended') return t('admin.accounts.forbidden')
|
||||
if (runtimeState === 'cooldown') return t('admin.accounts.status.rateLimited')
|
||||
if (usageInfo.value?.needs_reauth) return t('admin.accounts.needsReauth')
|
||||
if (isKiroProfileError.value) return t('admin.accounts.usageError')
|
||||
if (isKiroUsageForbidden.value) return t('admin.accounts.forbidden')
|
||||
if (kiroQuotaState.value === 'overage_active') return t('admin.accounts.status.overageActive')
|
||||
if (kiroQuotaState.value === 'credits_exhausted') return t('admin.accounts.status.creditsExhausted')
|
||||
if (kiroQuotaState.value === 'overage_exhausted') return t('admin.accounts.status.overageExhausted')
|
||||
return ''
|
||||
})
|
||||
|
||||
const kiroStatusToneClass = computed(() => {
|
||||
const runtimeState = usageInfo.value?.kiro_runtime_state
|
||||
if (runtimeState === 'suspended') return 'text-red-700 dark:text-red-300'
|
||||
if (runtimeState === 'cooldown') return 'text-amber-700 dark:text-amber-300'
|
||||
if (usageInfo.value?.needs_reauth) return 'text-orange-700 dark:text-orange-300'
|
||||
if (isKiroProfileError.value) return 'text-yellow-700 dark:text-yellow-300'
|
||||
if (isKiroUsageForbidden.value) return 'text-rose-700 dark:text-rose-300'
|
||||
if (kiroQuotaState.value === 'overage_active') return 'text-amber-700 dark:text-amber-300'
|
||||
if (kiroQuotaState.value === 'credits_exhausted' || kiroQuotaState.value === 'overage_exhausted') {
|
||||
return 'text-red-700 dark:text-red-300'
|
||||
}
|
||||
return 'text-gray-600 dark:text-gray-300'
|
||||
})
|
||||
|
||||
const kiroStatusHint = computed(() => {
|
||||
const runtimeState = usageInfo.value?.kiro_runtime_state
|
||||
if (runtimeState === 'cooldown' && kiroRuntimeResetDisplay.value) {
|
||||
return t('admin.accounts.status.rateLimitedUntil', { time: kiroRuntimeResetDisplay.value })
|
||||
}
|
||||
if (kiroQuotaState.value === 'credits_exhausted' && kiroQuotaResetDisplay.value) {
|
||||
return t('admin.accounts.status.creditsExhaustedUntil', { time: kiroQuotaResetDisplay.value })
|
||||
}
|
||||
if (kiroQuotaState.value === 'overage_exhausted' && kiroQuotaResetDisplay.value) {
|
||||
return t('admin.accounts.status.overageExhaustedUntil', { time: kiroQuotaResetDisplay.value })
|
||||
}
|
||||
return ''
|
||||
})
|
||||
|
||||
const kiroOverageSummary = computed(() => {
|
||||
const overage = usageInfo.value?.kiro_overage
|
||||
if (!overage) return ''
|
||||
const hasOverageCount = (overage.current_overages ?? 0) > 0
|
||||
const hasCharges = (overage.overage_charges ?? 0) > 0
|
||||
if (!hasOverageCount && !hasCharges) return ''
|
||||
|
||||
const parts: string[] = [t('admin.accounts.usageWindow.kiroOverage')]
|
||||
if (hasOverageCount) {
|
||||
parts.push(formatKiroAmount(overage.current_overages))
|
||||
}
|
||||
if (hasCharges) {
|
||||
const symbol = overage.currency_symbol || overage.currency_code || ''
|
||||
parts.push(`(${symbol}${(overage.overage_charges ?? 0).toFixed(2)})`)
|
||||
}
|
||||
return parts.join(' ')
|
||||
})
|
||||
|
||||
const loadUsage = async (options?: { source?: 'passive' | 'active'; bypassCache?: boolean }) => {
|
||||
if (!shouldFetchUsage.value) return
|
||||
|
||||
@@ -992,6 +1240,7 @@ const loadUsage = async (options?: { source?: 'passive' | 'active'; bypassCache?
|
||||
const cached = _usageCache.get(props.account.id)
|
||||
if (cached && Date.now() - cached.ts < USAGE_CACHE_TTL) {
|
||||
usageInfo.value = cached.data
|
||||
syncKiroUsageMeta(cached.data)
|
||||
loading.value = false
|
||||
return
|
||||
}
|
||||
@@ -1001,10 +1250,13 @@ const loadUsage = async (options?: { source?: 'passive' | 'active'; bypassCache?
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
const fetchFn = () => adminAPI.accounts.getUsage(props.account.id, options?.source)
|
||||
const fetchFn = () => options?.source
|
||||
? adminAPI.accounts.getUsage(props.account.id, options.source)
|
||||
: adminAPI.accounts.getUsage(props.account.id)
|
||||
const result = await enqueueUsageRequest(props.account, fetchFn)
|
||||
if (!unmounted.value) {
|
||||
usageInfo.value = result
|
||||
syncKiroUsageMeta(result)
|
||||
_usageCache.set(props.account.id, { data: result, ts: Date.now() })
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -1070,7 +1322,10 @@ const attachVisibilityObserver = () => {
|
||||
const loadActiveUsage = async () => {
|
||||
activeQueryLoading.value = true
|
||||
try {
|
||||
usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active')
|
||||
const result = await adminAPI.accounts.getUsage(props.account.id, 'active')
|
||||
usageInfo.value = result
|
||||
syncKiroUsageMeta(result)
|
||||
_usageCache.set(props.account.id, { data: result, ts: Date.now() })
|
||||
} catch (e: any) {
|
||||
console.error('Failed to load active usage:', e)
|
||||
} finally {
|
||||
@@ -1166,7 +1421,7 @@ const formatKeyUserCost = computed(() => {
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
if (typeof window !== 'undefined') {
|
||||
if (typeof window !== 'undefined' && typeof window.matchMedia === 'function') {
|
||||
desktopViewportMediaQuery = window.matchMedia(desktopViewportQuery)
|
||||
isDesktopViewport.value = desktopViewportMediaQuery.matches
|
||||
desktopViewportListener = (event: MediaQueryListEvent) => {
|
||||
@@ -1180,15 +1435,17 @@ onMounted(() => {
|
||||
}
|
||||
|
||||
if (!shouldAutoLoadUsageOnMount.value) return
|
||||
const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined
|
||||
requestAutoLoad(source)
|
||||
requestAutoLoad(defaultUsageSource.value)
|
||||
})
|
||||
|
||||
watch(openAIUsageRefreshKey, (nextKey, prevKey) => {
|
||||
if (!prevKey || nextKey === prevKey) return
|
||||
if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return
|
||||
|
||||
requestAutoLoad()
|
||||
_usageCache.delete(props.account.id)
|
||||
loadUsage({ bypassCache: true }).catch((e) => {
|
||||
console.error('Failed to reload OpenAI usage after row refresh:', e)
|
||||
})
|
||||
})
|
||||
|
||||
watch(
|
||||
@@ -1197,9 +1454,8 @@ watch(
|
||||
if (nextToken === prevToken) return
|
||||
if (!shouldFetchUsage.value) return
|
||||
|
||||
const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined
|
||||
_usageCache.delete(props.account.id)
|
||||
loadUsage({ source, bypassCache: true }).catch((e) => {
|
||||
loadUsage({ source: manualRefreshUsageSource.value, bypassCache: true }).catch((e) => {
|
||||
console.error('Failed to refresh usage after manual refresh:', e)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -147,6 +147,19 @@
|
||||
<Icon name="cloud" size="sm" />
|
||||
Antigravity
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="form.platform = 'kiro'"
|
||||
:class="[
|
||||
'flex flex-1 items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
|
||||
form.platform === 'kiro'
|
||||
? 'bg-white text-amber-700 shadow-sm dark:bg-dark-600 dark:text-amber-300'
|
||||
: 'text-gray-600 hover:text-gray-900 dark:text-gray-400 dark:hover:text-gray-200'
|
||||
]"
|
||||
>
|
||||
<Icon name="sparkles" size="sm" />
|
||||
Kiro
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -774,6 +787,457 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Kiro account type selection -->
|
||||
<div v-if="form.platform === 'kiro'">
|
||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||
<div class="mt-2 grid grid-cols-2 gap-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'oauth-based'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'oauth-based'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', accountCategory === 'oauth-based' ? 'bg-amber-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
|
||||
<Icon name="key" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.types.oauth') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.types.kiroOauth') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'apikey'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'apikey'
|
||||
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
|
||||
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
|
||||
]"
|
||||
>
|
||||
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', accountCategory === 'apikey' ? 'bg-purple-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
|
||||
<Icon name="cloud" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
API Key
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.types.kiroApikey') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Kiro OAuth auth mode selection -->
|
||||
<div v-if="form.platform === 'kiro' && accountCategory === 'oauth-based'">
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.authModeTitle') }}</label>
|
||||
<div class="mt-2 grid grid-cols-1 gap-3 md:grid-cols-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroAccountType = 'oauth'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroAccountType === 'oauth'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', kiroAccountType === 'oauth' ? 'bg-amber-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
|
||||
<Icon name="key" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.oauthTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.oauthSubtitle') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroAccountType = 'idc'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroAccountType === 'idc'
|
||||
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
|
||||
: 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
|
||||
]"
|
||||
>
|
||||
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', kiroAccountType === 'idc' ? 'bg-blue-500 text-white' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
|
||||
<Icon name="cloud" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.idcTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.idcSubtitle') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroAccountType = 'import'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroAccountType === 'import'
|
||||
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
|
||||
: 'border-gray-200 hover:border-slate-300 dark:border-dark-600 dark:hover:border-slate-700'
|
||||
]"
|
||||
>
|
||||
<div :class="['flex h-8 w-8 shrink-0 items-center justify-center rounded-lg', kiroAccountType === 'import' ? 'bg-slate-700 text-white dark:bg-slate-500' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400']">
|
||||
<Icon name="download" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.importTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.importSubtitle') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="form.platform === 'kiro' && accountCategory === 'oauth-based' && kiroAccountType === 'oauth'" class="mt-4 space-y-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.oauthProviderTitle') }}</label>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.oauth.kiro.socialSubtitle') }}</span>
|
||||
</div>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroOAuthProvider = 'google'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroOAuthProvider === 'google'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroOAuthProvider === 'google'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="user" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.googleTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.googleDesc') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroOAuthProvider = 'github'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroOAuthProvider === 'github'
|
||||
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
|
||||
: 'border-gray-200 hover:border-slate-300 dark:border-dark-600 dark:hover:border-slate-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroOAuthProvider === 'github'
|
||||
? 'bg-slate-700 text-white dark:bg-slate-500'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="terminal" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.githubTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.githubDesc') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="form.platform === 'kiro' && accountCategory === 'oauth-based' && kiroAccountType === 'idc'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.startUrlLabel') }}</label>
|
||||
<input
|
||||
v-model="kiroIDCStartUrl"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.oauth.kiro.startUrlPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.regionLabel') }}</label>
|
||||
<input
|
||||
v-model="kiroIDCRegion"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.oauth.kiro.regionPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="form.platform === 'kiro' && accountCategory === 'apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
||||
<input
|
||||
v-model="apiKeyBaseUrl"
|
||||
type="text"
|
||||
required
|
||||
class="input"
|
||||
placeholder="https://your-kiro-upstream.example.com"
|
||||
/>
|
||||
<p class="input-hint">{{ baseUrlHint }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.apiKeyRequired') }}</label>
|
||||
<input
|
||||
v-model="apiKeyValue"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
placeholder="sk-..."
|
||||
/>
|
||||
<p class="input-hint">{{ apiKeyHint }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="form.platform === 'kiro' && accountCategory === 'apikey'" class="space-y-4">
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.customErrorCodesHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="customErrorCodesEnabled = !customErrorCodesEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
customErrorCodesEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
customErrorCodesEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div v-if="customErrorCodesEnabled" class="space-y-3">
|
||||
<div class="rounded-lg bg-amber-50 p-3 dark:bg-amber-900/20">
|
||||
<p class="text-xs text-amber-700 dark:text-amber-400">
|
||||
<Icon name="exclamationTriangle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.customErrorCodesWarning') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="code in commonErrorCodes"
|
||||
:key="code.value"
|
||||
type="button"
|
||||
@click="toggleErrorCode(code.value)"
|
||||
:class="[
|
||||
'rounded-lg px-3 py-1.5 text-sm font-medium transition-colors',
|
||||
selectedErrorCodes.includes(code.value)
|
||||
? 'bg-red-100 text-red-700 ring-1 ring-red-500 dark:bg-red-900/30 dark:text-red-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ code.value }} {{ code.label }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
v-model.number="customErrorCodeInput"
|
||||
type="number"
|
||||
min="100"
|
||||
max="599"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.enterErrorCode')"
|
||||
@keyup.enter="addCustomErrorCode"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="addCustomErrorCode"
|
||||
class="btn btn-secondary shrink-0"
|
||||
>
|
||||
{{ t('admin.accounts.add') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-wrap gap-1.5">
|
||||
<span
|
||||
v-for="code in selectedErrorCodes.sort((a, b) => a - b)"
|
||||
:key="code"
|
||||
class="inline-flex items-center gap-1 rounded-full bg-red-100 px-2.5 py-0.5 text-sm font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400"
|
||||
>
|
||||
{{ code }}
|
||||
<button
|
||||
type="button"
|
||||
@click="removeErrorCode(code)"
|
||||
class="hover:text-red-900 dark:hover:text-red-300"
|
||||
>
|
||||
<Icon name="x" size="sm" :stroke-width="2" />
|
||||
</button>
|
||||
</span>
|
||||
<span v-if="selectedErrorCodes.length === 0" class="text-xs text-gray-400">
|
||||
{{ t('admin.accounts.noneSelectedUsesDefault') }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Kiro 只支持模型映射模式,不支持白名单模式 -->
|
||||
<div v-if="form.platform === 'kiro'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
<div>
|
||||
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
|
||||
<p class="text-xs text-purple-700 dark:text-purple-400">
|
||||
{{ t('admin.accounts.mapRequestModels') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-if="kiroModelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in kiroModelMappings"
|
||||
:key="getKiroModelMappingKey(mapping)"
|
||||
class="space-y-1"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
v-model="mapping.from"
|
||||
type="text"
|
||||
:class="[
|
||||
'input flex-1',
|
||||
!isValidWildcardPattern(mapping.from) ? 'border-red-500 dark:border-red-500' : ''
|
||||
]"
|
||||
:placeholder="t('admin.accounts.requestModel')"
|
||||
/>
|
||||
<svg class="h-4 w-4 flex-shrink-0 text-gray-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3" />
|
||||
</svg>
|
||||
<input
|
||||
v-model="mapping.to"
|
||||
type="text"
|
||||
:class="[
|
||||
'input flex-1',
|
||||
mapping.to.includes('*') ? 'border-red-500 dark:border-red-500' : ''
|
||||
]"
|
||||
:placeholder="t('admin.accounts.actualModel')"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeKiroModelMapping(index)"
|
||||
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
>
|
||||
<Icon name="x" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button type="button" @click="addKiroModelMapping" class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300">
|
||||
<svg class="mr-1 inline h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4" />
|
||||
</svg>
|
||||
{{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in kiroPresetMappings"
|
||||
:key="preset.label"
|
||||
type="button"
|
||||
@click="addKiroPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Upstream config (only for Antigravity upstream type) -->
|
||||
<div v-if="form.platform === 'antigravity' && antigravityAccountType === 'upstream'" class="space-y-4">
|
||||
<div>
|
||||
@@ -1009,7 +1473,7 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && form.platform !== 'kiro'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
||||
<input
|
||||
@@ -2750,7 +3214,23 @@
|
||||
|
||||
<!-- Step 2: OAuth Authorization -->
|
||||
<div v-else class="space-y-5">
|
||||
<div v-if="isKiroImportMode" class="space-y-4 rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-700 dark:bg-amber-900/20">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.tokenJsonLabel') }}</label>
|
||||
<textarea v-model="kiroTokenJson" rows="8" class="input font-mono text-xs" placeholder='{"accessToken":"...","refreshToken":"..."}'></textarea>
|
||||
<p class="input-hint">{{ t('admin.accounts.oauth.kiro.tokenJsonHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.deviceRegistrationLabel') }}</label>
|
||||
<textarea v-model="kiroDeviceRegistrationJson" rows="6" class="input font-mono text-xs" placeholder='{"clientId":"...","clientSecret":"..."}'></textarea>
|
||||
<p class="input-hint">{{ t('admin.accounts.oauth.kiro.deviceRegistrationHint') }}</p>
|
||||
</div>
|
||||
<div v-if="currentOAuthError" class="rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30">
|
||||
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">{{ currentOAuthError }}</p>
|
||||
</div>
|
||||
</div>
|
||||
<OAuthAuthorizationFlow
|
||||
v-else
|
||||
ref="oauthFlowRef"
|
||||
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
|
||||
:auth-url="currentAuthUrl"
|
||||
@@ -2824,7 +3304,16 @@
|
||||
{{ t('common.back') }}
|
||||
</button>
|
||||
<button
|
||||
v-if="isManualInputMethod"
|
||||
v-if="isKiroImportMode"
|
||||
type="button"
|
||||
:disabled="currentOAuthLoading || !kiroTokenJson.trim()"
|
||||
class="btn btn-primary"
|
||||
@click="handleKiroImport"
|
||||
>
|
||||
{{ currentOAuthLoading ? t('admin.accounts.creating') : t('common.create') }}
|
||||
</button>
|
||||
<button
|
||||
v-else-if="isManualInputMethod"
|
||||
type="button"
|
||||
:disabled="!canExchangeCode"
|
||||
class="btn btn-primary"
|
||||
@@ -3101,6 +3590,7 @@ import {
|
||||
commonErrorCodes,
|
||||
buildModelMappingObject,
|
||||
fetchAntigravityDefaultMappings,
|
||||
fetchKiroDefaultMappings,
|
||||
isValidWildcardPattern
|
||||
} from '@/composables/useModelWhitelist'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
@@ -3114,6 +3604,7 @@ import {
|
||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
|
||||
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
|
||||
import { useKiroOAuth } from '@/composables/useKiroOAuth'
|
||||
import type {
|
||||
Proxy,
|
||||
AdminGroup,
|
||||
@@ -3151,6 +3642,8 @@ import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||
interface OAuthFlowExposed {
|
||||
authCode: string
|
||||
oauthState: string
|
||||
oauthCallbackPath: string
|
||||
oauthLoginOption: string
|
||||
projectId: string
|
||||
sessionKey: string
|
||||
refreshToken: string
|
||||
@@ -3167,6 +3660,11 @@ const oauthStepTitle = computed(() => {
|
||||
if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title')
|
||||
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
|
||||
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
|
||||
if (form.platform === 'kiro') {
|
||||
return kiroAccountType.value === 'import'
|
||||
? t('admin.accounts.oauth.kiro.importDialogTitle')
|
||||
: t('admin.accounts.oauth.kiro.title')
|
||||
}
|
||||
return t('admin.accounts.oauth.title')
|
||||
})
|
||||
|
||||
@@ -3174,12 +3672,14 @@ const oauthStepTitle = computed(() => {
|
||||
const baseUrlHint = computed(() => {
|
||||
if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint')
|
||||
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
|
||||
if (form.platform === 'kiro') return t('admin.accounts.kiro.baseUrlHint')
|
||||
return t('admin.accounts.baseUrlHint')
|
||||
})
|
||||
|
||||
const apiKeyHint = computed(() => {
|
||||
if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint')
|
||||
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
|
||||
if (form.platform === 'kiro') return t('admin.accounts.kiro.apiKeyHint')
|
||||
return t('admin.accounts.apiKeyHint')
|
||||
})
|
||||
|
||||
@@ -3202,12 +3702,14 @@ const oauth = useAccountOAuth() // For Anthropic OAuth
|
||||
const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth
|
||||
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
|
||||
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
|
||||
const kiroOAuth = useKiroOAuth() // For Kiro OAuth / IDC
|
||||
|
||||
// Computed: current OAuth state for template binding
|
||||
const currentAuthUrl = computed(() => {
|
||||
if (form.platform === 'openai') return openaiOAuth.authUrl.value
|
||||
if (form.platform === 'gemini') return geminiOAuth.authUrl.value
|
||||
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
|
||||
if (form.platform === 'kiro') return kiroOAuth.authUrl.value
|
||||
return oauth.authUrl.value
|
||||
})
|
||||
|
||||
@@ -3215,6 +3717,7 @@ const currentSessionId = computed(() => {
|
||||
if (form.platform === 'openai') return openaiOAuth.sessionId.value
|
||||
if (form.platform === 'gemini') return geminiOAuth.sessionId.value
|
||||
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
|
||||
if (form.platform === 'kiro') return kiroOAuth.sessionId.value
|
||||
return oauth.sessionId.value
|
||||
})
|
||||
|
||||
@@ -3222,6 +3725,7 @@ const currentOAuthLoading = computed(() => {
|
||||
if (form.platform === 'openai') return openaiOAuth.loading.value
|
||||
if (form.platform === 'gemini') return geminiOAuth.loading.value
|
||||
if (form.platform === 'antigravity') return antigravityOAuth.loading.value
|
||||
if (form.platform === 'kiro') return kiroOAuth.loading.value
|
||||
return oauth.loading.value
|
||||
})
|
||||
|
||||
@@ -3229,6 +3733,7 @@ const currentOAuthError = computed(() => {
|
||||
if (form.platform === 'openai') return openaiOAuth.error.value
|
||||
if (form.platform === 'gemini') return geminiOAuth.error.value
|
||||
if (form.platform === 'antigravity') return antigravityOAuth.error.value
|
||||
if (form.platform === 'kiro') return kiroOAuth.error.value
|
||||
return oauth.error.value
|
||||
})
|
||||
|
||||
@@ -3307,6 +3812,14 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist'
|
||||
const antigravityWhitelistModels = ref<string[]>([])
|
||||
const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
||||
const kiroAccountType = ref<'oauth' | 'idc' | 'import'>('oauth')
|
||||
const kiroOAuthProvider = ref<'google' | 'github'>('google')
|
||||
const kiroIDCStartUrl = ref('https://view.awsapps.com/start')
|
||||
const kiroIDCRegion = ref('us-east-1')
|
||||
const kiroTokenJson = ref('')
|
||||
const kiroDeviceRegistrationJson = ref('')
|
||||
const kiroModelMappings = ref<ModelMapping[]>([])
|
||||
const kiroPresetMappings = computed(() => getPresetMappingsByPlatform('kiro'))
|
||||
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
|
||||
|
||||
// Bedrock credentials
|
||||
@@ -3328,6 +3841,7 @@ const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||
const getOpenAICompactModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-openai-compact-model-mapping')
|
||||
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-antigravity-model-mapping')
|
||||
const getKiroModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-kiro-model-mapping')
|
||||
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('create-temp-unsched-rule')
|
||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
||||
const geminiAIStudioOAuthEnabled = ref(false)
|
||||
@@ -3513,6 +4027,8 @@ const isOAuthFlow = computed(() => {
|
||||
return accountCategory.value === 'oauth-based'
|
||||
})
|
||||
|
||||
const isKiroImportMode = computed(() => form.platform === 'kiro' && kiroAccountType.value === 'import')
|
||||
|
||||
const isManualInputMethod = computed(() => {
|
||||
return oauthFlowRef.value?.inputMethod === 'manual'
|
||||
})
|
||||
@@ -3535,6 +4051,9 @@ const canExchangeCode = computed(() => {
|
||||
if (form.platform === 'antigravity') {
|
||||
return authCode.trim() && antigravityOAuth.sessionId.value && !antigravityOAuth.loading.value
|
||||
}
|
||||
if (form.platform === 'kiro') {
|
||||
return authCode.trim() && kiroOAuth.sessionId.value && !kiroOAuth.loading.value
|
||||
}
|
||||
return authCode.trim() && oauth.sessionId.value && !oauth.loading.value
|
||||
})
|
||||
|
||||
@@ -3556,10 +4075,15 @@ watch(
|
||||
antigravityModelMappings.value = [...mappings]
|
||||
})
|
||||
antigravityWhitelistModels.value = []
|
||||
} else if (form.platform === 'kiro') {
|
||||
fetchKiroDefaultMappings().then(mappings => {
|
||||
kiroModelMappings.value = [...mappings]
|
||||
})
|
||||
} else {
|
||||
antigravityWhitelistModels.value = []
|
||||
antigravityModelMappings.value = []
|
||||
antigravityModelRestrictionMode.value = 'mapping'
|
||||
kiroModelMappings.value = []
|
||||
}
|
||||
} else {
|
||||
resetForm()
|
||||
@@ -3576,6 +4100,10 @@ watch(
|
||||
form.type = 'apikey'
|
||||
return
|
||||
}
|
||||
if (form.platform === 'kiro') {
|
||||
form.type = category === 'oauth-based' ? 'oauth' : 'apikey'
|
||||
return
|
||||
}
|
||||
// Bedrock 类型
|
||||
if (form.platform === 'anthropic' && category === 'bedrock') {
|
||||
form.type = 'bedrock' as AccountType
|
||||
@@ -3602,6 +4130,8 @@ watch(
|
||||
? 'https://api.openai.com'
|
||||
: newPlatform === 'gemini'
|
||||
? 'https://generativelanguage.googleapis.com'
|
||||
: newPlatform === 'kiro'
|
||||
? ''
|
||||
: 'https://api.anthropic.com'
|
||||
// Clear model-related settings
|
||||
allowedModels.value = []
|
||||
@@ -3615,11 +4145,21 @@ watch(
|
||||
antigravityWhitelistModels.value = []
|
||||
accountCategory.value = 'oauth-based'
|
||||
antigravityAccountType.value = 'oauth'
|
||||
} else if (newPlatform === 'kiro') {
|
||||
fetchKiroDefaultMappings().then(mappings => {
|
||||
kiroModelMappings.value = [...mappings]
|
||||
})
|
||||
accountCategory.value = 'oauth-based'
|
||||
kiroAccountType.value = 'oauth'
|
||||
kiroOAuthProvider.value = 'google'
|
||||
apiKeyBaseUrl.value = ''
|
||||
apiKeyValue.value = ''
|
||||
} else {
|
||||
allowOverages.value = false
|
||||
antigravityWhitelistModels.value = []
|
||||
antigravityModelMappings.value = []
|
||||
antigravityModelRestrictionMode.value = 'mapping'
|
||||
kiroModelMappings.value = []
|
||||
}
|
||||
if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
|
||||
accountCategory.value = 'oauth-based'
|
||||
@@ -3659,6 +4199,7 @@ watch(
|
||||
|
||||
geminiOAuth.resetState()
|
||||
antigravityOAuth.resetState()
|
||||
kiroOAuth.resetState()
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3760,6 +4301,22 @@ const addAntigravityPresetMapping = (from: string, to: string) => {
|
||||
antigravityModelMappings.value.push({ from, to })
|
||||
}
|
||||
|
||||
const addKiroModelMapping = () => {
|
||||
kiroModelMappings.value.push({ from: '', to: '' })
|
||||
}
|
||||
|
||||
const removeKiroModelMapping = (index: number) => {
|
||||
kiroModelMappings.value.splice(index, 1)
|
||||
}
|
||||
|
||||
const addKiroPresetMapping = (from: string, to: string) => {
|
||||
if (kiroModelMappings.value.some((m) => m.from === from)) {
|
||||
appStore.showInfo(t('admin.accounts.mappingExists', { model: from }))
|
||||
return
|
||||
}
|
||||
kiroModelMappings.value.push({ from, to })
|
||||
}
|
||||
|
||||
// Error code toggle helper
|
||||
const toggleErrorCode = (code: number) => {
|
||||
const index = selectedErrorCodes.value.indexOf(code)
|
||||
@@ -4033,6 +4590,15 @@ const resetForm = () => {
|
||||
fetchAntigravityDefaultMappings().then(mappings => {
|
||||
antigravityModelMappings.value = [...mappings]
|
||||
})
|
||||
kiroAccountType.value = 'oauth'
|
||||
kiroOAuthProvider.value = 'google'
|
||||
kiroIDCStartUrl.value = 'https://view.awsapps.com/start'
|
||||
kiroIDCRegion.value = 'us-east-1'
|
||||
kiroTokenJson.value = ''
|
||||
kiroDeviceRegistrationJson.value = ''
|
||||
fetchKiroDefaultMappings().then(mappings => {
|
||||
kiroModelMappings.value = [...mappings]
|
||||
})
|
||||
poolModeEnabled.value = false
|
||||
poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
customErrorCodesEnabled.value = false
|
||||
@@ -4084,6 +4650,7 @@ const resetForm = () => {
|
||||
openaiOAuth.resetState()
|
||||
geminiOAuth.resetState()
|
||||
antigravityOAuth.resetState()
|
||||
kiroOAuth.resetState()
|
||||
oauthFlowRef.value?.reset()
|
||||
antigravityMixedChannelConfirmed.value = false
|
||||
clearMixedChannelDialog()
|
||||
@@ -4379,6 +4946,45 @@ const handleSubmit = async () => {
|
||||
return
|
||||
}
|
||||
|
||||
// For Kiro API key type, create directly
|
||||
if (form.platform === 'kiro' && accountCategory.value === 'apikey') {
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!apiKeyBaseUrl.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl'))
|
||||
return
|
||||
}
|
||||
if (!apiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
base_url: apiKeyBaseUrl.value.trim(),
|
||||
api_key: apiKeyValue.value.trim()
|
||||
}
|
||||
|
||||
const modelMapping = buildModelMappingObject('mapping', [], kiroModelMappings.value)
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
if (poolModeEnabled.value) {
|
||||
credentials.pool_mode = true
|
||||
credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
}
|
||||
|
||||
if (customErrorCodesEnabled.value) {
|
||||
credentials.custom_error_codes_enabled = true
|
||||
credentials.custom_error_codes = [...selectedErrorCodes.value]
|
||||
}
|
||||
|
||||
await createAccountAndFinish('kiro', 'apikey', credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For apikey type, create directly
|
||||
if (!apiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
|
||||
@@ -4450,6 +5056,7 @@ const goBackToBasicInfo = () => {
|
||||
openaiOAuth.resetState()
|
||||
geminiOAuth.resetState()
|
||||
antigravityOAuth.resetState()
|
||||
kiroOAuth.resetState()
|
||||
oauthFlowRef.value?.reset()
|
||||
}
|
||||
|
||||
@@ -4465,6 +5072,19 @@ const handleGenerateUrl = async () => {
|
||||
)
|
||||
} else if (form.platform === 'antigravity') {
|
||||
await antigravityOAuth.generateAuthUrl(form.proxy_id)
|
||||
} else if (form.platform === 'kiro') {
|
||||
if (kiroAccountType.value === 'idc') {
|
||||
await kiroOAuth.generateIDCAuthUrl({
|
||||
proxyId: form.proxy_id,
|
||||
startUrl: kiroIDCStartUrl.value.trim() || undefined,
|
||||
region: kiroIDCRegion.value.trim() || undefined
|
||||
})
|
||||
} else {
|
||||
await kiroOAuth.generateAuthUrl(
|
||||
form.proxy_id,
|
||||
kiroOAuthProvider.value === 'github' ? 'Github' : 'Google'
|
||||
)
|
||||
}
|
||||
} else {
|
||||
await oauth.generateAuthUrl(addMethod.value, form.proxy_id)
|
||||
}
|
||||
@@ -5035,6 +5655,50 @@ const handleAntigravityExchange = async (authCode: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
const buildKiroCredentials = (tokenInfo: Parameters<typeof kiroOAuth.buildCredentials>[0]) => {
|
||||
const credentials = kiroOAuth.buildCredentials(tokenInfo)
|
||||
const modelMapping = buildModelMappingObject('mapping', [], kiroModelMappings.value)
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
return credentials
|
||||
}
|
||||
|
||||
const handleKiroExchange = async (authCode: string) => {
|
||||
if (!authCode.trim() || !kiroOAuth.sessionId.value) return
|
||||
|
||||
kiroOAuth.loading.value = true
|
||||
kiroOAuth.error.value = ''
|
||||
|
||||
try {
|
||||
const stateFromInput = oauthFlowRef.value?.oauthState || ''
|
||||
const stateToUse = stateFromInput || kiroOAuth.state.value
|
||||
if (!stateToUse) {
|
||||
kiroOAuth.error.value = t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(kiroOAuth.error.value)
|
||||
return
|
||||
}
|
||||
|
||||
const tokenInfo = await kiroOAuth.exchangeAuthCode({
|
||||
code: authCode.trim(),
|
||||
sessionId: kiroOAuth.sessionId.value,
|
||||
state: stateToUse,
|
||||
callbackPath: oauthFlowRef.value?.oauthCallbackPath || '',
|
||||
loginOption: oauthFlowRef.value?.oauthLoginOption || '',
|
||||
proxyId: form.proxy_id
|
||||
})
|
||||
if (!tokenInfo) return
|
||||
|
||||
const credentials = buildKiroCredentials(tokenInfo)
|
||||
await createAccountAndFinish('kiro', 'oauth', credentials)
|
||||
} catch (error: any) {
|
||||
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(kiroOAuth.error.value)
|
||||
} finally {
|
||||
kiroOAuth.loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic OAuth 授权码兑换
|
||||
const handleAnthropicExchange = async (authCode: string) => {
|
||||
if (!authCode.trim() || !oauth.sessionId.value) return
|
||||
@@ -5133,6 +5797,8 @@ const handleExchangeCode = async () => {
|
||||
return handleOpenAIExchange(authCode)
|
||||
case 'gemini':
|
||||
return handleGeminiExchange(authCode)
|
||||
case 'kiro':
|
||||
return handleKiroExchange(authCode)
|
||||
case 'antigravity':
|
||||
return handleAntigravityExchange(authCode)
|
||||
default:
|
||||
@@ -5140,6 +5806,24 @@ const handleExchangeCode = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
const handleKiroImport = async () => {
|
||||
if (!isKiroImportMode.value || !kiroTokenJson.value.trim()) return
|
||||
|
||||
const tokenInfo = await kiroOAuth.importToken(
|
||||
kiroTokenJson.value,
|
||||
kiroDeviceRegistrationJson.value || undefined
|
||||
)
|
||||
if (!tokenInfo) return
|
||||
|
||||
try {
|
||||
const credentials = buildKiroCredentials(tokenInfo)
|
||||
await createAccountAndFinish('kiro', 'oauth', credentials)
|
||||
} catch (error: any) {
|
||||
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(kiroOAuth.error.value)
|
||||
}
|
||||
}
|
||||
|
||||
const handleCookieAuth = async (sessionKey: string) => {
|
||||
oauth.loading.value = true
|
||||
oauth.error.value = ''
|
||||
|
||||
@@ -39,6 +39,8 @@
|
||||
? 'https://api.openai.com'
|
||||
: account.platform === 'gemini'
|
||||
? 'https://generativelanguage.googleapis.com'
|
||||
: account.platform === 'kiro'
|
||||
? 'https://your-kiro-upstream.example.com'
|
||||
: account.platform === 'antigravity'
|
||||
? 'https://cloudcode-pa.googleapis.com'
|
||||
: 'https://api.anthropic.com'
|
||||
@@ -61,6 +63,8 @@
|
||||
? 'sk-proj-...'
|
||||
: account.platform === 'gemini'
|
||||
? 'AIza...'
|
||||
: account.platform === 'kiro'
|
||||
? 'sk-...'
|
||||
: account.platform === 'antigravity'
|
||||
? 'sk-...'
|
||||
: 'sk-ant-...'
|
||||
@@ -69,8 +73,93 @@
|
||||
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section (不适用于 Antigravity) -->
|
||||
<div v-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div v-if="account.platform === 'kiro'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
|
||||
<p class="text-xs text-purple-700 dark:text-purple-400">
|
||||
{{ t('admin.accounts.mapRequestModels') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="getModelMappingKey(mapping)"
|
||||
class="space-y-1"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
v-model="mapping.from"
|
||||
type="text"
|
||||
:class="[
|
||||
'input flex-1',
|
||||
!isValidWildcardPattern(mapping.from) ? 'border-red-500 dark:border-red-500' : ''
|
||||
]"
|
||||
:placeholder="t('admin.accounts.requestModel')"
|
||||
/>
|
||||
<svg class="h-4 w-4 flex-shrink-0 text-gray-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3" />
|
||||
</svg>
|
||||
<input
|
||||
v-model="mapping.to"
|
||||
type="text"
|
||||
:class="[
|
||||
'input flex-1',
|
||||
mapping.to.includes('*') ? 'border-red-500 dark:border-red-500' : ''
|
||||
]"
|
||||
:placeholder="t('admin.accounts.actualModel')"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeModelMapping(index)"
|
||||
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
>
|
||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<p v-if="!isValidWildcardPattern(mapping.from)" class="text-xs text-red-500">
|
||||
{{ t('admin.accounts.wildcardOnlyAtEnd') }}
|
||||
</p>
|
||||
<p v-if="mapping.to.includes('*')" class="text-xs text-red-500">
|
||||
{{ t('admin.accounts.targetNoWildcard') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="addModelMapping"
|
||||
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
|
||||
>
|
||||
<svg class="mr-1 inline h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4" />
|
||||
</svg>
|
||||
{{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in presetMappings"
|
||||
:key="preset.label"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section (不适用于 Antigravity / Kiro) -->
|
||||
<div v-else-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<div
|
||||
@@ -407,9 +496,9 @@
|
||||
|
||||
</div>
|
||||
|
||||
<!-- OpenAI OAuth Model Mapping (OAuth 类型没有 apikey 容器,需要独立的模型映射区域) -->
|
||||
<!-- OpenAI / Kiro OAuth Model Restriction (OAuth 类型没有 apikey 容器,需要独立区域) -->
|
||||
<div
|
||||
v-if="account.platform === 'openai' && account.type === 'oauth'"
|
||||
v-if="(account.platform === 'openai' || account.platform === 'kiro') && account.type === 'oauth'"
|
||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||
>
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
@@ -423,6 +512,82 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<template v-else-if="account.platform === 'kiro'">
|
||||
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
|
||||
<p class="text-xs text-purple-700 dark:text-purple-400">
|
||||
{{ t('admin.accounts.mapRequestModels') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="'oauth-' + getModelMappingKey(mapping)"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
v-model="mapping.from"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.requestModel')"
|
||||
/>
|
||||
<svg
|
||||
class="h-4 w-4 flex-shrink-0 text-gray-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M14 5l7 7m0 0l-7 7m7-7H3"
|
||||
/>
|
||||
</svg>
|
||||
<input
|
||||
v-model="mapping.to"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.actualModel')"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeModelMapping(index)"
|
||||
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
>
|
||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="addModelMapping"
|
||||
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
|
||||
>
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in presetMappings"
|
||||
:key="'oauth-' + preset.label"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<template v-else>
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
@@ -2205,6 +2370,7 @@ import {
|
||||
resolveOpenAIWSModeFromExtra
|
||||
} from '@/utils/openaiWsMode'
|
||||
import {
|
||||
fetchKiroDefaultMappings,
|
||||
getPresetMappingsByPlatform,
|
||||
commonErrorCodes,
|
||||
buildModelMappingObject,
|
||||
@@ -2233,11 +2399,13 @@ const baseUrlHint = computed(() => {
|
||||
if (!props.account) return t('admin.accounts.baseUrlHint')
|
||||
if (props.account.platform === 'openai') return t('admin.accounts.openai.baseUrlHint')
|
||||
if (props.account.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
|
||||
if (props.account.platform === 'kiro') return t('admin.accounts.kiro.baseUrlHint')
|
||||
return t('admin.accounts.baseUrlHint')
|
||||
})
|
||||
|
||||
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
||||
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
|
||||
const isKiroOAuthAccount = computed(() => props.account?.platform === 'kiro' && props.account?.type === 'oauth')
|
||||
|
||||
// Model mapping type
|
||||
interface ModelMapping {
|
||||
@@ -2295,6 +2463,21 @@ const getOpenAICompactModelMappingKey = createStableObjectKeyResolver<ModelMappi
|
||||
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
|
||||
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
|
||||
|
||||
const applyKiroModelMappings = (entries: Array<[string, string]>) => {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
|
||||
const loadDefaultKiroModelMappings = () => {
|
||||
fetchKiroDefaultMappings().then(mappings => {
|
||||
if (!isKiroOAuthAccount.value) return
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = mappings.map(({ from, to }) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
})
|
||||
}
|
||||
|
||||
const showMixedChannelWarning = ref(false)
|
||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
|
||||
null
|
||||
@@ -2486,6 +2669,7 @@ const tempUnschedPresets = computed(() => [
|
||||
const defaultBaseUrl = computed(() => {
|
||||
if (props.account?.platform === 'openai') return 'https://api.openai.com'
|
||||
if (props.account?.platform === 'gemini') return 'https://generativelanguage.googleapis.com'
|
||||
if (props.account?.platform === 'kiro') return ''
|
||||
return 'https://api.anthropic.com'
|
||||
})
|
||||
|
||||
@@ -2709,6 +2893,8 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
? 'https://api.openai.com'
|
||||
: newAccount.platform === 'gemini'
|
||||
? 'https://generativelanguage.googleapis.com'
|
||||
: newAccount.platform === 'kiro'
|
||||
? ''
|
||||
: 'https://api.anthropic.com'
|
||||
editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl
|
||||
|
||||
@@ -2717,6 +2903,11 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
const entries = Object.entries(existingMappings)
|
||||
|
||||
if (newAccount.platform === 'kiro') {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
} else {
|
||||
// Detect if this is whitelist mode (all from === to) or mapping mode
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
|
||||
@@ -2731,6 +2922,16 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
}
|
||||
} else if (newAccount.platform === 'kiro') {
|
||||
fetchKiroDefaultMappings().then(mappings => {
|
||||
if (props.account?.id !== newAccount.id || props.account?.type !== 'apikey' || props.account?.platform !== 'kiro') {
|
||||
return
|
||||
}
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = mappings.map(({ from, to }) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
})
|
||||
} else {
|
||||
// No mappings: default to whitelist mode with empty selection (allow all)
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
@@ -2785,15 +2986,18 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
const entries = Object.entries(existingMappings)
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
if (isWhitelistMode) {
|
||||
// Whitelist mode: populate allowedModels
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = entries.map(([from]) => from)
|
||||
modelMappings.value = []
|
||||
} else {
|
||||
// Mapping mode: populate modelMappings
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
// No mappings: default to whitelist mode with empty selection (allow all)
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
@@ -2835,8 +3039,16 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
: 'https://api.anthropic.com'
|
||||
editBaseUrl.value = platformDefaultUrl
|
||||
|
||||
// Load model mappings for OpenAI OAuth accounts
|
||||
if (newAccount.platform === 'openai' && newAccount.credentials) {
|
||||
// Load model mappings for OpenAI/Kiro OAuth accounts
|
||||
if (newAccount.platform === 'kiro' && newAccount.credentials) {
|
||||
const oauthCredentials = newAccount.credentials as Record<string, unknown>
|
||||
const existingMappings = oauthCredentials.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object' && Object.keys(existingMappings).length > 0) {
|
||||
applyKiroModelMappings(Object.entries(existingMappings))
|
||||
} else {
|
||||
loadDefaultKiroModelMappings()
|
||||
}
|
||||
} else if (newAccount.platform === 'openai' && newAccount.credentials) {
|
||||
const oauthCredentials = newAccount.credentials as Record<string, unknown>
|
||||
const existingMappings = oauthCredentials.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
@@ -2892,6 +3104,7 @@ watch(
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
|
||||
// Model mapping helpers
|
||||
const addModelMapping = () => {
|
||||
modelMappings.value.push({ from: '', to: '' })
|
||||
@@ -3333,9 +3546,16 @@ const handleSubmit = async () => {
|
||||
// For apikey type, handle credentials update
|
||||
if (props.account.type === 'apikey') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newBaseUrl = editBaseUrl.value.trim() || defaultBaseUrl.value
|
||||
const newBaseUrl = props.account.platform === 'kiro'
|
||||
? editBaseUrl.value.trim()
|
||||
: (editBaseUrl.value.trim() || defaultBaseUrl.value)
|
||||
const shouldApplyModelMapping = !(props.account.platform === 'openai' && openaiPassthroughEnabled.value)
|
||||
|
||||
if (!newBaseUrl) {
|
||||
appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl'))
|
||||
return
|
||||
}
|
||||
|
||||
// Always update credentials for apikey type to handle model mapping changes
|
||||
const newCredentials: Record<string, unknown> = {
|
||||
...currentCredentials,
|
||||
@@ -3356,7 +3576,11 @@ const handleSubmit = async () => {
|
||||
|
||||
// Add model mapping if configured(OpenAI 开启自动透传时保留现有映射,不再编辑)
|
||||
if (shouldApplyModelMapping) {
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
const modelMapping = buildModelMappingObject(
|
||||
props.account.platform === 'kiro' ? 'mapping' : modelRestrictionMode.value,
|
||||
props.account.platform === 'kiro' ? [] : allowedModels.value,
|
||||
modelMappings.value
|
||||
)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
} else {
|
||||
@@ -3494,7 +3718,7 @@ const handleSubmit = async () => {
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
const modelMapping = buildModelMappingObject('mapping', [], modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
} else {
|
||||
@@ -3548,6 +3772,22 @@ const handleSubmit = async () => {
|
||||
updatePayload.credentials = newCredentials
|
||||
}
|
||||
|
||||
// Kiro OAuth: persist model mapping to credentials
|
||||
if (props.account.platform === 'kiro' && props.account.type === 'oauth') {
|
||||
const currentCredentials = (updatePayload.credentials as Record<string, unknown>) ||
|
||||
((props.account.credentials as Record<string, unknown>) || {})
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
const modelMapping = buildModelMappingObject('mapping', [], modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
} else {
|
||||
delete newCredentials.model_mapping
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
}
|
||||
|
||||
// Antigravity: persist model mapping to credentials (applies to all antigravity types)
|
||||
// Antigravity 只支持映射模式
|
||||
if (props.account.platform === 'antigravity') {
|
||||
|
||||
@@ -696,6 +696,7 @@ const getOAuthKey = (key: string) => {
|
||||
if (props.platform === 'openai') return `admin.accounts.oauth.openai.${key}`
|
||||
if (props.platform === 'gemini') return `admin.accounts.oauth.gemini.${key}`
|
||||
if (props.platform === 'antigravity') return `admin.accounts.oauth.antigravity.${key}`
|
||||
if (props.platform === 'kiro') return `admin.accounts.oauth.kiro.${key}`
|
||||
return `admin.accounts.oauth.${key}`
|
||||
}
|
||||
|
||||
@@ -726,6 +727,8 @@ const sessionTokenInput = ref('')
|
||||
const codexSessionInput = ref('')
|
||||
const showHelpDialog = ref(false)
|
||||
const oauthState = ref('')
|
||||
const oauthCallbackPath = ref('')
|
||||
const oauthLoginOption = ref('')
|
||||
const projectId = ref('')
|
||||
|
||||
// Computed: show method selection when either cookie or refresh token option is enabled
|
||||
@@ -765,10 +768,10 @@ watch(inputMethod, (newVal) => {
|
||||
emit('update:inputMethod', newVal)
|
||||
})
|
||||
|
||||
// Auto-extract code from callback URL (OpenAI/Gemini/Antigravity)
|
||||
// Auto-extract code from callback URL (OpenAI/Gemini/Antigravity/Kiro)
|
||||
// e.g., http://localhost:8085/callback?code=xxx...&state=...
|
||||
watch(authCodeInput, (newVal) => {
|
||||
if (props.platform !== 'openai' && props.platform !== 'gemini' && props.platform !== 'antigravity') return
|
||||
if (props.platform !== 'openai' && props.platform !== 'gemini' && props.platform !== 'antigravity' && props.platform !== 'kiro') return
|
||||
|
||||
const trimmed = newVal.trim()
|
||||
// Check if it looks like a URL with code parameter
|
||||
@@ -778,7 +781,11 @@ watch(authCodeInput, (newVal) => {
|
||||
const url = new URL(trimmed)
|
||||
const code = url.searchParams.get('code')
|
||||
const stateParam = url.searchParams.get('state')
|
||||
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
|
||||
if (props.platform === 'kiro') {
|
||||
oauthCallbackPath.value = url.pathname || ''
|
||||
oauthLoginOption.value = url.searchParams.get('login_option') || ''
|
||||
}
|
||||
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity' || props.platform === 'kiro') && stateParam) {
|
||||
oauthState.value = stateParam
|
||||
}
|
||||
if (code && code !== trimmed) {
|
||||
@@ -789,7 +796,13 @@ watch(authCodeInput, (newVal) => {
|
||||
// If URL parsing fails, try regex extraction
|
||||
const match = trimmed.match(/[?&]code=([^&]+)/)
|
||||
const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
|
||||
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
|
||||
if (props.platform === 'kiro') {
|
||||
const pathMatch = trimmed.match(/^https?:\/\/[^/]+(\/[^?]*)/)
|
||||
oauthCallbackPath.value = pathMatch?.[1] || oauthCallbackPath.value
|
||||
const loginOptionMatch = trimmed.match(/[?&]login_option=([^&]+)/)
|
||||
oauthLoginOption.value = loginOptionMatch?.[1] || oauthLoginOption.value
|
||||
}
|
||||
if ((props.platform === 'openai' || props.platform === 'gemini' || props.platform === 'antigravity' || props.platform === 'kiro') && stateMatch && stateMatch[1]) {
|
||||
oauthState.value = stateMatch[1]
|
||||
}
|
||||
if (match && match[1] && match[1] !== trimmed) {
|
||||
@@ -841,6 +854,8 @@ const handleImportCodexSession = () => {
|
||||
defineExpose({
|
||||
authCode: authCodeInput,
|
||||
oauthState,
|
||||
oauthCallbackPath,
|
||||
oauthLoginOption,
|
||||
projectId,
|
||||
sessionKey: sessionKeyInput,
|
||||
refreshToken: refreshTokenInput,
|
||||
@@ -850,6 +865,8 @@ defineExpose({
|
||||
reset: () => {
|
||||
authCodeInput.value = ''
|
||||
oauthState.value = ''
|
||||
oauthCallbackPath.value = ''
|
||||
oauthLoginOption.value = ''
|
||||
projectId.value = ''
|
||||
sessionKeyInput.value = ''
|
||||
refreshTokenInput.value = ''
|
||||
|
||||
@@ -13,6 +13,15 @@ vi.mock('vue-i18n', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
i18n: {
|
||||
global: {
|
||||
t: (key: string) => key
|
||||
}
|
||||
},
|
||||
getLocale: () => 'en'
|
||||
}))
|
||||
|
||||
function makeAccount(overrides: Partial<Account>): Account {
|
||||
return {
|
||||
id: 1,
|
||||
@@ -35,6 +44,12 @@ function makeAccount(overrides: Partial<Account>): Account {
|
||||
overload_until: null,
|
||||
temp_unschedulable_until: null,
|
||||
temp_unschedulable_reason: null,
|
||||
kiro_quota_state: null,
|
||||
kiro_quota_reason: null,
|
||||
kiro_quota_reset_at: null,
|
||||
kiro_runtime_state: null,
|
||||
kiro_runtime_reason: null,
|
||||
kiro_runtime_reset_at: null,
|
||||
session_window_start: null,
|
||||
session_window_end: null,
|
||||
session_window_status: null,
|
||||
@@ -159,4 +174,96 @@ describe('AccountStatusIndicator', () => {
|
||||
// AICredits 积分耗尽状态应显示
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
|
||||
})
|
||||
|
||||
it('Kiro 运行时冷却在状态列复用限流展示', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 5,
|
||||
name: 'kiro-cooldown',
|
||||
platform: 'kiro',
|
||||
kiro_runtime_state: 'cooldown',
|
||||
kiro_runtime_reason: 'rate_limit_exceeded',
|
||||
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedAutoResume')
|
||||
expect(wrapper.text()).toContain('429')
|
||||
})
|
||||
|
||||
it('Kiro suspended 在状态列显示为 forbidden', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 6,
|
||||
name: 'kiro-suspended',
|
||||
platform: 'kiro',
|
||||
kiro_runtime_state: 'suspended',
|
||||
kiro_runtime_reason: 'account_suspended',
|
||||
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.forbidden')
|
||||
})
|
||||
|
||||
it('Kiro overage active 在状态列仍显示正常状态', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 7,
|
||||
name: 'kiro-overage-active',
|
||||
platform: 'kiro',
|
||||
kiro_quota_state: 'overage_active',
|
||||
kiro_quota_reason: 'overages_enabled',
|
||||
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.active')
|
||||
expect(wrapper.text()).not.toContain('admin.accounts.status.overageActive')
|
||||
})
|
||||
|
||||
it('Kiro overage exhausted 在状态列显示危险徽章', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 8,
|
||||
name: 'kiro-overage-exhausted',
|
||||
platform: 'kiro',
|
||||
kiro_quota_state: 'overage_exhausted',
|
||||
kiro_quota_reason: 'overage disabled after quota exhaustion',
|
||||
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.overageExhausted')
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -25,6 +25,15 @@ vi.mock('vue-i18n', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
i18n: {
|
||||
global: {
|
||||
t: (key: string) => key
|
||||
}
|
||||
},
|
||||
getLocale: () => 'en'
|
||||
}))
|
||||
|
||||
function makeAccount(overrides: Partial<Account>): Account {
|
||||
return {
|
||||
id: 1,
|
||||
@@ -47,6 +56,12 @@ function makeAccount(overrides: Partial<Account>): Account {
|
||||
overload_until: null,
|
||||
temp_unschedulable_until: null,
|
||||
temp_unschedulable_reason: null,
|
||||
kiro_quota_state: null,
|
||||
kiro_quota_reason: null,
|
||||
kiro_quota_reset_at: null,
|
||||
kiro_runtime_state: null,
|
||||
kiro_runtime_reason: null,
|
||||
kiro_runtime_reset_at: null,
|
||||
session_window_start: null,
|
||||
session_window_end: null,
|
||||
session_window_status: null,
|
||||
@@ -530,6 +545,234 @@ describe('AccountUsageCell', () => {
|
||||
expect(wrapper.text()).toContain('7d|100|106540000')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会用 passive source 拉取并展示 credits 额度', async () => {
|
||||
const account = makeAccount({
|
||||
id: 3001,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
|
||||
getUsage.mockResolvedValue({
|
||||
source: 'passive',
|
||||
kiro_subscription_name: 'KIRO PRO+',
|
||||
kiro_overages_enabled: true,
|
||||
kiro_credit: {
|
||||
current_usage: 125,
|
||||
usage_limit: 2000,
|
||||
percentage_used: 6.25,
|
||||
},
|
||||
kiro_bonus: {
|
||||
current_usage: 25,
|
||||
usage_limit: 500,
|
||||
percentage_used: 5,
|
||||
days_remaining: 7,
|
||||
},
|
||||
kiro_overage: {
|
||||
current_overages: 2,
|
||||
overage_charges: 0.08,
|
||||
currency_symbol: '$',
|
||||
currency_code: 'USD',
|
||||
},
|
||||
kiro_reset_at: '2099-03-13T12:00:00Z',
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(getUsage).toHaveBeenCalledWith(3001, 'passive')
|
||||
expect(wrapper.emitted('kiroUsageMeta')?.[0]).toEqual([
|
||||
{
|
||||
plan_type: 'KIRO PRO+',
|
||||
kiro_overages_enabled: true
|
||||
}
|
||||
])
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroCredits')
|
||||
expect(wrapper.text()).toContain('125 / 2.0K')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroBonus')
|
||||
expect(wrapper.text()).toContain('25 / 500')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroDaysLeft')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroReset')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroOverage 2 ($0.08)')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会展示运行时冷却状态', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
source: 'passive',
|
||||
kiro_runtime_state: 'cooldown',
|
||||
kiro_runtime_reason: 'rate_limit_exceeded',
|
||||
kiro_runtime_reset_at: '2099-03-13T12:00:00Z',
|
||||
kiro_credit: {
|
||||
current_usage: 10,
|
||||
usage_limit: 100,
|
||||
percentage_used: 10,
|
||||
},
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3002,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedUntil')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会展示 overage active 与 exhausted 状态', async () => {
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
kiro_quota_state: 'overage_active',
|
||||
kiro_quota_reason: 'overages_enabled',
|
||||
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
|
||||
kiro_overages_enabled: true,
|
||||
kiro_credit: {
|
||||
current_usage: 2100,
|
||||
usage_limit: 2000,
|
||||
percentage_used: 100,
|
||||
},
|
||||
kiro_overage: {
|
||||
current_overages: 3,
|
||||
overage_charges: 0.12,
|
||||
currency_symbol: '$',
|
||||
},
|
||||
})
|
||||
|
||||
const activeWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3005,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(activeWrapper.text()).toContain('admin.accounts.status.overageActive')
|
||||
expect(activeWrapper.text()).not.toContain('admin.accounts.status.overageActiveUntil')
|
||||
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
kiro_quota_state: 'overage_exhausted',
|
||||
kiro_quota_reason: 'usage API error: overage exhausted',
|
||||
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
|
||||
error: 'usage API error: kiro usage request failed (status 429): {"message":"overage exhausted"}',
|
||||
})
|
||||
|
||||
const exhaustedWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3006,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhausted')
|
||||
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会展示 profile 异常和 usage forbidden 徽章', async () => {
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
error_code: 'forbidden',
|
||||
error: 'usage API error: kiro usage request failed (status 400): {"message":"profileArn is required for this request."}',
|
||||
})
|
||||
|
||||
const profileWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3003,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(profileWrapper.text()).toContain('admin.accounts.usageError')
|
||||
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
error_code: 'forbidden',
|
||||
error: 'usage API error: kiro usage request failed (status 403): {"message":"User is not authorized to access this feature."}',
|
||||
})
|
||||
|
||||
const forbiddenWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3004,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(forbiddenWrapper.text()).toContain('admin.accounts.forbidden')
|
||||
})
|
||||
|
||||
it('Key 账号会展示 today stats 徽章并带 A/U 提示', async () => {
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
import { mount } from '@vue/test-utils'
|
||||
|
||||
vi.mock('@/stores/app', () => ({
|
||||
useAppStore: () => ({
|
||||
showSuccess: vi.fn(),
|
||||
showError: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useClipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copied: { value: false },
|
||||
copyToClipboard: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
import OAuthAuthorizationFlow from '../OAuthAuthorizationFlow.vue'
|
||||
|
||||
describe('OAuthAuthorizationFlow', () => {
|
||||
it('extracts code, state, and callback metadata from a full Kiro callback URL', async () => {
|
||||
const wrapper = mount(OAuthAuthorizationFlow, {
|
||||
props: {
|
||||
addMethod: 'oauth',
|
||||
platform: 'kiro',
|
||||
authUrl: 'https://example.com/authorize',
|
||||
sessionId: 'session-1'
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const textarea = wrapper.get('textarea')
|
||||
await textarea.setValue('http://localhost:49153/oauth/callback?code=abc123&state=state456&login_option=github')
|
||||
await nextTick()
|
||||
|
||||
expect((textarea.element as HTMLTextAreaElement).value).toBe('abc123')
|
||||
expect((wrapper.vm as any).oauthState).toBe('state456')
|
||||
expect((wrapper.vm as any).oauthCallbackPath).toBe('/oauth/callback')
|
||||
expect((wrapper.vm as any).oauthLoginOption).toBe('github')
|
||||
})
|
||||
})
|
||||
@@ -489,7 +489,8 @@ const platformOptions = [
|
||||
{ value: 'anthropic', label: 'Anthropic' },
|
||||
{ value: 'openai', label: 'OpenAI' },
|
||||
{ value: 'gemini', label: 'Gemini' },
|
||||
{ value: 'antigravity', label: 'Antigravity' }
|
||||
{ value: 'antigravity', label: 'Antigravity' },
|
||||
{ value: 'kiro', label: 'Kiro' }
|
||||
]
|
||||
|
||||
// Load rules when dialog opens
|
||||
|
||||
@@ -25,7 +25,7 @@ const updateType = (value: string | number | boolean | null) => { emit('update:f
|
||||
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
||||
const updatePrivacyMode = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, privacy_mode: value }) }
|
||||
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }])
|
||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'kiro', label: 'Kiro' }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
|
||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }, { value: 'unschedulable', label: t('admin.accounts.status.unschedulable') }])
|
||||
const privacyOpts = computed(() => [
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
<BaseDialog
|
||||
:show="show"
|
||||
:title="t('admin.accounts.reAuthorizeAccount')"
|
||||
width="normal"
|
||||
:width="isKiro ? 'wide' : 'normal'"
|
||||
@close="handleClose"
|
||||
>
|
||||
<div v-if="account" class="space-y-4">
|
||||
<!-- Account Info -->
|
||||
<div
|
||||
class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700"
|
||||
>
|
||||
<div class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
|
||||
<div class="flex items-center gap-3">
|
||||
<div
|
||||
:class="[
|
||||
@@ -18,6 +15,8 @@
|
||||
? 'from-green-500 to-green-600'
|
||||
: isGemini
|
||||
? 'from-blue-500 to-blue-600'
|
||||
: isKiro
|
||||
? 'from-amber-500 to-amber-600'
|
||||
: isAntigravity
|
||||
? 'from-purple-500 to-purple-600'
|
||||
: 'from-orange-500 to-orange-600'
|
||||
@@ -26,15 +25,15 @@
|
||||
<Icon name="sparkles" size="md" class="text-white" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block font-semibold text-gray-900 dark:text-white">{{
|
||||
account.name
|
||||
}}</span>
|
||||
<span class="block font-semibold text-gray-900 dark:text-white">{{ account.name }}</span>
|
||||
<span class="text-sm text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
isOpenAI
|
||||
? t('admin.accounts.openaiAccount')
|
||||
: isGemini
|
||||
? t('admin.accounts.geminiAccount')
|
||||
: isKiro
|
||||
? t('admin.accounts.kiroAccount')
|
||||
: isAntigravity
|
||||
? t('admin.accounts.antigravityAccount')
|
||||
: t('admin.accounts.claudeCodeAccount')
|
||||
@@ -44,7 +43,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Add Method Selection (Claude only) -->
|
||||
<fieldset v-if="isAnthropic" class="border-0 p-0">
|
||||
<legend class="input-label">{{ t('admin.accounts.oauth.authMethod') }}</legend>
|
||||
<div class="mt-2 flex gap-4">
|
||||
@@ -55,9 +53,7 @@
|
||||
value="oauth"
|
||||
class="mr-2 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{
|
||||
t('admin.accounts.types.oauth')
|
||||
}}</span>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.types.oauth') }}</span>
|
||||
</label>
|
||||
<label class="flex cursor-pointer items-center">
|
||||
<input
|
||||
@@ -66,14 +62,11 @@
|
||||
value="setup-token"
|
||||
class="mr-2 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{
|
||||
t('admin.accounts.setupTokenLongLived')
|
||||
}}</span>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.setupTokenLongLived') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</fieldset>
|
||||
|
||||
<!-- Gemini OAuth Type Display (read-only) -->
|
||||
<div v-if="isGemini" class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
|
||||
<div class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
|
||||
@@ -116,7 +109,187 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="isKiro" class="rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-700/40 dark:bg-amber-900/20">
|
||||
<div class="mb-3 text-sm font-medium text-amber-900 dark:text-amber-100">
|
||||
{{ t('admin.accounts.oauth.kiro.authModeTitle') }}
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroAccountType = 'oauth'"
|
||||
:class="kiroModeClass('oauth')"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroAccountType === 'oauth'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="key" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.oauthTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.oauthSubtitle') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroAccountType = 'idc'"
|
||||
:class="kiroModeClass('idc')"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroAccountType === 'idc'
|
||||
? 'bg-blue-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="cloud" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.idcTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.idcSubtitle') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroAccountType = 'import'"
|
||||
:class="kiroModeClass('import')"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroAccountType === 'import'
|
||||
? 'bg-slate-700 text-white dark:bg-slate-500'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="download" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.importTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.importSubtitle') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div v-if="kiroAccountType === 'oauth'" class="mt-3 space-y-3">
|
||||
<div class="text-xs text-amber-800 dark:text-amber-200">
|
||||
{{ t('admin.accounts.oauth.kiro.oauthSubtitle') }}
|
||||
</div>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroOAuthProvider = 'google'"
|
||||
:class="kiroProviderClass('google')"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroOAuthProvider === 'google'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="user" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.googleTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.googleDesc') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="kiroOAuthProvider = 'github'"
|
||||
:class="kiroProviderClass('github')"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
kiroOAuthProvider === 'github'
|
||||
? 'bg-slate-700 text-white dark:bg-slate-500'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="terminal" size="sm" />
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.accounts.oauth.kiro.githubTitle') }}
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.oauth.kiro.githubDesc') }}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="kiroAccountType === 'idc'" class="mt-3 grid gap-3 md:grid-cols-2">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.idcStartUrlLabel') }}</label>
|
||||
<input
|
||||
v-model="kiroIDCStartUrl"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.oauth.kiro.startUrlPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.regionLabel') }}</label>
|
||||
<input
|
||||
v-model="kiroIDCRegion"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="t('admin.accounts.oauth.kiro.regionPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="isKiroImportMode" class="mt-3 space-y-3">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.tokenJsonLabel') }}</label>
|
||||
<textarea
|
||||
v-model="kiroTokenJson"
|
||||
rows="7"
|
||||
class="input font-mono text-xs"
|
||||
placeholder='{"accessToken":"...","refreshToken":"..."}'
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.oauth.kiro.tokenJsonHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.kiro.deviceRegistrationLabel') }}</label>
|
||||
<textarea
|
||||
v-model="kiroDeviceRegistrationJson"
|
||||
rows="4"
|
||||
class="input font-mono text-xs"
|
||||
placeholder='{"clientId":"...","clientSecret":"..."}'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<OAuthAuthorizationFlow
|
||||
v-if="!isKiroImportMode"
|
||||
ref="oauthFlowRef"
|
||||
:add-method="addMethod"
|
||||
:auth-url="currentAuthUrl"
|
||||
@@ -128,12 +301,11 @@
|
||||
:show-cookie-option="isAnthropic"
|
||||
:allow-multiple="false"
|
||||
:method-label="t('admin.accounts.inputMethod')"
|
||||
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
||||
:platform="oauthPlatform"
|
||||
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
||||
@generate-url="handleGenerateUrl"
|
||||
@cookie-auth="handleCookieAuth"
|
||||
/>
|
||||
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
@@ -142,7 +314,16 @@
|
||||
{{ t('common.cancel') }}
|
||||
</button>
|
||||
<button
|
||||
v-if="isManualInputMethod"
|
||||
v-if="isKiroImportMode"
|
||||
type="button"
|
||||
:disabled="currentLoading || !kiroTokenJson.trim()"
|
||||
class="btn btn-primary"
|
||||
@click="handleKiroImport"
|
||||
>
|
||||
{{ currentLoading ? t('admin.accounts.oauth.verifying') : t('admin.accounts.oauth.kiro.importAndUpdate') }}
|
||||
</button>
|
||||
<button
|
||||
v-else-if="isManualInputMethod"
|
||||
type="button"
|
||||
:disabled="!canExchangeCode"
|
||||
class="btn btn-primary"
|
||||
@@ -161,18 +342,14 @@
|
||||
r="10"
|
||||
stroke="currentColor"
|
||||
stroke-width="4"
|
||||
></circle>
|
||||
/>
|
||||
<path
|
||||
class="opacity-75"
|
||||
fill="currentColor"
|
||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||
></path>
|
||||
/>
|
||||
</svg>
|
||||
{{
|
||||
currentLoading
|
||||
? t('admin.accounts.oauth.verifying')
|
||||
: t('admin.accounts.oauth.completeAuth')
|
||||
}}
|
||||
{{ currentLoading ? t('admin.accounts.oauth.verifying') : t('admin.accounts.oauth.completeAuth') }}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
@@ -180,28 +357,29 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import { computed, ref, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import {
|
||||
useAccountOAuth,
|
||||
type AddMethod,
|
||||
type AuthInputMethod
|
||||
} from '@/composables/useAccountOAuth'
|
||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
|
||||
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
|
||||
import type { Account } from '@/types'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import OAuthAuthorizationFlow from '@/components/account/OAuthAuthorizationFlow.vue'
|
||||
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
|
||||
import {
|
||||
type AddMethod,
|
||||
type AuthInputMethod,
|
||||
useAccountOAuth
|
||||
} from '@/composables/useAccountOAuth'
|
||||
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
|
||||
import { useKiroOAuth } from '@/composables/useKiroOAuth'
|
||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import type { Account, AccountPlatform } from '@/types'
|
||||
|
||||
// Type for exposed OAuthAuthorizationFlow component
|
||||
// Note: defineExpose automatically unwraps refs, so we use the unwrapped types
|
||||
interface OAuthFlowExposed {
|
||||
authCode: string
|
||||
oauthState: string
|
||||
oauthCallbackPath: string
|
||||
oauthLoginOption: string
|
||||
projectId: string
|
||||
sessionKey: string
|
||||
inputMethod: AuthInputMethod
|
||||
@@ -222,77 +400,96 @@ const emit = defineEmits<{
|
||||
const appStore = useAppStore()
|
||||
const { t } = useI18n()
|
||||
|
||||
// OAuth composables
|
||||
const claudeOAuth = useAccountOAuth()
|
||||
const openaiOAuth = useOpenAIOAuth()
|
||||
const geminiOAuth = useGeminiOAuth()
|
||||
const antigravityOAuth = useAntigravityOAuth()
|
||||
const kiroOAuth = useKiroOAuth()
|
||||
|
||||
// Refs
|
||||
const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
|
||||
|
||||
// State
|
||||
const addMethod = ref<AddMethod>('oauth')
|
||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist')
|
||||
const kiroAccountType = ref<'oauth' | 'idc' | 'import'>('oauth')
|
||||
const kiroOAuthProvider = ref<'google' | 'github'>('google')
|
||||
const kiroIDCStartUrl = ref('https://view.awsapps.com/start')
|
||||
const kiroIDCRegion = ref('us-east-1')
|
||||
const kiroTokenJson = ref('')
|
||||
const kiroDeviceRegistrationJson = ref('')
|
||||
|
||||
// Computed - check platform
|
||||
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
||||
const isOpenAILike = computed(() => isOpenAI.value)
|
||||
const isGemini = computed(() => props.account?.platform === 'gemini')
|
||||
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
||||
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
||||
const isKiro = computed(() => props.account?.platform === 'kiro')
|
||||
|
||||
const oauthPlatform = computed<AccountPlatform>(() => {
|
||||
if (isOpenAI.value) return 'openai'
|
||||
if (isGemini.value) return 'gemini'
|
||||
if (isKiro.value) return 'kiro'
|
||||
if (isAntigravity.value) return 'antigravity'
|
||||
return 'anthropic'
|
||||
})
|
||||
|
||||
// Computed - current OAuth state based on platform
|
||||
const currentAuthUrl = computed(() => {
|
||||
if (isOpenAILike.value) return openaiOAuth.authUrl.value
|
||||
if (isGemini.value) return geminiOAuth.authUrl.value
|
||||
if (isKiro.value) return kiroOAuth.authUrl.value
|
||||
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
||||
return claudeOAuth.authUrl.value
|
||||
})
|
||||
|
||||
const currentSessionId = computed(() => {
|
||||
if (isOpenAILike.value) return openaiOAuth.sessionId.value
|
||||
if (isGemini.value) return geminiOAuth.sessionId.value
|
||||
if (isKiro.value) return kiroOAuth.sessionId.value
|
||||
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
||||
return claudeOAuth.sessionId.value
|
||||
})
|
||||
|
||||
const currentLoading = computed(() => {
|
||||
if (isOpenAILike.value) return openaiOAuth.loading.value
|
||||
if (isGemini.value) return geminiOAuth.loading.value
|
||||
if (isKiro.value) return kiroOAuth.loading.value
|
||||
if (isAntigravity.value) return antigravityOAuth.loading.value
|
||||
return claudeOAuth.loading.value
|
||||
})
|
||||
|
||||
const currentError = computed(() => {
|
||||
if (isOpenAILike.value) return openaiOAuth.error.value
|
||||
if (isGemini.value) return geminiOAuth.error.value
|
||||
if (isKiro.value) return kiroOAuth.error.value
|
||||
if (isAntigravity.value) return antigravityOAuth.error.value
|
||||
return claudeOAuth.error.value
|
||||
})
|
||||
|
||||
// Computed
|
||||
const isKiroImportMode = computed(() => isKiro.value && kiroAccountType.value === 'import')
|
||||
|
||||
const isManualInputMethod = computed(() => {
|
||||
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option)
|
||||
return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
||||
return isOpenAILike.value || isGemini.value || isKiro.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
||||
})
|
||||
|
||||
const canExchangeCode = computed(() => {
|
||||
if (isKiroImportMode.value) {
|
||||
return false
|
||||
}
|
||||
const authCode = oauthFlowRef.value?.authCode || ''
|
||||
const sessionId = currentSessionId.value
|
||||
const loading = currentLoading.value
|
||||
return authCode.trim() && sessionId && !loading
|
||||
return !!(authCode.trim() && currentSessionId.value && !currentLoading.value)
|
||||
})
|
||||
|
||||
// Watchers
|
||||
watch(
|
||||
() => props.show,
|
||||
(newVal) => {
|
||||
if (newVal && props.account) {
|
||||
// Initialize addMethod based on current account type (Claude only)
|
||||
if (
|
||||
isAnthropic.value &&
|
||||
(props.account.type === 'oauth' || props.account.type === 'setup-token')
|
||||
) {
|
||||
if (!newVal || !props.account) {
|
||||
resetState()
|
||||
return
|
||||
}
|
||||
|
||||
if (isAnthropic.value && (props.account.type === 'oauth' || props.account.type === 'setup-token')) {
|
||||
addMethod.value = props.account.type as AddMethod
|
||||
}
|
||||
|
||||
if (isGemini.value) {
|
||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||
geminiOAuthType.value =
|
||||
@@ -302,42 +499,123 @@ watch(
|
||||
? 'ai_studio'
|
||||
: 'code_assist'
|
||||
}
|
||||
} else {
|
||||
resetState()
|
||||
|
||||
if (isKiro.value) {
|
||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||
const authMethod = typeof creds.auth_method === 'string' ? creds.auth_method : ''
|
||||
const provider = String(creds.provider || '').toLowerCase()
|
||||
kiroIDCStartUrl.value = typeof creds.start_url === 'string' && creds.start_url ? creds.start_url : 'https://view.awsapps.com/start'
|
||||
kiroIDCRegion.value = typeof creds.region === 'string' && creds.region ? creds.region : 'us-east-1'
|
||||
kiroAccountType.value = authMethod === 'idc' ? 'idc' : 'oauth'
|
||||
kiroOAuthProvider.value = provider === 'github' ? 'github' : 'google'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// Methods
|
||||
const resetState = () => {
|
||||
addMethod.value = 'oauth'
|
||||
geminiOAuthType.value = 'code_assist'
|
||||
kiroAccountType.value = 'oauth'
|
||||
kiroOAuthProvider.value = 'google'
|
||||
kiroIDCStartUrl.value = 'https://view.awsapps.com/start'
|
||||
kiroIDCRegion.value = 'us-east-1'
|
||||
kiroTokenJson.value = ''
|
||||
kiroDeviceRegistrationJson.value = ''
|
||||
claudeOAuth.resetState()
|
||||
openaiOAuth.resetState()
|
||||
geminiOAuth.resetState()
|
||||
antigravityOAuth.resetState()
|
||||
kiroOAuth.resetState()
|
||||
oauthFlowRef.value?.reset()
|
||||
}
|
||||
|
||||
const kiroModeClass = (mode: typeof kiroAccountType.value) => [
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroAccountType.value === mode
|
||||
? mode === 'idc'
|
||||
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
|
||||
: mode === 'import'
|
||||
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
|
||||
: 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: mode === 'idc'
|
||||
? 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
|
||||
: mode === 'import'
|
||||
? 'border-gray-200 hover:border-slate-300 dark:border-dark-600 dark:hover:border-slate-700'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]
|
||||
|
||||
const kiroProviderClass = (provider: typeof kiroOAuthProvider.value) => [
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
kiroOAuthProvider.value === provider
|
||||
? provider === 'github'
|
||||
? 'border-slate-500 bg-slate-50 dark:bg-slate-900/20'
|
||||
: 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: provider === 'github'
|
||||
? 'border-amber-200 hover:border-slate-300 dark:border-amber-700/40 dark:hover:border-slate-700'
|
||||
: 'border-amber-200 hover:border-amber-300 dark:border-amber-700/40 dark:hover:border-amber-600'
|
||||
]
|
||||
|
||||
const handleClose = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const buildUpdatedCredentials = (next: Record<string, unknown>) => ({
|
||||
...((props.account?.credentials || {}) as Record<string, unknown>),
|
||||
...next
|
||||
})
|
||||
|
||||
const updateAccountCredentials = async (payload: {
|
||||
type: 'oauth' | 'setup-token'
|
||||
credentials: Record<string, unknown>
|
||||
extra?: Record<string, unknown>
|
||||
}) => {
|
||||
if (!props.account) return
|
||||
|
||||
await adminAPI.accounts.update(props.account.id, payload)
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
}
|
||||
|
||||
const handleGenerateUrl = async () => {
|
||||
if (!props.account) return
|
||||
|
||||
if (isOpenAILike.value) {
|
||||
await openaiOAuth.generateAuthUrl(props.account.proxy_id)
|
||||
} else if (isGemini.value) {
|
||||
return
|
||||
}
|
||||
|
||||
if (isGemini.value) {
|
||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
||||
const projectId = geminiOAuthType.value === 'code_assist' ? oauthFlowRef.value?.projectId : undefined
|
||||
await geminiOAuth.generateAuthUrl(props.account.proxy_id, projectId, geminiOAuthType.value, tierId)
|
||||
} else if (isAntigravity.value) {
|
||||
await antigravityOAuth.generateAuthUrl(props.account.proxy_id)
|
||||
} else {
|
||||
await claudeOAuth.generateAuthUrl(addMethod.value, props.account.proxy_id)
|
||||
return
|
||||
}
|
||||
|
||||
if (isKiro.value) {
|
||||
if (kiroAccountType.value === 'idc') {
|
||||
await kiroOAuth.generateIDCAuthUrl({
|
||||
proxyId: props.account.proxy_id,
|
||||
startUrl: kiroIDCStartUrl.value,
|
||||
region: kiroIDCRegion.value
|
||||
})
|
||||
return
|
||||
}
|
||||
await kiroOAuth.generateAuthUrl(
|
||||
props.account.proxy_id,
|
||||
kiroOAuthProvider.value === 'github' ? 'Github' : 'Google'
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if (isAntigravity.value) {
|
||||
await antigravityOAuth.generateAuthUrl(props.account.proxy_id)
|
||||
return
|
||||
}
|
||||
|
||||
await claudeOAuth.generateAuthUrl(addMethod.value, props.account.proxy_id)
|
||||
}
|
||||
|
||||
const handleExchangeCode = async () => {
|
||||
@@ -347,53 +625,37 @@ const handleExchangeCode = async () => {
|
||||
if (!authCode.trim()) return
|
||||
|
||||
if (isOpenAILike.value) {
|
||||
// OpenAI OAuth flow
|
||||
const oauthClient = openaiOAuth
|
||||
const sessionId = oauthClient.sessionId.value
|
||||
const sessionId = openaiOAuth.sessionId.value
|
||||
if (!sessionId) return
|
||||
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||
|
||||
const stateToUse = (oauthFlowRef.value?.oauthState || openaiOAuth.oauthState.value || '').trim()
|
||||
if (!stateToUse) {
|
||||
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(oauthClient.error.value)
|
||||
openaiOAuth.error.value = t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(openaiOAuth.error.value)
|
||||
return
|
||||
}
|
||||
|
||||
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||
authCode.trim(),
|
||||
sessionId,
|
||||
stateToUse,
|
||||
props.account.proxy_id
|
||||
)
|
||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(authCode.trim(), sessionId, stateToUse, props.account.proxy_id)
|
||||
if (!tokenInfo) return
|
||||
|
||||
// Build credentials and extra info
|
||||
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||
const extra = oauthClient.buildExtraInfo(tokenInfo)
|
||||
|
||||
try {
|
||||
// Update account with new credentials
|
||||
await adminAPI.accounts.update(props.account.id, {
|
||||
type: 'oauth', // OpenAI OAuth is always 'oauth' type
|
||||
credentials,
|
||||
extra
|
||||
await updateAccountCredentials({
|
||||
type: 'oauth',
|
||||
credentials: buildUpdatedCredentials(openaiOAuth.buildCredentials(tokenInfo)),
|
||||
extra: openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||
})
|
||||
|
||||
// Clear error status after successful re-authorization
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(oauthClient.error.value)
|
||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(openaiOAuth.error.value)
|
||||
}
|
||||
} else if (isGemini.value) {
|
||||
return
|
||||
}
|
||||
|
||||
if (isGemini.value) {
|
||||
const sessionId = geminiOAuth.sessionId.value
|
||||
if (!sessionId) return
|
||||
|
||||
const stateFromInput = oauthFlowRef.value?.oauthState || ''
|
||||
const stateToUse = stateFromInput || geminiOAuth.state.value
|
||||
const stateToUse = oauthFlowRef.value?.oauthState || geminiOAuth.state.value
|
||||
if (!stateToUse) return
|
||||
|
||||
const tokenInfo = await geminiOAuth.exchangeAuthCode({
|
||||
@@ -402,32 +664,58 @@ const handleExchangeCode = async () => {
|
||||
state: stateToUse,
|
||||
proxyId: props.account.proxy_id,
|
||||
oauthType: geminiOAuthType.value,
|
||||
tierId: typeof (props.account.credentials as any)?.tier_id === 'string' ? ((props.account.credentials as any).tier_id as string) : undefined
|
||||
tierId: typeof (props.account.credentials as any)?.tier_id === 'string'
|
||||
? ((props.account.credentials as any).tier_id as string)
|
||||
: undefined
|
||||
})
|
||||
if (!tokenInfo) return
|
||||
|
||||
const credentials = geminiOAuth.buildCredentials(tokenInfo)
|
||||
|
||||
try {
|
||||
await adminAPI.accounts.update(props.account.id, {
|
||||
await updateAccountCredentials({
|
||||
type: 'oauth',
|
||||
credentials
|
||||
credentials: buildUpdatedCredentials(geminiOAuth.buildCredentials(tokenInfo))
|
||||
})
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(geminiOAuth.error.value)
|
||||
}
|
||||
} else if (isAntigravity.value) {
|
||||
// Antigravity OAuth flow
|
||||
return
|
||||
}
|
||||
|
||||
if (isKiro.value) {
|
||||
const sessionId = kiroOAuth.sessionId.value
|
||||
if (!sessionId) return
|
||||
|
||||
const stateToUse = oauthFlowRef.value?.oauthState || kiroOAuth.state.value
|
||||
if (!stateToUse) return
|
||||
|
||||
const tokenInfo = await kiroOAuth.exchangeAuthCode({
|
||||
code: authCode.trim(),
|
||||
sessionId,
|
||||
state: stateToUse,
|
||||
callbackPath: oauthFlowRef.value?.oauthCallbackPath || '',
|
||||
loginOption: oauthFlowRef.value?.oauthLoginOption || '',
|
||||
proxyId: props.account.proxy_id
|
||||
})
|
||||
if (!tokenInfo) return
|
||||
|
||||
try {
|
||||
await updateAccountCredentials({
|
||||
type: 'oauth',
|
||||
credentials: buildUpdatedCredentials(kiroOAuth.buildCredentials(tokenInfo))
|
||||
})
|
||||
} catch (error: any) {
|
||||
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(kiroOAuth.error.value)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (isAntigravity.value) {
|
||||
const sessionId = antigravityOAuth.sessionId.value
|
||||
if (!sessionId) return
|
||||
|
||||
const stateFromInput = oauthFlowRef.value?.oauthState || ''
|
||||
const stateToUse = stateFromInput || antigravityOAuth.state.value
|
||||
const stateToUse = oauthFlowRef.value?.oauthState || antigravityOAuth.state.value
|
||||
if (!stateToUse) return
|
||||
|
||||
const tokenInfo = await antigravityOAuth.exchangeAuthCode({
|
||||
@@ -438,23 +726,18 @@ const handleExchangeCode = async () => {
|
||||
})
|
||||
if (!tokenInfo) return
|
||||
|
||||
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
|
||||
|
||||
try {
|
||||
await adminAPI.accounts.update(props.account.id, {
|
||||
await updateAccountCredentials({
|
||||
type: 'oauth',
|
||||
credentials
|
||||
credentials: buildUpdatedCredentials(antigravityOAuth.buildCredentials(tokenInfo))
|
||||
})
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(antigravityOAuth.error.value)
|
||||
}
|
||||
} else {
|
||||
// Claude OAuth flow
|
||||
return
|
||||
}
|
||||
|
||||
const sessionId = claudeOAuth.sessionId.value
|
||||
if (!sessionId) return
|
||||
|
||||
@@ -474,32 +757,41 @@ const handleExchangeCode = async () => {
|
||||
...proxyConfig
|
||||
})
|
||||
|
||||
const extra = claudeOAuth.buildExtraInfo(tokenInfo)
|
||||
|
||||
// Update account with new credentials and type
|
||||
await adminAPI.accounts.update(props.account.id, {
|
||||
type: addMethod.value, // Update type based on selected method
|
||||
credentials: tokenInfo,
|
||||
extra
|
||||
await updateAccountCredentials({
|
||||
type: addMethod.value,
|
||||
credentials: buildUpdatedCredentials(tokenInfo),
|
||||
extra: claudeOAuth.buildExtraInfo(tokenInfo)
|
||||
})
|
||||
|
||||
// Clear error status after successful re-authorization
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(claudeOAuth.error.value)
|
||||
} finally {
|
||||
claudeOAuth.loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleKiroImport = async () => {
|
||||
if (!props.account || !isKiroImportMode.value || !kiroTokenJson.value.trim()) return
|
||||
|
||||
const tokenInfo = await kiroOAuth.importToken(
|
||||
kiroTokenJson.value,
|
||||
kiroDeviceRegistrationJson.value || undefined
|
||||
)
|
||||
if (!tokenInfo) return
|
||||
|
||||
try {
|
||||
await updateAccountCredentials({
|
||||
type: 'oauth',
|
||||
credentials: buildUpdatedCredentials(kiroOAuth.buildCredentials(tokenInfo))
|
||||
})
|
||||
} catch (error: any) {
|
||||
kiroOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
appStore.showError(kiroOAuth.error.value)
|
||||
}
|
||||
}
|
||||
|
||||
const handleCookieAuth = async (sessionKey: string) => {
|
||||
if (!props.account || isOpenAILike.value) return
|
||||
if (!props.account || isOpenAILike.value || isKiro.value) return
|
||||
|
||||
claudeOAuth.loading.value = true
|
||||
claudeOAuth.error.value = ''
|
||||
@@ -517,24 +809,13 @@ const handleCookieAuth = async (sessionKey: string) => {
|
||||
...proxyConfig
|
||||
})
|
||||
|
||||
const extra = claudeOAuth.buildExtraInfo(tokenInfo)
|
||||
|
||||
// Update account with new credentials and type
|
||||
await adminAPI.accounts.update(props.account.id, {
|
||||
type: addMethod.value, // Update type based on selected method
|
||||
credentials: tokenInfo,
|
||||
extra
|
||||
await updateAccountCredentials({
|
||||
type: addMethod.value,
|
||||
credentials: buildUpdatedCredentials(tokenInfo),
|
||||
extra: claudeOAuth.buildExtraInfo(tokenInfo)
|
||||
})
|
||||
|
||||
// Clear error status after successful re-authorization
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
claudeOAuth.error.value =
|
||||
error.response?.data?.detail || t('admin.accounts.oauth.cookieAuthFailed')
|
||||
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.cookieAuthFailed')
|
||||
} finally {
|
||||
claudeOAuth.loading.value = false
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ const labelClass = computed(() => {
|
||||
}
|
||||
|
||||
// 正常状态或无天数:根据平台显示主题色
|
||||
if (props.platform === 'anthropic') {
|
||||
if (props.platform === 'anthropic' || props.platform === 'kiro') {
|
||||
return `${base} bg-orange-200/60 text-orange-800 dark:bg-orange-800/40 dark:text-orange-300`
|
||||
}
|
||||
if (props.platform === 'openai') {
|
||||
@@ -129,7 +129,7 @@ const labelClass = computed(() => {
|
||||
|
||||
// Badge color based on platform and subscription type
|
||||
const badgeClass = computed(() => {
|
||||
if (props.platform === 'anthropic') {
|
||||
if (props.platform === 'anthropic' || props.platform === 'kiro') {
|
||||
// Claude: orange theme
|
||||
return isSubscription.value
|
||||
? 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
|
||||
@@ -91,6 +91,8 @@ const ratePillClass = computed(() => {
|
||||
return 'bg-green-50 text-green-700 dark:bg-green-900/20 dark:text-green-400'
|
||||
case 'gemini':
|
||||
return 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400'
|
||||
case 'kiro':
|
||||
return 'bg-amber-50 text-amber-700 dark:bg-amber-900/20 dark:text-amber-400'
|
||||
default: // antigravity and others
|
||||
return 'bg-violet-50 text-violet-700 dark:bg-violet-900/20 dark:text-violet-400'
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<template>
|
||||
<!-- Claude/Anthropic logo -->
|
||||
<svg v-if="platform === 'anthropic'" :class="sizeClass" viewBox="0 0 16 16" fill="currentColor">
|
||||
<svg v-if="platform === 'anthropic' || platform === 'kiro'" :class="sizeClass" viewBox="0 0 16 16" fill="currentColor">
|
||||
<path
|
||||
d="m3.127 10.604 3.135-1.76.053-.153-.053-.085H6.11l-.525-.032-1.791-.048-1.554-.065-1.505-.08-.38-.081L0 7.832l.036-.234.32-.214.455.04 1.009.069 1.513.105 1.097.064 1.626.17h.259l.036-.105-.089-.065-.068-.064-1.566-1.062-1.695-1.121-.887-.646-.48-.327-.243-.306-.104-.67.435-.48.585.04.15.04.593.456 1.267.981 1.654 1.218.242.202.097-.068.012-.049-.109-.181-.9-1.626-.96-1.655-.428-.686-.113-.411a2 2 0 0 1-.068-.484l.496-.674L4.446 0l.662.089.279.242.411.94.666 1.48 1.033 2.014.302.597.162.553.06.17h.105v-.097l.085-1.134.157-1.392.154-1.792.052-.504.25-.605.497-.327.387.186.319.456-.045.294-.19 1.23-.37 1.93-.243 1.29h.142l.161-.16.654-.868 1.097-1.372.484-.545.565-.601.363-.287h.686l.505.751-.226.775-.707.895-.585.759-.839 1.13-.524.904.048.072.125-.012 1.897-.403 1.024-.186 1.223-.21.553.258.06.263-.218.536-1.307.323-1.533.307-2.284.54-.028.02.032.04 1.029.098.44.024h1.077l2.005.15.525.346.315.424-.053.323-.807.411-3.631-.863-.872-.218h-.12v.073l.726.71 1.331 1.202 1.667 1.55.084.383-.214.302-.226-.032-1.464-1.101-.565-.497-1.28-1.077h-.084v.113l.295.432 1.557 2.34.08.718-.112.234-.404.141-.444-.08-.911-1.28-.94-1.44-.759-1.291-.093.053-.448 4.821-.21.246-.484.186-.403-.307-.214-.496.214-.98.258-1.28.21-1.016.19-1.263.112-.42-.008-.028-.092.012-.953 1.307-1.448 1.957-1.146 1.227-.274.109-.477-.247.045-.44.266-.39 1.586-2.018.956-1.25.617-.723-.004-.105h-.036l-4.212 2.736-.75.096-.324-.302.04-.496.154-.162 1.267-.871z"
|
||||
/>
|
||||
|
||||
@@ -31,10 +31,18 @@
|
||||
</span>
|
||||
</div>
|
||||
<!-- Row 2: Plan type + Privacy mode (only if either exists) -->
|
||||
<div v-if="planLabel || privacyBadge" class="inline-flex items-center overflow-hidden rounded-md">
|
||||
<div v-if="planLabel || privacyBadge || overagesBadge" class="inline-flex items-center overflow-hidden rounded-md">
|
||||
<span v-if="planLabel" :class="['inline-flex items-center gap-1 px-1.5 py-1', planBadgeClass]">
|
||||
<span>{{ planLabel }}</span>
|
||||
</span>
|
||||
<span
|
||||
v-if="overagesBadge"
|
||||
:class="['inline-flex items-center gap-1 px-1.5 py-1', overagesBadge.class]"
|
||||
:title="overagesBadge.title"
|
||||
>
|
||||
<Icon name="sparkles" size="xs" />
|
||||
<span>{{ overagesBadge.label }}</span>
|
||||
</span>
|
||||
<span
|
||||
v-if="privacyBadge"
|
||||
:class="['inline-flex items-center gap-1 px-1.5 py-1', privacyBadge.class]"
|
||||
@@ -66,6 +74,7 @@ interface Props {
|
||||
platform: AccountPlatform
|
||||
type: AccountType
|
||||
planType?: string
|
||||
overagesEnabled?: boolean
|
||||
privacyMode?: string
|
||||
subscriptionExpiresAt?: string
|
||||
}
|
||||
@@ -76,6 +85,7 @@ const platformLabel = computed(() => {
|
||||
if (props.platform === 'anthropic') return 'Anthropic'
|
||||
if (props.platform === 'openai') return 'OpenAI'
|
||||
if (props.platform === 'antigravity') return 'Antigravity'
|
||||
if (props.platform === 'kiro') return 'Kiro'
|
||||
return 'Gemini'
|
||||
})
|
||||
|
||||
@@ -126,6 +136,9 @@ const platformClass = computed(() => {
|
||||
if (props.platform === 'antigravity') {
|
||||
return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
}
|
||||
if (props.platform === 'kiro') {
|
||||
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
}
|
||||
return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
})
|
||||
|
||||
@@ -139,6 +152,9 @@ const typeClass = computed(() => {
|
||||
if (props.platform === 'antigravity') {
|
||||
return 'bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
}
|
||||
if (props.platform === 'kiro') {
|
||||
return 'bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
}
|
||||
return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
})
|
||||
|
||||
@@ -149,6 +165,15 @@ const planBadgeClass = computed(() => {
|
||||
return typeClass.value
|
||||
})
|
||||
|
||||
const overagesBadge = computed(() => {
|
||||
if (props.platform !== 'kiro' || !props.overagesEnabled) return null
|
||||
return {
|
||||
label: t('admin.accounts.status.overageActive'),
|
||||
title: t('admin.accounts.usageWindow.kiroOverage'),
|
||||
class: 'bg-amber-100 text-amber-700 dark:bg-amber-900/30 dark:text-amber-300'
|
||||
}
|
||||
})
|
||||
|
||||
// Subscription expiration label (non-free only)
|
||||
const expiresLabel = computed(() => {
|
||||
if (!props.subscriptionExpiresAt || !props.planType) return ''
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import { mount } from '@vue/test-utils'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import PlatformTypeBadge from '../PlatformTypeBadge.vue'
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key === 'admin.accounts.status.overageActive' ? 'Overage' : key
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
describe('PlatformTypeBadge', () => {
|
||||
it('shows Kiro overages tag next to the plan tag when enabled', () => {
|
||||
const wrapper = mount(PlatformTypeBadge, {
|
||||
props: {
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
planType: 'KIRO PRO+',
|
||||
overagesEnabled: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('KIRO PRO+')
|
||||
expect(wrapper.text()).toContain('Overage')
|
||||
})
|
||||
|
||||
it('does not show overages tag for non-Kiro accounts', () => {
|
||||
const wrapper = mount(PlatformTypeBadge, {
|
||||
props: {
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
planType: 'Pro',
|
||||
overagesEnabled: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).not.toContain('Overage')
|
||||
})
|
||||
})
|
||||
@@ -4,7 +4,12 @@ vi.mock('@/api/admin/accounts', () => ({
|
||||
getAntigravityDefaultModelMapping: vi.fn()
|
||||
}))
|
||||
|
||||
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
|
||||
import {
|
||||
buildModelMappingObject,
|
||||
fetchKiroDefaultMappings,
|
||||
getModelsByPlatform,
|
||||
getPresetMappingsByPlatform
|
||||
} from '../useModelWhitelist'
|
||||
|
||||
describe('useModelWhitelist', () => {
|
||||
it('openai 模型列表包含 GPT-5.4 官方快照', () => {
|
||||
@@ -51,6 +56,58 @@ describe('useModelWhitelist', () => {
|
||||
expect(models.indexOf('gemini-2.5-flash-image')).toBeLessThan(models.indexOf('gemini-2.5-flash-lite'))
|
||||
})
|
||||
|
||||
it('kiro 模型列表不暴露旧的 -agentic / -chat 后缀', () => {
|
||||
const models = getModelsByPlatform('kiro')
|
||||
|
||||
expect(models).toContain('claude-sonnet-4-6')
|
||||
expect(models).toContain('claude-sonnet-4-6-thinking')
|
||||
expect(models).not.toContain('claude-sonnet-4-6-chat')
|
||||
expect(models.every((model) => !model.endsWith('-agentic') && !model.endsWith('-chat'))).toBe(true)
|
||||
})
|
||||
|
||||
it('kiro 模型列表只保留 Claude 模型', () => {
|
||||
const models = getModelsByPlatform('kiro')
|
||||
|
||||
expect(models).toEqual([
|
||||
'claude-opus-4-6',
|
||||
'claude-opus-4-6-thinking',
|
||||
'claude-sonnet-4-6',
|
||||
'claude-sonnet-4-6-thinking',
|
||||
'claude-opus-4-5-20251101',
|
||||
'claude-opus-4-5-20251101-thinking',
|
||||
'claude-sonnet-4-5-20250929',
|
||||
'claude-sonnet-4-5-20250929-thinking',
|
||||
'claude-haiku-4-5-20251001',
|
||||
'claude-haiku-4-5-20251001-thinking'
|
||||
])
|
||||
expect(models.every(model => model.startsWith('claude-'))).toBe(true)
|
||||
expect(models.some(model => model.endsWith('-agentic'))).toBe(false)
|
||||
expect(models.some(model => model.endsWith('-chat'))).toBe(false)
|
||||
expect(models).not.toContain('kiro-auto')
|
||||
expect(models).not.toContain('claude-opus-4-5')
|
||||
expect(models).not.toContain('claude-sonnet-4-5')
|
||||
expect(models).not.toContain('claude-sonnet-4')
|
||||
expect(models).not.toContain('claude-3-5-sonnet-20241022')
|
||||
expect(models).not.toContain('claude-3-5-haiku-20241022')
|
||||
expect(models).not.toContain('claude-haiku-4-5')
|
||||
expect(models).not.toContain('gpt-4o')
|
||||
expect(models).not.toContain('gpt-4')
|
||||
expect(models).not.toContain('gpt-4-turbo')
|
||||
expect(models).not.toContain('gpt-3.5-turbo')
|
||||
expect(models).not.toContain('deepseek-3-2')
|
||||
expect(models).not.toContain('minimax-m2-1')
|
||||
expect(models).not.toContain('qwen3-coder-next')
|
||||
})
|
||||
|
||||
it('claude 模型列表包含 dated 和 thinking 兼容别名', () => {
|
||||
const models = getModelsByPlatform('claude')
|
||||
|
||||
expect(models).toContain('claude-opus-4-6-thinking')
|
||||
expect(models).toContain('claude-opus-4-5-20251101-thinking')
|
||||
expect(models).toContain('claude-sonnet-4-20250514-thinking')
|
||||
expect(models).toContain('claude-haiku-4-5-20251001-thinking')
|
||||
})
|
||||
|
||||
it('whitelist 模式会忽略通配符条目', () => {
|
||||
const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
|
||||
expect(mapping).toEqual({
|
||||
@@ -73,4 +130,68 @@ describe('useModelWhitelist', () => {
|
||||
'gpt-5.4-mini': 'gpt-5.4-mini'
|
||||
})
|
||||
})
|
||||
|
||||
it('kiro 预设映射只暴露 Claude 入口', () => {
|
||||
const mappings = getPresetMappingsByPlatform('kiro')
|
||||
const mappingTargets = mappings.map(item => item.to)
|
||||
|
||||
expect(mappings.map(({ from, to }) => ({ from, to }))).toEqual([
|
||||
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
|
||||
])
|
||||
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
|
||||
expect(mappingTargets.every(model => model.startsWith('claude-'))).toBe(true)
|
||||
expect(mappingTargets.some(model => model.endsWith('-agentic'))).toBe(false)
|
||||
expect(mappingTargets.some(model => model.endsWith('-chat'))).toBe(false)
|
||||
expect(mappingTargets).not.toContain('kiro-auto')
|
||||
expect(mappingTargets.some(model => model.startsWith('kiro-'))).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-opus-4-5')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-sonnet-4-5')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-sonnet-4')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-3-5-sonnet-20241022')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-3-5-haiku-20241022')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-haiku-4-5')).toBe(false)
|
||||
expect(mappingTargets).not.toContain('gpt-4o')
|
||||
expect(mappingTargets).not.toContain('gpt-4')
|
||||
expect(mappingTargets).not.toContain('gpt-4-turbo')
|
||||
expect(mappingTargets).not.toContain('gpt-3.5-turbo')
|
||||
expect(mappingTargets).not.toContain('deepseek-3.2')
|
||||
expect(mappingTargets).not.toContain('minimax-m2.1')
|
||||
expect(mappingTargets).not.toContain('qwen3-coder-next')
|
||||
})
|
||||
|
||||
it('kiro 默认映射会在前端填充所有可精确定价模型', async () => {
|
||||
const mappings = await fetchKiroDefaultMappings()
|
||||
|
||||
expect(mappings).toEqual(expect.arrayContaining([
|
||||
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
|
||||
]))
|
||||
expect(mappings).toHaveLength(10)
|
||||
expect(mappings.every(item => !item.from.startsWith('kiro-'))).toBe(true)
|
||||
expect(mappings.every(item => !item.to.startsWith('kiro-'))).toBe(true)
|
||||
expect(mappings.every(item => !item.from.endsWith('-agentic'))).toBe(true)
|
||||
expect(mappings.every(item => !item.to.endsWith('-agentic'))).toBe(true)
|
||||
expect(mappings.every(item => !item.from.endsWith('-chat'))).toBe(true)
|
||||
expect(mappings.every(item => !item.to.endsWith('-chat'))).toBe(true)
|
||||
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
|
||||
expect(mappings.every(item => item.to.startsWith('claude-'))).toBe(true)
|
||||
expect(mappings.some(item => item.to === 'claude-opus-4-7')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user