feat: complete kiro platform support

This commit is contained in:
nianzs
2026-04-29 18:20:46 +08:00
parent fcaa8ea86a
commit b09bcb6a3c
39 changed files with 2063 additions and 164 deletions
@@ -57,6 +57,8 @@ func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-sonnet-4",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"gpt-4o",
"gpt-4",
"deepseek-3-2",
@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"slices"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 44,
Name: "kiro-oauth",
Platform: service.PlatformKiro,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
}
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 47,
Name: "kiro-oauth-mapped",
Platform: service.PlatformKiro,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4.6",
"custom-model": "custom-upstream-model",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 2)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
require.True(t, slices.Contains(ids, "custom-model"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
}
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 45,
Name: "kiro-apikey",
Platform: service.PlatformKiro,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4.6",
"custom-model": "custom-upstream-model",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 2)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
require.True(t, slices.Contains(ids, "custom-model"))
}
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 46,
Name: "kiro-apikey-defaults",
Platform: service.PlatformKiro,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
ids := make([]string, 0, len(resp.Data))
for _, model := range resp.Data {
ids = append(ids, model.ID)
}
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
}
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -120,7 +120,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -0,0 +1,16 @@
package admin
import (
"testing"
"github.com/gin-gonic/gin/binding"
"github.com/stretchr/testify/require"
)
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
require.NoError(t, binding.Validator.ValidateStruct(createReq))
updateReq := UpdateGroupRequest{Platform: "kiro"}
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
}
@@ -36,11 +36,12 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformKiro = "kiro"
)
// AllPlatforms 返回所有支持的平台列表
func AllPlatforms() []string {
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
}
// Validate 验证规则配置的有效性
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
service.PlatformOpenAI: 1,
service.PlatformGemini: 1,
service.PlatformAntigravity: 2,
service.PlatformKiro: 1,
}
for platform, minCount := range requiredByPlatform {
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
requestedModel: "any-model",
expected: true,
},
{
name: "kiro no mapping falls back to default whitelist",
platform: PlatformKiro,
credentials: nil,
requestedModel: "claude-sonnet-4-6",
expected: true,
},
{
name: "kiro no mapping rejects model outside default whitelist",
platform: PlatformKiro,
credentials: nil,
requestedModel: "auto",
expected: false,
},
// 精确匹配
{
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
{
name: "kiro no mapping uses default upstream mapping",
platform: PlatformKiro,
credentials: nil,
requestedModel: "claude-sonnet-4-6",
expected: "claude-sonnet-4.6",
},
// 精确匹配
{
+17 -2
View File
@@ -194,6 +194,13 @@ func (s *BillingService) initFallbackPricing() {
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
@@ -268,13 +275,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
return s.fallbackPrices["claude-3-opus"]
}
if strings.Contains(modelLower, "sonnet") {
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
switch {
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
return s.fallbackPrices["claude-sonnet-4.6"]
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
return s.fallbackPrices["claude-sonnet-4.5"]
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
return s.fallbackPrices["claude-sonnet-4"]
}
return s.fallbackPrices["claude-3-5-sonnet"]
}
if strings.Contains(modelLower, "haiku") {
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
switch {
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
return s.fallbackPrices["claude-haiku-4.5"]
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
return s.fallbackPrices["claude-3-5-haiku"]
}
return s.fallbackPrices["claude-3-haiku"]
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
@@ -105,44 +109,63 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
var resp *http.Response
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
} else {
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
if account.Platform == PlatformKiro {
if next := account.GetMappedModel(originalModel); next != "" {
mappedModel = next
}
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
@@ -102,44 +106,63 @@ func (s *GatewayService) ForwardAsResponses(
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
var resp *http.Response
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
} else {
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
nil,
)
}
+15 -1
View File
@@ -912,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
injectMetadata bool
metadataUserID string
stripSystemCacheControl bool
preserveToolChoice bool
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
@@ -1126,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
modified = true
}
}
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
if !gjson.GetBytes(out, "max_tokens").Exists() {
@@ -4518,7 +4525,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
normalizeOpts := claudeOAuthNormalizeOptions{
stripSystemCacheControl: !systemRewritten,
preserveToolChoice: account.Platform == PlatformKiro,
}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
@@ -8968,6 +8978,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
return nil
}
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
return nil
}
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
// Anthropic API Key keeps the existing account-level override:
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
// Kiro OAuth uses channel-level switch only.
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
@@ -62,22 +64,37 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
return false
}
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
if account == nil {
return false
default: // "default" → follow channel config
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(account.Platform)
}
switch {
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
return false
default:
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
}
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
default:
return false
}
}
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(platform)
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
"message": map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
"usage": map[string]int{
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
}
return flushSSEJSON(w, "message_start", evt)
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
"name": toolNameWebSearch, "input": map[string]any{},
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
inputJSON, err := json.Marshal(map[string]string{"query": query})
if err != nil {
return fmt.Errorf("marshal query: %w", err)
}
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
}); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
// --- Helpers ---
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
blocks := make([]map[string]string, 0, len(results))
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
blocks := make([]map[string]any, 0, len(results))
for _, r := range results {
block := map[string]string{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
}
if r.Snippet != "" {
block["page_content"] = r.Snippet
block := map[string]any{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
"encrypted_content": r.Snippet,
"page_age": nil,
}
if r.PageAge != "" {
block["page_age"] = r.PageAge
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
@@ -13,6 +14,31 @@ import (
"github.com/stretchr/testify/require"
)
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
rec := httptest.NewRecorder()
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, `"cache_creation_input_tokens":0`)
require.Contains(t, body, `"cache_read_input_tokens":0`)
}
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
rec := httptest.NewRecorder()
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, `event: content_block_start`)
require.Contains(t, body, `"type":"server_tool_use"`)
require.Contains(t, body, `"input":{}`)
require.Contains(t, body, `event: content_block_delta`)
require.Contains(t, body, `"type":"input_json_delta"`)
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
require.Contains(t, body, `event: content_block_stop`)
}
// --- isOnlyWebSearchToolInBody ---
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
require.Len(t, blocks, 2)
require.Equal(t, "web_search_result", blocks[0]["type"])
require.Equal(t, "https://a.com", blocks[0]["url"])
require.Equal(t, "snippet a", blocks[0]["page_content"])
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
require.Equal(t, "2 days", blocks[0]["page_age"])
// Second result has no PageAge
require.Equal(t, "https://b.com", blocks[1]["url"])
_, hasPageAge := blocks[1]["page_age"]
require.False(t, hasPageAge)
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
require.Nil(t, blocks[1]["page_age"])
}
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
_, hasContent := blocks[0]["page_content"]
require.False(t, hasContent)
require.Equal(t, "", blocks[0]["encrypted_content"])
require.Nil(t, blocks[0]["page_age"])
}
// --- buildTextSummary ---
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
}
}
func newKiroOAuthAccount() *Account {
return &Account{
ID: 2,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
}
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
// nil channelService + default mode → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 11,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
},
}
channelSvc := newChannelServiceWithCache(77, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newKiroOAuthAccount()
groupID := int64(77)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 12,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
},
}
channelSvc := newChannelServiceWithCache(78, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newKiroOAuthAccount()
groupID := int64(78)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newKiroOAuthAccount()
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
@@ -0,0 +1,623 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
"github.com/stretchr/testify/require"
)
type kiroUsageCooldownStore struct {
state *kirocooldown.State
err error
}
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
return nil
}
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
return 0, nil
}
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
return s.state, s.err
}
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
return false, nil
}
func kiroFloatPtr(v float64) *float64 {
return &v
}
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
},
}
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
}
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.billingService = NewBillingService(svc.cfg, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-sonnet-4-6": {
InputCostPerToken: 2.5e-6,
OutputCostPerToken: 10e-6,
},
},
})
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_kiro_billing_normalized",
Model: "claude-sonnet-4-6",
UpstreamModel: "claude-sonnet-4.6",
Usage: OpenAIUsage{
InputTokens: 20,
OutputTokens: 10,
},
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
account := Account{
ID: 701,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
require.Equal(t, "*/*", r.Header.Get("Accept"))
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
"overageConfiguration": {"overageStatus":"ENABLED"},
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
"usageBreakdownList": [{
"currency":"USD",
"currentOveragesWithPrecision":2,
"currentUsageWithPrecision":125,
"freeTrialInfo":{
"currentUsageWithPrecision":25,
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
"freeTrialStatus":"ACTIVE",
"usageLimitWithPrecision":500
},
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
"overageCharges":0.08,
"resourceType":"CREDIT",
"usageLimitWithPrecision":2000
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, "active", usage.Source)
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
require.True(t, usage.KiroOveragesEnabled)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
require.NotNil(t, usage.KiroBonus)
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
require.NotNil(t, usage.KiroOverage)
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
require.NotNil(t, usage.KiroResetAt)
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
require.NotNil(t, usage.KiroQuotaResetAt)
}
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
account := Account{
ID: 702,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":300,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer successServer.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, firstUsage)
require.NotNil(t, firstUsage.KiroCredit)
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
}))
defer failingServer.Close()
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
require.Empty(t, usage.Error)
require.Empty(t, usage.ErrorCode)
}
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
account := Account{
ID: 703,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "BuilderId",
"auth_method": "idc",
"region": "us-east-1",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Empty(t, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":42,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
}
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
account := Account{
ID: 707,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"region": "us-east-1",
"start_url": "https://d-example.awsapps.com/start",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":64,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.NotNil(t, usage.KiroCredit)
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
}
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
account := Account{
ID: 709,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"api_region": "eu-west-1",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
gotRegions := make([]string, 0, 2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":11,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(region string) string {
gotRegions = append(gotRegions, region)
return server.URL
}
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, []string{"eu-west-1"}, gotRegions)
}
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
account := Account{
ID: 710,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "AWS",
"auth_method": "idc",
"region": "ap-northeast-2",
"start_url": "https://d-example.awsapps.com/start",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
gotRegions := make([]string, 0, 2)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/getUsageLimits", r.URL.Path)
require.Empty(t, r.URL.Query().Get("profileArn"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":7,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(region string) string {
gotRegions = append(gotRegions, region)
return server.URL
}
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
}
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
account := Account{
ID: 704,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
SetKiroCooldownStore(&kiroUsageCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(90 * time.Second),
Remaining: 90 * time.Second,
},
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":42,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
usage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.Equal(t, "cooldown", usage.KiroRuntimeState)
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
require.NotNil(t, usage.KiroRuntimeResetAt)
}
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
StatusCode: http.StatusBadRequest,
Body: `{"message":"profileArn is required for this request."}`,
})
require.Equal(t, errorCodeForbidden, info.ErrorCode)
require.False(t, info.NeedsReauth)
}
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
StatusCode: http.StatusTooManyRequests,
Body: `{"message":"overage exhausted for this billing window"}`,
})
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
}
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
account := Account{
ID: 708,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, firstUsage)
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
require.NotNil(t, secondUsage)
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
require.Equal(t, 1, requestCount)
}
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
NextDateReset: "2099-03-13T12:00:00Z",
OverageConfiguration: kiroOverageConfiguration{
OverageStatus: "DISABLED",
},
UsageBreakdownList: []kiroUsageBreakdown{
{
ResourceType: "CREDIT",
CurrentUsageWithPrecision: kiroFloatPtr(2000),
UsageLimitWithPrecision: kiroFloatPtr(2000),
CurrentOveragesWithPrecision: kiroFloatPtr(0),
},
},
})
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
require.NotNil(t, info.KiroQuotaResetAt)
}
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
SetKiroCooldownStore(&kiroUsageCooldownStore{
state: &kirocooldown.State{
Active: true,
Reason: kirocooldown.CooldownReason429,
CooldownUntil: time.Now().Add(2 * time.Minute),
Remaining: 2 * time.Minute,
},
})
account := &Account{
ID: 705,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "kiro-access-token"},
}
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
require.Equal(t, "cooldown", account.KiroRuntimeState)
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
require.NotNil(t, account.KiroRuntimeResetAt)
}
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
account := Account{
ID: 706,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "kiro-access-token",
"provider": "Github",
"auth_method": "social",
},
}
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"nextDateReset":"2099-03-13T12:00:00Z",
"overageConfiguration":{"overageStatus":"ENABLED"},
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
"usageBreakdownList": [{
"currency":"USD",
"currentUsageWithPrecision":2000,
"currentOveragesWithPrecision":4,
"overageCharges":0.2,
"usageLimitWithPrecision":2000,
"resourceType":"CREDIT"
}]
}`))
}))
defer server.Close()
prevResolver := resolveKiroRuntimeEndpoint
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
_, err := svc.GetUsage(context.Background(), account.ID)
require.NoError(t, err)
target := &Account{
ID: account.ID,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "kiro-access-token"},
}
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
require.NotNil(t, target.KiroQuotaResetAt)
}
@@ -0,0 +1,163 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
account := Account{
Type: AccountTypeAPIKey,
Platform: PlatformKiro,
Credentials: map[string]any{},
}
require.Empty(t, account.GetBaseURL())
}
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
svc := &GatewayService{}
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
require.Equal(t, 25*time.Second, got)
}
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
StreamKeepaliveInterval: 10,
KiroStreamKeepaliveInterval: 25,
},
},
}
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
}
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
svc := NewBillingService(&config.Config{}, nil)
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
}
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
tests := []struct {
name string
requestedModel string
upstreamModel string
want string
}{
{
name: "kiro claude sonnet 4.6 uses pricing key format",
requestedModel: "claude-sonnet-4-6",
upstreamModel: "claude-sonnet-4.6",
want: "claude-sonnet-4-6",
},
{
name: "falls back to upstream when requested model empty",
requestedModel: "",
upstreamModel: "claude-haiku-4-5",
want: "claude-haiku-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
})
}
}
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.billingService = NewBillingService(svc.cfg, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-sonnet-4-6": {
InputCostPerToken: 2.5e-6,
OutputCostPerToken: 10e-6,
},
},
})
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_kiro_billing_normalized",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
Model: "claude-sonnet-4-6",
UpstreamModel: "claude-sonnet-4.6",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_kiro_auto_fallback_cost",
Usage: ClaudeUsage{
InputTokens: 20,
OutputTokens: 10,
},
Model: "auto",
UpstreamModel: "auto",
Duration: time.Second,
},
APIKey: &APIKey{ID: 601, Quota: 100},
User: &User{ID: 701},
Account: &Account{ID: 801, Platform: PlatformKiro},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
}
@@ -10,6 +10,7 @@ import (
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
@@ -144,6 +145,60 @@ func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
}
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
account := &Account{
ID: 8,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
body := []byte(`{
"model":"claude-opus-4.6",
"messages":[{"role":"user","content":"hello"}]
}`)
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-opus-4.6",
"kiro-access-token",
"claude-opus-4-6-thinking",
nil,
)
require.NoError(t, err)
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
}
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
account := &Account{
ID: 9,
Platform: PlatformKiro,
Type: AccountTypeOAuth,
}
body := []byte(`{
"model":"claude-opus-4.6",
"messages":[{"role":"user","content":"hello"}]
}`)
payload, err := buildKiroPayloadForAccount(
context.Background(),
account,
body,
"claude-opus-4.6",
"kiro-access-token",
"claude-opus-4-6",
nil,
)
require.NoError(t, err)
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
require.NotContains(t, systemContent, "<thinking_mode>")
}
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
account := &Account{
Credentials: map[string]any{
+25 -9
View File
@@ -81,7 +81,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
}
if parsed.Stream {
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -146,7 +146,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
}
if isOnlyWebSearchToolInBody(body) {
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
switch {
case errors.Is(webSearchErr, errKiroWebSearchFallback):
case webSearchErr == nil:
@@ -194,7 +194,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
}
inputTokens := estimateKiroInputTokens(body)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -253,7 +253,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
}, nil
}
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, 0, err
@@ -268,7 +268,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
headers := make(http.Header)
headers.Set("Content-Type", "text/event-stream")
go func() {
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
if streamErr != nil {
_ = pw.CloseWithError(streamErr)
return
@@ -282,7 +282,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
}, inputTokens, nil
}
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -318,7 +318,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
}, inputTokens, nil
}
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
var requestCtx kiropkg.KiroRequestContext
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
@@ -329,7 +329,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
modelID := kiropkg.MapModel(mappedModel)
currentToken := token
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
if err != nil {
return nil, requestCtx, err
}
@@ -432,7 +432,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
currentToken = refreshedToken
accountKey = buildKiroAccountKey(account)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
if err != nil {
return nil, requestCtx, err
}
@@ -501,9 +501,25 @@ func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropic
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
profileArn := resolveKiroPayloadProfileArn(account)
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
}
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
requestModel = strings.TrimSpace(requestModel)
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
return anthropicBody
}
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
return anthropicBody
}
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
return next
}
return anthropicBody
}
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
if s == nil || s.accountRepo == nil || account == nil {
return
@@ -283,7 +283,7 @@ func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
},
}
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
@@ -323,7 +323,7 @@ func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testi
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
require.Len(t, upstream.requests, 1)
@@ -461,7 +461,7 @@ func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailov
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
@@ -501,7 +501,7 @@ func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T)
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
@@ -550,7 +550,7 @@ func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedu
payloadBytes, err := json.Marshal(payload)
require.NoError(t, err)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
require.Equal(t, 1, repo.setErrorCalls)
+4 -4
View File
@@ -102,7 +102,7 @@ func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens in
}
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
) error {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
@@ -141,7 +141,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
return errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
if err != nil {
return err
}
@@ -190,7 +190,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
return fmt.Errorf("kiro web search exceeded max iterations")
}
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
query := kiropkg.ExtractSearchQuery(anthropicBody)
if strings.TrimSpace(query) == "" {
return nil, errKiroWebSearchFallback
@@ -227,7 +227,7 @@ func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Acco
return nil, errKiroWebSearchFallback
}
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
t.Run(mode, func(t *testing.T) {
t.Parallel()
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
privacyCalls := 0
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
privacyCalls++
+3
View File
@@ -271,6 +271,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
out := *ev
out.Platform = strings.TrimSpace(out.Platform)
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
+32 -2
View File
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-2.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model)
model = canonicalModelNameForPricing(model)
model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/")
model = strings.TrimPrefix(model, "publishers/google/models/")
@@ -625,7 +625,31 @@ func normalizeModelNameForPricing(model string) string {
}
model = strings.TrimLeft(model, "/")
return model
return canonicalModelNameForPricing(model)
}
func canonicalModelNameForPricing(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {
return ""
}
switch model {
case "claude-opus-4.5":
return "claude-opus-4-5"
case "claude-opus-4.6":
return "claude-opus-4-6"
case "claude-opus-4.7":
return "claude-opus-4-7"
case "claude-sonnet-4.5":
return "claude-sonnet-4-5"
case "claude-sonnet-4.6":
return "claude-sonnet-4-6"
case "claude-haiku-4.5":
return "claude-haiku-4-5"
default:
return model
}
}
func lastSegment(model string) string {
@@ -671,8 +695,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
@@ -710,6 +736,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
}
case strings.Contains(model, "sonnet"):
switch {
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
fallbackName = "sonnet-4.6"
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
fallbackName = "sonnet-4.5"
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
@@ -719,6 +747,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
}
case strings.Contains(model, "haiku"):
switch {
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
fallbackName = "haiku-4.5"
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
fallbackName = "haiku-3.5"
default:
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
if len(groupIDs) == 0 {
return nil
}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
var firstErr error
for _, platform := range platforms {
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
buckets := make([]SchedulerBucket, 0)
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
for _, platform := range platforms {
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 5,
Platform: PlatformGemini,
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 6,
Platform: PlatformGemini,
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 7,
Platform: PlatformGemini,
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 9,
Platform: PlatformGemini,
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 11,
Platform: PlatformGemini,
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 12,
Platform: PlatformGemini,
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
until := time.Now().Add(10 * time.Minute)
account := &Account{
ID: 15,
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 16,
Platform: tt.platform,
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 18,
Platform: PlatformOpenAI,
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
+12
View File
@@ -89,6 +89,10 @@ security:
enabled: false
# Allowed upstream hosts for API proxying
# 允许代理的上游 API 主机列表
# If you enable Kiro OAuth / IDC, also allow Kiro auth and AWS SSO OIDC hosts.
# 如果启用 Kiro OAuth / IDC,请同时放行 Kiro 鉴权域名和 AWS SSO OIDC 域名。
# If you enable Kiro runtime forwarding, also allow the corresponding AWS Q API endpoint.
# 如果启用 Kiro 运行时转发,还需要放行对应的 AWS Q API 域名。
upstream_hosts:
- "api.openai.com"
- "api.anthropic.com"
@@ -97,6 +101,11 @@ security:
- "api.minimaxi.com"
- "generativelanguage.googleapis.com"
- "cloudcode-pa.googleapis.com"
- "prod.us-east-1.auth.desktop.kiro.dev"
- "oidc.us-east-1.amazonaws.com"
- "oidc.*.amazonaws.com"
- "device.sso.*.amazonaws.com"
- "q.*.amazonaws.com"
- "*.openai.azure.com"
# Allowed hosts for pricing data download
# 允许下载定价数据的主机列表
@@ -340,6 +349,9 @@ gateway:
# Stream keepalive interval (seconds), 0=disable
# 流式 keepalive 间隔(秒),0=禁用
stream_keepalive_interval: 10
# Kiro stream keepalive interval (seconds), 0=use default 25
# Kiro 流式 keepalive 间隔(秒),0=使用默认 25
kiro_stream_keepalive_interval: 25
# SSE max line size in bytes (default: 40MB)
# SSE 单行最大字节数(默认 40MB)
max_line_size: 41943040
+7
View File
@@ -558,6 +558,13 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
return data
}
export async function getKiroDefaultModelMapping(): Promise<Record<string, string>> {
const { data } = await apiClient.get<Record<string, string>>(
'/admin/accounts/kiro/default-model-mapping'
)
return data
}
/**
* Refresh OpenAI token using refresh token
* @param refreshToken - The refresh token
@@ -13,6 +13,15 @@ vi.mock('vue-i18n', async () => {
}
})
vi.mock('@/i18n', () => ({
i18n: {
global: {
t: (key: string) => key
}
},
getLocale: () => 'en'
}))
function makeAccount(overrides: Partial<Account>): Account {
return {
id: 1,
@@ -35,6 +44,12 @@ function makeAccount(overrides: Partial<Account>): Account {
overload_until: null,
temp_unschedulable_until: null,
temp_unschedulable_reason: null,
kiro_quota_state: null,
kiro_quota_reason: null,
kiro_quota_reset_at: null,
kiro_runtime_state: null,
kiro_runtime_reason: null,
kiro_runtime_reset_at: null,
session_window_start: null,
session_window_end: null,
session_window_status: null,
@@ -159,4 +174,96 @@ describe('AccountStatusIndicator', () => {
// AICredits 积分耗尽状态应显示
expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
it('Kiro 运行时冷却在状态列复用限流展示', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 5,
name: 'kiro-cooldown',
platform: 'kiro',
kiro_runtime_state: 'cooldown',
kiro_runtime_reason: 'rate_limit_exceeded',
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedAutoResume')
expect(wrapper.text()).toContain('429')
})
it('Kiro suspended 在状态列显示为 forbidden', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 6,
name: 'kiro-suspended',
platform: 'kiro',
kiro_runtime_state: 'suspended',
kiro_runtime_reason: 'account_suspended',
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.forbidden')
})
it('Kiro overage active 在状态列仍显示正常状态', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 7,
name: 'kiro-overage-active',
platform: 'kiro',
kiro_quota_state: 'overage_active',
kiro_quota_reason: 'overages_enabled',
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.status.active')
expect(wrapper.text()).not.toContain('admin.accounts.status.overageActive')
})
it('Kiro overage exhausted 在状态列显示危险徽章', () => {
const wrapper = mount(AccountStatusIndicator, {
props: {
account: makeAccount({
id: 8,
name: 'kiro-overage-exhausted',
platform: 'kiro',
kiro_quota_state: 'overage_exhausted',
kiro_quota_reason: 'overage disabled after quota exhaustion',
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
})
},
global: {
stubs: {
Icon: true
}
}
})
expect(wrapper.text()).toContain('admin.accounts.status.overageExhausted')
expect(wrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
})
})
@@ -25,6 +25,15 @@ vi.mock('vue-i18n', async () => {
}
})
vi.mock('@/i18n', () => ({
i18n: {
global: {
t: (key: string) => key
}
},
getLocale: () => 'en'
}))
function makeAccount(overrides: Partial<Account>): Account {
return {
id: 1,
@@ -47,6 +56,12 @@ function makeAccount(overrides: Partial<Account>): Account {
overload_until: null,
temp_unschedulable_until: null,
temp_unschedulable_reason: null,
kiro_quota_state: null,
kiro_quota_reason: null,
kiro_quota_reset_at: null,
kiro_runtime_state: null,
kiro_runtime_reason: null,
kiro_runtime_reset_at: null,
session_window_start: null,
session_window_end: null,
session_window_status: null,
@@ -530,6 +545,234 @@ describe('AccountUsageCell', () => {
expect(wrapper.text()).toContain('7d|100|106540000')
})
it('Kiro OAuth 会用 passive source 拉取并展示 credits 额度', async () => {
const account = makeAccount({
id: 3001,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
getUsage.mockResolvedValue({
source: 'passive',
kiro_subscription_name: 'KIRO PRO+',
kiro_overages_enabled: true,
kiro_credit: {
current_usage: 125,
usage_limit: 2000,
percentage_used: 6.25,
},
kiro_bonus: {
current_usage: 25,
usage_limit: 500,
percentage_used: 5,
days_remaining: 7,
},
kiro_overage: {
current_overages: 2,
overage_charges: 0.08,
currency_symbol: '$',
currency_code: 'USD',
},
kiro_reset_at: '2099-03-13T12:00:00Z',
})
const wrapper = mount(AccountUsageCell, {
props: {
account
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(getUsage).toHaveBeenCalledWith(3001, 'passive')
expect(wrapper.emitted('kiroUsageMeta')?.[0]).toEqual([
{
plan_type: 'KIRO PRO+',
kiro_overages_enabled: true
}
])
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroCredits')
expect(wrapper.text()).toContain('125 / 2.0K')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroBonus')
expect(wrapper.text()).toContain('25 / 500')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroDaysLeft')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroReset')
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroOverage 2 ($0.08)')
})
it('Kiro OAuth 会展示运行时冷却状态', async () => {
getUsage.mockResolvedValue({
source: 'passive',
kiro_runtime_state: 'cooldown',
kiro_runtime_reason: 'rate_limit_exceeded',
kiro_runtime_reset_at: '2099-03-13T12:00:00Z',
kiro_credit: {
current_usage: 10,
usage_limit: 100,
percentage_used: 10,
},
})
const wrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3002,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedUntil')
})
it('Kiro OAuth 会展示 overage active 与 exhausted 状态', async () => {
getUsage.mockResolvedValueOnce({
source: 'passive',
kiro_quota_state: 'overage_active',
kiro_quota_reason: 'overages_enabled',
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
kiro_overages_enabled: true,
kiro_credit: {
current_usage: 2100,
usage_limit: 2000,
percentage_used: 100,
},
kiro_overage: {
current_overages: 3,
overage_charges: 0.12,
currency_symbol: '$',
},
})
const activeWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3005,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(activeWrapper.text()).toContain('admin.accounts.status.overageActive')
expect(activeWrapper.text()).not.toContain('admin.accounts.status.overageActiveUntil')
getUsage.mockResolvedValueOnce({
source: 'passive',
kiro_quota_state: 'overage_exhausted',
kiro_quota_reason: 'usage API error: overage exhausted',
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
error: 'usage API error: kiro usage request failed (status 429): {"message":"overage exhausted"}',
})
const exhaustedWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3006,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhausted')
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
})
it('Kiro OAuth 会展示 profile 异常和 usage forbidden 徽章', async () => {
getUsage.mockResolvedValueOnce({
source: 'passive',
error_code: 'forbidden',
error: 'usage API error: kiro usage request failed (status 400): {"message":"profileArn is required for this request."}',
})
const profileWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3003,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(profileWrapper.text()).toContain('admin.accounts.usageError')
getUsage.mockResolvedValueOnce({
source: 'passive',
error_code: 'forbidden',
error: 'usage API error: kiro usage request failed (status 403): {"message":"User is not authorized to access this feature."}',
})
const forbiddenWrapper = mount(AccountUsageCell, {
props: {
account: makeAccount({
id: 3004,
platform: 'kiro',
type: 'oauth',
extra: {},
credentials: {}
})
},
global: {
stubs: {
UsageProgressBar: true,
AccountQuotaInfo: true
}
}
})
await flushPromises()
expect(forbiddenWrapper.text()).toContain('admin.accounts.forbidden')
})
it('Key 账号会展示 today stats 徽章并带 A/U 提示', async () => {
const wrapper = mount(AccountUsageCell, {
props: {
@@ -0,0 +1,56 @@
import { describe, expect, it, vi } from 'vitest'
import { nextTick } from 'vue'
import { mount } from '@vue/test-utils'
vi.mock('@/stores/app', () => ({
useAppStore: () => ({
showSuccess: vi.fn(),
showError: vi.fn()
})
}))
vi.mock('@/composables/useClipboard', () => ({
useClipboard: () => ({
copied: { value: false },
copyToClipboard: vi.fn()
})
}))
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
import OAuthAuthorizationFlow from '../OAuthAuthorizationFlow.vue'
describe('OAuthAuthorizationFlow', () => {
it('extracts code, state, and callback metadata from a full Kiro callback URL', async () => {
const wrapper = mount(OAuthAuthorizationFlow, {
props: {
addMethod: 'oauth',
platform: 'kiro',
authUrl: 'https://example.com/authorize',
sessionId: 'session-1'
},
global: {
stubs: {
Icon: true
}
}
})
const textarea = wrapper.get('textarea')
await textarea.setValue('http://localhost:49153/oauth/callback?code=abc123&state=state456&login_option=github')
await nextTick()
expect((textarea.element as HTMLTextAreaElement).value).toBe('abc123')
expect((wrapper.vm as any).oauthState).toBe('state456')
expect((wrapper.vm as any).oauthCallbackPath).toBe('/oauth/callback')
expect((wrapper.vm as any).oauthLoginOption).toBe('github')
})
})
@@ -489,7 +489,8 @@ const platformOptions = [
{ value: 'anthropic', label: 'Anthropic' },
{ value: 'openai', label: 'OpenAI' },
{ value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' }
{ value: 'antigravity', label: 'Antigravity' },
{ value: 'kiro', label: 'Kiro' }
]
// Load rules when dialog opens
@@ -115,7 +115,7 @@ const labelClass = computed(() => {
}
// 正常状态或无天数:根据平台显示主题色
if (props.platform === 'anthropic') {
if (props.platform === 'anthropic' || props.platform === 'kiro') {
return `${base} bg-orange-200/60 text-orange-800 dark:bg-orange-800/40 dark:text-orange-300`
}
if (props.platform === 'openai') {
@@ -129,7 +129,7 @@ const labelClass = computed(() => {
// Badge color based on platform and subscription type
const badgeClass = computed(() => {
if (props.platform === 'anthropic') {
if (props.platform === 'anthropic' || props.platform === 'kiro') {
// Claude: orange theme
return isSubscription.value
? 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
@@ -91,6 +91,8 @@ const ratePillClass = computed(() => {
return 'bg-green-50 text-green-700 dark:bg-green-900/20 dark:text-green-400'
case 'gemini':
return 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400'
case 'kiro':
return 'bg-amber-50 text-amber-700 dark:bg-amber-900/20 dark:text-amber-400'
default: // antigravity and others
return 'bg-violet-50 text-violet-700 dark:bg-violet-900/20 dark:text-violet-400'
}
@@ -0,0 +1,42 @@
import { mount } from '@vue/test-utils'
import { describe, expect, it, vi } from 'vitest'
import PlatformTypeBadge from '../PlatformTypeBadge.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key === 'admin.accounts.status.overageActive' ? 'Overage' : key
})
}
})
describe('PlatformTypeBadge', () => {
it('shows Kiro overages tag next to the plan tag when enabled', () => {
const wrapper = mount(PlatformTypeBadge, {
props: {
platform: 'kiro',
type: 'oauth',
planType: 'KIRO PRO+',
overagesEnabled: true
}
})
expect(wrapper.text()).toContain('KIRO PRO+')
expect(wrapper.text()).toContain('Overage')
})
it('does not show overages tag for non-Kiro accounts', () => {
const wrapper = mount(PlatformTypeBadge, {
props: {
platform: 'openai',
type: 'oauth',
planType: 'Pro',
overagesEnabled: true
}
})
expect(wrapper.text()).not.toContain('Overage')
})
})
@@ -4,7 +4,12 @@ vi.mock('@/api/admin/accounts', () => ({
getAntigravityDefaultModelMapping: vi.fn()
}))
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
import {
buildModelMappingObject,
fetchKiroDefaultMappings,
getModelsByPlatform,
getPresetMappingsByPlatform
} from '../useModelWhitelist'
describe('useModelWhitelist', () => {
it('openai 模型列表包含 GPT-5.4 官方快照', () => {
@@ -59,6 +64,49 @@ describe('useModelWhitelist', () => {
expect(models.every((model) => !model.endsWith('-agentic') && !model.endsWith('-chat'))).toBe(true)
})
it('kiro 模型列表只保留 Claude 模型', () => {
const models = getModelsByPlatform('kiro')
expect(models).toEqual([
'claude-opus-4-6',
'claude-opus-4-6-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-6-thinking',
'claude-opus-4-5-20251101',
'claude-opus-4-5-20251101-thinking',
'claude-sonnet-4-5-20250929',
'claude-sonnet-4-5-20250929-thinking',
'claude-haiku-4-5-20251001',
'claude-haiku-4-5-20251001-thinking'
])
expect(models.every(model => model.startsWith('claude-'))).toBe(true)
expect(models.some(model => model.endsWith('-agentic'))).toBe(false)
expect(models.some(model => model.endsWith('-chat'))).toBe(false)
expect(models).not.toContain('kiro-auto')
expect(models).not.toContain('claude-opus-4-5')
expect(models).not.toContain('claude-sonnet-4-5')
expect(models).not.toContain('claude-sonnet-4')
expect(models).not.toContain('claude-3-5-sonnet-20241022')
expect(models).not.toContain('claude-3-5-haiku-20241022')
expect(models).not.toContain('claude-haiku-4-5')
expect(models).not.toContain('gpt-4o')
expect(models).not.toContain('gpt-4')
expect(models).not.toContain('gpt-4-turbo')
expect(models).not.toContain('gpt-3.5-turbo')
expect(models).not.toContain('deepseek-3-2')
expect(models).not.toContain('minimax-m2-1')
expect(models).not.toContain('qwen3-coder-next')
})
it('claude 模型列表包含 dated 和 thinking 兼容别名', () => {
const models = getModelsByPlatform('claude')
expect(models).toContain('claude-opus-4-6-thinking')
expect(models).toContain('claude-opus-4-5-20251101-thinking')
expect(models).toContain('claude-sonnet-4-20250514-thinking')
expect(models).toContain('claude-haiku-4-5-20251001-thinking')
})
it('whitelist 模式会忽略通配符条目', () => {
const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
expect(mapping).toEqual({
@@ -81,4 +129,68 @@ describe('useModelWhitelist', () => {
'gpt-5.4-mini': 'gpt-5.4-mini'
})
})
it('kiro 预设映射只暴露 Claude 入口', () => {
const mappings = getPresetMappingsByPlatform('kiro')
const mappingTargets = mappings.map(item => item.to)
expect(mappings.map(({ from, to }) => ({ from, to }))).toEqual([
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
])
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
expect(mappingTargets.every(model => model.startsWith('claude-'))).toBe(true)
expect(mappingTargets.some(model => model.endsWith('-agentic'))).toBe(false)
expect(mappingTargets.some(model => model.endsWith('-chat'))).toBe(false)
expect(mappingTargets).not.toContain('kiro-auto')
expect(mappingTargets.some(model => model.startsWith('kiro-'))).toBe(false)
expect(mappings.some(item => item.from === 'claude-opus-4-5')).toBe(false)
expect(mappings.some(item => item.from === 'claude-sonnet-4-5')).toBe(false)
expect(mappings.some(item => item.from === 'claude-sonnet-4')).toBe(false)
expect(mappings.some(item => item.from === 'claude-3-5-sonnet-20241022')).toBe(false)
expect(mappings.some(item => item.from === 'claude-3-5-haiku-20241022')).toBe(false)
expect(mappings.some(item => item.from === 'claude-haiku-4-5')).toBe(false)
expect(mappingTargets).not.toContain('gpt-4o')
expect(mappingTargets).not.toContain('gpt-4')
expect(mappingTargets).not.toContain('gpt-4-turbo')
expect(mappingTargets).not.toContain('gpt-3.5-turbo')
expect(mappingTargets).not.toContain('deepseek-3.2')
expect(mappingTargets).not.toContain('minimax-m2.1')
expect(mappingTargets).not.toContain('qwen3-coder-next')
})
it('kiro 默认映射会在前端填充所有可精确定价模型', async () => {
const mappings = await fetchKiroDefaultMappings()
expect(mappings).toEqual(expect.arrayContaining([
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
]))
expect(mappings).toHaveLength(10)
expect(mappings.every(item => !item.from.startsWith('kiro-'))).toBe(true)
expect(mappings.every(item => !item.to.startsWith('kiro-'))).toBe(true)
expect(mappings.every(item => !item.from.endsWith('-agentic'))).toBe(true)
expect(mappings.every(item => !item.to.endsWith('-agentic'))).toBe(true)
expect(mappings.every(item => !item.from.endsWith('-chat'))).toBe(true)
expect(mappings.every(item => !item.to.endsWith('-chat'))).toBe(true)
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
expect(mappings.every(item => item.to.startsWith('claude-'))).toBe(true)
expect(mappings.some(item => item.to === 'claude-opus-4-7')).toBe(false)
})
})
@@ -23,13 +23,21 @@ export const claudeModels = [
'claude-3-5-sonnet-20241022', 'claude-3-5-sonnet-20240620',
'claude-3-5-haiku-20241022',
'claude-3-7-sonnet-20250219',
'claude-sonnet-4-20250514-thinking', 'claude-opus-4-20250514-thinking',
'claude-sonnet-4-20250514', 'claude-opus-4-20250514',
'claude-opus-4-1-20250805',
'claude-sonnet-4-5-thinking', 'claude-sonnet-4-5-20250929-thinking',
'claude-haiku-4-5-thinking', 'claude-haiku-4-5-20251001-thinking',
'claude-opus-4-5-thinking', 'claude-opus-4-5-20251101-thinking',
'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001',
'claude-opus-4-5-20251101',
'claude-opus-4-6-thinking',
'claude-opus-4-6',
'claude-opus-4-7-thinking',
'claude-opus-4-7',
'claude-sonnet-4-6'
'claude-sonnet-4-6-thinking',
'claude-sonnet-4-6',
'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
]
// Google Gemini
+8 -4
View File
@@ -324,8 +324,8 @@
</div>
</div>
<!-- Web Search Emulation (Anthropic only, hidden when global disabled) -->
<div v-if="section.platform === 'anthropic' && webSearchGlobalEnabled" class="border-t border-gray-200 pt-3 dark:border-dark-600">
<!-- Web Search Emulation (supported platforms only, hidden when global disabled) -->
<div v-if="supportsWebSearchEmulation(section.platform) && webSearchGlobalEnabled" class="border-t border-gray-200 pt-3 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="text-xs font-medium text-gray-700 dark:text-gray-300">
@@ -718,7 +718,7 @@ const form = reactive({
let abortController: AbortController | null = null
// ── Platform config ──
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity', 'kiro']
// ── Helpers ──
function formatDate(value: string): string {
@@ -758,6 +758,10 @@ function getGroupsForPlatform(platform: GroupPlatform): AdminGroup[] {
return allGroups.value.filter(g => g.platform === platform)
}
function supportsWebSearchEmulation(platform: GroupPlatform): boolean {
return platform === 'anthropic' || platform === 'kiro'
}
// ── Group helpers ──
const groupToChannelMap = computed(() => {
const map = new Map<number, Channel>()
@@ -1037,7 +1041,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
const wsEmulation: Record<string, boolean> = {}
for (const section of form.platforms) {
if (!section.enabled) continue
if (section.platform === 'anthropic') {
if (supportsWebSearchEmulation(section.platform)) {
wsEmulation[section.platform] = !!section.web_search_emulation
}
}
@@ -967,7 +967,8 @@ const platformFilterOptions = computed(() => [
{ value: 'anthropic', label: 'Anthropic' },
{ value: 'openai', label: 'OpenAI' },
{ value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' }
{ value: 'antigravity', label: 'Antigravity' },
{ value: 'kiro', label: 'Kiro' }
])
// Group options for assign (only subscription type groups)
@@ -111,7 +111,8 @@ const platformOptions = computed(() => [
{ value: 'openai', label: 'OpenAI' },
{ value: 'anthropic', label: 'Anthropic' },
{ value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' }
{ value: 'antigravity', label: 'Antigravity' },
{ value: 'kiro', label: 'Kiro' }
])
const timeRangeOptions = computed(() => [