release: prepare v0.1.134
This commit is contained in:
@@ -1 +1 @@
|
||||
0.1.133
|
||||
0.1.134
|
||||
|
||||
@@ -146,6 +146,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
kiroOAuthService := service.NewKiroOAuthService(proxyRepository)
|
||||
kiroTokenProvider := service.ProvideKiroTokenProvider(accountRepository, geminiTokenCache, kiroOAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
@@ -154,7 +155,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, openAITokenProvider, kiroTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
@@ -189,7 +190,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, kiroTokenProvider, kiroCooldownStore, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
|
||||
@@ -1862,6 +1862,76 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetKiroUpstreamModels handles getting upstream Kiro models with the account credentials/proxy.
|
||||
// GET /api/v1/admin/accounts/:id/kiro/upstream-models
|
||||
func (h *AccountHandler) GetKiroUpstreamModels(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
if account.Platform != service.PlatformKiro {
|
||||
response.BadRequest(c, "Account is not a Kiro account")
|
||||
return
|
||||
}
|
||||
if h.accountTestService == nil {
|
||||
response.InternalError(c, "Kiro account service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
models, err := h.accountTestService.FetchKiroUpstreamModels(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrKiroModelListUnsupported) {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.InternalError(c, "Failed to fetch Kiro upstream model list: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// GetOpenAIUpstreamModels handles getting upstream OpenAI models with the account credentials/proxy.
|
||||
// GET /api/v1/admin/accounts/:id/openai/upstream-models
|
||||
func (h *AccountHandler) GetOpenAIUpstreamModels(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
if account.Platform != service.PlatformOpenAI {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
return
|
||||
}
|
||||
if h.accountTestService == nil {
|
||||
response.InternalError(c, "OpenAI account service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
models, err := h.accountTestService.FetchOpenAIUpstreamModels(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOpenAIModelListUnsupported) {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.InternalError(c, "Failed to fetch OpenAI upstream model list: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// GetAvailableModels handles getting available models for an account
|
||||
// GET /api/v1/admin/accounts/:id/models
|
||||
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
@@ -129,6 +129,7 @@ func (h *KiroOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
type KiroImportTokenRequest struct {
|
||||
TokenJSON string `json:"token_json" binding:"required"`
|
||||
DeviceRegistrationJSON string `json:"device_registration_json"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
|
||||
@@ -140,6 +141,7 @@ func (h *KiroOAuthHandler) ImportToken(c *gin.Context) {
|
||||
tokenInfo, err := h.kiroOAuthService.ImportToken(&service.KiroImportTokenInput{
|
||||
TokenJSON: req.TokenJSON,
|
||||
DeviceRegistrationJSON: req.DeviceRegistrationJSON,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "导入 Kiro Token 失败: "+err.Error())
|
||||
|
||||
@@ -424,6 +424,24 @@ func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*Token
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func ParseImportedRefreshToken(tokenJSON string) (*TokenData, bool, error) {
|
||||
var token TokenData
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse kiro token: %w", err)
|
||||
}
|
||||
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
||||
if token.AuthMethod == "" {
|
||||
token.AuthMethod = "social"
|
||||
}
|
||||
if token.Provider == "" && token.AuthMethod == "social" {
|
||||
token.Provider = string(SocialProviderGoogle)
|
||||
}
|
||||
if strings.TrimSpace(token.RefreshToken) == "" || strings.TrimSpace(token.AccessToken) != "" {
|
||||
return &token, false, nil
|
||||
}
|
||||
return &token, true, nil
|
||||
}
|
||||
|
||||
func getOIDCEndpoint(region string) string {
|
||||
if strings.TrimSpace(oidcEndpointOverride) != "" {
|
||||
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
|
||||
|
||||
@@ -54,3 +54,35 @@ func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
|
||||
t.Fatalf("fresh session should remain after pruning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImportedRefreshTokenAcceptsRefreshTokenOnlyPayload(t *testing.T) {
|
||||
token, refreshOnly, err := ParseImportedRefreshToken(`{"refreshToken":"rt","provider":"Google"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseImportedRefreshToken() error = %v", err)
|
||||
}
|
||||
if !refreshOnly {
|
||||
t.Fatalf("refreshOnly = false, want true")
|
||||
}
|
||||
if token.RefreshToken != "rt" {
|
||||
t.Fatalf("refresh token = %q, want rt", token.RefreshToken)
|
||||
}
|
||||
if token.Provider != "Google" {
|
||||
t.Fatalf("provider = %q, want Google", token.Provider)
|
||||
}
|
||||
if token.AuthMethod != "social" {
|
||||
t.Fatalf("auth method = %q, want social", token.AuthMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImportedRefreshTokenKeepsFullTokenAsNonRefreshOnly(t *testing.T) {
|
||||
token, refreshOnly, err := ParseImportedRefreshToken(`{"accessToken":"at","refreshToken":"rt"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseImportedRefreshToken() error = %v", err)
|
||||
}
|
||||
if refreshOnly {
|
||||
t.Fatalf("refreshOnly = true, want false")
|
||||
}
|
||||
if token.Provider != "Google" {
|
||||
t.Fatalf("provider = %q, want default Google", token.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,7 +408,7 @@ LEFT JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN user_affiliate_ledger ual
|
||||
ON ual.user_id = $1
|
||||
AND ual.source_user_id = ua.user_id
|
||||
AND ual.action = 'accrue'
|
||||
AND ual.action IN ('accrue', 'signup_reward')
|
||||
WHERE ua.inviter_id = $1
|
||||
GROUP BY ua.user_id, u.email, u.username, ua.created_at
|
||||
ORDER BY ua.created_at DESC
|
||||
|
||||
@@ -306,6 +306,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/openai/upstream-models", h.Admin.Account.GetOpenAIUpstreamModels)
|
||||
accounts.GET("/:id/kiro/upstream-models", h.Admin.Account.GetKiroUpstreamModels)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.GET("/data", h.Admin.Account.ExportData)
|
||||
|
||||
@@ -67,6 +67,7 @@ type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
kiroTokenProvider *KiroTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
@@ -79,6 +80,7 @@ func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
kiroTokenProvider *KiroTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
@@ -89,6 +91,7 @@ func NewAccountTestService(
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
kiroTokenProvider: kiroTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var ErrKiroModelListUnsupported = errors.New("kiro upstream model list requires an OAuth/access-token account")
|
||||
|
||||
type kiroAvailableModelsResponse struct {
|
||||
AvailableModels []kiroAvailableModelItem `json:"availableModels"`
|
||||
AvailableModelsSnake []kiroAvailableModelItem `json:"available_models"`
|
||||
Models []kiroAvailableModelItem `json:"models"`
|
||||
NextToken string `json:"nextToken"`
|
||||
NextTokenSnake string `json:"next_token"`
|
||||
}
|
||||
|
||||
type kiroAvailableModelItem struct {
|
||||
ModelID string `json:"modelId"`
|
||||
ModelIDSnake string `json:"model_id"`
|
||||
ID string `json:"id"`
|
||||
ModelName string `json:"modelName"`
|
||||
ModelNameSnake string `json:"model_name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
DisplayNameSnake string `json:"display_name"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *AccountTestService) FetchKiroUpstreamModels(ctx context.Context, account *Account) ([]kiropkg.Model, error) {
|
||||
if account == nil {
|
||||
return nil, errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformKiro {
|
||||
return nil, fmt.Errorf("not a kiro account")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if s == nil || s.kiroTokenProvider == nil {
|
||||
return nil, errors.New("kiro token provider not configured")
|
||||
}
|
||||
accessToken, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get kiro access token failed: %w", err)
|
||||
}
|
||||
token = strings.TrimSpace(accessToken)
|
||||
}
|
||||
if token == "" {
|
||||
return nil, ErrKiroModelListUnsupported
|
||||
}
|
||||
|
||||
return requestKiroAvailableModels(ctx, account, kiroAPIRegion(account), strings.TrimSpace(account.GetCredential("profile_arn")), token)
|
||||
}
|
||||
|
||||
func requestKiroAvailableModels(ctx context.Context, account *Account, region, profileArn, token string) ([]kiropkg.Model, error) {
|
||||
endpoint := resolveKiroRuntimeEndpoint(region)
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: accountProxyURL(account),
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: isLoopbackEndpoint(endpoint),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kiro model list client failed: %w", err)
|
||||
}
|
||||
|
||||
var all []kiroAvailableModelItem
|
||||
nextToken := ""
|
||||
for {
|
||||
resp, err := requestKiroAvailableModelsPage(ctx, client, account, endpoint, profileArn, token, nextToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, resp.items()...)
|
||||
|
||||
nextToken = strings.TrimSpace(resp.NextToken)
|
||||
if nextToken == "" {
|
||||
nextToken = strings.TrimSpace(resp.NextTokenSnake)
|
||||
}
|
||||
if nextToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return mapKiroAvailableModels(all), nil
|
||||
}
|
||||
|
||||
func requestKiroAvailableModelsPage(ctx context.Context, client *http.Client, account *Account, endpoint, profileArn, token, nextToken string) (*kiroAvailableModelsResponse, error) {
|
||||
reqURL, err := url.Parse(endpoint + "/ListAvailableModels")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build kiro model list url failed: %w", err)
|
||||
}
|
||||
q := reqURL.Query()
|
||||
q.Set("origin", kiroUsageOrigin)
|
||||
q.Set("maxResults", "50")
|
||||
if profileArn != "" {
|
||||
q.Set("profileArn", profileArn)
|
||||
}
|
||||
if nextToken != "" {
|
||||
q.Set("nextToken", nextToken)
|
||||
}
|
||||
reqURL.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kiro model list request failed: %w", err)
|
||||
}
|
||||
applyKiroModelListHeaders(req, account, token)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro model list request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read kiro model list response failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
|
||||
}
|
||||
|
||||
var parsed kiroAvailableModelsResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("decode kiro model list response failed: %w", err)
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func applyKiroModelListHeaders(req *http.Request, account *Account, token string) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
accountKey := buildKiroAccountKey(account)
|
||||
machineID := buildKiroMachineID(account)
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(accountKey, machineID))
|
||||
req.Header.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(accountKey, machineID))
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.NewString())
|
||||
if account != nil {
|
||||
applyKiroConditionalHeaders(req, account)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *kiroAvailableModelsResponse) items() []kiroAvailableModelItem {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case len(r.AvailableModels) > 0:
|
||||
return r.AvailableModels
|
||||
case len(r.AvailableModelsSnake) > 0:
|
||||
return r.AvailableModelsSnake
|
||||
default:
|
||||
return r.Models
|
||||
}
|
||||
}
|
||||
|
||||
func mapKiroAvailableModels(items []kiroAvailableModelItem) []kiropkg.Model {
|
||||
seen := make(map[string]struct{}, len(items))
|
||||
models := make([]kiropkg.Model, 0, len(items))
|
||||
for _, item := range items {
|
||||
id := firstNonEmptyKiroModelField(item.ModelID, item.ModelIDSnake, item.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
displayName := firstNonEmptyKiroModelField(item.ModelName, item.ModelNameSnake, item.DisplayName, item.DisplayNameSnake, item.Name, id)
|
||||
models = append(models, kiropkg.Model{ID: id, Type: "model", DisplayName: displayName})
|
||||
}
|
||||
sort.Slice(models, func(i, j int) bool { return models[i].ID < models[j].ID })
|
||||
return models
|
||||
}
|
||||
|
||||
func firstNonEmptyKiroModelField(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -96,6 +96,7 @@ type KiroRefreshTokenInput struct {
|
||||
type KiroImportTokenInput struct {
|
||||
TokenJSON string
|
||||
DeviceRegistrationJSON string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
|
||||
@@ -284,6 +285,28 @@ func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Acc
|
||||
}
|
||||
|
||||
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
|
||||
tokenFromRefresh, refreshOnly, err := kiropkg.ParseImportedRefreshToken(input.TokenJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshOnly {
|
||||
token, err := s.RefreshToken(context.Background(), &KiroRefreshTokenInput{
|
||||
RefreshToken: tokenFromRefresh.RefreshToken,
|
||||
AuthMethod: tokenFromRefresh.AuthMethod,
|
||||
Provider: tokenFromRefresh.Provider,
|
||||
ClientID: tokenFromRefresh.ClientID,
|
||||
ClientSecret: tokenFromRefresh.ClientSecret,
|
||||
StartURL: tokenFromRefresh.StartURL,
|
||||
Region: tokenFromRefresh.Region,
|
||||
ProfileArn: tokenFromRefresh.ProfileArn,
|
||||
ProxyID: input.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var ErrOpenAIModelListUnsupported = errors.New("openai upstream model list requires an OAuth access token or API key")
|
||||
|
||||
type openAIModelsResponse struct {
|
||||
Data []openaiModelListItem `json:"data"`
|
||||
}
|
||||
|
||||
type openaiModelListItem struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
func (s *AccountTestService) FetchOpenAIUpstreamModels(ctx context.Context, account *Account) ([]openaipkg.Model, error) {
|
||||
if account == nil {
|
||||
return nil, errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI {
|
||||
return nil, fmt.Errorf("not an openai account")
|
||||
}
|
||||
|
||||
token, baseURL, err := s.resolveOpenAIModelListAuth(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return requestOpenAIAvailableModels(ctx, account, baseURL, token)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) resolveOpenAIModelListAuth(ctx context.Context, account *Account) (token, baseURL string, err error) {
|
||||
if account.IsOpenAIOAuth() {
|
||||
if s == nil || s.openAITokenProvider == nil {
|
||||
token = strings.TrimSpace(account.GetOpenAIAccessToken())
|
||||
} else {
|
||||
token, err = s.openAITokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get openai access token failed: %w", err)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return "", "", ErrOpenAIModelListUnsupported
|
||||
}
|
||||
return token, "https://api.openai.com", nil
|
||||
}
|
||||
|
||||
if account.IsOpenAIApiKey() {
|
||||
token = strings.TrimSpace(account.GetOpenAIApiKey())
|
||||
if token == "" {
|
||||
return "", "", ErrOpenAIModelListUnsupported
|
||||
}
|
||||
baseURL = account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
if s != nil {
|
||||
normalized, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
baseURL = normalized
|
||||
}
|
||||
return token, baseURL, nil
|
||||
}
|
||||
|
||||
return "", "", ErrOpenAIModelListUnsupported
|
||||
}
|
||||
|
||||
func requestOpenAIAvailableModels(ctx context.Context, account *Account, baseURL, token string) ([]openaipkg.Model, error) {
|
||||
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
modelsURL := baseURL
|
||||
if !strings.HasSuffix(modelsURL, "/models") {
|
||||
if strings.HasSuffix(modelsURL, "/v1") {
|
||||
modelsURL += "/models"
|
||||
} else {
|
||||
modelsURL += "/v1/models"
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai model list request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
if userAgent := strings.TrimSpace(account.GetOpenAIUserAgent()); userAgent != "" {
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
}
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: accountProxyURL(account),
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: isLoopbackEndpoint(modelsURL),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai model list client failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai model list request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read openai model list response failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &kiroUsageHTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(body))}
|
||||
}
|
||||
|
||||
var parsed openAIModelsResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("decode openai model list response failed: %w", err)
|
||||
}
|
||||
|
||||
models := make([]openaipkg.Model, 0, len(parsed.Data))
|
||||
for _, item := range parsed.Data {
|
||||
id := strings.TrimSpace(item.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
models = append(models, openaipkg.Model{
|
||||
ID: id,
|
||||
Object: firstNonEmptyOpenAIModelField(item.Object, "model"),
|
||||
Created: item.Created,
|
||||
OwnedBy: item.OwnedBy,
|
||||
Type: "model",
|
||||
DisplayName: id,
|
||||
})
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func firstNonEmptyOpenAIModelField(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -227,6 +227,116 @@ func TestForwardAsChatCompletions_RequestErrorRetriesBeforeSuccess(t *testing.T)
|
||||
require.Contains(t, events[0].Message, "connection reset by peer")
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_ClosedNetworkConnectionRetriesBeforeSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_closed_network_retry","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &sequentialHTTPUpstreamRecorder{
|
||||
errs: []error{
|
||||
errors.New("Post \"https://chatgpt.com/backend-api/codex/responses\": use of closed network connection"),
|
||||
nil,
|
||||
},
|
||||
responses: []*http.Response{{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_closed_network_retry"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Len(t, upstream.requests, 2)
|
||||
|
||||
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, "request_error", events[0].Kind)
|
||||
require.Contains(t, events[0].Message, "use of closed network connection")
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_TLSBadRecordMACRetriesBeforeSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_tls_retry","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &sequentialHTTPUpstreamRecorder{
|
||||
errs: []error{
|
||||
errors.New("Post \"https://chatgpt.com/backend-api/codex/responses\": local error: tls: bad record MAC"),
|
||||
nil,
|
||||
},
|
||||
responses: []*http.Response{{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_tls_retry"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Len(t, upstream.requests, 2)
|
||||
|
||||
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, "request_error", events[0].Kind)
|
||||
require.Contains(t, strings.ToLower(events[0].Message), "tls: bad record mac")
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_RequestErrorExhaustionReturnsFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -121,8 +121,10 @@ func isRetryableOpenAIHTTPRequestError(err error) bool {
|
||||
"connection refused",
|
||||
"unexpected eof",
|
||||
"server closed idle connection",
|
||||
"use of closed network connection",
|
||||
"broken pipe",
|
||||
"connection aborted",
|
||||
"tls: bad record mac",
|
||||
"tls: use of closed connection",
|
||||
"http2: client connection lost",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user