feat: complete kiro platform support

This commit is contained in:
nianzs
2026-04-29 18:20:46 +08:00
parent fcaa8ea86a
commit b09bcb6a3c
39 changed files with 2063 additions and 164 deletions
@@ -57,6 +57,8 @@ func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-sonnet-4",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"gpt-4o",
"gpt-4",
"deepseek-3-2",
@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"slices"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 44,
Name: "kiro-oauth",
Platform: service.PlatformKiro,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
}
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 47,
Name: "kiro-oauth-mapped",
Platform: service.PlatformKiro,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4.6",
"custom-model": "custom-upstream-model",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 2)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
require.True(t, slices.Contains(ids, "custom-model"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
}
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 45,
Name: "kiro-apikey",
Platform: service.PlatformKiro,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4.6",
"custom-model": "custom-upstream-model",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 2)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
require.True(t, slices.Contains(ids, "custom-model"))
}
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 46,
Name: "kiro-apikey-defaults",
Platform: service.PlatformKiro,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
}
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -120,7 +120,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -0,0 +1,16 @@
package admin
import (
"testing"
"github.com/gin-gonic/gin/binding"
"github.com/stretchr/testify/require"
)
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
require.NoError(t, binding.Validator.ValidateStruct(createReq))
updateReq := UpdateGroupRequest{Platform: "kiro"}
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
}
@@ -36,11 +36,12 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformKiro = "kiro"
)
// AllPlatforms 返回所有支持的平台列表
func AllPlatforms() []string {
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
}
// Validate 验证规则配置的有效性
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
service.PlatformOpenAI: 1,
service.PlatformGemini: 1,
service.PlatformAntigravity: 2,
service.PlatformKiro: 1,
}
for platform, minCount := range requiredByPlatform {
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
requestedModel: "any-model",
expected: true,
},
{
name: "kiro no mapping falls back to default whitelist",
platform: PlatformKiro,
credentials: nil,
requestedModel: "claude-sonnet-4-6",
expected: true,
},
{
name: "kiro no mapping rejects model outside default whitelist",
platform: PlatformKiro,
credentials: nil,
requestedModel: "auto",
expected: false,
},
// 精确匹配
{
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
{
name: "kiro no mapping uses default upstream mapping",
platform: PlatformKiro,
credentials: nil,
requestedModel: "claude-sonnet-4-6",
expected: "claude-sonnet-4.6",
},
// 精确匹配
{
+17 -2
View File
@@ -194,6 +194,13 @@ func (s *BillingService) initFallbackPricing() {
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
@@ -268,13 +275,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
return s.fallbackPrices["claude-3-opus"]
}
if strings.Contains(modelLower, "sonnet") {
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
switch {
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
return s.fallbackPrices["claude-sonnet-4.6"]
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
return s.fallbackPrices["claude-sonnet-4.5"]
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
return s.fallbackPrices["claude-sonnet-4"]
}
return s.fallbackPrices["claude-3-5-sonnet"]
}
if strings.Contains(modelLower, "haiku") {
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
switch {
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
return s.fallbackPrices["claude-haiku-4.5"]
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
return s.fallbackPrices["claude-3-5-haiku"]
}
return s.fallbackPrices["claude-3-haiku"]
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
@@ -105,44 +109,63 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
var resp *http.Response
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
} else {
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
@@ -102,44 +106,63 @@ func (s *GatewayService) ForwardAsResponses(
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
var resp *http.Response
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
} else {
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
nil,
)
}
+15 -1
View File
@@ -912,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
injectMetadata bool
metadataUserID string
stripSystemCacheControl bool
preserveToolChoice bool
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
@@ -1126,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
modified = true
}
}
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
if !gjson.GetBytes(out, "max_tokens").Exists() {
@@ -4518,7 +4525,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
normalizeOpts := claudeOAuthNormalizeOptions{
stripSystemCacheControl: !systemRewritten,
preserveToolChoice: account.Platform == PlatformKiro,
}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
@@ -8968,6 +8978,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
return nil
}
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
return nil
}
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
// Anthropic API Key keeps the existing account-level override:
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
// Kiro OAuth uses channel-level switch only.
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
@@ -62,22 +64,37 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
return false
}
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
if account == nil {
return false
default: // "default" → follow channel config
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(account.Platform)
}
switch {
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
return false
default:
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
}
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
default:
return false
}
}
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(platform)
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
"message": map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
"usage": map[string]int{
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
}
return flushSSEJSON(w, "message_start", evt)
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
"name": toolNameWebSearch, "input": map[string]any{},
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
inputJSON, err := json.Marshal(map[string]string{"query": query})
if err != nil {
return fmt.Errorf("marshal query: %w", err)
}
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
}); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
// --- Helpers ---
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
blocks := make([]map[string]string, 0, len(results))
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
blocks := make([]map[string]any, 0, len(results))
for _, r := range results {
block := map[string]string{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
}
if r.Snippet != "" {
block["page_content"] = r.Snippet
block := map[string]any{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
"encrypted_content": r.Snippet,
"page_age": nil,
}
if r.PageAge != "" {
block["page_age"] = r.PageAge
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
@@ -13,6 +14,31 @@ import (
"github.com/stretchr/testify/require"
)
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
rec := httptest.NewRecorder()
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, `"cache_creation_input_tokens":0`)
require.Contains(t, body, `"cache_read_input_tokens":0`)
}
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
rec := httptest.NewRecorder()
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, `event: content_block_start`)
require.Contains(t, body, `"type":"server_tool_use"`)
require.Contains(t, body, `"input":{}`)
require.Contains(t, body, `event: content_block_delta`)
require.Contains(t, body, `"type":"input_json_delta"`)
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
require.Contains(t, body, `event: content_block_stop`)
}
// --- isOnlyWebSearchToolInBody ---
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
require.Len(t, blocks, 2)
require.Equal(t, "web_search_result", blocks[0]["type"])
require.Equal(t, "https://a.com", blocks[0]["url"])
require.Equal(t, "snippet a", blocks[0]["page_content"])
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
require.Equal(t, "2 days", blocks[0]["page_age"])
// Second result has no PageAge
require.Equal(t, "https://b.com", blocks[1]["url"])
_, hasPageAge := blocks[1]["page_age"]
require.False(t, hasPageAge)
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
require.Nil(t, blocks[1]["page_age"])
}
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
_, hasContent := blocks[0]["page_content"]
require.False(t, hasContent)
require.Equal(t, "", blocks[0]["encrypted_content"])
require.Nil(t, blocks[0]["page_age"])
}
// --- buildTextSummary ---
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
}
}
func newKiroOAuthAccount() *Account {
return &Account{
ID: 2,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
}
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
// nil channelService + default mode → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 11,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
},
}
channelSvc := newChannelServiceWithCache(77, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newKiroOAuthAccount()
groupID := int64(77)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 12,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
},
}
channelSvc := newChannelServiceWithCache(78, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newKiroOAuthAccount()
groupID := int64(78)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newKiroOAuthAccount()
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
@@ -0,0 +1,623 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/stretchr/testify/require"
)
type kiroUsageCooldownStore struct {
state *kirocooldown.State
err error
}
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
return nil
}
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
return s.state, s.err
}
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
return false, nil
}
func kiroFloatPtr(v float64) *float64 {
return &v
}
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
},
}
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
}
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.billingService = NewBillingService(svc.cfg, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-sonnet-4-6": {
InputCostPerToken: 2.5e-6,
OutputCostPerToken: 10e-6,
},
},
})
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_kiro_billing_normalized",
Model: "claude-sonnet-4-6",
UpstreamModel: "claude-sonnet-4.6",
Usage: OpenAIUsage{
InputTokens: 20,
OutputTokens: 10,
},
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
account := Account{
ID: 701,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
require.Equal(t, "*/*", r.Header.Get("Accept"))
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
"overageConfiguration": {"overageStatus":"ENABLED"},
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
"usageBreakdownList": [{
"currency":"USD",
"currentOveragesWithPrecision":2,
"currentUsageWithPrecision":125,
"freeTrialInfo":{
"currentUsageWithPrecision":25,
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
"freeTrialStatus":"ACTIVE",
"usageLimitWithPrecision":500
},
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
"overageCharges":0.08,
"resourceType":"CREDIT",
"usageLimitWithPrecision":2000
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, "active", usage.Source)
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
require.True(t, usage.KiroOveragesEnabled)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
require.NotNil(t, usage.KiroBonus)
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
require.NotNil(t, usage.KiroOverage)
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
require.NotNil(t, usage.KiroResetAt)
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
require.NotNil(t, usage.KiroQuotaResetAt)
}
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
account := Account{
ID: 702,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":300,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer successServer.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, firstUsage)
require.NotNil(t, firstUsage.KiroCredit)
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
}))
defer failingServer.Close()
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
require.Empty(t, usage.Error)
require.Empty(t, usage.ErrorCode)
}
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
account := Account{
ID: 703,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "BuilderId",
"auth_method": "idc",
"region": "us-east-1",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Empty(t, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":42,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
}
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
account := Account{
ID: 707,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"region": "us-east-1",
"start_url": "https://d-example.awsapps.com/start",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":64,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
}
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
account := Account{
ID: 709,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"api_region": "eu-west-1",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
gotRegions := make([]string, 0, 2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":11,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(region string) string {
gotRegions = append(gotRegions, region)
return server.URL
}
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, []string{"eu-west-1"}, gotRegions)
}
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
account := Account{
ID: 710,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
gotRegions := make([]string, 0, 2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Empty(t, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":7,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(region string) string {
gotRegions = append(gotRegions, region)
return server.URL
}
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
}
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
account := Account{
ID: 704,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
SetKiroCooldownStore(&kiroUsageCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(90 * time.Second),
Remaining: 90 * time.Second,
},
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":42,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.Equal(t, "cooldown", usage.KiroRuntimeState)
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
require.NotNil(t, usage.KiroRuntimeResetAt)
}
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
StatusCode: http.StatusBadRequest,
Body: `{"message":"profileArn is required for this request."}`,
})
require.Equal(t, errorCodeForbidden, info.ErrorCode)
require.False(t, info.NeedsReauth)
}
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
StatusCode: http.StatusTooManyRequests,
Body: `{"message":"overage exhausted for this billing window"}`,
})
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
}
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
account := Account{
ID: 708,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, firstUsage)
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, secondUsage)
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
require.Equal(t, 1, requestCount)
}
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
NextDateReset: "2099-03-13T12:00:00Z",
OverageConfiguration: kiroOverageConfiguration{
OverageStatus: "DISABLED",
},
UsageBreakdownList: []kiroUsageBreakdown{
{
ResourceType: "CREDIT",
CurrentUsageWithPrecision: kiroFloatPtr(2000),
UsageLimitWithPrecision: kiroFloatPtr(2000),
CurrentOveragesWithPrecision: kiroFloatPtr(0),
},
},
})
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
require.NotNil(t, info.KiroQuotaResetAt)
}
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
SetKiroCooldownStore(&kiroUsageCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(2 * time.Minute),
Remaining: 2 * time.Minute,
},
})
account := &Account{
ID: 705,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "kiro-access-token"},
}
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
require.Equal(t, "cooldown", account.KiroRuntimeState)
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
require.NotNil(t, account.KiroRuntimeResetAt)
}
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
account := Account{
ID: 706,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"nextDateReset":"2099-03-13T12:00:00Z",
"overageConfiguration":{"overageStatus":"ENABLED"},
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":2000,
"currentOveragesWithPrecision":4,
"overageCharges":0.2,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
_, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
target := &Account{
ID: account.ID,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "kiro-access-token"},
}
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
require.NotNil(t, target.KiroQuotaResetAt)
}
@@ -0,0 +1,163 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
account := Account{
Type: AccountTypeAPIKey,
Platform: PlatformKiro,
Credentials: map[string]any{},
}
require.Empty(t, account.GetBaseURL())
}
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
svc := &GatewayService{}
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
require.Equal(t, 25*time.Second, got)
}
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
StreamKeepaliveInterval: 10,
KiroStreamKeepaliveInterval: 25,
},
},
}
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
}
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
svc := NewBillingService(&config.Config{}, nil)
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
}
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
tests := []struct {
name string
requestedModel string
upstreamModel string
want string
}{
{
name: "kiro claude sonnet 4.6 uses pricing key format",
requestedModel: "claude-sonnet-4-6",
upstreamModel: "claude-sonnet-4.6",
want: "claude-sonnet-4-6",
},
{
name: "falls back to upstream when requested model empty",
requestedModel: "",
upstreamModel: "claude-haiku-4-5",
want: "claude-haiku-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
})
}
}
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.billingService = NewBillingService(svc.cfg, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-sonnet-4-6": {
InputCostPerToken: 2.5e-6,
OutputCostPerToken: 10e-6,
},
},
})
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_kiro_billing_normalized",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
Model: "claude-sonnet-4-6",
UpstreamModel: "claude-sonnet-4.6",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_kiro_auto_fallback_cost",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
Model: "auto",
UpstreamModel: "auto",
Duration: time.Second,
},
APIKey: &APIKey{ID: 601, Quota: 100},
User: &User{ID: 701},
Account: &Account{ID: 801, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
@@ -10,6 +10,7 @@ import (
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
@@ -144,6 +145,60 @@ func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
}
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
account := &Account{
ID: 8,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
body := []byte(`{
"model":"claude-opus-4.6",
"messages":[{"role":"user","content":"hello"}]
}`)
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-opus-4.6",
"kiro-access-token",
"claude-opus-4-6-thinking",
nil,
)
require.NoError(t, err)
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
}
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
account := &Account{
ID: 9,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
body := []byte(`{
"model":"claude-opus-4.6",
"messages":[{"role":"user","content":"hello"}]
}`)
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-opus-4.6",
"kiro-access-token",
"claude-opus-4-6",
nil,
)
require.NoError(t, err)
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
require.NotContains(t, systemContent, "<thinking_mode>")
}
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
account := &Account{
Credentials: map[string]any{
+25 -9
View File
@@ -81,7 +81,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
}
if parsed.Stream {
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -146,7 +146,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
}
if isOnlyWebSearchToolInBody(body) {
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
switch {
case errors.Is(webSearchErr, errKiroWebSearchFallback):
case webSearchErr == nil:
@@ -194,7 +194,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
}
inputTokens := estimateKiroInputTokens(body)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -253,7 +253,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
}, nil
}
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, 0, err
@@ -268,7 +268,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
headers := make(http.Header)
headers.Set("Content-Type", "text/event-stream")
go func() {
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
@@ -282,7 +282,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
}, inputTokens, nil
}
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -318,7 +318,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
}, inputTokens, nil
}
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
var requestCtx kiropkg.KiroRequestContext
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
@@ -329,7 +329,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
modelID := kiropkg.MapModel(mappedModel)
currentToken := token
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
if err != nil {
return nil, requestCtx, err
}
@@ -432,7 +432,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
if err != nil {
return nil, requestCtx, err
}
@@ -501,9 +501,25 @@ func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropic
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
profileArn := resolveKiroPayloadProfileArn(account)
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
}
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
requestModel = strings.TrimSpace(requestModel)
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
return anthropicBody
}
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
return anthropicBody
}
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
return next
}
return anthropicBody
}
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
if s == nil || s.accountRepo == nil || account == nil {
return
@@ -283,7 +283,7 @@ func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
},
}
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
@@ -323,7 +323,7 @@ func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testi
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
require.Len(t, upstream.requests, 1)
@@ -461,7 +461,7 @@ func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailov
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
@@ -501,7 +501,7 @@ func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
@@ -550,7 +550,7 @@ func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedu
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
require.Equal(t, 1, repo.setErrorCalls)
+4 -4
View File
@@ -102,7 +102,7 @@ func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens in
}
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
) error {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
@@ -141,7 +141,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
return errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
if err != nil {
return err
}
@@ -190,7 +190,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
return fmt.Errorf("kiro web search exceeded max iterations")
}
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
return nil, errKiroWebSearchFallback
@@ -227,7 +227,7 @@ func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Acco
return nil, errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
t.Run(mode, func(t *testing.T) {
t.Parallel()
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
privacyCalls := 0
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
privacyCalls++
+3
View File
@@ -271,6 +271,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
out := *ev
out.Platform = strings.TrimSpace(out.Platform)
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
+32 -2
View File
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-2.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model)
model = canonicalModelNameForPricing(model)
model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/")
model = strings.TrimPrefix(model, "publishers/google/models/")
@@ -625,7 +625,31 @@ func normalizeModelNameForPricing(model string) string {
}
model = strings.TrimLeft(model, "/")
return model
return canonicalModelNameForPricing(model)
}
func canonicalModelNameForPricing(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {
return ""
}
switch model {
case "claude-opus-4.5":
return "claude-opus-4-5"
case "claude-opus-4.6":
return "claude-opus-4-6"
case "claude-opus-4.7":
return "claude-opus-4-7"
case "claude-sonnet-4.5":
return "claude-sonnet-4-5"
case "claude-sonnet-4.6":
return "claude-sonnet-4-6"
case "claude-haiku-4.5":
return "claude-haiku-4-5"
default:
return model
}
}
func lastSegment(model string) string {
@@ -671,8 +695,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
@@ -710,6 +736,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
}
case strings.Contains(model, "sonnet"):
switch {
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
fallbackName = "sonnet-4.6"
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
fallbackName = "sonnet-4.5"
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
@@ -719,6 +747,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
}
case strings.Contains(model, "haiku"):
switch {
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
fallbackName = "haiku-4.5"
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
fallbackName = "haiku-3.5"
default:
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
if len(groupIDs) == 0 {
return nil
}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
var firstErr error
for _, platform := range platforms {
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
buckets := make([]SchedulerBucket, 0)
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
for _, platform := range platforms {
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 5,
Platform: PlatformGemini,
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 6,
Platform: PlatformGemini,
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 7,
Platform: PlatformGemini,
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 9,
Platform: PlatformGemini,
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 11,
Platform: PlatformGemini,
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 12,
Platform: PlatformGemini,
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
until := time.Now().Add(10 * time.Minute)
account := &Account{
ID: 15,
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 16,
Platform: tt.platform,
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 18,
Platform: PlatformOpenAI,
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)