diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go index a53cc802..6fd947d8 100644 --- a/backend/internal/domain/constants_test.go +++ b/backend/internal/domain/constants_test.go @@ -57,6 +57,8 @@ func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) { "claude-opus-4-5", "claude-sonnet-4-5", "claude-sonnet-4", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", "gpt-4o", "gpt-4", "deepseek-3-2", diff --git a/backend/internal/handler/admin/account_handler_available_models_test.go b/backend/internal/handler/admin/account_handler_available_models_test.go index c5f1e2d8..6d040565 100644 --- a/backend/internal/handler/admin/account_handler_available_models_test.go +++ b/backend/internal/handler/admin/account_handler_available_models_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "slices" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau require.NotEmpty(t, resp.Data) require.NotEqual(t, "gpt-5", resp.Data[0].ID) } + +func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 44, + Name: "kiro-oauth", + Platform: service.PlatformKiro, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotEmpty(t, resp.Data) + ids := make([]string, 0, len(resp.Data)) + for _, model := range resp.Data { + ids = append(ids, model.ID) + } + require.True(t, slices.Contains(ids, "claude-opus-4-6")) + require.False(t, slices.Contains(ids, "claude-opus-4-7")) + require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7")) +} + +func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 47, + Name: "kiro-oauth-mapped", + Platform: service.PlatformKiro, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-6": "claude-sonnet-4.6", + "custom-model": "custom-upstream-model", + }, + }, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Len(t, resp.Data, 2) + + ids := make([]string, 0, len(resp.Data)) + for _, model := range resp.Data { + ids = append(ids, model.ID) + } + require.True(t, slices.Contains(ids, "claude-sonnet-4-6")) + require.True(t, slices.Contains(ids, "custom-model")) + require.False(t, slices.Contains(ids, "claude-opus-4-7")) +} + +func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 45, + Name: "kiro-apikey", + Platform: service.PlatformKiro, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-6": "claude-sonnet-4.6", + "custom-model": "custom-upstream-model", + }, + }, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Len(t, resp.Data, 2) + + ids := make([]string, 0, len(resp.Data)) + for _, model := range resp.Data { + ids = append(ids, model.ID) + } + require.True(t, slices.Contains(ids, "claude-sonnet-4-6")) + require.True(t, slices.Contains(ids, "custom-model")) +} + +func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 46, + Name: "kiro-apikey-defaults", + Platform: service.PlatformKiro, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotEmpty(t, resp.Data) + ids := make([]string, 0, len(resp.Data)) + for _, model := range resp.Data { + ids = append(ids, model.ID) + } + require.True(t, slices.Contains(ids, "claude-opus-4-6")) + require.False(t, slices.Contains(ids, "claude-opus-4-7")) + require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7")) +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 65e5ec78..afca3361 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -120,7 +120,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` diff --git a/backend/internal/handler/admin/group_handler_kiro_validation_test.go b/backend/internal/handler/admin/group_handler_kiro_validation_test.go new file mode 100644 index 00000000..60c1bc12 --- /dev/null +++ b/backend/internal/handler/admin/group_handler_kiro_validation_test.go @@ -0,0 +1,16 @@ +package admin + +import ( + "testing" + + "github.com/gin-gonic/gin/binding" + "github.com/stretchr/testify/require" +) + +func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) { + createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"} + require.NoError(t, binding.Validator.ValidateStruct(createReq)) + + updateReq := UpdateGroupRequest{Platform: "kiro"} + require.NoError(t, binding.Validator.ValidateStruct(updateReq)) +} diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go index 620736cd..d3b07930 100644 --- a/backend/internal/model/error_passthrough_rule.go +++ b/backend/internal/model/error_passthrough_rule.go @@ -36,11 +36,12 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" + PlatformKiro = "kiro" ) // AllPlatforms 返回所有支持的平台列表 func AllPlatforms() []string { - return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity} + return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro} } // Validate 验证规则配置的有效性 diff --git a/backend/internal/repository/simple_mode_default_groups.go b/backend/internal/repository/simple_mode_default_groups.go index 56309184..46ec6ec1 100644 --- a/backend/internal/repository/simple_mode_default_groups.go +++ b/backend/internal/repository/simple_mode_default_groups.go @@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er service.PlatformOpenAI: 1, service.PlatformGemini: 1, service.PlatformAntigravity: 2, + service.PlatformKiro: 1, } for platform, minCount := range requiredByPlatform { diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index d903b940..6d249545 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) { requestedModel: "any-model", expected: true, }, + { + name: "kiro no mapping falls back to default whitelist", + platform: PlatformKiro, + credentials: nil, + requestedModel: "claude-sonnet-4-6", + expected: true, + }, + { + name: "kiro no mapping rejects model outside default whitelist", + platform: PlatformKiro, + credentials: nil, + requestedModel: "auto", + expected: false, + }, // 精确匹配 { @@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) { requestedModel: "gemini-3.1-pro-preview-customtools", expected: "gemini-3.1-pro-preview-customtools", }, + { + name: "kiro no mapping uses default upstream mapping", + platform: PlatformKiro, + credentials: nil, + requestedModel: "claude-sonnet-4-6", + expected: "claude-sonnet-4.6", + }, // 精确匹配 { diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 392b3e0b..d613cb4a 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -194,6 +194,13 @@ func (s *BillingService) initFallbackPricing() { // Claude 4.7 Opus (暂与4.6同价,待官方定价更新) s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"] + // Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价 + s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"] + s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"] + + // Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价 + s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"] + // Gemini 3.1 Pro s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{ InputPricePerToken: 2e-6, // $2 per MTok @@ -268,13 +275,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { return s.fallbackPrices["claude-3-opus"] } if strings.Contains(modelLower, "sonnet") { - if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") { + switch { + case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"): + return s.fallbackPrices["claude-sonnet-4.6"] + case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"): + return s.fallbackPrices["claude-sonnet-4.5"] + case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"): return s.fallbackPrices["claude-sonnet-4"] } return s.fallbackPrices["claude-3-5-sonnet"] } if strings.Contains(modelLower, "haiku") { - if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") { + switch { + case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"): + return s.fallbackPrices["claude-haiku-4.5"] + case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"): return s.fallbackPrices["claude-3-5-haiku"] } return s.fallbackPrices["claude-3-haiku"] diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index 7ac77f77..9991b19d 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions( // 4. Model mapping mappedModel := originalModel - if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { + if account.Platform == PlatformKiro { + if next := account.GetMappedModel(originalModel); next != "" { + mappedModel = next + } + } else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { @@ -105,44 +109,63 @@ func (s *GatewayService) ForwardAsChatCompletions( // 7. Enforce cache_control block limit anthropicBody = enforceCacheControlLimit(anthropicBody) - // 8. Get access token - token, tokenType, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, fmt.Errorf("get access token: %w", err) - } - - // 9. Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 10. Build upstream request - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) - upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) - releaseUpstreamCtx() - if err != nil { - return nil, fmt.Errorf("build upstream request: %w", err) - } - - // 11. Send request - resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) - if err != nil { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() + var resp *http.Response + if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth { + resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + } else { + // 8. Get access token + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 9. Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 10. Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // 11. Send request + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) } - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed") - return nil, fmt.Errorf("upstream request failed: %s", safeErr) } defer func() { _ = resp.Body.Close() }() diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index 8f8a1e94..f7fd9b71 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses( // 4. Model mapping mappedModel := originalModel reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) - if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { + if account.Platform == PlatformKiro { + if next := account.GetMappedModel(originalModel); next != "" { + mappedModel = next + } + } else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { @@ -102,44 +106,63 @@ func (s *GatewayService) ForwardAsResponses( // 7. Enforce cache_control block limit anthropicBody = enforceCacheControlLimit(anthropicBody) - // 8. Get access token - token, tokenType, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, fmt.Errorf("get access token: %w", err) - } - - // 9. Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 10. Build upstream request - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) - upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) - releaseUpstreamCtx() - if err != nil { - return nil, fmt.Errorf("build upstream request: %w", err) - } - - // 11. Send request - resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) - if err != nil { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() + var resp *http.Response + if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth { + resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + } else { + // 8. Get access token + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 9. Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 10. Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // 11. Send request + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) } - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed") - return nil, fmt.Errorf("upstream request failed: %s", safeErr) } defer func() { _ = resp.Body.Close() }() diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 140bdc67..f9e3a896 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, + nil, ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3d7b1c10..c72eecfe 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -912,6 +912,7 @@ type claudeOAuthNormalizeOptions struct { injectMetadata bool metadataUserID string stripSystemCacheControl bool + preserveToolChoice bool } // sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). @@ -1126,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu modified = true } } + if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } + } // max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。 if !gjson.GetBytes(out, "max_tokens").Exists() { @@ -4518,7 +4525,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为); // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。 // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。 - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} + normalizeOpts := claudeOAuthNormalizeOptions{ + stripSystemCacheControl: !systemRewritten, + preserveToolChoice: account.Platform == PlatformKiro, + } if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { @@ -8968,6 +8978,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform") return nil } + if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform") + return nil + } // 应用模型映射: // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 diff --git a/backend/internal/service/gateway_websearch_emulation.go b/backend/internal/service/gateway_websearch_emulation.go index a42b5585..e8d5a6c4 100644 --- a/backend/internal/service/gateway_websearch_emulation.go +++ b/backend/internal/service/gateway_websearch_emulation.go @@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager { // shouldEmulateWebSearch checks whether a request should be intercepted. // -// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled. -// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel). +// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy. +// Anthropic API Key keeps the existing account-level override: +// "enabled" (force on), "disabled" (force off), "default" (follow channel). +// Kiro OAuth uses channel-level switch only. func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool { if getWebSearchManager() == nil { return false @@ -62,22 +64,37 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac return false } - mode := account.GetWebSearchEmulationMode() - switch mode { - case WebSearchModeEnabled: - return true - case WebSearchModeDisabled: + if account == nil { return false - default: // "default" → follow channel config - if groupID == nil || s.channelService == nil { - return false - } - ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) - if err != nil || ch == nil { - return false - } - return ch.IsWebSearchEmulationEnabled(account.Platform) } + + switch { + case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey: + mode := account.GetWebSearchEmulationMode() + switch mode { + case WebSearchModeEnabled: + return true + case WebSearchModeDisabled: + return false + default: + return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform) + } + case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth: + return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform) + default: + return false + } +} + +func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil || ch == nil { + return false + } + return ch.IsWebSearchEmulationEnabled(platform) } // isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool. @@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error { "message": map[string]any{ "id": msgID, "type": "message", "role": "assistant", "model": model, "content": []any{}, "stop_reason": nil, "stop_sequence": nil, - "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, + "usage": map[string]int{ + "input_tokens": 0, + "output_tokens": 0, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, }, } return flushSSEJSON(w, "message_start", evt) @@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index "type": "content_block_start", "index": index, "content_block": map[string]any{ "type": "server_tool_use", "id": toolUseID, - "name": toolNameWebSearch, "input": map[string]string{"query": query}, + "name": toolNameWebSearch, "input": map[string]any{}, }, } if err := flushSSEJSON(w, "content_block_start", start); err != nil { return err } + inputJSON, err := json.Marshal(map[string]string{"query": query}) + if err != nil { + return fmt.Errorf("marshal query: %w", err) + } + if err := flushSSEJSON(w, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": index, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": string(inputJSON), + }, + }); err != nil { + return err + } return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index}) } @@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse( // --- Helpers --- -func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string { - blocks := make([]map[string]string, 0, len(results)) +func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any { + blocks := make([]map[string]any, 0, len(results)) for _, r := range results { - block := map[string]string{ - "type": "web_search_result", - "url": r.URL, - "title": r.Title, - } - if r.Snippet != "" { - block["page_content"] = r.Snippet + block := map[string]any{ + "type": "web_search_result", + "url": r.URL, + "title": r.Title, + "encrypted_content": r.Snippet, + "page_age": nil, } if r.PageAge != "" { block["page_age"] = r.PageAge diff --git a/backend/internal/service/gateway_websearch_emulation_test.go b/backend/internal/service/gateway_websearch_emulation_test.go index de1f0014..bd82dd33 100644 --- a/backend/internal/service/gateway_websearch_emulation_test.go +++ b/backend/internal/service/gateway_websearch_emulation_test.go @@ -5,6 +5,7 @@ package service import ( "context" "encoding/json" + "net/http/httptest" "testing" "time" @@ -13,6 +14,31 @@ import ( "github.com/stretchr/testify/require" ) +func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) { + rec := httptest.NewRecorder() + err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5") + require.NoError(t, err) + + body := rec.Body.String() + require.Contains(t, body, `"cache_creation_input_tokens":0`) + require.Contains(t, body, `"cache_read_input_tokens":0`) +} + +func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) { + rec := httptest.NewRecorder() + err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0) + require.NoError(t, err) + + body := rec.Body.String() + require.Contains(t, body, `event: content_block_start`) + require.Contains(t, body, `"type":"server_tool_use"`) + require.Contains(t, body, `"input":{}`) + require.Contains(t, body, `event: content_block_delta`) + require.Contains(t, body, `"type":"input_json_delta"`) + require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`) + require.Contains(t, body, `event: content_block_stop`) +} + // --- isOnlyWebSearchToolInBody --- func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) { @@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) { require.Len(t, blocks, 2) require.Equal(t, "web_search_result", blocks[0]["type"]) require.Equal(t, "https://a.com", blocks[0]["url"]) - require.Equal(t, "snippet a", blocks[0]["page_content"]) + require.Equal(t, "snippet a", blocks[0]["encrypted_content"]) require.Equal(t, "2 days", blocks[0]["page_age"]) // Second result has no PageAge require.Equal(t, "https://b.com", blocks[1]["url"]) - _, hasPageAge := blocks[1]["page_age"] - require.False(t, hasPageAge) + require.Equal(t, "snippet b", blocks[1]["encrypted_content"]) + require.Nil(t, blocks[1]["page_age"]) } func TestBuildSearchResultBlocks_Empty(t *testing.T) { @@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) { func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) { blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}}) - _, hasContent := blocks[0]["page_content"] - require.False(t, hasContent) + require.Equal(t, "", blocks[0]["encrypted_content"]) + require.Nil(t, blocks[0]["page_age"]) } // --- buildTextSummary --- @@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account { } } +func newKiroOAuthAccount() *Account { + return &Account{ + ID: 2, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + } +} + // setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled. func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) { webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{ @@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) { // nil channelService + default mode → returns false require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) } + +func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 11, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true}, + }, + } + channelSvc := newChannelServiceWithCache(77, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newKiroOAuthAccount() + groupID := int64(77) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 12, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false}, + }, + } + channelSvc := newChannelServiceWithCache(78, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newKiroOAuthAccount() + groupID := int64(78) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + + account := newKiroOAuthAccount() + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} diff --git a/backend/internal/service/kiro_alignment_test.go b/backend/internal/service/kiro_alignment_test.go new file mode 100644 index 00000000..67ebc03b --- /dev/null +++ b/backend/internal/service/kiro_alignment_test.go @@ -0,0 +1,623 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown" + "github.com/stretchr/testify/require" +) + +type kiroUsageCooldownStore struct { + state *kirocooldown.State + err error +} + +func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) { + return 0, nil +} + +func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error { + return nil +} + +func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) { + return 0, nil +} + +func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) { + return 0, nil +} + +func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) { + return s.state, s.err +} + +func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) { + return false, nil +} + +func kiroFloatPtr(v float64) *float64 { + return &v +} + +func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{"kiro": true}, + }, + } + + require.True(t, c.IsWebSearchEmulationEnabled("kiro")) +} + +func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.billingService = NewBillingService(svc.cfg, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "claude-sonnet-4-6": { + InputCostPerToken: 2.5e-6, + OutputCostPerToken: 10e-6, + }, + }, + }) + + expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_kiro_billing_normalized", + Model: "claude-sonnet-4-6", + UpstreamModel: "claude-sonnet-4.6", + Usage: OpenAIUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30, Platform: PlatformKiro}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model) + require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12) +} + +func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) { + account := Account{ + ID: 701, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "Github", + "auth_method": "social", + "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + resetAt := time.Now().Add(10 * 24 * time.Hour).Unix() + bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/getUsageLimits", r.URL.Path) + require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn")) + require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin")) + require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType")) + require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization")) + require.Equal(t, "*/*", r.Header.Get("Accept")) + require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-")) + require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-")) + require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode")) + require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout")) + require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `, + "overageConfiguration": {"overageStatus":"ENABLED"}, + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentOveragesWithPrecision":2, + "currentUsageWithPrecision":125, + "freeTrialInfo":{ + "currentUsageWithPrecision":25, + "freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `, + "freeTrialStatus":"ACTIVE", + "usageLimitWithPrecision":500 + }, + "nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `, + "overageCharges":0.08, + "resourceType":"CREDIT", + "usageLimitWithPrecision":2000 + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, "active", usage.Source) + require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName) + require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType) + require.True(t, usage.KiroOveragesEnabled) + require.NotNil(t, usage.KiroCredit) + require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage) + require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit) + require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001) + require.NotNil(t, usage.KiroBonus) + require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage) + require.Equal(t, 500.0, usage.KiroBonus.UsageLimit) + require.NotNil(t, usage.KiroOverage) + require.Equal(t, "$", usage.KiroOverage.CurrencySymbol) + require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages) + require.Equal(t, 0.08, usage.KiroOverage.OverageCharges) + require.NotNil(t, usage.KiroResetAt) + require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState) + require.Equal(t, "overages_enabled", usage.KiroQuotaReason) + require.NotNil(t, usage.KiroQuotaResetAt) +} + +func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) { + account := Account{ + ID: 702, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "Github", + "auth_method": "social", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":300, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer successServer.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + firstUsage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, firstUsage) + require.NotNil(t, firstUsage.KiroCredit) + require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage) + + failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError) + })) + defer failingServer.Close() + resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL } + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, usage) + require.NotNil(t, usage.KiroCredit) + require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage) + require.Empty(t, usage.Error) + require.Empty(t, usage.ErrorCode) +} + +func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) { + account := Account{ + ID: 703, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "BuilderId", + "auth_method": "idc", + "region": "us-east-1", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/getUsageLimits", r.URL.Path) + require.Empty(t, r.URL.Query().Get("profileArn")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":42, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, usage) + require.NotNil(t, usage.KiroCredit) + require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage) +} + +func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) { + account := Account{ + ID: 707, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "AWS", + "auth_method": "idc", + "region": "us-east-1", + "start_url": "https://d-example.awsapps.com/start", + "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/getUsageLimits", r.URL.Path) + require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":64, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, usage) + require.NotNil(t, usage.KiroCredit) + require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage) +} + +func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) { + account := Account{ + ID: 709, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "AWS", + "auth_method": "idc", + "api_region": "eu-west-1", + "region": "ap-northeast-2", + "start_url": "https://d-example.awsapps.com/start", + "profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION" + gotRegions := make([]string, 0, 2) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/getUsageLimits", r.URL.Path) + require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":11, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(region string) string { + gotRegions = append(gotRegions, region) + return server.URL + } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, []string{"eu-west-1"}, gotRegions) +} + +func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) { + account := Account{ + ID: 710, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "AWS", + "auth_method": "idc", + "region": "ap-northeast-2", + "start_url": "https://d-example.awsapps.com/start", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + gotRegions := make([]string, 0, 2) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/getUsageLimits", r.URL.Path) + require.Empty(t, r.URL.Query().Get("profileArn")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":7, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(region string) string { + gotRegions = append(gotRegions, region) + return server.URL + } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, []string{kiroDefaultRegion}, gotRegions) +} + +func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) { + account := Account{ + ID: 704, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "Github", + "auth_method": "social", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil). + SetKiroCooldownStore(&kiroUsageCooldownStore{ + state: &kirocooldown.State{ + Active: true, + Reason: kirocooldown.CooldownReason429, + CooldownUntil: time.Now().Add(90 * time.Second), + Remaining: 90 * time.Second, + }, + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":42, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + usage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.Equal(t, "cooldown", usage.KiroRuntimeState) + require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason) + require.NotNil(t, usage.KiroRuntimeResetAt) +} + +func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) { + info := buildKiroDegradedUsage(&kiroUsageHTTPError{ + StatusCode: http.StatusBadRequest, + Body: `{"message":"profileArn is required for this request."}`, + }) + + require.Equal(t, errorCodeForbidden, info.ErrorCode) + require.False(t, info.NeedsReauth) +} + +func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) { + info := buildKiroDegradedUsage(&kiroUsageHTTPError{ + StatusCode: http.StatusTooManyRequests, + Body: `{"message":"overage exhausted for this billing window"}`, + }) + + require.Equal(t, errorCodeNetworkError, info.ErrorCode) + require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState) + require.Contains(t, info.KiroQuotaReason, "overage exhausted") +} + +func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) { + account := Account{ + ID: 708, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "Github", + "auth_method": "social", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + firstUsage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, firstUsage) + require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode) + + secondUsage, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, secondUsage) + require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode) + require.Equal(t, 1, requestCount) +} + +func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) { + info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{ + NextDateReset: "2099-03-13T12:00:00Z", + OverageConfiguration: kiroOverageConfiguration{ + OverageStatus: "DISABLED", + }, + UsageBreakdownList: []kiroUsageBreakdown{ + { + ResourceType: "CREDIT", + CurrentUsageWithPrecision: kiroFloatPtr(2000), + UsageLimitWithPrecision: kiroFloatPtr(2000), + CurrentOveragesWithPrecision: kiroFloatPtr(0), + }, + }, + }) + + require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState) + require.Equal(t, "credits_exhausted", info.KiroQuotaReason) + require.NotNil(t, info.KiroQuotaResetAt) +} + +func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) { + svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil). + SetKiroCooldownStore(&kiroUsageCooldownStore{ + state: &kirocooldown.State{ + Active: true, + Reason: kirocooldown.CooldownReason429, + CooldownUntil: time.Now().Add(2 * time.Minute), + Remaining: 2 * time.Minute, + }, + }) + + account := &Account{ + ID: 705, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{"access_token": "kiro-access-token"}, + } + + svc.EnrichAccountWithKiroRuntimeState(context.Background(), account) + require.Equal(t, "cooldown", account.KiroRuntimeState) + require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason) + require.NotNil(t, account.KiroRuntimeResetAt) +} + +func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) { + account := Account{ + ID: 706, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "kiro-access-token", + "provider": "Github", + "auth_method": "social", + }, + } + repo := &stubOpenAIAccountRepo{accounts: []Account{account}} + svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "nextDateReset":"2099-03-13T12:00:00Z", + "overageConfiguration":{"overageStatus":"ENABLED"}, + "subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"}, + "usageBreakdownList": [{ + "currency":"USD", + "currentUsageWithPrecision":2000, + "currentOveragesWithPrecision":4, + "overageCharges":0.2, + "usageLimitWithPrecision":2000, + "resourceType":"CREDIT" + }] + }`)) + })) + defer server.Close() + + prevResolver := resolveKiroRuntimeEndpoint + resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL } + defer func() { resolveKiroRuntimeEndpoint = prevResolver }() + + _, err := svc.GetUsage(context.Background(), account.ID) + require.NoError(t, err) + + target := &Account{ + ID: account.ID, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + Credentials: map[string]any{"access_token": "kiro-access-token"}, + } + svc.EnrichAccountWithKiroRuntimeState(context.Background(), target) + + require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState) + require.Equal(t, "overages_enabled", target.KiroQuotaReason) + require.NotNil(t, target.KiroQuotaResetAt) +} diff --git a/backend/internal/service/kiro_alignment_unit_test.go b/backend/internal/service/kiro_alignment_unit_test.go new file mode 100644 index 00000000..9606ea64 --- /dev/null +++ b/backend/internal/service/kiro_alignment_unit_test.go @@ -0,0 +1,163 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) { + account := Account{ + Type: AccountTypeAPIKey, + Platform: PlatformKiro, + Credentials: map[string]any{}, + } + + require.Empty(t, account.GetBaseURL()) +} + +func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) { + svc := &GatewayService{} + + got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}) + + require.Equal(t, 25*time.Second, got) +} + +func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) { + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamKeepaliveInterval: 10, + KiroStreamKeepaliveInterval: 25, + }, + }, + } + + require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})) + require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic})) +} + +func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) { + svc := NewBillingService(&config.Config{}, nil) + + pricing, err := svc.GetModelPricing("claude-haiku-4-5") + + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12) +} + +func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) { + tests := []struct { + name string + requestedModel string + upstreamModel string + want string + }{ + { + name: "kiro claude sonnet 4.6 uses pricing key format", + requestedModel: "claude-sonnet-4-6", + upstreamModel: "claude-sonnet-4.6", + want: "claude-sonnet-4-6", + }, + { + name: "falls back to upstream when requested model empty", + requestedModel: "", + upstreamModel: "claude-haiku-4-5", + want: "claude-haiku-4-5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel)) + }) + } +} + +func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.billingService = NewBillingService(svc.cfg, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "claude-sonnet-4-6": { + InputCostPerToken: 2.5e-6, + OutputCostPerToken: 10e-6, + }, + }, + }) + + expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_kiro_billing_normalized", + Usage: ClaudeUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Model: "claude-sonnet-4-6", + UpstreamModel: "claude-sonnet-4.6", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701, Platform: PlatformKiro}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model) + require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12) +} + +func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_kiro_auto_fallback_cost", + Usage: ClaudeUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Model: "auto", + UpstreamModel: "auto", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 601, Quota: 100}, + User: &User{ID: 701}, + Account: &Account{ID: 801, Platform: PlatformKiro}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12) +} diff --git a/backend/internal/service/kiro_http_helpers_test.go b/backend/internal/service/kiro_http_helpers_test.go index 8bd2c508..7185974e 100644 --- a/backend/internal/service/kiro_http_helpers_test.go +++ b/backend/internal/service/kiro_http_helpers_test.go @@ -10,6 +10,7 @@ import ( kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) { @@ -144,6 +145,60 @@ func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) { require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e") } +func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) { + account := &Account{ + ID: 8, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + } + body := []byte(`{ + "model":"claude-opus-4.6", + "messages":[{"role":"user","content":"hello"}] + }`) + + payload, err := buildKiroPayloadForAccount( + context.Background(), + account, + body, + "claude-opus-4.6", + "kiro-access-token", + "claude-opus-4-6-thinking", + nil, + ) + require.NoError(t, err) + + require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String()) + systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String() + require.Contains(t, systemContent, "adaptive") + require.Contains(t, systemContent, "high") +} + +func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) { + account := &Account{ + ID: 9, + Platform: PlatformKiro, + Type: AccountTypeOAuth, + } + body := []byte(`{ + "model":"claude-opus-4.6", + "messages":[{"role":"user","content":"hello"}] + }`) + + payload, err := buildKiroPayloadForAccount( + context.Background(), + account, + body, + "claude-opus-4.6", + "kiro-access-token", + "claude-opus-4-6", + nil, + ) + require.NoError(t, err) + + systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String() + require.NotContains(t, systemContent, "") +} + func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) { account := &Account{ Credentials: map[string]any{ diff --git a/backend/internal/service/kiro_runtime.go b/backend/internal/service/kiro_runtime.go index 78ed30b2..49a5654b 100644 --- a/backend/internal/service/kiro_runtime.go +++ b/backend/internal/service/kiro_runtime.go @@ -81,7 +81,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context } if parsed.Stream { - resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header) + resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header) if err != nil { var failoverErr *UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -146,7 +146,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType) } if isOnlyWebSearchToolInBody(body) { - webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header) + webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header) switch { case errors.Is(webSearchErr, errKiroWebSearchFallback): case webSearchErr == nil: @@ -194,7 +194,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context } inputTokens := estimateKiroInputTokens(body) - resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header) + resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header) if err != nil { var failoverErr *UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -253,7 +253,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context }, nil } -func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) { +func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) { token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { return nil, 0, err @@ -268,7 +268,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac headers := make(http.Header) headers.Set("Content-Type", "text/event-stream") go func() { - streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw) + streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw) if streamErr != nil { _ = pw.CloseWithError(streamErr) return @@ -282,7 +282,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac }, inputTokens, nil } - resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers) + resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers) if err != nil { var failoverErr *UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -318,7 +318,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac }, inputTokens, nil } -func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) { +func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) { var requestCtx kiropkg.KiroRequestContext if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil { if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil { @@ -329,7 +329,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou modelID := kiropkg.MapModel(mappedModel) currentToken := token - buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers) + buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers) if err != nil { return nil, requestCtx, err } @@ -432,7 +432,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" { currentToken = refreshedToken accountKey = buildKiroAccountKey(account) - buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers) + buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers) if err != nil { return nil, requestCtx, err } @@ -501,9 +501,25 @@ func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropic func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) { profileArn := resolveKiroPayloadProfileArn(account) + anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel) return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers) } +func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte { + requestModel = strings.TrimSpace(requestModel) + if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") { + return anthropicBody + } + bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String()) + if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") { + return anthropicBody + } + if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok { + return next + } + return anthropicBody +} + func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) { if s == nil || s.accountRepo == nil || account == nil { return diff --git a/backend/internal/service/kiro_runtime_state_test.go b/backend/internal/service/kiro_runtime_state_test.go index 8eeba068..5b26361f 100644 --- a/backend/internal/service/kiro_runtime_state_test.go +++ b/backend/internal/service/kiro_runtime_state_test.go @@ -283,7 +283,7 @@ func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) { }, } - _, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil) + _, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil) require.Error(t, err) var failoverErr *UpstreamFailoverError @@ -323,7 +323,7 @@ func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testi payloadBytes, err := json.Marshal(payload) require.NoError(t, err) - resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil) + resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, resp.StatusCode) require.Len(t, upstream.requests, 1) @@ -461,7 +461,7 @@ func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailov payloadBytes, err := json.Marshal(payload) require.NoError(t, err) - _, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil) + _, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil) require.Error(t, err) var failoverErr *UpstreamFailoverError @@ -501,7 +501,7 @@ func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T) payloadBytes, err := json.Marshal(payload) require.NoError(t, err) - _, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil) + _, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil) require.Error(t, err) var failoverErr *UpstreamFailoverError @@ -550,7 +550,7 @@ func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedu payloadBytes, err := json.Marshal(payload) require.NoError(t, err) - resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil) + resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil) require.NoError(t, err) require.Equal(t, http.StatusUnauthorized, resp.StatusCode) require.Equal(t, 1, repo.setErrorCalls) diff --git a/backend/internal/service/kiro_websearch.go b/backend/internal/service/kiro_websearch.go index dc97e992..8e1419cd 100644 --- a/backend/internal/service/kiro_websearch.go +++ b/backend/internal/service/kiro_websearch.go @@ -102,7 +102,7 @@ func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens in } func (s *GatewayService) streamKiroWebSearchAsAnthropic( - ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer, + ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer, ) error { query := kiropkg.ExtractSearchQuery(anthropicBody) if strings.TrimSpace(query) == "" { @@ -141,7 +141,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic( return errKiroWebSearchFallback } - resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers) + resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers) if err != nil { return err } @@ -190,7 +190,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic( return fmt.Errorf("kiro web search exceeded max iterations") } -func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) { +func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) { query := kiropkg.ExtractSearchQuery(anthropicBody) if strings.TrimSpace(query) == "" { return nil, errKiroWebSearchFallback @@ -227,7 +227,7 @@ func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Acco return nil, errKiroWebSearchFallback } - resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers) + resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers) if err != nil { return nil, err } diff --git a/backend/internal/service/openai_privacy_retry_test.go b/backend/internal/service/openai_privacy_retry_test.go index 24534ea9..9f4f7aa9 100644 --- a/backend/internal/service/openai_privacy_retry_test.go +++ b/backend/internal/service/openai_privacy_retry_test.go @@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi t.Run(mode, func(t *testing.T) { t.Parallel() - service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil) + service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil) privacyCalls := 0 service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) { privacyCalls++ diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index cd3974a0..82b334f0 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -271,6 +271,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error { out := *ev out.Platform = strings.TrimSpace(out.Platform) + out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128) + out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128) + out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128) out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 91a02901..0e1301ba 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string { // - models/gemini-2.0-flash-exp // - publishers/google/models/gemini-2.5-pro // - projects/.../locations/.../publishers/google/models/gemini-2.5-pro - model = strings.TrimSpace(model) + model = canonicalModelNameForPricing(model) model = strings.TrimLeft(model, "/") model = strings.TrimPrefix(model, "models/") model = strings.TrimPrefix(model, "publishers/google/models/") @@ -625,7 +625,31 @@ func normalizeModelNameForPricing(model string) string { } model = strings.TrimLeft(model, "/") - return model + return canonicalModelNameForPricing(model) +} + +func canonicalModelNameForPricing(model string) string { + model = strings.ToLower(strings.TrimSpace(model)) + if model == "" { + return "" + } + + switch model { + case "claude-opus-4.5": + return "claude-opus-4-5" + case "claude-opus-4.6": + return "claude-opus-4-6" + case "claude-opus-4.7": + return "claude-opus-4-7" + case "claude-sonnet-4.5": + return "claude-sonnet-4-5" + case "claude-sonnet-4.6": + return "claude-sonnet-4-6" + case "claude-haiku-4.5": + return "claude-haiku-4-5" + default: + return model + } } func lastSegment(model string) string { @@ -671,8 +695,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { {name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}}, {name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}}, {name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}}, + {name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}}, {name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}}, {name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}}, + {name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}}, {name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}}, {name: "sonnet-3", match: []string{"claude-3-sonnet"}}, {name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}}, @@ -710,6 +736,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { } case strings.Contains(model, "sonnet"): switch { + case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"): + fallbackName = "sonnet-4.6" case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"): fallbackName = "sonnet-4.5" case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"): @@ -719,6 +747,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { } case strings.Contains(model, "haiku"): switch { + case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"): + fallbackName = "haiku-4.5" case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"): fallbackName = "haiku-3.5" default: diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index a68cdf0c..16bbfde3 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI if len(groupIDs) == 0 { return nil } - platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} + platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro} var firstErr error for _, platform := range platforms { if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil { @@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration { func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) { buckets := make([]SchedulerBucket, 0) - platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} + platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro} for _, platform := range platforms { buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle}) buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced}) diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 2179a85e..cd9f756a 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 5, Platform: PlatformGemini, @@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 6, Platform: PlatformGemini, @@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil) account := &Account{ ID: 7, Platform: PlatformGemini, @@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 8, Platform: PlatformAntigravity, @@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 9, Platform: PlatformGemini, @@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 10, Platform: PlatformOpenAI, // OpenAI OAuth 账户 @@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing. RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil) resetAt := time.Now().Add(30 * time.Minute) account := &Account{ ID: 17, @@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 11, Platform: PlatformGemini, @@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 12, Platform: PlatformGemini, @@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 13, Platform: PlatformAntigravity, @@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 14, Platform: PlatformAntigravity, @@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache) until := time.Now().Add(10 * time.Minute) account := &Account{ ID: 15, @@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 16, Platform: tt.platform, @@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil) account := &Account{ ID: 18, Platform: PlatformOpenAI, @@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) refreshAPI := NewOAuthRefreshAPI(repo, cache) service.SetRefreshAPI(refreshAPI) @@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil) refreshAPI := NewOAuthRefreshAPI(repo, cache) service.SetRefreshAPI(refreshAPI) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index dfc363b5..252b8093 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -89,6 +89,10 @@ security: enabled: false # Allowed upstream hosts for API proxying # 允许代理的上游 API 主机列表 + # If you enable Kiro OAuth / IDC, also allow Kiro auth and AWS SSO OIDC hosts. + # 如果启用 Kiro OAuth / IDC,请同时放行 Kiro 鉴权域名和 AWS SSO OIDC 域名。 + # If you enable Kiro runtime forwarding, also allow the corresponding AWS Q API endpoint. + # 如果启用 Kiro 运行时转发,还需要放行对应的 AWS Q API 域名。 upstream_hosts: - "api.openai.com" - "api.anthropic.com" @@ -97,6 +101,11 @@ security: - "api.minimaxi.com" - "generativelanguage.googleapis.com" - "cloudcode-pa.googleapis.com" + - "prod.us-east-1.auth.desktop.kiro.dev" + - "oidc.us-east-1.amazonaws.com" + - "oidc.*.amazonaws.com" + - "device.sso.*.amazonaws.com" + - "q.*.amazonaws.com" - "*.openai.azure.com" # Allowed hosts for pricing data download # 允许下载定价数据的主机列表 @@ -340,6 +349,9 @@ gateway: # Stream keepalive interval (seconds), 0=disable # 流式 keepalive 间隔(秒),0=禁用 stream_keepalive_interval: 10 + # Kiro stream keepalive interval (seconds), 0=use default 25 + # Kiro 流式 keepalive 间隔(秒),0=使用默认 25 + kiro_stream_keepalive_interval: 25 # SSE max line size in bytes (default: 40MB) # SSE 单行最大字节数(默认 40MB) max_line_size: 41943040 diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 8a127793..882dff67 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -558,6 +558,13 @@ export async function getAntigravityDefaultModelMapping(): Promise> { + const { data } = await apiClient.get>( + '/admin/accounts/kiro/default-model-mapping' + ) + return data +} + /** * Refresh OpenAI token using refresh token * @param refreshToken - The refresh token diff --git a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts index f758e6b0..ded857a1 100644 --- a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts +++ b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts @@ -13,6 +13,15 @@ vi.mock('vue-i18n', async () => { } }) +vi.mock('@/i18n', () => ({ + i18n: { + global: { + t: (key: string) => key + } + }, + getLocale: () => 'en' +})) + function makeAccount(overrides: Partial): Account { return { id: 1, @@ -35,6 +44,12 @@ function makeAccount(overrides: Partial): Account { overload_until: null, temp_unschedulable_until: null, temp_unschedulable_reason: null, + kiro_quota_state: null, + kiro_quota_reason: null, + kiro_quota_reset_at: null, + kiro_runtime_state: null, + kiro_runtime_reason: null, + kiro_runtime_reset_at: null, session_window_start: null, session_window_end: null, session_window_status: null, @@ -159,4 +174,96 @@ describe('AccountStatusIndicator', () => { // AICredits 积分耗尽状态应显示 expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted') }) + + it('Kiro 运行时冷却在状态列复用限流展示', () => { + const wrapper = mount(AccountStatusIndicator, { + props: { + account: makeAccount({ + id: 5, + name: 'kiro-cooldown', + platform: 'kiro', + kiro_runtime_state: 'cooldown', + kiro_runtime_reason: 'rate_limit_exceeded', + kiro_runtime_reset_at: '2099-03-15T00:00:00Z' + }) + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).toContain('admin.accounts.status.rateLimited') + expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedAutoResume') + expect(wrapper.text()).toContain('429') + }) + + it('Kiro suspended 在状态列显示为 forbidden', () => { + const wrapper = mount(AccountStatusIndicator, { + props: { + account: makeAccount({ + id: 6, + name: 'kiro-suspended', + platform: 'kiro', + kiro_runtime_state: 'suspended', + kiro_runtime_reason: 'account_suspended', + kiro_runtime_reset_at: '2099-03-15T00:00:00Z' + }) + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).toContain('admin.accounts.forbidden') + }) + + it('Kiro overage active 在状态列仍显示正常状态', () => { + const wrapper = mount(AccountStatusIndicator, { + props: { + account: makeAccount({ + id: 7, + name: 'kiro-overage-active', + platform: 'kiro', + kiro_quota_state: 'overage_active', + kiro_quota_reason: 'overages_enabled', + kiro_quota_reset_at: '2099-03-15T00:00:00Z' + }) + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).toContain('admin.accounts.status.active') + expect(wrapper.text()).not.toContain('admin.accounts.status.overageActive') + }) + + it('Kiro overage exhausted 在状态列显示危险徽章', () => { + const wrapper = mount(AccountStatusIndicator, { + props: { + account: makeAccount({ + id: 8, + name: 'kiro-overage-exhausted', + platform: 'kiro', + kiro_quota_state: 'overage_exhausted', + kiro_quota_reason: 'overage disabled after quota exhaustion', + kiro_quota_reset_at: '2099-03-15T00:00:00Z' + }) + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).toContain('admin.accounts.status.overageExhausted') + expect(wrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil') + }) }) diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts index fa4104f6..5634a605 100644 --- a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts +++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts @@ -25,6 +25,15 @@ vi.mock('vue-i18n', async () => { } }) +vi.mock('@/i18n', () => ({ + i18n: { + global: { + t: (key: string) => key + } + }, + getLocale: () => 'en' +})) + function makeAccount(overrides: Partial): Account { return { id: 1, @@ -47,6 +56,12 @@ function makeAccount(overrides: Partial): Account { overload_until: null, temp_unschedulable_until: null, temp_unschedulable_reason: null, + kiro_quota_state: null, + kiro_quota_reason: null, + kiro_quota_reset_at: null, + kiro_runtime_state: null, + kiro_runtime_reason: null, + kiro_runtime_reset_at: null, session_window_start: null, session_window_end: null, session_window_status: null, @@ -530,6 +545,234 @@ describe('AccountUsageCell', () => { expect(wrapper.text()).toContain('7d|100|106540000') }) + it('Kiro OAuth 会用 passive source 拉取并展示 credits 额度', async () => { + const account = makeAccount({ + id: 3001, + platform: 'kiro', + type: 'oauth', + extra: {}, + credentials: {} + }) + + getUsage.mockResolvedValue({ + source: 'passive', + kiro_subscription_name: 'KIRO PRO+', + kiro_overages_enabled: true, + kiro_credit: { + current_usage: 125, + usage_limit: 2000, + percentage_used: 6.25, + }, + kiro_bonus: { + current_usage: 25, + usage_limit: 500, + percentage_used: 5, + days_remaining: 7, + }, + kiro_overage: { + current_overages: 2, + overage_charges: 0.08, + currency_symbol: '$', + currency_code: 'USD', + }, + kiro_reset_at: '2099-03-13T12:00:00Z', + }) + + const wrapper = mount(AccountUsageCell, { + props: { + account + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + + expect(getUsage).toHaveBeenCalledWith(3001, 'passive') + expect(wrapper.emitted('kiroUsageMeta')?.[0]).toEqual([ + { + plan_type: 'KIRO PRO+', + kiro_overages_enabled: true + } + ]) + expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroCredits') + expect(wrapper.text()).toContain('125 / 2.0K') + expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroBonus') + expect(wrapper.text()).toContain('25 / 500') + expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroDaysLeft') + expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroReset') + expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroOverage 2 ($0.08)') + }) + + it('Kiro OAuth 会展示运行时冷却状态', async () => { + getUsage.mockResolvedValue({ + source: 'passive', + kiro_runtime_state: 'cooldown', + kiro_runtime_reason: 'rate_limit_exceeded', + kiro_runtime_reset_at: '2099-03-13T12:00:00Z', + kiro_credit: { + current_usage: 10, + usage_limit: 100, + percentage_used: 10, + }, + }) + + const wrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 3002, + platform: 'kiro', + type: 'oauth', + extra: {}, + credentials: {} + }) + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('admin.accounts.status.rateLimited') + expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedUntil') + }) + + it('Kiro OAuth 会展示 overage active 与 exhausted 状态', async () => { + getUsage.mockResolvedValueOnce({ + source: 'passive', + kiro_quota_state: 'overage_active', + kiro_quota_reason: 'overages_enabled', + kiro_quota_reset_at: '2099-03-13T12:00:00Z', + kiro_overages_enabled: true, + kiro_credit: { + current_usage: 2100, + usage_limit: 2000, + percentage_used: 100, + }, + kiro_overage: { + current_overages: 3, + overage_charges: 0.12, + currency_symbol: '$', + }, + }) + + const activeWrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 3005, + platform: 'kiro', + type: 'oauth', + extra: {}, + credentials: {} + }) + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + expect(activeWrapper.text()).toContain('admin.accounts.status.overageActive') + expect(activeWrapper.text()).not.toContain('admin.accounts.status.overageActiveUntil') + + getUsage.mockResolvedValueOnce({ + source: 'passive', + kiro_quota_state: 'overage_exhausted', + kiro_quota_reason: 'usage API error: overage exhausted', + kiro_quota_reset_at: '2099-03-13T12:00:00Z', + error: 'usage API error: kiro usage request failed (status 429): {"message":"overage exhausted"}', + }) + + const exhaustedWrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 3006, + platform: 'kiro', + type: 'oauth', + extra: {}, + credentials: {} + }) + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhausted') + expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil') + }) + + it('Kiro OAuth 会展示 profile 异常和 usage forbidden 徽章', async () => { + getUsage.mockResolvedValueOnce({ + source: 'passive', + error_code: 'forbidden', + error: 'usage API error: kiro usage request failed (status 400): {"message":"profileArn is required for this request."}', + }) + + const profileWrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 3003, + platform: 'kiro', + type: 'oauth', + extra: {}, + credentials: {} + }) + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + expect(profileWrapper.text()).toContain('admin.accounts.usageError') + + getUsage.mockResolvedValueOnce({ + source: 'passive', + error_code: 'forbidden', + error: 'usage API error: kiro usage request failed (status 403): {"message":"User is not authorized to access this feature."}', + }) + + const forbiddenWrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 3004, + platform: 'kiro', + type: 'oauth', + extra: {}, + credentials: {} + }) + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + expect(forbiddenWrapper.text()).toContain('admin.accounts.forbidden') + }) + it('Key 账号会展示 today stats 徽章并带 A/U 提示', async () => { const wrapper = mount(AccountUsageCell, { props: { diff --git a/frontend/src/components/account/__tests__/OAuthAuthorizationFlow.spec.ts b/frontend/src/components/account/__tests__/OAuthAuthorizationFlow.spec.ts new file mode 100644 index 00000000..efba0ba1 --- /dev/null +++ b/frontend/src/components/account/__tests__/OAuthAuthorizationFlow.spec.ts @@ -0,0 +1,56 @@ +import { describe, expect, it, vi } from 'vitest' +import { nextTick } from 'vue' +import { mount } from '@vue/test-utils' + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: vi.fn(), + showError: vi.fn() + }) +})) + +vi.mock('@/composables/useClipboard', () => ({ + useClipboard: () => ({ + copied: { value: false }, + copyToClipboard: vi.fn() + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +import OAuthAuthorizationFlow from '../OAuthAuthorizationFlow.vue' + +describe('OAuthAuthorizationFlow', () => { + it('extracts code, state, and callback metadata from a full Kiro callback URL', async () => { + const wrapper = mount(OAuthAuthorizationFlow, { + props: { + addMethod: 'oauth', + platform: 'kiro', + authUrl: 'https://example.com/authorize', + sessionId: 'session-1' + }, + global: { + stubs: { + Icon: true + } + } + }) + + const textarea = wrapper.get('textarea') + await textarea.setValue('http://localhost:49153/oauth/callback?code=abc123&state=state456&login_option=github') + await nextTick() + + expect((textarea.element as HTMLTextAreaElement).value).toBe('abc123') + expect((wrapper.vm as any).oauthState).toBe('state456') + expect((wrapper.vm as any).oauthCallbackPath).toBe('/oauth/callback') + expect((wrapper.vm as any).oauthLoginOption).toBe('github') + }) +}) diff --git a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue index 2ed6ded3..b4df5629 100644 --- a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue +++ b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue @@ -489,7 +489,8 @@ const platformOptions = [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, - { value: 'antigravity', label: 'Antigravity' } + { value: 'antigravity', label: 'Antigravity' }, + { value: 'kiro', label: 'Kiro' } ] // Load rules when dialog opens diff --git a/frontend/src/components/common/GroupBadge.vue b/frontend/src/components/common/GroupBadge.vue index 3303d909..e8265e79 100644 --- a/frontend/src/components/common/GroupBadge.vue +++ b/frontend/src/components/common/GroupBadge.vue @@ -115,7 +115,7 @@ const labelClass = computed(() => { } // 正常状态或无天数:根据平台显示主题色 - if (props.platform === 'anthropic') { + if (props.platform === 'anthropic' || props.platform === 'kiro') { return `${base} bg-orange-200/60 text-orange-800 dark:bg-orange-800/40 dark:text-orange-300` } if (props.platform === 'openai') { @@ -129,7 +129,7 @@ const labelClass = computed(() => { // Badge color based on platform and subscription type const badgeClass = computed(() => { - if (props.platform === 'anthropic') { + if (props.platform === 'anthropic' || props.platform === 'kiro') { // Claude: orange theme return isSubscription.value ? 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' diff --git a/frontend/src/components/common/GroupOptionItem.vue b/frontend/src/components/common/GroupOptionItem.vue index 28b5d6e3..3bbe8dc8 100644 --- a/frontend/src/components/common/GroupOptionItem.vue +++ b/frontend/src/components/common/GroupOptionItem.vue @@ -91,6 +91,8 @@ const ratePillClass = computed(() => { return 'bg-green-50 text-green-700 dark:bg-green-900/20 dark:text-green-400' case 'gemini': return 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400' + case 'kiro': + return 'bg-amber-50 text-amber-700 dark:bg-amber-900/20 dark:text-amber-400' default: // antigravity and others return 'bg-violet-50 text-violet-700 dark:bg-violet-900/20 dark:text-violet-400' } diff --git a/frontend/src/components/common/__tests__/PlatformTypeBadge.spec.ts b/frontend/src/components/common/__tests__/PlatformTypeBadge.spec.ts new file mode 100644 index 00000000..3ce11856 --- /dev/null +++ b/frontend/src/components/common/__tests__/PlatformTypeBadge.spec.ts @@ -0,0 +1,42 @@ +import { mount } from '@vue/test-utils' +import { describe, expect, it, vi } from 'vitest' +import PlatformTypeBadge from '../PlatformTypeBadge.vue' + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key === 'admin.accounts.status.overageActive' ? 'Overage' : key + }) + } +}) + +describe('PlatformTypeBadge', () => { + it('shows Kiro overages tag next to the plan tag when enabled', () => { + const wrapper = mount(PlatformTypeBadge, { + props: { + platform: 'kiro', + type: 'oauth', + planType: 'KIRO PRO+', + overagesEnabled: true + } + }) + + expect(wrapper.text()).toContain('KIRO PRO+') + expect(wrapper.text()).toContain('Overage') + }) + + it('does not show overages tag for non-Kiro accounts', () => { + const wrapper = mount(PlatformTypeBadge, { + props: { + platform: 'openai', + type: 'oauth', + planType: 'Pro', + overagesEnabled: true + } + }) + + expect(wrapper.text()).not.toContain('Overage') + }) +}) diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts index 4a17d858..79950637 100644 --- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts +++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts @@ -4,7 +4,12 @@ vi.mock('@/api/admin/accounts', () => ({ getAntigravityDefaultModelMapping: vi.fn() })) -import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist' +import { + buildModelMappingObject, + fetchKiroDefaultMappings, + getModelsByPlatform, + getPresetMappingsByPlatform +} from '../useModelWhitelist' describe('useModelWhitelist', () => { it('openai 模型列表包含 GPT-5.4 官方快照', () => { @@ -59,6 +64,49 @@ describe('useModelWhitelist', () => { expect(models.every((model) => !model.endsWith('-agentic') && !model.endsWith('-chat'))).toBe(true) }) + it('kiro 模型列表只保留 Claude 模型', () => { + const models = getModelsByPlatform('kiro') + + expect(models).toEqual([ + 'claude-opus-4-6', + 'claude-opus-4-6-thinking', + 'claude-sonnet-4-6', + 'claude-sonnet-4-6-thinking', + 'claude-opus-4-5-20251101', + 'claude-opus-4-5-20251101-thinking', + 'claude-sonnet-4-5-20250929', + 'claude-sonnet-4-5-20250929-thinking', + 'claude-haiku-4-5-20251001', + 'claude-haiku-4-5-20251001-thinking' + ]) + expect(models.every(model => model.startsWith('claude-'))).toBe(true) + expect(models.some(model => model.endsWith('-agentic'))).toBe(false) + expect(models.some(model => model.endsWith('-chat'))).toBe(false) + expect(models).not.toContain('kiro-auto') + expect(models).not.toContain('claude-opus-4-5') + expect(models).not.toContain('claude-sonnet-4-5') + expect(models).not.toContain('claude-sonnet-4') + expect(models).not.toContain('claude-3-5-sonnet-20241022') + expect(models).not.toContain('claude-3-5-haiku-20241022') + expect(models).not.toContain('claude-haiku-4-5') + expect(models).not.toContain('gpt-4o') + expect(models).not.toContain('gpt-4') + expect(models).not.toContain('gpt-4-turbo') + expect(models).not.toContain('gpt-3.5-turbo') + expect(models).not.toContain('deepseek-3-2') + expect(models).not.toContain('minimax-m2-1') + expect(models).not.toContain('qwen3-coder-next') + }) + + it('claude 模型列表包含 dated 和 thinking 兼容别名', () => { + const models = getModelsByPlatform('claude') + + expect(models).toContain('claude-opus-4-6-thinking') + expect(models).toContain('claude-opus-4-5-20251101-thinking') + expect(models).toContain('claude-sonnet-4-20250514-thinking') + expect(models).toContain('claude-haiku-4-5-20251001-thinking') + }) + it('whitelist 模式会忽略通配符条目', () => { const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], []) expect(mapping).toEqual({ @@ -81,4 +129,68 @@ describe('useModelWhitelist', () => { 'gpt-5.4-mini': 'gpt-5.4-mini' }) }) + + it('kiro 预设映射只暴露 Claude 入口', () => { + const mappings = getPresetMappingsByPlatform('kiro') + const mappingTargets = mappings.map(item => item.to) + + expect(mappings.map(({ from, to }) => ({ from, to }))).toEqual([ + { from: 'claude-opus-4-6', to: 'claude-opus-4.6' }, + { from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' }, + { from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' }, + { from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' }, + { from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' }, + { from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' }, + { from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' }, + { from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' }, + { from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' }, + { from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' } + ]) + expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true) + expect(mappingTargets.every(model => model.startsWith('claude-'))).toBe(true) + expect(mappingTargets.some(model => model.endsWith('-agentic'))).toBe(false) + expect(mappingTargets.some(model => model.endsWith('-chat'))).toBe(false) + expect(mappingTargets).not.toContain('kiro-auto') + expect(mappingTargets.some(model => model.startsWith('kiro-'))).toBe(false) + expect(mappings.some(item => item.from === 'claude-opus-4-5')).toBe(false) + expect(mappings.some(item => item.from === 'claude-sonnet-4-5')).toBe(false) + expect(mappings.some(item => item.from === 'claude-sonnet-4')).toBe(false) + expect(mappings.some(item => item.from === 'claude-3-5-sonnet-20241022')).toBe(false) + expect(mappings.some(item => item.from === 'claude-3-5-haiku-20241022')).toBe(false) + expect(mappings.some(item => item.from === 'claude-haiku-4-5')).toBe(false) + expect(mappingTargets).not.toContain('gpt-4o') + expect(mappingTargets).not.toContain('gpt-4') + expect(mappingTargets).not.toContain('gpt-4-turbo') + expect(mappingTargets).not.toContain('gpt-3.5-turbo') + expect(mappingTargets).not.toContain('deepseek-3.2') + expect(mappingTargets).not.toContain('minimax-m2.1') + expect(mappingTargets).not.toContain('qwen3-coder-next') + }) + + it('kiro 默认映射会在前端填充所有可精确定价模型', async () => { + const mappings = await fetchKiroDefaultMappings() + + expect(mappings).toEqual(expect.arrayContaining([ + { from: 'claude-opus-4-6', to: 'claude-opus-4.6' }, + { from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' }, + { from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' }, + { from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' }, + { from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' }, + { from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' }, + { from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' }, + { from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' }, + { from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' }, + { from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' } + ])) + expect(mappings).toHaveLength(10) + expect(mappings.every(item => !item.from.startsWith('kiro-'))).toBe(true) + expect(mappings.every(item => !item.to.startsWith('kiro-'))).toBe(true) + expect(mappings.every(item => !item.from.endsWith('-agentic'))).toBe(true) + expect(mappings.every(item => !item.to.endsWith('-agentic'))).toBe(true) + expect(mappings.every(item => !item.from.endsWith('-chat'))).toBe(true) + expect(mappings.every(item => !item.to.endsWith('-chat'))).toBe(true) + expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true) + expect(mappings.every(item => item.to.startsWith('claude-'))).toBe(true) + expect(mappings.some(item => item.to === 'claude-opus-4-7')).toBe(false) + }) }) diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 4e953fdc..8b6e4993 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -23,13 +23,21 @@ export const claudeModels = [ 'claude-3-5-sonnet-20241022', 'claude-3-5-sonnet-20240620', 'claude-3-5-haiku-20241022', 'claude-3-7-sonnet-20250219', + 'claude-sonnet-4-20250514-thinking', 'claude-opus-4-20250514-thinking', 'claude-sonnet-4-20250514', 'claude-opus-4-20250514', 'claude-opus-4-1-20250805', + 'claude-sonnet-4-5-thinking', 'claude-sonnet-4-5-20250929-thinking', + 'claude-haiku-4-5-thinking', 'claude-haiku-4-5-20251001-thinking', + 'claude-opus-4-5-thinking', 'claude-opus-4-5-20251101-thinking', 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001', 'claude-opus-4-5-20251101', + 'claude-opus-4-6-thinking', 'claude-opus-4-6', + 'claude-opus-4-7-thinking', 'claude-opus-4-7', - 'claude-sonnet-4-6' + 'claude-sonnet-4-6-thinking', + 'claude-sonnet-4-6', + 'claude-2.1', 'claude-2.0', 'claude-instant-1.2' ] // Google Gemini diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index e4452b98..d6c57a5b 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -324,8 +324,8 @@ - -
+ +