diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 56ebc9e5..738d9f7b 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.133 +0.1.134 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f5dc748b..5dfd644a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -146,6 +146,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) + openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) kiroOAuthService := service.NewKiroOAuthService(proxyRepository) kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) @@ -154,7 +155,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) - accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, openAITokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) @@ -189,7 +190,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) - openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 37563692..7703df32 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1862,6 +1862,76 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// GetKiroUpstreamModels handles getting upstream Kiro models with the account credentials/proxy. +// GET /api/v1/admin/accounts/:id/kiro/upstream-models +func (h *AccountHandler) GetKiroUpstreamModels(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + if account.Platform != service.PlatformKiro { + response.BadRequest(c, "Account is not a Kiro account") + return + } + if h.accountTestService == nil { + response.InternalError(c, "Kiro account service not configured") + return + } + + models, err := h.accountTestService.FetchKiroUpstreamModels(c.Request.Context(), account) + if err != nil { + if errors.Is(err, service.ErrKiroModelListUnsupported) { + response.BadRequest(c, err.Error()) + return + } + response.InternalError(c, "Failed to fetch Kiro upstream model list: "+err.Error()) + return + } + response.Success(c, models) +} + +// GetOpenAIUpstreamModels handles getting upstream OpenAI models with the account credentials/proxy. +// GET /api/v1/admin/accounts/:id/openai/upstream-models +func (h *AccountHandler) GetOpenAIUpstreamModels(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + if account.Platform != service.PlatformOpenAI { + response.BadRequest(c, "Account is not an OpenAI account") + return + } + if h.accountTestService == nil { + response.InternalError(c, "OpenAI account service not configured") + return + } + + models, err := h.accountTestService.FetchOpenAIUpstreamModels(c.Request.Context(), account) + if err != nil { + if errors.Is(err, service.ErrOpenAIModelListUnsupported) { + response.BadRequest(c, err.Error()) + return + } + response.InternalError(c, "Failed to fetch OpenAI upstream model list: "+err.Error()) + return + } + response.Success(c, models) +} + // GetAvailableModels handles getting available models for an account // GET /api/v1/admin/accounts/:id/models func (h *AccountHandler) GetAvailableModels(c *gin.Context) { diff --git a/backend/internal/handler/admin/kiro_oauth_handler.go b/backend/internal/handler/admin/kiro_oauth_handler.go index fc6727b8..0b4dd4fb 100644 --- a/backend/internal/handler/admin/kiro_oauth_handler.go +++ b/backend/internal/handler/admin/kiro_oauth_handler.go @@ -129,6 +129,7 @@ func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) { type KiroImportTokenRequest struct { TokenJSON string `json:"token_json" binding:"required"` DeviceRegistrationJSON string `json:"device_registration_json"` + ProxyID *int64 `json:"proxy_id"` } func (h *KiroOAuthHandler) ImportToken(c *gin.Context) { @@ -140,6 +141,7 @@ func (h *KiroOAuthHandler) ImportToken(c *gin.Context) { tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{ TokenJSON: req.TokenJSON, DeviceRegistrationJSON: req.DeviceRegistrationJSON, + ProxyID: req.ProxyID, }) if err != nil { response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error()) diff --git a/backend/internal/pkg/kiro/oauth.go b/backend/internal/pkg/kiro/oauth.go index 2a6e1338..94615f49 100644 --- a/backend/internal/pkg/kiro/oauth.go +++ b/backend/internal/pkg/kiro/oauth.go @@ -424,6 +424,24 @@ func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*Token return &token, nil } +func ParseImportedRefreshToken(tokenJSON string) (*TokenData, bool, error) { + var token TokenData + if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil { + return nil, false, fmt.Errorf("failed to parse kiro token: %w", err) + } + token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod)) + if token.AuthMethod == "" { + token.AuthMethod = "social" + } + if token.Provider == "" && token.AuthMethod == "social" { + token.Provider = string(SocialProviderGoogle) + } + if strings.TrimSpace(token.RefreshToken) == "" || strings.TrimSpace(token.AccessToken) != "" { + return &token, false, nil + } + return &token, true, nil +} + func getOIDCEndpoint(region string) string { if strings.TrimSpace(oidcEndpointOverride) != "" { return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/") diff --git a/backend/internal/pkg/kiro/oauth_test.go b/backend/internal/pkg/kiro/oauth_test.go index b6b9b52d..dc601b65 100644 --- a/backend/internal/pkg/kiro/oauth_test.go +++ b/backend/internal/pkg/kiro/oauth_test.go @@ -54,3 +54,35 @@ func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) { t.Fatalf("fresh session should remain after pruning") } } + +func TestParseImportedRefreshTokenAcceptsRefreshTokenOnlyPayload(t *testing.T) { + token, refreshOnly, err := ParseImportedRefreshToken(`{"refreshToken":"rt","provider":"Google"}`) + if err != nil { + t.Fatalf("ParseImportedRefreshToken() error = %v", err) + } + if !refreshOnly { + t.Fatalf("refreshOnly = false, want true") + } + if token.RefreshToken != "rt" { + t.Fatalf("refresh token = %q, want rt", token.RefreshToken) + } + if token.Provider != "Google" { + t.Fatalf("provider = %q, want Google", token.Provider) + } + if token.AuthMethod != "social" { + t.Fatalf("auth method = %q, want social", token.AuthMethod) + } +} + +func TestParseImportedRefreshTokenKeepsFullTokenAsNonRefreshOnly(t *testing.T) { + token, refreshOnly, err := ParseImportedRefreshToken(`{"accessToken":"at","refreshToken":"rt"}`) + if err != nil { + t.Fatalf("ParseImportedRefreshToken() error = %v", err) + } + if refreshOnly { + t.Fatalf("refreshOnly = true, want false") + } + if token.Provider != "Google" { + t.Fatalf("provider = %q, want default Google", token.Provider) + } +} diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index 3b02c905..24bab839 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -408,7 +408,7 @@ LEFT JOIN users u ON u.id = ua.user_id LEFT JOIN user_affiliate_ledger ual ON ual.user_id = $1 AND ual.source_user_id = ua.user_id - AND ual.action = 'accrue' + AND ual.action IN ('accrue', 'signup_reward') WHERE ua.inviter_id = $1 GROUP BY ua.user_id, u.email, u.username, ua.created_at ORDER BY ua.created_at DESC diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 496eed7b..5f9df263 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -306,6 +306,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) + accounts.GET("/:id/openai/upstream-models", h.Admin.Account.GetOpenAIUpstreamModels) + accounts.GET("/:id/kiro/upstream-models", h.Admin.Account.GetKiroUpstreamModels) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.GET("/data", h.Admin.Account.ExportData) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 78c72a06..d010079c 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -67,6 +67,7 @@ type AccountTestService struct { accountRepo AccountRepository geminiTokenProvider *GeminiTokenProvider claudeTokenProvider *ClaudeTokenProvider + openAITokenProvider *OpenAITokenProvider kiroTokenProvider *KiroTokenProvider antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream @@ -79,6 +80,7 @@ func NewAccountTestService( accountRepo AccountRepository, geminiTokenProvider *GeminiTokenProvider, claudeTokenProvider *ClaudeTokenProvider, + openAITokenProvider *OpenAITokenProvider, kiroTokenProvider *KiroTokenProvider, antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, @@ -89,6 +91,7 @@ func NewAccountTestService( accountRepo: accountRepo, geminiTokenProvider: geminiTokenProvider, claudeTokenProvider: claudeTokenProvider, + openAITokenProvider: openAITokenProvider, kiroTokenProvider: kiroTokenProvider, antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, diff --git a/backend/internal/service/kiro_available_models.go b/backend/internal/service/kiro_available_models.go new file mode 100644 index 00000000..3385e437 --- /dev/null +++ b/backend/internal/service/kiro_available_models.go @@ -0,0 +1,201 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro" + "github.com/google/uuid" +) + +var ErrKiroModelListUnsupported = errors.New("kiro upstream model list requires an OAuth/access-token account") + +type kiroAvailableModelsResponse struct { + AvailableModels []kiroAvailableModelItem `json:"availableModels"` + AvailableModelsSnake []kiroAvailableModelItem `json:"available_models"` + Models []kiroAvailableModelItem `json:"models"` + NextToken string `json:"nextToken"` + NextTokenSnake string `json:"next_token"` +} + +type kiroAvailableModelItem struct { + ModelID string `json:"modelId"` + ModelIDSnake string `json:"model_id"` + ID string `json:"id"` + ModelName string `json:"modelName"` + ModelNameSnake string `json:"model_name"` + DisplayName string `json:"displayName"` + DisplayNameSnake string `json:"display_name"` + Name string `json:"name"` +} + +func (s *AccountTestService) FetchKiroUpstreamModels(ctx context.Context, account *Account) ([]kiropkg.Model, error) { + if account == nil { + return nil, errors.New("account is nil") + } + if account.Platform != PlatformKiro { + return nil, fmt.Errorf("not a kiro account") + } + + token := strings.TrimSpace(account.GetCredential("access_token")) + if account.Type == AccountTypeOAuth { + if s == nil || s.kiroTokenProvider == nil { + return nil, errors.New("kiro token provider not configured") + } + accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get kiro access token failed: %w", err) + } + token = strings.TrimSpace(accessToken) + } + if token == "" { + return nil, ErrKiroModelListUnsupported + } + + return requestKiroAvailableModels(ctx, account, kiroAPIRegion(account), strings.TrimSpace(account.GetCredential("profile_arn")), token) +} + +func requestKiroAvailableModels(ctx context.Context, account *Account, region, profileArn, token string) ([]kiropkg.Model, error) { + endpoint := resolveKiroRuntimeEndpoint(region) + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: accountProxyURL(account), + Timeout: 30 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: isLoopbackEndpoint(endpoint), + }) + if err != nil { + return nil, fmt.Errorf("create kiro model list client failed: %w", err) + } + + var all []kiroAvailableModelItem + nextToken := "" + for { + resp, err := requestKiroAvailableModelsPage(ctx, client, account, endpoint, profileArn, token, nextToken) + if err != nil { + return nil, err + } + all = append(all, resp.items()...) + + nextToken = strings.TrimSpace(resp.NextToken) + if nextToken == "" { + nextToken = strings.TrimSpace(resp.NextTokenSnake) + } + if nextToken == "" { + break + } + } + + return mapKiroAvailableModels(all), nil +} + +func requestKiroAvailableModelsPage(ctx context.Context, client *http.Client, account *Account, endpoint, profileArn, token, nextToken string) (*kiroAvailableModelsResponse, error) { + reqURL, err := url.Parse(endpoint + "/ListAvailableModels") + if err != nil { + return nil, fmt.Errorf("build kiro model list url failed: %w", err) + } + q := reqURL.Query() + q.Set("origin", kiroUsageOrigin) + q.Set("maxResults", "50") + if profileArn != "" { + q.Set("profileArn", profileArn) + } + if nextToken != "" { + q.Set("nextToken", nextToken) + } + reqURL.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("create kiro model list request failed: %w", err) + } + applyKiroModelListHeaders(req, account, token) + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("kiro model list request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read kiro model list response failed: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))} + } + + var parsed kiroAvailableModelsResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("decode kiro model list response failed: %w", err) + } + return &parsed, nil +} + +func applyKiroModelListHeaders(req *http.Request, account *Account, token string) { + if req == nil { + return + } + accountKey := buildKiroAccountKey(account) + machineID := buildKiroMachineID(account) + req.Header.Set("Accept", "*/*") + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID)) + req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID)) + req.Header.Set("x-amzn-codewhisperer-optout", "true") + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString()) + if account != nil { + applyKiroConditionalHeaders(req, account) + } +} + +func (r *kiroAvailableModelsResponse) items() []kiroAvailableModelItem { + if r == nil { + return nil + } + switch { + case len(r.AvailableModels) > 0: + return r.AvailableModels + case len(r.AvailableModelsSnake) > 0: + return r.AvailableModelsSnake + default: + return r.Models + } +} + +func mapKiroAvailableModels(items []kiroAvailableModelItem) []kiropkg.Model { + seen := make(map[string]struct{}, len(items)) + models := make([]kiropkg.Model, 0, len(items)) + for _, item := range items { + id := firstNonEmptyKiroModelField(item.ModelID, item.ModelIDSnake, item.ID) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + displayName := firstNonEmptyKiroModelField(item.ModelName, item.ModelNameSnake, item.DisplayName, item.DisplayNameSnake, item.Name, id) + models = append(models, kiropkg.Model{ID: id, Type: "model", DisplayName: displayName}) + } + sort.Slice(models, func(i, j int) bool { return models[i].ID < models[j].ID }) + return models +} + +func firstNonEmptyKiroModelField(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/backend/internal/service/kiro_oauth_service.go b/backend/internal/service/kiro_oauth_service.go index 1b86639a..39ad43dd 100644 --- a/backend/internal/service/kiro_oauth_service.go +++ b/backend/internal/service/kiro_oauth_service.go @@ -96,6 +96,7 @@ type KiroRefreshTokenInput struct { type KiroImportTokenInput struct { TokenJSON string DeviceRegistrationJSON string + ProxyID *int64 } func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) { @@ -284,6 +285,28 @@ func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Acc } func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) { + tokenFromRefresh, refreshOnly, err := kiropkg.ParseImportedRefreshToken(input.TokenJSON) + if err != nil { + return nil, err + } + if refreshOnly { + token, err := s.RefreshToken(context.Background(), &KiroRefreshTokenInput{ + RefreshToken: tokenFromRefresh.RefreshToken, + AuthMethod: tokenFromRefresh.AuthMethod, + Provider: tokenFromRefresh.Provider, + ClientID: tokenFromRefresh.ClientID, + ClientSecret: tokenFromRefresh.ClientSecret, + StartURL: tokenFromRefresh.StartURL, + Region: tokenFromRefresh.Region, + ProfileArn: tokenFromRefresh.ProfileArn, + ProxyID: input.ProxyID, + }) + if err != nil { + return nil, err + } + return token, nil + } + token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON) if err != nil { return nil, err diff --git a/backend/internal/service/openai_available_models.go b/backend/internal/service/openai_available_models.go new file mode 100644 index 00000000..fb7cbc64 --- /dev/null +++ b/backend/internal/service/openai_available_models.go @@ -0,0 +1,161 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai" +) + +var ErrOpenAIModelListUnsupported = errors.New("openai upstream model list requires an OAuth access token or API key") + +type openAIModelsResponse struct { + Data []openaiModelListItem `json:"data"` +} + +type openaiModelListItem struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + +func (s *AccountTestService) FetchOpenAIUpstreamModels(ctx context.Context, account *Account) ([]openaipkg.Model, error) { + if account == nil { + return nil, errors.New("account is nil") + } + if account.Platform != PlatformOpenAI { + return nil, fmt.Errorf("not an openai account") + } + + token, baseURL, err := s.resolveOpenAIModelListAuth(ctx, account) + if err != nil { + return nil, err + } + return requestOpenAIAvailableModels(ctx, account, baseURL, token) +} + +func (s *AccountTestService) resolveOpenAIModelListAuth(ctx context.Context, account *Account) (token, baseURL string, err error) { + if account.IsOpenAIOAuth() { + if s == nil || s.openAITokenProvider == nil { + token = strings.TrimSpace(account.GetOpenAIAccessToken()) + } else { + token, err = s.openAITokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", fmt.Errorf("get openai access token failed: %w", err) + } + } + if strings.TrimSpace(token) == "" { + return "", "", ErrOpenAIModelListUnsupported + } + return token, "https://api.openai.com", nil + } + + if account.IsOpenAIApiKey() { + token = strings.TrimSpace(account.GetOpenAIApiKey()) + if token == "" { + return "", "", ErrOpenAIModelListUnsupported + } + baseURL = account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + if s != nil { + normalized, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", "", fmt.Errorf("invalid base_url: %w", err) + } + baseURL = normalized + } + return token, baseURL, nil + } + + return "", "", ErrOpenAIModelListUnsupported +} + +func requestOpenAIAvailableModels(ctx context.Context, account *Account, baseURL, token string) ([]openaipkg.Model, error) { + baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/") + if baseURL == "" { + baseURL = "https://api.openai.com" + } + modelsURL := baseURL + if !strings.HasSuffix(modelsURL, "/models") { + if strings.HasSuffix(modelsURL, "/v1") { + modelsURL += "/models" + } else { + modelsURL += "/v1/models" + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("create openai model list request failed: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + if userAgent := strings.TrimSpace(account.GetOpenAIUserAgent()); userAgent != "" { + req.Header.Set("User-Agent", userAgent) + } + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: accountProxyURL(account), + Timeout: 30 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: isLoopbackEndpoint(modelsURL), + }) + if err != nil { + return nil, fmt.Errorf("create openai model list client failed: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("openai model list request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read openai model list response failed: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))} + } + + var parsed openAIModelsResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("decode openai model list response failed: %w", err) + } + + models := make([]openaipkg.Model, 0, len(parsed.Data)) + for _, item := range parsed.Data { + id := strings.TrimSpace(item.ID) + if id == "" { + continue + } + models = append(models, openaipkg.Model{ + ID: id, + Object: firstNonEmptyOpenAIModelField(item.Object, "model"), + Created: item.Created, + OwnedBy: item.OwnedBy, + Type: "model", + DisplayName: id, + }) + } + return models, nil +} + +func firstNonEmptyOpenAIModelField(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index 5a84e37a..885d9b68 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -227,6 +227,116 @@ func TestForwardAsChatCompletions_RequestErrorRetriesBeforeSuccess(t *testing.T) require.Contains(t, events[0].Message, "connection reset by peer") } +func TestForwardAsChatCompletions_ClosedNetworkConnectionRetriesBeforeSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_closed_network_retry","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &sequentialHTTPUpstreamRecorder{ + errs: []error{ + errors.New("Post \"https://chatgpt.com/backend-api/codex/responses\": use of closed network connection"), + nil, + }, + responses: []*http.Response{{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_closed_network_retry"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }}, + } + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, rec.Code) + require.Len(t, upstream.requests, 2) + + rawEvents, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := rawEvents.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, "request_error", events[0].Kind) + require.Contains(t, events[0].Message, "use of closed network connection") +} + +func TestForwardAsChatCompletions_TLSBadRecordMACRetriesBeforeSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_tls_retry","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &sequentialHTTPUpstreamRecorder{ + errs: []error{ + errors.New("Post \"https://chatgpt.com/backend-api/codex/responses\": local error: tls: bad record MAC"), + nil, + }, + responses: []*http.Response{{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_tls_retry"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }}, + } + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, rec.Code) + require.Len(t, upstream.requests, 2) + + rawEvents, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := rawEvents.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, "request_error", events[0].Kind) + require.Contains(t, strings.ToLower(events[0].Message), "tls: bad record mac") +} + func TestForwardAsChatCompletions_RequestErrorExhaustionReturnsFailover(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_http_retry.go b/backend/internal/service/openai_http_retry.go index 6004ecba..eb936978 100644 --- a/backend/internal/service/openai_http_retry.go +++ b/backend/internal/service/openai_http_retry.go @@ -121,8 +121,10 @@ func isRetryableOpenAIHTTPRequestError(err error) bool { "connection refused", "unexpected eof", "server closed idle connection", + "use of closed network connection", "broken pipe", "connection aborted", + "tls: bad record mac", "tls: use of closed connection", "http2: client connection lost", } diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 08a0840f..b1aa0aae 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -446,6 +446,26 @@ export async function getAvailableModels(id: number): Promise { return data } +/** + * Get Kiro models from the upstream Kiro ListAvailableModels API using the account proxy. + * @param id - Account ID + * @returns List of upstream Kiro models + */ +export async function getKiroUpstreamModels(id: number): Promise { + const { data } = await apiClient.get(`/admin/accounts/${id}/kiro/upstream-models`) + return data +} + +/** + * Get OpenAI models from the upstream /v1/models API using the account proxy. + * @param id - Account ID + * @returns List of upstream OpenAI models + */ +export async function getOpenAIUpstreamModels(id: number): Promise { + const { data } = await apiClient.get(`/admin/accounts/${id}/openai/upstream-models`) + return data +} + export interface CRSPreviewAccount { crs_account_id: string kind: string @@ -667,6 +687,8 @@ export const accountsAPI = { resetTempUnschedulable, setSchedulable, getAvailableModels, + getOpenAIUpstreamModels, + getKiroUpstreamModels, generateAuthUrl, exchangeCode, refreshOpenAIToken, diff --git a/frontend/src/api/admin/kiro.ts b/frontend/src/api/admin/kiro.ts index 60751b3c..795b3b82 100644 --- a/frontend/src/api/admin/kiro.ts +++ b/frontend/src/api/admin/kiro.ts @@ -75,6 +75,7 @@ export async function refreshToken(payload: { export async function importToken(payload: { token_json: string device_registration_json?: string + proxy_id?: number }): Promise { const { data } = await apiClient.post('/admin/kiro/oauth/import-token', payload) return data diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 90064e05..1f619429 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -3217,7 +3217,7 @@
- +

{{ t('admin.accounts.oauth.kiro.tokenJsonHint') }}

@@ -5811,7 +5811,8 @@ const handleKiroImport = async () => { const tokenInfo = await kiroOAuth.importToken( kiroTokenJson.value, - kiroDeviceRegistrationJson.value || undefined + kiroDeviceRegistrationJson.value || undefined, + form.proxy_id ) if (!tokenInfo) return diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 93ebb575..30deefbc 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -74,7 +74,18 @@
- +
+ + +

@@ -160,7 +171,19 @@

- +
+ + +
- +
+ + + +
([]) const openAICompactModelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const openAIModelListLoading = ref(false) +const kiroModelListLoading = ref(false) const DEFAULT_POOL_MODE_RETRY_COUNT = 3 const MAX_POOL_MODE_RETRY_COUNT = 10 const poolModeEnabled = ref(false) @@ -2478,6 +2525,75 @@ const loadDefaultKiroModelMappings = () => { }) } +const extractModelID = (model: unknown): string => { + if (typeof model === 'string') return model.trim() + if (!model || typeof model !== 'object') return '' + + const record = model as Record + const rawID = record.modelId ?? record.model_id ?? record.id ?? record.name ?? record.model + return typeof rawID === 'string' ? rawID.trim() : '' +} + +const handleFetchKiroModelMappings = async () => { + if (!props.account || props.account.platform !== 'kiro') return + + kiroModelListLoading.value = true + try { + const models = await adminAPI.accounts.getKiroUpstreamModels(props.account.id) + const modelIDs = Array.from( + new Set( + models + .map(extractModelID) + .filter((id) => id.length > 0) + ) + ) + + if (modelIDs.length === 0) { + appStore.showError(t('admin.accounts.noModelsFetched')) + return + } + + modelRestrictionMode.value = 'mapping' + modelMappings.value = modelIDs.map((model) => ({ from: model, to: model })) + allowedModels.value = [] + appStore.showSuccess(t('admin.accounts.modelListApplied', { count: modelIDs.length })) + } catch (error: any) { + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToFetchModelList')) + } finally { + kiroModelListLoading.value = false + } +} + +const handleFetchOpenAIModels = async () => { + if (!props.account || props.account.platform !== 'openai') return + + openAIModelListLoading.value = true + try { + const models = await adminAPI.accounts.getOpenAIUpstreamModels(props.account.id) + const modelIDs = Array.from( + new Set( + models + .map(extractModelID) + .filter((id) => id.length > 0) + ) + ) + + if (modelIDs.length === 0) { + appStore.showError(t('admin.accounts.noModelsFetched')) + return + } + + modelRestrictionMode.value = 'whitelist' + allowedModels.value = modelIDs + modelMappings.value = [] + appStore.showSuccess(t('admin.accounts.modelListApplied', { count: modelIDs.length })) + } catch (error: any) { + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToFetchModelList')) + } finally { + openAIModelListLoading.value = false + } +} + const showMixedChannelWarning = ref(false) const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>( null diff --git a/frontend/src/components/admin/account/ReAuthAccountModal.vue b/frontend/src/components/admin/account/ReAuthAccountModal.vue index c5d94de1..bac16faa 100644 --- a/frontend/src/components/admin/account/ReAuthAccountModal.vue +++ b/frontend/src/components/admin/account/ReAuthAccountModal.vue @@ -775,7 +775,8 @@ const handleKiroImport = async () => { const tokenInfo = await kiroOAuth.importToken( kiroTokenJson.value, - kiroDeviceRegistrationJson.value || undefined + kiroDeviceRegistrationJson.value || undefined, + props.account.proxy_id ) if (!tokenInfo) return diff --git a/frontend/src/composables/useKiroOAuth.ts b/frontend/src/composables/useKiroOAuth.ts index 010f231d..3f2893b1 100644 --- a/frontend/src/composables/useKiroOAuth.ts +++ b/frontend/src/composables/useKiroOAuth.ts @@ -141,14 +141,16 @@ export function useKiroOAuth() { const importToken = async ( tokenJSON: string, - deviceRegistrationJSON?: string + deviceRegistrationJSON?: string, + proxyId?: number | null ): Promise => { loading.value = true error.value = '' try { return await adminAPI.kiro.importToken({ token_json: tokenJSON, - device_registration_json: deviceRegistrationJSON + device_registration_json: deviceRegistrationJSON, + proxy_id: proxyId || undefined }) } catch (err: any) { error.value = err.response?.data?.detail || t('admin.accounts.oauth.authFailed') diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index a939cbc3..92c0a005 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1028,7 +1028,7 @@ export default { columns: { email: 'Email', username: 'Username', - rebate: 'Rebate', + rebate: 'Total Earnings', joinedAt: 'Joined At' } }, @@ -3234,6 +3234,11 @@ export default { requestModel: 'Request model', actualModel: 'Actual model', addMapping: 'Add Mapping', + fetchModelList: 'Fetch Models', + fetchUpstreamModelList: 'Fetch Upstream Models', + modelListApplied: 'Replaced with {count} model(s)', + noModelsFetched: 'No available models were returned', + failedToFetchModelList: 'Failed to fetch model list', mappingExists: 'Mapping for {model} already exists', wildcardOnlyAtEnd: 'Wildcard * can only be at the end', targetNoWildcard: 'Target model cannot contain wildcard *', @@ -3643,7 +3648,7 @@ export default { regionLabel: 'Region', regionPlaceholder: 'us-east-1', tokenJsonLabel: 'Kiro Token JSON', - tokenJsonHint: 'Sign in through Kiro IDE first, then paste the contents of `~/.aws/sso/cache/kiro-auth-token.json` here.', + tokenJsonHint: 'Supports full Kiro token JSON, or a refresh-token-only payload with refreshToken + provider. The server will refresh it into usable credentials.', deviceRegistrationLabel: 'Device Registration JSON', deviceRegistrationHint: 'Optional. Only needed when the token file does not include full client details and only has `clientIdHash`.', importAndUpdate: 'Import and Update' diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index df163a07..e40691b0 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1032,7 +1032,7 @@ export default { columns: { email: '邮箱', username: '用户名', - rebate: '返利明细', + rebate: '累计收益', joinedAt: '注册时间' } }, @@ -3391,6 +3391,11 @@ export default { requestModel: '请求模型', actualModel: '实际模型', addMapping: '添加映射', + fetchModelList: '获取模型列表', + fetchUpstreamModelList: '从上游获取模型', + modelListApplied: '已覆盖为 {count} 个模型', + noModelsFetched: '未获取到可用模型', + failedToFetchModelList: '获取模型列表失败', mappingExists: '模型 {model} 的映射已存在', wildcardOnlyAtEnd: '通配符 * 只能放在末尾', targetNoWildcard: '目标模型不能包含通配符 *', @@ -3787,7 +3792,7 @@ export default { regionLabel: 'Region', regionPlaceholder: 'us-east-1', tokenJsonLabel: 'Kiro Token JSON', - tokenJsonHint: '先在 Kiro IDE 完成登录,再粘贴 `~/.aws/sso/cache/kiro-auth-token.json` 的内容。', + tokenJsonHint: '支持完整 Kiro token JSON,也支持仅粘贴 refreshToken + provider 格式,系统会自动刷新为可用凭据。', deviceRegistrationLabel: 'Device Registration JSON', deviceRegistrationHint: '可选。只有 token 文件里缺少完整客户端信息、只剩 `clientIdHash` 时才需要补充。', importAndUpdate: '导入并更新'