feat: complete kiro platform support
This commit is contained in:
@@ -57,6 +57,8 @@ func TestDefaultKiroModelMapping_MatchesKiroReferenceModels(t *testing.T) {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
"deepseek-3-2",
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -103,3 +104,156 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "kiro-oauth",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/44/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 47,
|
||||
Name: "kiro-oauth-mapped",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/47/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "kiro-apikey",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"custom-model": "custom-upstream-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/45/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 2)
|
||||
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-sonnet-4-6"))
|
||||
require.True(t, slices.Contains(ids, "custom-model"))
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_KiroAPIKeyWithoutMappingFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 46,
|
||||
Name: "kiro-apikey-defaults",
|
||||
Platform: service.PlatformKiro,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/46/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
ids := make([]string, 0, len(resp.Data))
|
||||
for _, model := range resp.Data {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
require.True(t, slices.Contains(ids, "claude-opus-4-6"))
|
||||
require.False(t, slices.Contains(ids, "claude-opus-4-7"))
|
||||
require.False(t, slices.Contains(ids, "kiro-claude-opus-4-7"))
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -120,7 +120,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity kiro"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGroupRequestValidationAcceptsKiroPlatform(t *testing.T) {
|
||||
createReq := CreateGroupRequest{Name: "kiro-default", Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(createReq))
|
||||
|
||||
updateReq := UpdateGroupRequest{Platform: "kiro"}
|
||||
require.NoError(t, binding.Validator.ValidateStruct(updateReq))
|
||||
}
|
||||
@@ -36,11 +36,12 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformKiro = "kiro"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity, PlatformKiro}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
|
||||
@@ -19,6 +19,7 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er
|
||||
service.PlatformOpenAI: 1,
|
||||
service.PlatformGemini: 1,
|
||||
service.PlatformAntigravity: 2,
|
||||
service.PlatformKiro: 1,
|
||||
}
|
||||
|
||||
for platform, minCount := range requiredByPlatform {
|
||||
|
||||
@@ -151,6 +151,20 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
requestedModel: "any-model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping falls back to default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping rejects model outside default whitelist",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "auto",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
@@ -244,6 +258,13 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
{
|
||||
name: "kiro no mapping uses default upstream mapping",
|
||||
platform: PlatformKiro,
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
expected: "claude-sonnet-4.6",
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
|
||||
@@ -194,6 +194,13 @@ func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
|
||||
s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
|
||||
|
||||
// Claude Sonnet 4.5/4.6 当前与 Sonnet 4 同价
|
||||
s.fallbackPrices["claude-sonnet-4.5"] = s.fallbackPrices["claude-sonnet-4"]
|
||||
s.fallbackPrices["claude-sonnet-4.6"] = s.fallbackPrices["claude-sonnet-4.5"]
|
||||
|
||||
// Claude Haiku 4.5 当前与 Claude 3.5 Haiku 同价
|
||||
s.fallbackPrices["claude-haiku-4.5"] = s.fallbackPrices["claude-3-5-haiku"]
|
||||
|
||||
// Gemini 3.1 Pro
|
||||
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-6, // $2 per MTok
|
||||
@@ -268,13 +275,21 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6"):
|
||||
return s.fallbackPrices["claude-sonnet-4.6"]
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-sonnet-4.5"]
|
||||
case strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3"):
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
switch {
|
||||
case strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5"):
|
||||
return s.fallbackPrices["claude-haiku-4.5"]
|
||||
case strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5"):
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
|
||||
@@ -61,7 +61,11 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -105,6 +109,24 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -126,7 +148,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -144,6 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 12. Handle error response with failover
|
||||
|
||||
@@ -58,7 +58,11 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
if account.Platform == PlatformKiro {
|
||||
if next := account.GetMappedModel(originalModel); next != "" {
|
||||
mappedModel = next
|
||||
}
|
||||
} else if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
@@ -102,6 +106,24 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
var resp *http.Response
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
resp, _, err = s.openKiroAnthropicStreamResponse(ctx, account, anthropicBody, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
} else {
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -123,7 +145,7 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -141,6 +163,7 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 12. Handle error response with failover
|
||||
|
||||
@@ -44,6 +44,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -912,6 +912,7 @@ type claudeOAuthNormalizeOptions struct {
|
||||
injectMetadata bool
|
||||
metadataUserID string
|
||||
stripSystemCacheControl bool
|
||||
preserveToolChoice bool
|
||||
}
|
||||
|
||||
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
|
||||
@@ -1126,6 +1127,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !opts.preserveToolChoice && gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
|
||||
if !gjson.GetBytes(out, "max_tokens").Exists() {
|
||||
@@ -4518,7 +4525,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
|
||||
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
|
||||
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{
|
||||
stripSystemCacheControl: !systemRewritten,
|
||||
preserveToolChoice: account.Platform == PlatformKiro,
|
||||
}
|
||||
if s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil && fp != nil {
|
||||
@@ -8968,6 +8978,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
if account.Platform == PlatformKiro && account.Type == AccountTypeOAuth {
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "Token counting is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射:
|
||||
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
|
||||
|
||||
@@ -49,8 +49,10 @@ func getWebSearchManager() *websearch.Manager {
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
|
||||
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → platform-specific policy.
|
||||
// Anthropic API Key keeps the existing account-level override:
|
||||
// "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
// Kiro OAuth uses channel-level switch only.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
@@ -62,13 +64,29 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
return false
|
||||
}
|
||||
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case account.Platform == PlatformAnthropic && account.Type == AccountTypeAPIKey:
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
return false
|
||||
default: // "default" → follow channel config
|
||||
default:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
}
|
||||
case account.Platform == PlatformKiro && account.Type == AccountTypeOAuth:
|
||||
return s.isChannelWebSearchEmulationEnabled(ctx, groupID, account.Platform)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) isChannelWebSearchEmulationEnabled(ctx context.Context, groupID *int64, platform string) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
@@ -76,8 +94,7 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(account.Platform)
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(platform)
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
@@ -249,7 +266,12 @@ func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||
"message": map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
return flushSSEJSON(w, "message_start", evt)
|
||||
@@ -260,12 +282,26 @@ func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
"name": toolNameWebSearch, "input": map[string]any{},
|
||||
},
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
|
||||
return err
|
||||
}
|
||||
inputJSON, err := json.Marshal(map[string]string{"query": query})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal query: %w", err)
|
||||
}
|
||||
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
@@ -362,16 +398,15 @@ func writeWebSearchNonStreamResponse(
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||
blocks := make([]map[string]string, 0, len(results))
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]any {
|
||||
blocks := make([]map[string]any, 0, len(results))
|
||||
for _, r := range results {
|
||||
block := map[string]string{
|
||||
block := map[string]any{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
}
|
||||
if r.Snippet != "" {
|
||||
block["page_content"] = r.Snippet
|
||||
"encrypted_content": r.Snippet,
|
||||
"page_age": nil,
|
||||
}
|
||||
if r.PageAge != "" {
|
||||
block["page_age"] = r.PageAge
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +14,31 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteSSEMessageStart_IncludesCacheUsageFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEMessageStart(rec, "msg_test", "claude-sonnet-4-5")
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"cache_creation_input_tokens":0`)
|
||||
require.Contains(t, body, `"cache_read_input_tokens":0`)
|
||||
}
|
||||
|
||||
func TestWriteSSEServerToolUse_UsesInputJSONDelta(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
err := writeSSEServerToolUse(rec, "srvtoolu_test", "golang concurrency", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `event: content_block_start`)
|
||||
require.Contains(t, body, `"type":"server_tool_use"`)
|
||||
require.Contains(t, body, `"input":{}`)
|
||||
require.Contains(t, body, `event: content_block_delta`)
|
||||
require.Contains(t, body, `"type":"input_json_delta"`)
|
||||
require.Contains(t, body, `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, body, `event: content_block_stop`)
|
||||
}
|
||||
|
||||
// --- isOnlyWebSearchToolInBody ---
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
|
||||
@@ -111,12 +137,12 @@ func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||
require.Equal(t, "snippet a", blocks[0]["encrypted_content"])
|
||||
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||
// Second result has no PageAge
|
||||
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||
_, hasPageAge := blocks[1]["page_age"]
|
||||
require.False(t, hasPageAge)
|
||||
require.Equal(t, "snippet b", blocks[1]["encrypted_content"])
|
||||
require.Nil(t, blocks[1]["page_age"])
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
@@ -126,8 +152,8 @@ func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
|
||||
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||
_, hasContent := blocks[0]["page_content"]
|
||||
require.False(t, hasContent)
|
||||
require.Equal(t, "", blocks[0]["encrypted_content"])
|
||||
require.Nil(t, blocks[0]["page_age"])
|
||||
}
|
||||
|
||||
// --- buildTextSummary ---
|
||||
@@ -165,6 +191,14 @@ func newAnthropicAPIKeyAccount(mode string) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
func newKiroOAuthAccount() *Account {
|
||||
return &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
|
||||
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
@@ -378,3 +412,75 @@ func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
|
||||
// nil channelService + default mode → returns false
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelEnabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 11,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: true},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(77, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(77)
|
||||
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroChannelDisabledFallsBack(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 12,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformKiro: false},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(78, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
groupID := int64(78)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_KiroRequiresChannelConfig(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
|
||||
account := newKiroOAuthAccount()
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,623 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/kirocooldown"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type kiroUsageCooldownStore struct {
|
||||
state *kirocooldown.State
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ReserveRequest(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuccess(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) Mark429(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) MarkSuspended(context.Context, string) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) GetState(context.Context, string) (*kirocooldown.State, error) {
|
||||
return s.state, s.err
|
||||
}
|
||||
|
||||
func (s *kiroUsageCooldownStore) ClearEarliestTransientCooldown(context.Context, []string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func kiroFloatPtr(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Kiro(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"kiro": true},
|
||||
},
|
||||
}
|
||||
|
||||
require.True(t, c.IsWebSearchEmulationEnabled("kiro"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_kiro_billing_normalized",
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroMapsCredits(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
resetAt := time.Now().Add(10 * 24 * time.Hour).Unix()
|
||||
bonusExpiry := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, "arn:aws:codewhisperer:us-east-1:123456789012:profile/SOCIAL", r.URL.Query().Get("profileArn"))
|
||||
require.Equal(t, kiroUsageOrigin, r.URL.Query().Get("origin"))
|
||||
require.Equal(t, kiroUsageResourceType, r.URL.Query().Get("resourceType"))
|
||||
require.Equal(t, "Bearer kiro-access-token", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "*/*", r.Header.Get("Accept"))
|
||||
require.True(t, strings.Contains(r.Header.Get("User-Agent"), "KiroIDE-"))
|
||||
require.True(t, strings.Contains(r.Header.Get("X-Amz-User-Agent"), "KiroIDE-"))
|
||||
require.Equal(t, "vibe", r.Header.Get("x-amzn-kiro-agent-mode"))
|
||||
require.Equal(t, "true", r.Header.Get("x-amzn-codewhisperer-optout"))
|
||||
require.NotEmpty(t, r.Header.Get("Amz-Sdk-Invocation-Id"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageConfiguration": {"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+","type":"Q_DEVELOPER_STANDALONE_PRO_PLUS"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentOveragesWithPrecision":2,
|
||||
"currentUsageWithPrecision":125,
|
||||
"freeTrialInfo":{
|
||||
"currentUsageWithPrecision":25,
|
||||
"freeTrialExpiry":` + strconv.FormatInt(bonusExpiry, 10) + `,
|
||||
"freeTrialStatus":"ACTIVE",
|
||||
"usageLimitWithPrecision":500
|
||||
},
|
||||
"nextDateReset": ` + strconv.FormatInt(resetAt, 10) + `,
|
||||
"overageCharges":0.08,
|
||||
"resourceType":"CREDIT",
|
||||
"usageLimitWithPrecision":2000
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, "active", usage.Source)
|
||||
require.Equal(t, "KIRO PRO+", usage.KiroSubscriptionName)
|
||||
require.Equal(t, "Q_DEVELOPER_STANDALONE_PRO_PLUS", usage.KiroSubscriptionType)
|
||||
require.True(t, usage.KiroOveragesEnabled)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 125.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Equal(t, 2000.0, usage.KiroCredit.UsageLimit)
|
||||
require.InDelta(t, 6.25, usage.KiroCredit.PercentageUsed, 0.001)
|
||||
require.NotNil(t, usage.KiroBonus)
|
||||
require.Equal(t, 25.0, usage.KiroBonus.CurrentUsage)
|
||||
require.Equal(t, 500.0, usage.KiroBonus.UsageLimit)
|
||||
require.NotNil(t, usage.KiroOverage)
|
||||
require.Equal(t, "$", usage.KiroOverage.CurrencySymbol)
|
||||
require.Equal(t, 2.0, usage.KiroOverage.CurrentOverages)
|
||||
require.Equal(t, 0.08, usage.KiroOverage.OverageCharges)
|
||||
require.NotNil(t, usage.KiroResetAt)
|
||||
require.Equal(t, kiroQuotaStateOverageActive, usage.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", usage.KiroQuotaReason)
|
||||
require.NotNil(t, usage.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroActiveUsesCachedSnapshotWithinTTL(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 702,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":300,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer successServer.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return successServer.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.NotNil(t, firstUsage.KiroCredit)
|
||||
require.Equal(t, 300.0, firstUsage.KiroCredit.CurrentUsage)
|
||||
|
||||
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"message":"temporary failure"}`, http.StatusInternalServerError)
|
||||
}))
|
||||
defer failingServer.Close()
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return failingServer.URL }
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 300.0, usage.KiroCredit.CurrentUsage)
|
||||
require.Empty(t, usage.Error)
|
||||
require.Empty(t, usage.ErrorCode)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroBuilderIDWithoutProfileArnOmitsProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 703,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "BuilderId",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 42.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroEnterpriseUsesCredentialProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 707,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "us-east-1",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REALENTERPRISE"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":64,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.KiroCredit)
|
||||
require.Equal(t, 64.0, usage.KiroCredit.CurrentUsage)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroUsesAPIRegionForUsageRequest(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 709,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"api_region": "eu-west-1",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
const resolvedProfileArn = "arn:aws:codewhisperer:eu-west-1:123456789012:profile/REALAPIREGION"
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Equal(t, resolvedProfileArn, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":11,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{"eu-west-1"}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroOmitsProfileArnAndUsesDefaultRegionWithoutAPIRegionOrProfileArn(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 710,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "AWS",
|
||||
"auth_method": "idc",
|
||||
"region": "ap-northeast-2",
|
||||
"start_url": "https://d-example.awsapps.com/start",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
gotRegions := make([]string, 0, 2)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/getUsageLimits", r.URL.Path)
|
||||
require.Empty(t, r.URL.Query().Get("profileArn"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":7,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(region string) string {
|
||||
gotRegions = append(gotRegions, region)
|
||||
return server.URL
|
||||
}
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, []string{kiroDefaultRegion}, gotRegions)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroIncludesRuntimeCooldownState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 704,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(90 * time.Second),
|
||||
Remaining: 90 * time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":42,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
usage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cooldown", usage.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, usage.KiroRuntimeReason)
|
||||
require.NotNil(t, usage.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesProfileError(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: `{"message":"profileArn is required for this request."}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeForbidden, info.ErrorCode)
|
||||
require.False(t, info.NeedsReauth)
|
||||
}
|
||||
|
||||
func TestBuildKiroDegradedUsage_ClassifiesOverageExhausted(t *testing.T) {
|
||||
info := buildKiroDegradedUsage(&kiroUsageHTTPError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Body: `{"message":"overage exhausted for this billing window"}`,
|
||||
})
|
||||
|
||||
require.Equal(t, errorCodeNetworkError, info.ErrorCode)
|
||||
require.Equal(t, kiroQuotaStateOverageExhausted, info.KiroQuotaState)
|
||||
require.Contains(t, info.KiroQuotaReason, "overage exhausted")
|
||||
}
|
||||
|
||||
func TestAccountUsageService_GetUsage_KiroCachesErrorSnapshotWhenRefreshFailsWithoutPriorSuccess(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 708,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
requestCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
http.Error(w, `{"message":"FEATURE_NOT_SUPPORTED","reason":"FEATURE_NOT_SUPPORTED"}`, http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
firstUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstUsage)
|
||||
require.Equal(t, errorCodeForbidden, firstUsage.ErrorCode)
|
||||
|
||||
secondUsage, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, secondUsage)
|
||||
require.Equal(t, errorCodeForbidden, secondUsage.ErrorCode)
|
||||
require.Equal(t, 1, requestCount)
|
||||
}
|
||||
|
||||
func TestMapKiroUsageToInfo_CreditsExhaustedWithoutOverages(t *testing.T) {
|
||||
info := mapKiroUsageToInfo(&kiroUsageLimitsResponse{
|
||||
NextDateReset: "2099-03-13T12:00:00Z",
|
||||
OverageConfiguration: kiroOverageConfiguration{
|
||||
OverageStatus: "DISABLED",
|
||||
},
|
||||
UsageBreakdownList: []kiroUsageBreakdown{
|
||||
{
|
||||
ResourceType: "CREDIT",
|
||||
CurrentUsageWithPrecision: kiroFloatPtr(2000),
|
||||
UsageLimitWithPrecision: kiroFloatPtr(2000),
|
||||
CurrentOveragesWithPrecision: kiroFloatPtr(0),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, kiroQuotaStateCreditsExhausted, info.KiroQuotaState)
|
||||
require.Equal(t, "credits_exhausted", info.KiroQuotaReason)
|
||||
require.NotNil(t, info.KiroQuotaResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeState(t *testing.T) {
|
||||
svc := NewAccountUsageService(nil, nil, nil, nil, nil, NewUsageCache(), nil, nil).
|
||||
SetKiroCooldownStore(&kiroUsageCooldownStore{
|
||||
state: &kirocooldown.State{
|
||||
Active: true,
|
||||
Reason: kirocooldown.CooldownReason429,
|
||||
CooldownUntil: time.Now().Add(2 * time.Minute),
|
||||
Remaining: 2 * time.Minute,
|
||||
},
|
||||
})
|
||||
|
||||
account := &Account{
|
||||
ID: 705,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), account)
|
||||
require.Equal(t, "cooldown", account.KiroRuntimeState)
|
||||
require.Equal(t, kirocooldown.CooldownReason429, account.KiroRuntimeReason)
|
||||
require.NotNil(t, account.KiroRuntimeResetAt)
|
||||
}
|
||||
|
||||
func TestAccountUsageService_EnrichAccountWithKiroRuntimeStateIncludesCachedQuotaState(t *testing.T) {
|
||||
account := Account{
|
||||
ID: 706,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "kiro-access-token",
|
||||
"provider": "Github",
|
||||
"auth_method": "social",
|
||||
},
|
||||
}
|
||||
repo := &stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
svc := NewAccountUsageService(repo, nil, nil, nil, nil, NewUsageCache(), nil, nil)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"nextDateReset":"2099-03-13T12:00:00Z",
|
||||
"overageConfiguration":{"overageStatus":"ENABLED"},
|
||||
"subscriptionInfo": {"subscriptionTitle":"KIRO PRO+"},
|
||||
"usageBreakdownList": [{
|
||||
"currency":"USD",
|
||||
"currentUsageWithPrecision":2000,
|
||||
"currentOveragesWithPrecision":4,
|
||||
"overageCharges":0.2,
|
||||
"usageLimitWithPrecision":2000,
|
||||
"resourceType":"CREDIT"
|
||||
}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
prevResolver := resolveKiroRuntimeEndpoint
|
||||
resolveKiroRuntimeEndpoint = func(_ string) string { return server.URL }
|
||||
defer func() { resolveKiroRuntimeEndpoint = prevResolver }()
|
||||
|
||||
_, err := svc.GetUsage(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
target := &Account{
|
||||
ID: account.ID,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"access_token": "kiro-access-token"},
|
||||
}
|
||||
svc.EnrichAccountWithKiroRuntimeState(context.Background(), target)
|
||||
|
||||
require.Equal(t, kiroQuotaStateOverageActive, target.KiroQuotaState)
|
||||
require.Equal(t, "overages_enabled", target.KiroQuotaReason)
|
||||
require.NotNil(t, target.KiroQuotaResetAt)
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBaseURL_KiroAPIKeyWithoutBaseURLReturnsEmpty(t *testing.T) {
|
||||
account := Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformKiro,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
require.Empty(t, account.GetBaseURL())
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveDefaultsTo25Seconds(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
got := svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro})
|
||||
|
||||
require.Equal(t, 25*time.Second, got)
|
||||
}
|
||||
|
||||
func TestGatewayServiceKiroStreamKeepaliveUsesKiroSpecificConfig(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamKeepaliveInterval: 10,
|
||||
KiroStreamKeepaliveInterval: 25,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, 25*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformKiro}))
|
||||
require.Equal(t, 10*time.Second, svc.streamKeepaliveIntervalForAccount(&Account{Platform: PlatformAnthropic}))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_KiroHaiku45UsesDedicatedFallback(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, nil)
|
||||
|
||||
pricing, err := svc.GetModelPricing("claude-haiku-4-5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestForwardResultBillingModel_NormalizesKiroModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
upstreamModel string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "kiro claude sonnet 4.6 uses pricing key format",
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
upstreamModel: "claude-sonnet-4.6",
|
||||
want: "claude-sonnet-4-6",
|
||||
},
|
||||
{
|
||||
name: "falls back to upstream when requested model empty",
|
||||
requestedModel: "",
|
||||
upstreamModel: "claude-haiku-4-5",
|
||||
want: "claude-haiku-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, forwardResultBillingModel(tt.requestedModel, tt.upstreamModel))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_NormalizesKiroBillingModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.billingService = NewBillingService(svc.cfg, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-sonnet-4-6": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
OutputCostPerToken: 10e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("claude-sonnet-4-6", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_billing_normalized",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "claude-sonnet-4-6",
|
||||
UpstreamModel: "claude-sonnet-4.6",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4-6", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "claude-sonnet-4.6", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_KiroUnknownPricingFallsBackToConservativeCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost(kiroConservativeFallbackBillingModel, UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_kiro_auto_fallback_cost",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Model: "auto",
|
||||
UpstreamModel: "auto",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 601, Quota: 100},
|
||||
User: &User{ID: 701},
|
||||
Account: &Account{ID: 801, Platform: PlatformKiro},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildKiroAccountKeyIgnoresAccessToken(t *testing.T) {
|
||||
@@ -144,6 +145,60 @@ func TestBuildKiroPayloadForAccountPropagatesThinkingHeaders(t *testing.T) {
|
||||
require.Contains(t, string(payload), "\\u003cthinking_mode\\u003eenabled\\u003c/thinking_mode\\u003e")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountPreservesThinkingAliasAfterMapping(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6-thinking",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "claude-opus-4.6", gjson.GetBytes(payload, "conversationState.currentMessage.userInputMessage.modelId").String())
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.Contains(t, systemContent, "<thinking_mode>adaptive</thinking_mode>")
|
||||
require.Contains(t, systemContent, "<thinking_effort>high</thinking_effort>")
|
||||
}
|
||||
|
||||
func TestBuildKiroPayloadForAccountDoesNotEnableThinkingForNonThinkingAlias(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformKiro,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4.6",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
payload, err := buildKiroPayloadForAccount(
|
||||
context.Background(),
|
||||
account,
|
||||
body,
|
||||
"claude-opus-4.6",
|
||||
"kiro-access-token",
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemContent := gjson.GetBytes(payload, "conversationState.history.0.userInputMessage.content").String()
|
||||
require.NotContains(t, systemContent, "<thinking_mode>")
|
||||
}
|
||||
|
||||
func TestKiroAPIRegionPrefersAPIRegionOverProfileARN(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
|
||||
@@ -81,7 +81,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
}
|
||||
|
||||
if parsed.Stream {
|
||||
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, c.Request.Header)
|
||||
resp, _, err := s.openKiroAnthropicStreamResponse(ctx, account, body, mappedModel, originalModel, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -146,7 +146,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
return nil, fmt.Errorf("kiro requires oauth token, got %s", tokenType)
|
||||
}
|
||||
if isOnlyWebSearchToolInBody(body) {
|
||||
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, token, c.Request.Header)
|
||||
webSearchResult, webSearchErr := s.executeKiroWebSearch(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
switch {
|
||||
case errors.Is(webSearchErr, errKiroWebSearchFallback):
|
||||
case webSearchErr == nil:
|
||||
@@ -194,7 +194,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
}
|
||||
|
||||
inputTokens := estimateKiroInputTokens(body)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, token, c.Request.Header)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, body, mappedModel, originalModel, token, c.Request.Header)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -253,7 +253,7 @@ func (s *GatewayService) forwardKiroMessages(ctx context.Context, c *gin.Context
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel string, headers http.Header) (*http.Response, int, error) {
|
||||
func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel string, headers http.Header) (*http.Response, int, error) {
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -268,7 +268,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "text/event-stream")
|
||||
go func() {
|
||||
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, token, inputTokens, headers, pw)
|
||||
streamErr := s.streamKiroWebSearchAsAnthropic(ctx, account, anthropicBody, mappedModel, requestModel, token, inputTokens, headers, pw)
|
||||
if streamErr != nil {
|
||||
_ = pw.CloseWithError(streamErr)
|
||||
return
|
||||
@@ -282,7 +282,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, token, headers)
|
||||
resp, requestCtx, err := s.executeKiroUpstream(ctx, account, anthropicBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -318,7 +318,7 @@ func (s *GatewayService) openKiroAnthropicStreamResponse(ctx context.Context, ac
|
||||
}, inputTokens, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||
func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*http.Response, kiropkg.KiroRequestContext, error) {
|
||||
var requestCtx kiropkg.KiroRequestContext
|
||||
if err := s.checkAndWaitKiroCooldown(ctx, buildKiroAccountKey(account)); err != nil {
|
||||
if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
|
||||
@@ -329,7 +329,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
|
||||
|
||||
modelID := kiropkg.MapModel(mappedModel)
|
||||
currentToken := token
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
|
||||
buildResult, err := buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func (s *GatewayService) executeKiroUpstream(ctx context.Context, account *Accou
|
||||
if refreshErr == nil && strings.TrimSpace(refreshedToken) != "" {
|
||||
currentToken = refreshedToken
|
||||
accountKey = buildKiroAccountKey(account)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, mappedModel, headers)
|
||||
buildResult, err = buildKiroPayloadForAccountWithRepo(ctx, s.accountRepo, account, anthropicBody, modelID, currentToken, requestModel, headers)
|
||||
if err != nil {
|
||||
return nil, requestCtx, err
|
||||
}
|
||||
@@ -501,9 +501,25 @@ func buildKiroPayloadForAccount(ctx context.Context, account *Account, anthropic
|
||||
|
||||
func buildKiroPayloadForAccountWithRepo(ctx context.Context, repo AccountRepository, account *Account, anthropicBody []byte, modelID, token, requestModel string, headers http.Header) (*kiropkg.KiroBuildResult, error) {
|
||||
profileArn := resolveKiroPayloadProfileArn(account)
|
||||
anthropicBody = prepareKiroPayloadBodyForRequestModel(anthropicBody, requestModel)
|
||||
return kiropkg.BuildKiroPayloadWithContext(anthropicBody, modelID, profileArn, "AI_EDITOR", headers)
|
||||
}
|
||||
|
||||
func prepareKiroPayloadBodyForRequestModel(anthropicBody []byte, requestModel string) []byte {
|
||||
requestModel = strings.TrimSpace(requestModel)
|
||||
if requestModel == "" || !strings.Contains(strings.ToLower(requestModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
bodyModel := strings.TrimSpace(gjson.GetBytes(anthropicBody, "model").String())
|
||||
if bodyModel == "" || strings.EqualFold(bodyModel, requestModel) || strings.Contains(strings.ToLower(bodyModel), "thinking") {
|
||||
return anthropicBody
|
||||
}
|
||||
if next, ok := setJSONValueBytes(anthropicBody, "model", requestModel); ok {
|
||||
return next
|
||||
}
|
||||
return anthropicBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) markKiroAuthTemporarilyUnavailable(ctx context.Context, account *Account, statusCode int, body string) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
|
||||
@@ -283,7 +283,7 @@ func TestExecuteKiroUpstreamCooldownReturnsFailoverError(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "token", nil)
|
||||
_, _, err := svc.executeKiroUpstream(context.Background(), &Account{ID: 42}, []byte(`{}`), "claude-sonnet-4-6", "claude-sonnet-4-6", "token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
@@ -323,7 +323,7 @@ func TestExecuteKiroUpstreamInvalidModelDoesNotRefreshProfileArnOrRetry(t *testi
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "test-token", nil)
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-opus-4-6", "claude-opus-4-6", "test-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
require.Len(t, upstream.requests, 1)
|
||||
@@ -461,7 +461,7 @@ func TestExecuteKiroUpstreamMonthlyRequestCountRateLimitsUntilNextMonthAndFailov
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
@@ -501,7 +501,7 @@ func TestExecuteKiroUpstreamPlain402FailoversWithoutTempUnschedule(t *testing.T)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "test-token", nil)
|
||||
_, _, err = svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "test-token", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
@@ -550,7 +550,7 @@ func TestExecuteKiroUpstreamInvalidGrantForceRefreshSetsErrorWithoutTempUnschedu
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "stale-token", nil)
|
||||
resp, _, err := svc.executeKiroUpstream(context.Background(), account, payloadBytes, "claude-sonnet-4-6", "claude-sonnet-4-6", "stale-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
|
||||
@@ -102,7 +102,7 @@ func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens in
|
||||
}
|
||||
|
||||
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||
ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
|
||||
) error {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
@@ -141,7 +141,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
return errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -190,7 +190,7 @@ func (s *GatewayService) streamKiroWebSearchAsAnthropic(
|
||||
return fmt.Errorf("kiro web search exceeded max iterations")
|
||||
}
|
||||
|
||||
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
|
||||
query := kiropkg.ExtractSearchQuery(anthropicBody)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil, errKiroWebSearchFallback
|
||||
@@ -227,7 +227,7 @@ func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Acco
|
||||
return nil, errKiroWebSearchFallback
|
||||
}
|
||||
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, token, headers)
|
||||
resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testi
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
privacyCalls := 0
|
||||
service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) {
|
||||
privacyCalls++
|
||||
|
||||
@@ -271,6 +271,9 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
|
||||
out := *ev
|
||||
|
||||
out.Platform = strings.TrimSpace(out.Platform)
|
||||
out.RequestedModel = truncateString(strings.TrimSpace(out.RequestedModel), 128)
|
||||
out.MappedModel = truncateString(strings.TrimSpace(out.MappedModel), 128)
|
||||
out.KiroModelID = truncateString(strings.TrimSpace(out.KiroModelID), 128)
|
||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||
|
||||
|
||||
@@ -612,7 +612,7 @@ func normalizeModelNameForPricing(model string) string {
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-2.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = canonicalModelNameForPricing(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
model = strings.TrimPrefix(model, "publishers/google/models/")
|
||||
@@ -625,8 +625,32 @@ func normalizeModelNameForPricing(model string) string {
|
||||
}
|
||||
|
||||
model = strings.TrimLeft(model, "/")
|
||||
return canonicalModelNameForPricing(model)
|
||||
}
|
||||
|
||||
func canonicalModelNameForPricing(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch model {
|
||||
case "claude-opus-4.5":
|
||||
return "claude-opus-4-5"
|
||||
case "claude-opus-4.6":
|
||||
return "claude-opus-4-6"
|
||||
case "claude-opus-4.7":
|
||||
return "claude-opus-4-7"
|
||||
case "claude-sonnet-4.5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "claude-sonnet-4.6":
|
||||
return "claude-sonnet-4-6"
|
||||
case "claude-haiku-4.5":
|
||||
return "claude-haiku-4-5"
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func lastSegment(model string) string {
|
||||
if idx := strings.LastIndex(model, "/"); idx != -1 {
|
||||
@@ -671,8 +695,10 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
{name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
|
||||
{name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
|
||||
{name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
|
||||
{name: "sonnet-4.6", match: []string{"claude-sonnet-4-6", "claude-sonnet-4.6"}},
|
||||
{name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
|
||||
{name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
|
||||
{name: "haiku-4.5", match: []string{"claude-haiku-4-5", "claude-haiku-4.5"}},
|
||||
{name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
|
||||
{name: "sonnet-3", match: []string{"claude-3-sonnet"}},
|
||||
{name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
|
||||
@@ -710,6 +736,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "sonnet"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
|
||||
fallbackName = "sonnet-4.6"
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "sonnet-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
@@ -719,6 +747,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
case strings.Contains(model, "haiku"):
|
||||
switch {
|
||||
case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
|
||||
fallbackName = "haiku-4.5"
|
||||
case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
|
||||
fallbackName = "haiku-3.5"
|
||||
default:
|
||||
|
||||
@@ -481,7 +481,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
|
||||
@@ -783,7 +783,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity, PlatformKiro}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
@@ -154,7 +154,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Platform: PlatformGemini,
|
||||
@@ -180,7 +180,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformGemini,
|
||||
@@ -207,7 +207,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -235,7 +235,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformGemini,
|
||||
@@ -263,7 +263,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformOpenAI, // OpenAI OAuth 账户
|
||||
@@ -290,7 +290,7 @@ func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
resetAt := time.Now().Add(30 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
@@ -325,7 +325,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformGemini,
|
||||
@@ -354,7 +354,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformGemini,
|
||||
@@ -381,7 +381,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -408,7 +408,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Platform: PlatformAntigravity,
|
||||
@@ -436,7 +436,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 15,
|
||||
@@ -479,7 +479,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Platform: tt.platform,
|
||||
@@ -504,7 +504,7 @@ func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedul
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
account := &Account{
|
||||
ID: 18,
|
||||
Platform: PlatformOpenAI,
|
||||
@@ -585,7 +585,7 @@ func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, in
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
@@ -720,7 +720,7 @@ func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
|
||||
@@ -89,6 +89,10 @@ security:
|
||||
enabled: false
|
||||
# Allowed upstream hosts for API proxying
|
||||
# 允许代理的上游 API 主机列表
|
||||
# If you enable Kiro OAuth / IDC, also allow Kiro auth and AWS SSO OIDC hosts.
|
||||
# 如果启用 Kiro OAuth / IDC,请同时放行 Kiro 鉴权域名和 AWS SSO OIDC 域名。
|
||||
# If you enable Kiro runtime forwarding, also allow the corresponding AWS Q API endpoint.
|
||||
# 如果启用 Kiro 运行时转发,还需要放行对应的 AWS Q API 域名。
|
||||
upstream_hosts:
|
||||
- "api.openai.com"
|
||||
- "api.anthropic.com"
|
||||
@@ -97,6 +101,11 @@ security:
|
||||
- "api.minimaxi.com"
|
||||
- "generativelanguage.googleapis.com"
|
||||
- "cloudcode-pa.googleapis.com"
|
||||
- "prod.us-east-1.auth.desktop.kiro.dev"
|
||||
- "oidc.us-east-1.amazonaws.com"
|
||||
- "oidc.*.amazonaws.com"
|
||||
- "device.sso.*.amazonaws.com"
|
||||
- "q.*.amazonaws.com"
|
||||
- "*.openai.azure.com"
|
||||
# Allowed hosts for pricing data download
|
||||
# 允许下载定价数据的主机列表
|
||||
@@ -340,6 +349,9 @@ gateway:
|
||||
# Stream keepalive interval (seconds), 0=disable
|
||||
# 流式 keepalive 间隔(秒),0=禁用
|
||||
stream_keepalive_interval: 10
|
||||
# Kiro stream keepalive interval (seconds), 0=use default 25
|
||||
# Kiro 流式 keepalive 间隔(秒),0=使用默认 25
|
||||
kiro_stream_keepalive_interval: 25
|
||||
# SSE max line size in bytes (default: 40MB)
|
||||
# SSE 单行最大字节数(默认 40MB)
|
||||
max_line_size: 41943040
|
||||
|
||||
@@ -558,6 +558,13 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getKiroDefaultModelMapping(): Promise<Record<string, string>> {
|
||||
const { data } = await apiClient.get<Record<string, string>>(
|
||||
'/admin/accounts/kiro/default-model-mapping'
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OpenAI token using refresh token
|
||||
* @param refreshToken - The refresh token
|
||||
|
||||
@@ -13,6 +13,15 @@ vi.mock('vue-i18n', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
i18n: {
|
||||
global: {
|
||||
t: (key: string) => key
|
||||
}
|
||||
},
|
||||
getLocale: () => 'en'
|
||||
}))
|
||||
|
||||
function makeAccount(overrides: Partial<Account>): Account {
|
||||
return {
|
||||
id: 1,
|
||||
@@ -35,6 +44,12 @@ function makeAccount(overrides: Partial<Account>): Account {
|
||||
overload_until: null,
|
||||
temp_unschedulable_until: null,
|
||||
temp_unschedulable_reason: null,
|
||||
kiro_quota_state: null,
|
||||
kiro_quota_reason: null,
|
||||
kiro_quota_reset_at: null,
|
||||
kiro_runtime_state: null,
|
||||
kiro_runtime_reason: null,
|
||||
kiro_runtime_reset_at: null,
|
||||
session_window_start: null,
|
||||
session_window_end: null,
|
||||
session_window_status: null,
|
||||
@@ -159,4 +174,96 @@ describe('AccountStatusIndicator', () => {
|
||||
// AICredits 积分耗尽状态应显示
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
|
||||
})
|
||||
|
||||
it('Kiro 运行时冷却在状态列复用限流展示', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 5,
|
||||
name: 'kiro-cooldown',
|
||||
platform: 'kiro',
|
||||
kiro_runtime_state: 'cooldown',
|
||||
kiro_runtime_reason: 'rate_limit_exceeded',
|
||||
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedAutoResume')
|
||||
expect(wrapper.text()).toContain('429')
|
||||
})
|
||||
|
||||
it('Kiro suspended 在状态列显示为 forbidden', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 6,
|
||||
name: 'kiro-suspended',
|
||||
platform: 'kiro',
|
||||
kiro_runtime_state: 'suspended',
|
||||
kiro_runtime_reason: 'account_suspended',
|
||||
kiro_runtime_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.forbidden')
|
||||
})
|
||||
|
||||
it('Kiro overage active 在状态列仍显示正常状态', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 7,
|
||||
name: 'kiro-overage-active',
|
||||
platform: 'kiro',
|
||||
kiro_quota_state: 'overage_active',
|
||||
kiro_quota_reason: 'overages_enabled',
|
||||
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.active')
|
||||
expect(wrapper.text()).not.toContain('admin.accounts.status.overageActive')
|
||||
})
|
||||
|
||||
it('Kiro overage exhausted 在状态列显示危险徽章', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 8,
|
||||
name: 'kiro-overage-exhausted',
|
||||
platform: 'kiro',
|
||||
kiro_quota_state: 'overage_exhausted',
|
||||
kiro_quota_reason: 'overage disabled after quota exhaustion',
|
||||
kiro_quota_reset_at: '2099-03-15T00:00:00Z'
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.overageExhausted')
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -25,6 +25,15 @@ vi.mock('vue-i18n', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
i18n: {
|
||||
global: {
|
||||
t: (key: string) => key
|
||||
}
|
||||
},
|
||||
getLocale: () => 'en'
|
||||
}))
|
||||
|
||||
function makeAccount(overrides: Partial<Account>): Account {
|
||||
return {
|
||||
id: 1,
|
||||
@@ -47,6 +56,12 @@ function makeAccount(overrides: Partial<Account>): Account {
|
||||
overload_until: null,
|
||||
temp_unschedulable_until: null,
|
||||
temp_unschedulable_reason: null,
|
||||
kiro_quota_state: null,
|
||||
kiro_quota_reason: null,
|
||||
kiro_quota_reset_at: null,
|
||||
kiro_runtime_state: null,
|
||||
kiro_runtime_reason: null,
|
||||
kiro_runtime_reset_at: null,
|
||||
session_window_start: null,
|
||||
session_window_end: null,
|
||||
session_window_status: null,
|
||||
@@ -530,6 +545,234 @@ describe('AccountUsageCell', () => {
|
||||
expect(wrapper.text()).toContain('7d|100|106540000')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会用 passive source 拉取并展示 credits 额度', async () => {
|
||||
const account = makeAccount({
|
||||
id: 3001,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
|
||||
getUsage.mockResolvedValue({
|
||||
source: 'passive',
|
||||
kiro_subscription_name: 'KIRO PRO+',
|
||||
kiro_overages_enabled: true,
|
||||
kiro_credit: {
|
||||
current_usage: 125,
|
||||
usage_limit: 2000,
|
||||
percentage_used: 6.25,
|
||||
},
|
||||
kiro_bonus: {
|
||||
current_usage: 25,
|
||||
usage_limit: 500,
|
||||
percentage_used: 5,
|
||||
days_remaining: 7,
|
||||
},
|
||||
kiro_overage: {
|
||||
current_overages: 2,
|
||||
overage_charges: 0.08,
|
||||
currency_symbol: '$',
|
||||
currency_code: 'USD',
|
||||
},
|
||||
kiro_reset_at: '2099-03-13T12:00:00Z',
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(getUsage).toHaveBeenCalledWith(3001, 'passive')
|
||||
expect(wrapper.emitted('kiroUsageMeta')?.[0]).toEqual([
|
||||
{
|
||||
plan_type: 'KIRO PRO+',
|
||||
kiro_overages_enabled: true
|
||||
}
|
||||
])
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroCredits')
|
||||
expect(wrapper.text()).toContain('125 / 2.0K')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroBonus')
|
||||
expect(wrapper.text()).toContain('25 / 500')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroDaysLeft')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroReset')
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.kiroOverage 2 ($0.08)')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会展示运行时冷却状态', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
source: 'passive',
|
||||
kiro_runtime_state: 'cooldown',
|
||||
kiro_runtime_reason: 'rate_limit_exceeded',
|
||||
kiro_runtime_reset_at: '2099-03-13T12:00:00Z',
|
||||
kiro_credit: {
|
||||
current_usage: 10,
|
||||
usage_limit: 100,
|
||||
percentage_used: 10,
|
||||
},
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3002,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimited')
|
||||
expect(wrapper.text()).toContain('admin.accounts.status.rateLimitedUntil')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会展示 overage active 与 exhausted 状态', async () => {
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
kiro_quota_state: 'overage_active',
|
||||
kiro_quota_reason: 'overages_enabled',
|
||||
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
|
||||
kiro_overages_enabled: true,
|
||||
kiro_credit: {
|
||||
current_usage: 2100,
|
||||
usage_limit: 2000,
|
||||
percentage_used: 100,
|
||||
},
|
||||
kiro_overage: {
|
||||
current_overages: 3,
|
||||
overage_charges: 0.12,
|
||||
currency_symbol: '$',
|
||||
},
|
||||
})
|
||||
|
||||
const activeWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3005,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(activeWrapper.text()).toContain('admin.accounts.status.overageActive')
|
||||
expect(activeWrapper.text()).not.toContain('admin.accounts.status.overageActiveUntil')
|
||||
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
kiro_quota_state: 'overage_exhausted',
|
||||
kiro_quota_reason: 'usage API error: overage exhausted',
|
||||
kiro_quota_reset_at: '2099-03-13T12:00:00Z',
|
||||
error: 'usage API error: kiro usage request failed (status 429): {"message":"overage exhausted"}',
|
||||
})
|
||||
|
||||
const exhaustedWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3006,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhausted')
|
||||
expect(exhaustedWrapper.text()).toContain('admin.accounts.status.overageExhaustedUntil')
|
||||
})
|
||||
|
||||
it('Kiro OAuth 会展示 profile 异常和 usage forbidden 徽章', async () => {
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
error_code: 'forbidden',
|
||||
error: 'usage API error: kiro usage request failed (status 400): {"message":"profileArn is required for this request."}',
|
||||
})
|
||||
|
||||
const profileWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3003,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(profileWrapper.text()).toContain('admin.accounts.usageError')
|
||||
|
||||
getUsage.mockResolvedValueOnce({
|
||||
source: 'passive',
|
||||
error_code: 'forbidden',
|
||||
error: 'usage API error: kiro usage request failed (status 403): {"message":"User is not authorized to access this feature."}',
|
||||
})
|
||||
|
||||
const forbiddenWrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3004,
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
extra: {},
|
||||
credentials: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(forbiddenWrapper.text()).toContain('admin.accounts.forbidden')
|
||||
})
|
||||
|
||||
it('Key 账号会展示 today stats 徽章并带 A/U 提示', async () => {
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
import { mount } from '@vue/test-utils'
|
||||
|
||||
vi.mock('@/stores/app', () => ({
|
||||
useAppStore: () => ({
|
||||
showSuccess: vi.fn(),
|
||||
showError: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useClipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copied: { value: false },
|
||||
copyToClipboard: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
import OAuthAuthorizationFlow from '../OAuthAuthorizationFlow.vue'
|
||||
|
||||
describe('OAuthAuthorizationFlow', () => {
|
||||
it('extracts code, state, and callback metadata from a full Kiro callback URL', async () => {
|
||||
const wrapper = mount(OAuthAuthorizationFlow, {
|
||||
props: {
|
||||
addMethod: 'oauth',
|
||||
platform: 'kiro',
|
||||
authUrl: 'https://example.com/authorize',
|
||||
sessionId: 'session-1'
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const textarea = wrapper.get('textarea')
|
||||
await textarea.setValue('http://localhost:49153/oauth/callback?code=abc123&state=state456&login_option=github')
|
||||
await nextTick()
|
||||
|
||||
expect((textarea.element as HTMLTextAreaElement).value).toBe('abc123')
|
||||
expect((wrapper.vm as any).oauthState).toBe('state456')
|
||||
expect((wrapper.vm as any).oauthCallbackPath).toBe('/oauth/callback')
|
||||
expect((wrapper.vm as any).oauthLoginOption).toBe('github')
|
||||
})
|
||||
})
|
||||
@@ -489,7 +489,8 @@ const platformOptions = [
|
||||
{ value: 'anthropic', label: 'Anthropic' },
|
||||
{ value: 'openai', label: 'OpenAI' },
|
||||
{ value: 'gemini', label: 'Gemini' },
|
||||
{ value: 'antigravity', label: 'Antigravity' }
|
||||
{ value: 'antigravity', label: 'Antigravity' },
|
||||
{ value: 'kiro', label: 'Kiro' }
|
||||
]
|
||||
|
||||
// Load rules when dialog opens
|
||||
|
||||
@@ -115,7 +115,7 @@ const labelClass = computed(() => {
|
||||
}
|
||||
|
||||
// 正常状态或无天数:根据平台显示主题色
|
||||
if (props.platform === 'anthropic') {
|
||||
if (props.platform === 'anthropic' || props.platform === 'kiro') {
|
||||
return `${base} bg-orange-200/60 text-orange-800 dark:bg-orange-800/40 dark:text-orange-300`
|
||||
}
|
||||
if (props.platform === 'openai') {
|
||||
@@ -129,7 +129,7 @@ const labelClass = computed(() => {
|
||||
|
||||
// Badge color based on platform and subscription type
|
||||
const badgeClass = computed(() => {
|
||||
if (props.platform === 'anthropic') {
|
||||
if (props.platform === 'anthropic' || props.platform === 'kiro') {
|
||||
// Claude: orange theme
|
||||
return isSubscription.value
|
||||
? 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
|
||||
@@ -91,6 +91,8 @@ const ratePillClass = computed(() => {
|
||||
return 'bg-green-50 text-green-700 dark:bg-green-900/20 dark:text-green-400'
|
||||
case 'gemini':
|
||||
return 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400'
|
||||
case 'kiro':
|
||||
return 'bg-amber-50 text-amber-700 dark:bg-amber-900/20 dark:text-amber-400'
|
||||
default: // antigravity and others
|
||||
return 'bg-violet-50 text-violet-700 dark:bg-violet-900/20 dark:text-violet-400'
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import { mount } from '@vue/test-utils'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import PlatformTypeBadge from '../PlatformTypeBadge.vue'
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key === 'admin.accounts.status.overageActive' ? 'Overage' : key
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
describe('PlatformTypeBadge', () => {
|
||||
it('shows Kiro overages tag next to the plan tag when enabled', () => {
|
||||
const wrapper = mount(PlatformTypeBadge, {
|
||||
props: {
|
||||
platform: 'kiro',
|
||||
type: 'oauth',
|
||||
planType: 'KIRO PRO+',
|
||||
overagesEnabled: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('KIRO PRO+')
|
||||
expect(wrapper.text()).toContain('Overage')
|
||||
})
|
||||
|
||||
it('does not show overages tag for non-Kiro accounts', () => {
|
||||
const wrapper = mount(PlatformTypeBadge, {
|
||||
props: {
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
planType: 'Pro',
|
||||
overagesEnabled: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).not.toContain('Overage')
|
||||
})
|
||||
})
|
||||
@@ -4,7 +4,12 @@ vi.mock('@/api/admin/accounts', () => ({
|
||||
getAntigravityDefaultModelMapping: vi.fn()
|
||||
}))
|
||||
|
||||
import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
|
||||
import {
|
||||
buildModelMappingObject,
|
||||
fetchKiroDefaultMappings,
|
||||
getModelsByPlatform,
|
||||
getPresetMappingsByPlatform
|
||||
} from '../useModelWhitelist'
|
||||
|
||||
describe('useModelWhitelist', () => {
|
||||
it('openai 模型列表包含 GPT-5.4 官方快照', () => {
|
||||
@@ -59,6 +64,49 @@ describe('useModelWhitelist', () => {
|
||||
expect(models.every((model) => !model.endsWith('-agentic') && !model.endsWith('-chat'))).toBe(true)
|
||||
})
|
||||
|
||||
it('kiro 模型列表只保留 Claude 模型', () => {
|
||||
const models = getModelsByPlatform('kiro')
|
||||
|
||||
expect(models).toEqual([
|
||||
'claude-opus-4-6',
|
||||
'claude-opus-4-6-thinking',
|
||||
'claude-sonnet-4-6',
|
||||
'claude-sonnet-4-6-thinking',
|
||||
'claude-opus-4-5-20251101',
|
||||
'claude-opus-4-5-20251101-thinking',
|
||||
'claude-sonnet-4-5-20250929',
|
||||
'claude-sonnet-4-5-20250929-thinking',
|
||||
'claude-haiku-4-5-20251001',
|
||||
'claude-haiku-4-5-20251001-thinking'
|
||||
])
|
||||
expect(models.every(model => model.startsWith('claude-'))).toBe(true)
|
||||
expect(models.some(model => model.endsWith('-agentic'))).toBe(false)
|
||||
expect(models.some(model => model.endsWith('-chat'))).toBe(false)
|
||||
expect(models).not.toContain('kiro-auto')
|
||||
expect(models).not.toContain('claude-opus-4-5')
|
||||
expect(models).not.toContain('claude-sonnet-4-5')
|
||||
expect(models).not.toContain('claude-sonnet-4')
|
||||
expect(models).not.toContain('claude-3-5-sonnet-20241022')
|
||||
expect(models).not.toContain('claude-3-5-haiku-20241022')
|
||||
expect(models).not.toContain('claude-haiku-4-5')
|
||||
expect(models).not.toContain('gpt-4o')
|
||||
expect(models).not.toContain('gpt-4')
|
||||
expect(models).not.toContain('gpt-4-turbo')
|
||||
expect(models).not.toContain('gpt-3.5-turbo')
|
||||
expect(models).not.toContain('deepseek-3-2')
|
||||
expect(models).not.toContain('minimax-m2-1')
|
||||
expect(models).not.toContain('qwen3-coder-next')
|
||||
})
|
||||
|
||||
it('claude 模型列表包含 dated 和 thinking 兼容别名', () => {
|
||||
const models = getModelsByPlatform('claude')
|
||||
|
||||
expect(models).toContain('claude-opus-4-6-thinking')
|
||||
expect(models).toContain('claude-opus-4-5-20251101-thinking')
|
||||
expect(models).toContain('claude-sonnet-4-20250514-thinking')
|
||||
expect(models).toContain('claude-haiku-4-5-20251001-thinking')
|
||||
})
|
||||
|
||||
it('whitelist 模式会忽略通配符条目', () => {
|
||||
const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
|
||||
expect(mapping).toEqual({
|
||||
@@ -81,4 +129,68 @@ describe('useModelWhitelist', () => {
|
||||
'gpt-5.4-mini': 'gpt-5.4-mini'
|
||||
})
|
||||
})
|
||||
|
||||
it('kiro 预设映射只暴露 Claude 入口', () => {
|
||||
const mappings = getPresetMappingsByPlatform('kiro')
|
||||
const mappingTargets = mappings.map(item => item.to)
|
||||
|
||||
expect(mappings.map(({ from, to }) => ({ from, to }))).toEqual([
|
||||
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
|
||||
])
|
||||
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
|
||||
expect(mappingTargets.every(model => model.startsWith('claude-'))).toBe(true)
|
||||
expect(mappingTargets.some(model => model.endsWith('-agentic'))).toBe(false)
|
||||
expect(mappingTargets.some(model => model.endsWith('-chat'))).toBe(false)
|
||||
expect(mappingTargets).not.toContain('kiro-auto')
|
||||
expect(mappingTargets.some(model => model.startsWith('kiro-'))).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-opus-4-5')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-sonnet-4-5')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-sonnet-4')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-3-5-sonnet-20241022')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-3-5-haiku-20241022')).toBe(false)
|
||||
expect(mappings.some(item => item.from === 'claude-haiku-4-5')).toBe(false)
|
||||
expect(mappingTargets).not.toContain('gpt-4o')
|
||||
expect(mappingTargets).not.toContain('gpt-4')
|
||||
expect(mappingTargets).not.toContain('gpt-4-turbo')
|
||||
expect(mappingTargets).not.toContain('gpt-3.5-turbo')
|
||||
expect(mappingTargets).not.toContain('deepseek-3.2')
|
||||
expect(mappingTargets).not.toContain('minimax-m2.1')
|
||||
expect(mappingTargets).not.toContain('qwen3-coder-next')
|
||||
})
|
||||
|
||||
it('kiro 默认映射会在前端填充所有可精确定价模型', async () => {
|
||||
const mappings = await fetchKiroDefaultMappings()
|
||||
|
||||
expect(mappings).toEqual(expect.arrayContaining([
|
||||
{ from: 'claude-opus-4-6', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-opus-4-6-thinking', to: 'claude-opus-4.6' },
|
||||
{ from: 'claude-sonnet-4-6', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-sonnet-4-6-thinking', to: 'claude-sonnet-4.6' },
|
||||
{ from: 'claude-opus-4-5-20251101', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-opus-4-5-20251101-thinking', to: 'claude-opus-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-sonnet-4-5-20250929-thinking', to: 'claude-sonnet-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4.5' },
|
||||
{ from: 'claude-haiku-4-5-20251001-thinking', to: 'claude-haiku-4.5' }
|
||||
]))
|
||||
expect(mappings).toHaveLength(10)
|
||||
expect(mappings.every(item => !item.from.startsWith('kiro-'))).toBe(true)
|
||||
expect(mappings.every(item => !item.to.startsWith('kiro-'))).toBe(true)
|
||||
expect(mappings.every(item => !item.from.endsWith('-agentic'))).toBe(true)
|
||||
expect(mappings.every(item => !item.to.endsWith('-agentic'))).toBe(true)
|
||||
expect(mappings.every(item => !item.from.endsWith('-chat'))).toBe(true)
|
||||
expect(mappings.every(item => !item.to.endsWith('-chat'))).toBe(true)
|
||||
expect(mappings.every(item => item.from.startsWith('claude-'))).toBe(true)
|
||||
expect(mappings.every(item => item.to.startsWith('claude-'))).toBe(true)
|
||||
expect(mappings.some(item => item.to === 'claude-opus-4-7')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -23,13 +23,21 @@ export const claudeModels = [
|
||||
'claude-3-5-sonnet-20241022', 'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-haiku-20241022',
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-sonnet-4-20250514-thinking', 'claude-opus-4-20250514-thinking',
|
||||
'claude-sonnet-4-20250514', 'claude-opus-4-20250514',
|
||||
'claude-opus-4-1-20250805',
|
||||
'claude-sonnet-4-5-thinking', 'claude-sonnet-4-5-20250929-thinking',
|
||||
'claude-haiku-4-5-thinking', 'claude-haiku-4-5-20251001-thinking',
|
||||
'claude-opus-4-5-thinking', 'claude-opus-4-5-20251101-thinking',
|
||||
'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001',
|
||||
'claude-opus-4-5-20251101',
|
||||
'claude-opus-4-6-thinking',
|
||||
'claude-opus-4-6',
|
||||
'claude-opus-4-7-thinking',
|
||||
'claude-opus-4-7',
|
||||
'claude-sonnet-4-6'
|
||||
'claude-sonnet-4-6-thinking',
|
||||
'claude-sonnet-4-6',
|
||||
'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
|
||||
]
|
||||
|
||||
// Google Gemini
|
||||
|
||||
@@ -324,8 +324,8 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Web Search Emulation (Anthropic only, hidden when global disabled) -->
|
||||
<div v-if="section.platform === 'anthropic' && webSearchGlobalEnabled" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
||||
<!-- Web Search Emulation (supported platforms only, hidden when global disabled) -->
|
||||
<div v-if="supportsWebSearchEmulation(section.platform) && webSearchGlobalEnabled" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="text-xs font-medium text-gray-700 dark:text-gray-300">
|
||||
@@ -718,7 +718,7 @@ const form = reactive({
|
||||
let abortController: AbortController | null = null
|
||||
|
||||
// ── Platform config ──
|
||||
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
|
||||
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity', 'kiro']
|
||||
|
||||
// ── Helpers ──
|
||||
function formatDate(value: string): string {
|
||||
@@ -758,6 +758,10 @@ function getGroupsForPlatform(platform: GroupPlatform): AdminGroup[] {
|
||||
return allGroups.value.filter(g => g.platform === platform)
|
||||
}
|
||||
|
||||
function supportsWebSearchEmulation(platform: GroupPlatform): boolean {
|
||||
return platform === 'anthropic' || platform === 'kiro'
|
||||
}
|
||||
|
||||
// ── Group helpers ──
|
||||
const groupToChannelMap = computed(() => {
|
||||
const map = new Map<number, Channel>()
|
||||
@@ -1037,7 +1041,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
|
||||
const wsEmulation: Record<string, boolean> = {}
|
||||
for (const section of form.platforms) {
|
||||
if (!section.enabled) continue
|
||||
if (section.platform === 'anthropic') {
|
||||
if (supportsWebSearchEmulation(section.platform)) {
|
||||
wsEmulation[section.platform] = !!section.web_search_emulation
|
||||
}
|
||||
}
|
||||
|
||||
@@ -967,7 +967,8 @@ const platformFilterOptions = computed(() => [
|
||||
{ value: 'anthropic', label: 'Anthropic' },
|
||||
{ value: 'openai', label: 'OpenAI' },
|
||||
{ value: 'gemini', label: 'Gemini' },
|
||||
{ value: 'antigravity', label: 'Antigravity' }
|
||||
{ value: 'antigravity', label: 'Antigravity' },
|
||||
{ value: 'kiro', label: 'Kiro' }
|
||||
])
|
||||
|
||||
// Group options for assign (only subscription type groups)
|
||||
|
||||
@@ -111,7 +111,8 @@ const platformOptions = computed(() => [
|
||||
{ value: 'openai', label: 'OpenAI' },
|
||||
{ value: 'anthropic', label: 'Anthropic' },
|
||||
{ value: 'gemini', label: 'Gemini' },
|
||||
{ value: 'antigravity', label: 'Antigravity' }
|
||||
{ value: 'antigravity', label: 'Antigravity' },
|
||||
{ value: 'kiro', label: 'Kiro' }
|
||||
])
|
||||
|
||||
const timeRangeOptions = computed(() => [
|
||||
|
||||
Reference in New Issue
Block a user