feat: complete kiro platform support
This commit is contained in:
@@ -57,6 +57,8 @@ func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
"deepseek-3-2",
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "kiro-oauth",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 47,
|
||||
Name: "kiro-oauth-mapped",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "kiro-apikey",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 46,
|
||||
Name: "kiro-apikey-defaults",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -120,7 +120,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
|
||||
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(createReq))
|
||||
|
||||
updateReq := UpdateGroupRequest{Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
|
||||
}
|
||||
@@ -36,11 +36,12 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformKiro = "kiro"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
|
||||
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
|
||||
service.PlatformOpenAI: 1,
|
||||
service.PlatformGemini: 1,
|
||||
service.PlatformAntigravity: 2,
|
||||
service.PlatformKiro: 1,
|
||||
}
|
||||
|
||||
for platform, minCount := range requiredByPlatform {
|
||||
|
||||
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
requestedModel: "any-model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping falls back to default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping rejects model outside default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "auto",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping uses default upstream mapping",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: "claude-sonnet-4.6",
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
|
||||
@@ -194,6 +194,13 @@ func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
|
||||
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
|
||||
|
||||
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
|
||||
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
|
||||
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
|
||||
|
||||
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
|
||||
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
|
||||
|
||||
// Gemini 3.1 Pro
|
||||
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-6, // $2 per MTok
|
||||
@@ -268,13 +275,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
|
||||
return s.fallbackPrices["claude-sonnet-4.6"]
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-sonnet-4.5"]
|
||||
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-haiku-4.5"]
|
||||
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
|
||||
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -105,44 +109,63 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -102,44 +106,63 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -912,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
|
||||
injectMetadata bool
|
||||
metadataUserID string
|
||||
stripSystemCacheControl bool
|
||||
preserveToolChoice bool
|
||||
}
|
||||
|
||||
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
|
||||
@@ -1126,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
|
||||
if !gjson.GetBytes(out, "max_tokens").Exists() {
|
||||
@@ -4518,7 +4525,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
|
||||
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
|
||||
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{
|
||||
stripSystemCacheControl: !systemRewritten,
|
||||
preserveToolChoice: account.Platform == PlatformKiro,
|
||||
}
|
||||
if s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil && fp != nil {
|
||||
@@ -8968,6 +8978,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射:
|
||||
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
|
||||
|
||||
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
|
||||
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
|
||||
// Anthropic API Key keeps the existing account-level override:
|
||||
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Kiro OAuth uses channel-level switch only.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
@@ -62,22 +64,37 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
return false
|
||||
}
|
||||
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
if account == nil {
|
||||
return false
|
||||
default: // "default" → follow channel config
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(account.Platform)
|
||||
}
|
||||
|
||||
switch {
|
||||
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
return false
|
||||
default:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
}
|
||||
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(platform)
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||
"message": map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
return flushSSEJSON(w, "message_start", evt)
|
||||
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
"name": toolNameWebSearch, "input": map[string]any{},
|
||||
},
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
|
||||
return err
|
||||
}
|
||||
inputJSON, err := json.Marshal(map[string]string{"query": query})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal query: %w", err)
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||
blocks := make([]map[string]string, 0, len(results))
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
|
||||
blocks := make([]map[string]any, 0, len(results))
|
||||
for _, r := range results {
|
||||
block := map[string]string{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
}
|
||||
if r.Snippet != "" {
|
||||
block["page_content"] = r.Snippet
|
||||
block := map[string]any{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
"encrypted_content": r.Snippet,
|
||||
"page_age": nil,
|
||||
}
|
||||
if r.PageAge != "" {
|
||||
block["page_age"] = r.PageAge
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +14,31 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"cache_creation_input_tokens":0`)
|
||||
require.Contains(t, body, `"cache_read_input_tokens":0`)
|
||||
}
|
||||
|
||||
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `event: content_block_start`)
|
||||
require.Contains(t, body, `"type":"server_tool_use"`)
|
||||
require.Contains(t, body, `"input":{}`)
|
||||
require.Contains(t, body, `event: content_block_delta`)
|
||||
require.Contains(t, body, `"type":"input_json_delta"`)
|
||||
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, body, `event: content_block_stop`)
|
||||
}
|
||||
|
||||
// --- isOnlyWebSearchToolInBody ---
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
|
||||
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
|
||||
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||
// Second result has no PageAge
|
||||
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||
_, hasPageAge := blocks[1]["page_age"]
|
||||
require.False(t, hasPageAge)
|
||||
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
|
||||
require.Nil(t, blocks[1]["page_age"])
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
|
||||
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||
_, hasContent := blocks[0]["page_content"]
|
||||
require.False(t, hasContent)
|
||||
require.Equal(t, "", blocks[0]["encrypted_content"])
|
||||
require.Nil(t, blocks[0]["page_age"])
|
||||
}
|
||||
|
||||
// --- buildTextSummary ---
|
||||
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
func newKiroOAuthAccount() *Account {
|
||||
return &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
|
||||
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
|
||||
// nil channelService + default mode → returns false
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 11,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(77, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(77)
|
||||
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 12,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(78, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(78)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,623 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type kiroUsageCooldownStore struct {
|
||||
state *kirocooldown.State
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||
return s.state, s.err
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func kiroFloatPtr(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
|
||||
},
|
||||
}
|
||||
|
||||
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_kiro_billing_normalized",
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
|
||||
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
|
||||
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
|
||||
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
|
||||
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "*/*", r.Header.Get("Accept"))
|
||||
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
|
||||
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
|
||||
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
|
||||
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageConfiguration": {"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentOveragesWithPrecision":2,
|
||||
"currentUsageWithPrecision":125,
|
||||
"freeTrialInfo":{
|
||||
"currentUsageWithPrecision":25,
|
||||
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
|
||||
"freeTrialStatus":"ACTIVE",
|
||||
"usageLimitWithPrecision":500
|
||||
},
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageCharges":0.08,
|
||||
"resourceType":"CREDIT",
|
||||
"usageLimitWithPrecision":2000
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, "active", usage.Source)
|
||||
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
|
||||
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
|
||||
require.True(t, usage.KiroOveragesEnabled)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
|
||||
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
|
||||
require.NotNil(t, usage.KiroBonus)
|
||||
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
|
||||
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
|
||||
require.NotNil(t, usage.KiroOverage)
|
||||
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
|
||||
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
|
||||
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
|
||||
require.NotNil(t, usage.KiroResetAt)
|
||||
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
|
||||
require.NotNil(t, usage.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 702,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":300,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer successServer.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.NotNil(t, firstUsage.KiroCredit)
|
||||
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
|
||||
|
||||
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
|
||||
}))
|
||||
defer failingServer.Close()
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Empty(t, usage.Error)
|
||||
require.Empty(t, usage.ErrorCode)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 703,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "BuilderId",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 707,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":64,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 709,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"api_region": "eu-west-1",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":11,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{"eu-west-1"}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 710,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":7,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 704,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(90 * time.Second),
|
||||
Remaining: 90 * time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cooldown", usage.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
|
||||
require.NotNil(t, usage.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: `{"message":"profileArn is required for this request."}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeForbidden, info.ErrorCode)
|
||||
require.False(t, info.NeedsReauth)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Body: `{"message":"overage exhausted for this billing window"}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
|
||||
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
|
||||
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 708,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
requestCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
|
||||
|
||||
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, secondUsage)
|
||||
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
|
||||
require.Equal(t, 1, requestCount)
|
||||
}
|
||||
|
||||
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
|
||||
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
|
||||
NextDateReset: "2099-03-13T12:00:00Z",
|
||||
OverageConfiguration: kiroOverageConfiguration{
|
||||
OverageStatus: "DISABLED",
|
||||
},
|
||||
UsageBreakdownList: []kiroUsageBreakdown{
|
||||
{
|
||||
ResourceType: "CREDIT",
|
||||
CurrentUsageWithPrecision: kiroFloatPtr(2000),
|
||||
UsageLimitWithPrecision: kiroFloatPtr(2000),
|
||||
CurrentOveragesWithPrecision: kiroFloatPtr(0),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
|
||||
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
|
||||
require.NotNil(t, info.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
|
||||
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(2 * time.Minute),
|
||||
Remaining: 2 * time.Minute,
|
||||
},
|
||||
})
|
||||
|
||||
account := &Account{
|
||||
ID: 705,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
|
||||
require.Equal(t, "cooldown", account.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
|
||||
require.NotNil(t, account.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 706,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset":"2099-03-13T12:00:00Z",
|
||||
"overageConfiguration":{"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":2000,
|
||||
"currentOveragesWithPrecision":4,
|
||||
"overageCharges":0.2,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
_, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
target := &Account{
|
||||
ID: account.ID,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
|
||||
|
||||
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
|
||||
require.NotNil(t, target.KiroQuotaResetAt)
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
|
||||
account := Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformKiro,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
require.Empty(t, account.GetBaseURL())
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
|
||||
|
||||
require.Equal(t, 25*time.Second, got)
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamKeepaliveInterval: 10,
|
||||
KiroStreamKeepaliveInterval: 25,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
|
||||
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, nil)
|
||||
|
||||
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
upstreamModel string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "kiro claude sonnet 4.6 uses pricing key format",
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
upstreamModel: "claude-sonnet-4.6",
|
||||
want: "claude-sonnet-4-6",
|
||||
},
|
||||
{
|
||||
name: "falls back to upstream when requested model empty",
|
||||
requestedModel: "",
|
||||
upstreamModel: "claude-haiku-4-5",
|
||||
want: "claude-haiku-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_billing_normalized",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_auto_fallback_cost",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "auto",
|
||||
UpstreamModel: "auto",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 601, Quota: 100},
|
||||
User: &User{ID: 701},
|
||||
Account: &Account{ID: 801, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
|
||||
@@ -144,6 +145,60 @@ func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
|
||||
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6-thinking",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
|
||||
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.NotContains(t, systemContent, "<thinking_mode>")
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
|
||||
@@ -81,7 +81,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
}
|
||||
|
||||
if parsed.Stream {
|
||||
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
|
||||
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -146,7 +146,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||
}
|
||||
if isOnlyWebSearchToolInBody(body) {
|
||||
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
|
||||
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
switch {
|
||||
case errors.Is(webSearchErr, errKiroWebSearchFallback):
|
||||
case webSearchErr == nil:
|
||||
@@ -194,7 +194,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(body)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -253,7 +253,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
|
||||
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -268,7 +268,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "text/event-stream")
|
||||
go func() {
|
||||
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
|
||||
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
@@ -282,7 +282,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -318,7 +318,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||
var requestCtx kiropkg.KiroRequestContext
|
||||
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
|
||||
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||
@@ -329,7 +329,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
|
||||
|
||||
modelID := kiropkg.MapModel(mappedModel)
|
||||
currentToken := token
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
|
||||
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
@@ -501,9 +501,25 @@ func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropic
|
||||
|
||||
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
|
||||
profileArn := resolveKiroPayloadProfileArn(account)
|
||||
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
|
||||
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
|
||||
}
|
||||
|
||||
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
|
||||
requestModel = strings.TrimSpace(requestModel)
|
||||
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
|
||||
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
|
||||
return next
|
||||
}
|
||||
return anthropicBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
|
||||
@@ -283,7 +283,7 @@ func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
|
||||
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
@@ -323,7 +323,7 @@ func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testi
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
@@ -461,7 +461,7 @@ func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailov
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
@@ -501,7 +501,7 @@ func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
@@ -550,7 +550,7 @@ func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedu
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
|
||||
@@ -102,7 +102,7 @@ func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens in
|
||||
}
|
||||
|
||||
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||
) error {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
@@ -141,7 +141,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
return errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -190,7 +190,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
return fmt.Errorf("kiro web search exceeded max iterations")
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil, errKiroWebSearchFallback
|
||||
@@ -227,7 +227,7 @@ func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Acco
|
||||
return nil, errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
privacyCalls := 0
|
||||
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
|
||||
privacyCalls++
|
||||
|
||||
@@ -271,6 +271,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
|
||||
out := *ev
|
||||
|
||||
out.Platform = strings.TrimSpace(out.Platform)
|
||||
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
|
||||
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
|
||||
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
|
||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||
|
||||
|
||||
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-2.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = canonicalModelNameForPricing(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
model = strings.TrimPrefix(model, "publishers/google/models/")
|
||||
@@ -625,7 +625,31 @@ func normalizeModelNameForPricing(model string) string {
|
||||
}
|
||||
|
||||
model = strings.TrimLeft(model, "/")
|
||||
return model
|
||||
return canonicalModelNameForPricing(model)
|
||||
}
|
||||
|
||||
func canonicalModelNameForPricing(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch model {
|
||||
case "claude-opus-4.5":
|
||||
return "claude-opus-4-5"
|
||||
case "claude-opus-4.6":
|
||||
return "claude-opus-4-6"
|
||||
case "claude-opus-4.7":
|
||||
return "claude-opus-4-7"
|
||||
case "claude-sonnet-4.5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "claude-sonnet-4.6":
|
||||
return "claude-sonnet-4-6"
|
||||
case "claude-haiku-4.5":
|
||||
return "claude-haiku-4-5"
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func lastSegment(model string) string {
|
||||
@@ -671,8 +695,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
|
||||
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
|
||||
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
|
||||
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
|
||||
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
|
||||
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
|
||||
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
|
||||
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
|
||||
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
|
||||
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
|
||||
@@ -710,6 +736,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "sonnet"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
|
||||
fallbackName = "sonnet-4.6"
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "sonnet-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
@@ -719,6 +747,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "haiku"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "haiku-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
fallbackName = "haiku-3.5"
|
||||
default:
|
||||
|
||||
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
|
||||
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Platform: PlatformGemini,
|
||||
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformGemini,
|
||||
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformGemini,
|
||||
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformOpenAI, // OpenAI OAuth 账户
|
||||
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
resetAt := time.Now().Add(30 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformGemini,
|
||||
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformGemini,
|
||||
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 15,
|
||||
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Platform: tt.platform,
|
||||
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 18,
|
||||
Platform: PlatformOpenAI,
|
||||
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user