revert: completely remove all Sora functionality

This commit is contained in:
erio
2026-04-05 17:11:01 +08:00
parent dbb248df52
commit 62e80c602d
136 changed files with 256 additions and 24221 deletions
+1 -2
View File
@@ -28,8 +28,7 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora'
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
// FindByExtraField 根据 extra 字段中的键值对查找账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
@@ -13,18 +13,14 @@ import (
"log"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
)
// TestEvent represents a SSE event for account testing
@@ -71,13 +62,8 @@ type AccountTestService struct {
httpUpstream HTTPUpstream
cfg *config.Config
tlsFPProfileService *TLSFingerprintProfileService
soraTestGuardMu sync.Mutex
soraTestLastRun map[int64]time.Time
soraTestCooldown time.Duration
}
const defaultSoraTestCooldown = 10 * time.Second
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
@@ -94,8 +80,6 @@ func NewAccountTestService(
httpUpstream: httpUpstream,
cfg: cfg,
tlsFPProfileService: tlsFPProfileService,
soraTestLastRun: make(map[int64]time.Time),
soraTestCooldown: defaultSoraTestCooldown,
}
}
@@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt)
}
if account.Platform == PlatformSora {
return s.testSoraAccountConnection(c, account)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -634,697 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
type soraProbeStep struct {
Name string `json:"name"`
Status string `json:"status"`
HTTPStatus int `json:"http_status,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
type soraProbeSummary struct {
Status string `json:"status"`
Steps []soraProbeStep `json:"steps"`
}
type soraProbeRecorder struct {
steps []soraProbeStep
}
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
r.steps = append(r.steps, soraProbeStep{
Name: name,
Status: status,
HTTPStatus: httpStatus,
ErrorCode: strings.TrimSpace(errorCode),
Message: strings.TrimSpace(message),
})
}
func (r *soraProbeRecorder) finalize() soraProbeSummary {
meSuccess := false
partial := false
for _, step := range r.steps {
if step.Name == "me" {
meSuccess = strings.EqualFold(step.Status, "success")
continue
}
if strings.EqualFold(step.Status, "failed") {
partial = true
}
}
status := "success"
if !meSuccess {
status = "failed"
} else if partial {
status = "partial_success"
}
return soraProbeSummary{
Status: status,
Steps: append([]soraProbeStep(nil), r.steps...),
}
}
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
if rec == nil {
return
}
summary := rec.finalize()
code := ""
for _, step := range summary.Steps {
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
code = step.ErrorCode
break
}
}
s.sendEvent(c, TestEvent{
Type: "sora_test_result",
Status: summary.Status,
Code: code,
Data: summary,
})
}
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
if accountID <= 0 {
return 0, true
}
s.soraTestGuardMu.Lock()
defer s.soraTestGuardMu.Unlock()
if s.soraTestLastRun == nil {
s.soraTestLastRun = make(map[int64]time.Time)
}
cooldown := s.soraTestCooldown
if cooldown <= 0 {
cooldown = defaultSoraTestCooldown
}
now := time.Now()
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
elapsed := now.Sub(lastRun)
if elapsed < cooldown {
return cooldown - elapsed, false
}
}
s.soraTestLastRun[accountID] = now
return 0, true
}
func ceilSeconds(d time.Duration) int {
if d <= 0 {
return 1
}
sec := int(d / time.Second)
if d%time.Second != 0 {
sec++
}
if sec < 1 {
sec = 1
}
return sec
}
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
}
// 验证 base_url 格式
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
}
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
// 设置 SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
return s.sendErrorAndEnd(c, msg)
}
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
// 构建轻量级 prompt-enhance 请求作为连通性测试
testPayload := map[string]any{
"model": "prompt-enhance-short-10s",
"messages": []map[string]string{{"role": "user", "content": "test"}},
"stream": false,
}
payloadBytes, _ := json.Marshal(testPayload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "构建测试请求失败")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if resp.StatusCode == http.StatusOK {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
}
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
if resp.StatusCode == http.StatusBadRequest {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, msg)
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
// 使用 Sora 客户端标准请求头
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
soraTLSProfile := s.resolveSoraTLSProfile()
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
if err != nil {
recorder.addStep("me", "failed", 0, "network_error", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
switch {
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
case strings.TrimSpace(upstreamMessage) != "":
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
default:
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
}
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
if err := json.Unmarshal(body, &meResp); err != nil {
// 能收到 200 就说明 token 有效
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
} else {
// 尝试提取用户名或邮箱信息
info := "Sora connection OK"
if name, ok := meResp["name"].(string); ok && name != "" {
info = fmt.Sprintf("Sora connection OK - User: %s", name)
} else if email, ok := meResp["email"].(string); ok && email != "" {
info = fmt.Sprintf("Sora connection OK - Email: %s", email)
}
s.sendEvent(c, TestEvent{Type: "content", Text: info})
}
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
if err == nil {
subReq.Header.Set("Authorization", "Bearer "+authToken)
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
subReq.Header.Set("Accept", "application/json")
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
if subErr != nil {
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
}
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) testSora2Capabilities(
c *gin.Context,
ctx context.Context,
account *Account,
authToken string,
proxyURL string,
tlsProfile *tlsfingerprint.Profile,
recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
tlsProfile,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
if inviteStatus == http.StatusUnauthorized {
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraBootstrapURL,
proxyURL,
tlsProfile,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
if recorder != nil {
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
tlsProfile,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
} else if recorder != nil {
code := ""
msg := ""
if bootstrapErr != nil {
code = "network_error"
msg = bootstrapErr.Error()
}
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
}
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
}
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraRemainingURL,
proxyURL,
tlsProfile,
)
if remainingErr != nil {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
}
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
}
}
func (s *AccountTestService) fetchSoraTestEndpoint(
ctx context.Context,
account *Account,
authToken string,
url string,
proxyURL string,
tlsProfile *tlsfingerprint.Profile,
) (int, http.Header, []byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return 0, nil, nil, err
}
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return 0, nil, nil, err
}
defer func() { _ = resp.Body.Close() }()
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return resp.StatusCode, resp.Header, nil, readErr
}
return resp.StatusCode, resp.Header, body, nil
}
func parseSoraSubscriptionSummary(body []byte) string {
var subResp struct {
Data []struct {
Plan struct {
ID string `json:"id"`
Title string `json:"title"`
} `json:"plan"`
EndTS string `json:"end_ts"`
} `json:"data"`
}
if err := json.Unmarshal(body, &subResp); err != nil {
return ""
}
if len(subResp.Data) == 0 {
return ""
}
first := subResp.Data[0]
parts := make([]string, 0, 3)
if first.Plan.Title != "" {
parts = append(parts, first.Plan.Title)
}
if first.Plan.ID != "" {
parts = append(parts, first.Plan.ID)
}
if first.EndTS != "" {
parts = append(parts, "end="+first.EndTS)
}
if len(parts) == 0 {
return ""
}
return "Subscription: " + strings.Join(parts, " | ")
}
func parseSoraInviteSummary(body []byte) string {
var inviteResp struct {
InviteCode string `json:"invite_code"`
RedeemedCount int64 `json:"redeemed_count"`
TotalCount int64 `json:"total_count"`
}
if err := json.Unmarshal(body, &inviteResp); err != nil {
return ""
}
parts := []string{"Sora2: supported"}
if inviteResp.InviteCode != "" {
parts = append(parts, "invite="+inviteResp.InviteCode)
}
if inviteResp.TotalCount > 0 {
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
}
return strings.Join(parts, " | ")
}
func parseSoraRemainingSummary(body []byte) string {
var remainingResp struct {
RateLimitAndCreditBalance struct {
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
RateLimitReached bool `json:"rate_limit_reached"`
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
} `json:"rate_limit_and_credit_balance"`
}
if err := json.Unmarshal(body, &remainingResp); err != nil {
return ""
}
info := remainingResp.RateLimitAndCreditBalance
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
if info.RateLimitReached {
parts = append(parts, "rate_limited=true")
}
if info.AccessResetsInSeconds > 0 {
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
}
return strings.Join(parts, " | ")
}
func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
// Sora TLS fingerprint enabled — use built-in default profile
return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
}
return nil // disabled
}
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
if headers == nil {
return "unknown"
}
candidates := []string{
"x-openai-public-ip",
"x-envoy-external-address",
"cf-connecting-ip",
"x-forwarded-for",
}
for _, key := range candidates {
if value := strings.TrimSpace(headers.Get(key)); value != "" {
return value
}
}
return "unknown"
}
func sanitizeProxyURLForLog(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
u, err := url.Parse(raw)
if err != nil {
return "<invalid_proxy_url>"
}
if u.User != nil {
u.User = nil
}
return u.String()
}
func endpointPathForLog(endpoint string) string {
parsed, err := url.Parse(strings.TrimSpace(endpoint))
if err != nil || parsed.Path == "" {
return endpoint
}
return parsed.Path
}
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
accountID := int64(0)
platform := ""
proxyID := "none"
if account != nil {
accountID = account.ID
platform = account.Platform
if account.ProxyID != nil {
proxyID = fmt.Sprintf("%d", *account.ProxyID)
}
}
cfRay := extractCloudflareRayID(headers, body)
if cfRay == "" {
cfRay = "unknown"
}
log.Printf(
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
accountID,
platform,
endpoint,
endpointPathForLog(endpoint),
proxyID,
sanitizeProxyURLForLog(proxyURL),
cfRay,
extractSoraEgressIPHint(headers),
)
}
func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
@@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
ctx, recorder := newTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
@@ -6,6 +6,7 @@ import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -14,6 +15,14 @@ import (
"github.com/stretchr/testify/require"
)
func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
@@ -34,7 +43,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
@@ -68,7 +77,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newSoraTestContext()
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
@@ -1,320 +0,0 @@
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, profile != nil)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestSanitizeProxyURLForLog(t *testing.T) {
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
require.Equal(t, "", sanitizeProxyURLForLog(""))
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
}
func TestExtractSoraEgressIPHint(t *testing.T) {
h := make(http.Header)
h.Set("x-openai-public-ip", "203.0.113.10")
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
h2 := make(http.Header)
h2.Set("x-envoy-external-address", "198.51.100.9")
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}
+5 -96
View File
@@ -15,7 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/httputil"
)
// AdminService interface defines admin management operations
@@ -111,7 +111,6 @@ type CreateUserInput struct {
Balance float64
Concurrency int
AllowedGroups []int64
SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
@@ -126,7 +125,6 @@ type UpdateUserInput struct {
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
@@ -143,11 +141,6 @@ type CreateGroupInput struct {
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
@@ -158,8 +151,6 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
@@ -184,11 +175,6 @@ type UpdateGroupInput struct {
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
@@ -199,8 +185,6 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
@@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{
http.StatusOK: {},
},
},
{
Target: "sora",
URL: "https://sora.chatgpt.com/backend/me",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
},
}
const (
@@ -448,7 +424,6 @@ type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
@@ -473,7 +448,6 @@ func NewAdminService(
userRepo UserRepository,
groupRepo GroupRepository,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
@@ -492,7 +466,6 @@ func NewAdminService(
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
@@ -582,7 +555,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
@@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组
if input.FallbackGroupID != nil {
@@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
SoraImagePrice360: soraImagePrice360,
SoraImagePrice540: soraImagePrice540,
SoraVideoPricePerRequest: soraVideoPrice,
SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
@@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
if input.SoraImagePrice360 != nil {
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
}
if input.SoraImagePrice540 != nil {
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
}
if input.SoraVideoPricePerRequest != nil {
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
}
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, err
}
// 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
if account.Platform == PlatformSora && s.soraAccountRepo != nil {
soraUpdates := map[string]any{
"access_token": account.GetCredential("access_token"),
"refresh_token": account.GetCredential("refresh_token"),
}
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
// 只记录警告日志,不阻塞账号创建
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
}
}
// 绑定分组
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
@@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
@@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
body = body[:proxyQualityMaxBodyBytes]
}
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
// Cloudflare challenge 检测
if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
item.Status = "challenge"
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
item.Message = "Sora 命中 Cloudflare challenge"
item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
item.Message = "命中 Cloudflare challenge"
return item
}
@@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
require.Contains(t, result.Summary, "挑战 1 项")
}
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("cf-ray", "test-ray-123")
@@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
defer server.Close()
target := proxyQualityTarget{
Target: "sora",
Target: "openai",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
@@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
@@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
@@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
@@ -808,14 +808,6 @@ type ImagePriceConfig struct {
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
}
// SoraPriceConfig Sora 按次计费配置
type SoraPriceConfig struct {
ImagePrice360 *float64
ImagePrice540 *float64
VideoPricePerRequest *float64
VideoPricePerRequestHD *float64
}
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
@@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
}
}
// CalculateSoraImageCost 计算 Sora 图片按次费用
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
unitPrice := 0.0
if groupConfig != nil {
switch imageSize {
case "540":
if groupConfig.ImagePrice540 != nil {
unitPrice = *groupConfig.ImagePrice540
}
default:
if groupConfig.ImagePrice360 != nil {
unitPrice = *groupConfig.ImagePrice360
}
}
}
totalCost := unitPrice * float64(imageCount)
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
unitPrice := 0.0
if groupConfig != nil {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
if groupConfig.VideoPricePerRequestHD != nil {
unitPrice = *groupConfig.VideoPricePerRequestHD
}
}
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
unitPrice = *groupConfig.VideoPricePerRequest
}
}
totalCost := unitPrice
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格
@@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
func TestCalculateSoraVideoCost(t *testing.T) {
svc := newTestBillingService()
price := 0.5
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
}
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
svc := newTestBillingService()
hdPrice := 1.0
normalPrice := 0.5
cfg := &SoraPriceConfig{
VideoPricePerRequest: &normalPrice,
VideoPricePerRequestHD: &hdPrice,
}
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
}
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
@@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) {
require.Contains(t, err.Error(), "not initialized")
}
func TestCalculateSoraImageCost(t *testing.T) {
svc := newTestBillingService()
price360 := 0.05
price540 := 0.08
cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
}
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
svc := newTestBillingService()
cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
}
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
svc := newTestBillingService()
cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
}
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
// 使用空的 fallback prices 让 GetModelPricing 失败
svc := &BillingService{
@@ -24,7 +24,6 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
PlatformSora = domain.PlatformSora
)
// Account type constants
@@ -107,7 +106,6 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
@@ -199,27 +197,6 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
// =========================
// Claude Code Version Check
// =========================
+45 -426
View File
@@ -60,13 +60,6 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
@@ -510,10 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType string // image / video / prompt
MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
@@ -1341,6 +1330,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
@@ -1349,12 +1343,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded
}
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
@@ -1442,24 +1430,19 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
var stickyCacheMissReason string
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
if s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
if rpmPass { // 粘性会话窗口费用+RPM 检查
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择
} else {
if s.debugModelRoutingEnabled() {
@@ -1473,49 +1456,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
if stickyCacheMissReason == "" {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
stickyCacheMissReason = "session_limit"
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
// 会话限制已满,继续到负载感知选择
} else {
stickyCacheMissReason = "wait_queue_full"
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} else if !gatePass {
stickyCacheMissReason = "gate_check"
} else {
stickyCacheMissReason = "rpm_red"
}
// 记录粘性缓存未命中的结构化日志
if stickyCacheMissReason != "" {
baseRPM := stickyAccount.GetBaseRPM()
var currentRPM int
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
currentRPM = count
}
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
stickyAccountID, shortSessionHash(sessionHash))
}
}
}
@@ -1621,7 +1582,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
@@ -1637,7 +1597,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
@@ -1652,10 +1611,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return &AccountSelectionResult{
Account: account,
@@ -1971,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if platform == PlatformSora {
return s.listSoraSchedulableAccounts(ctx, groupID)
}
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
@@ -2070,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
const useMixed = false
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
} else if groupID != nil {
accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform != PlatformSora {
continue
}
if !s.isSoraAccountSchedulable(&acc) {
continue
}
filtered = append(filtered, acc)
}
slog.Debug("account_scheduling_list_sora",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
@@ -2141,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
return s.soraUnschedulableReason(account) == ""
}
func (s *GatewayService) soraUnschedulableReason(account *Account) string {
if account == nil {
return "account_nil"
}
if account.Status != StatusActive {
return fmt.Sprintf("status=%s", account.Status)
}
if !account.Schedulable {
return "schedulable=false"
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
}
return ""
}
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil {
return false
}
if account.Platform == PlatformSora {
return s.isSoraAccountSchedulable(account)
}
return account.IsSchedulable()
}
@@ -2175,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil {
return false
}
if account.Platform == PlatformSora {
if !s.isSoraAccountSchedulable(account) {
return false
}
return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
}
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
}
@@ -2795,12 +2673,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
@@ -2824,7 +2696,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2872,12 +2744,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2983,12 +2849,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -3055,12 +2915,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
@@ -3128,12 +2982,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3203,7 +3051,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -3227,7 +3075,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -3240,12 +3087,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3357,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs,
stats.SampleRateLimitIDs,
)
if platform == PlatformSora {
s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
}
return stats
}
@@ -3417,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
}
if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable"
if acc.Platform == PlatformSora {
detail = s.soraUnschedulableReason(acc)
}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
@@ -3444,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
func (s *GatewayService) logSoraSelectionFailureDetails(
ctx context.Context,
groupID *int64,
sessionHash string,
requestedModel string,
accounts []Account,
excludedIDs map[int64]struct{},
allowMixedScheduling bool,
) {
const maxLines = 30
logged := 0
for i := range accounts {
if logged >= maxLines {
break
}
acc := &accounts[i]
diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
if diagnosis.Category == "eligible" {
continue
}
detail := diagnosis.Detail
if detail == "" {
detail = "-"
}
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
acc.ID,
acc.Platform,
diagnosis.Category,
detail,
)
logged++
}
if len(accounts) > maxLines {
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
len(accounts),
logged,
)
}
}
// GetAccessToken 获取账号凭证
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
@@ -3573,13 +3358,14 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
if account.Platform == PlatformSora {
return s.isSoraModelSupportedByAccount(account, requestedModel)
}
if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
}
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
@@ -3588,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
if account == nil {
return false
}
if strings.TrimSpace(requestedModel) == "" {
return true
}
// 先走原始精确/通配符匹配。
mapping := account.GetModelMapping()
if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
return true
}
aliases := buildSoraModelAliases(requestedModel)
if len(aliases) == 0 {
return false
}
hasSoraSelector := false
for pattern := range mapping {
if !isSoraModelSelector(pattern) {
continue
}
hasSoraSelector = true
if matchPatternAnyAlias(pattern, aliases) {
return true
}
}
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
// 此时不应误拦截 Sora 模型请求。
if !hasSoraSelector {
return true
}
return false
}
func matchPatternAnyAlias(pattern string, aliases []string) bool {
normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
if normalizedPattern == "" {
return false
}
for _, alias := range aliases {
if matchWildcard(normalizedPattern, alias) {
return true
}
}
return false
}
func isSoraModelSelector(pattern string) bool {
p := strings.ToLower(strings.TrimSpace(pattern))
if p == "" {
return false
}
switch {
case strings.HasPrefix(p, "sora"),
strings.HasPrefix(p, "gpt-image"),
strings.HasPrefix(p, "prompt-enhance"),
strings.HasPrefix(p, "sy_"):
return true
}
return p == "video" || p == "image"
}
func buildSoraModelAliases(requestedModel string) []string {
modelID := strings.ToLower(strings.TrimSpace(requestedModel))
if modelID == "" {
return nil
}
aliases := make([]string, 0, 8)
addAlias := func(value string) {
v := strings.ToLower(strings.TrimSpace(value))
if v == "" {
return
}
for _, existing := range aliases {
if existing == v {
return
}
}
aliases = append(aliases, v)
}
addAlias(modelID)
cfg, ok := GetSoraModelConfig(modelID)
if ok {
addAlias(cfg.Model)
switch cfg.Type {
case "video":
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case "image":
addAlias("image")
addAlias("gpt-image")
case "prompt_enhance":
addAlias("prompt-enhance")
}
return aliases
}
switch {
case strings.HasPrefix(modelID, "sora"):
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case strings.HasPrefix(modelID, "gpt-image"):
addAlias("image")
addAlias("gpt-image")
case strings.HasPrefix(modelID, "prompt-enhance"):
addAlias("prompt-enhance")
default:
return nil
}
return aliases
}
func soraVideoFamilyAlias(modelID string) string {
switch {
case strings.HasPrefix(modelID, "sora2pro-hd"):
return "sora2pro-hd"
case strings.HasPrefix(modelID, "sora2pro"):
return "sora2pro"
case strings.HasPrefix(modelID, "sora2"):
return "sora2"
default:
return ""
}
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
@@ -7434,6 +7083,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey
User *User
Account *Account
@@ -7745,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct {
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
// ParsedRequest(可选,仅 Claude 路径传入)
ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支(image/video/prompt
// - MediaType 字段写入使用日志
EnableClaudePath bool
@@ -7776,6 +7424,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
ParsedRequest: input.ParsedRequest,
EnableClaudePath: true,
})
}
@@ -7841,8 +7490,6 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
@@ -7944,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
@@ -7963,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -8163,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if isSoraMedia {
return nil
}
var mode string
switch {
case cost != nil && cost.BillingMode != "":
@@ -8183,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil
}
@@ -8293,6 +7899,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return ch.BillingModelSource == BillingModelSourceUpstream
}
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
if groupID == nil {
return false
}
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
return false
}
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
@@ -9,35 +9,35 @@ import (
func TestCollectSelectionFailureStats(t *testing.T) {
svc := &GatewayService{}
model := "sora2-landscape-10s"
model := "gpt-5.4"
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
accounts := []Account{
// excluded
{
ID: 1,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
},
// unschedulable
{
ID: 2,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: false,
},
// platform filtered
{
ID: 3,
Platform: PlatformOpenAI,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
},
// model unsupported
{
ID: 4,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{
@@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// model rate limited
{
ID: 5,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
@@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// eligible
{
ID: 6,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
},
}
excluded := map[int64]struct{}{1: {}}
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false)
if stats.Total != 6 {
t.Fatalf("total=%d want=6", stats.Total)
@@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
}
}
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) {
svc := &GatewayService{}
acc := &Account{
ID: 7,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: false,
}
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false)
if diagnosis.Category != "unschedulable" {
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
}
if diagnosis.Detail != "schedulable=false" {
t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
if diagnosis.Detail != "generic_unschedulable" {
t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail)
}
}
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) {
svc := &GatewayService{}
model := "sora2-landscape-10s"
model := "gpt-5.4"
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
acc := &Account{
ID: 8,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
@@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
},
}
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false)
if diagnosis.Category != "model_rate_limited" {
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
}
@@ -1,79 +0,0 @@
package service
import "testing"
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected sora model to be supported when model_mapping is empty")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-4o": "gpt-4o",
},
},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"sora2": "sora2",
},
},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"sy_8": "sy_8",
},
},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-image": "gpt-image",
},
},
}
if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
}
}
@@ -1,89 +0,0 @@
package service
import (
"context"
"testing"
"time"
)
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
svc := &GatewayService{}
now := time.Now()
past := now.Add(-1 * time.Minute)
future := now.Add(5 * time.Minute)
acc := &Account{
Platform: PlatformSora,
Status: StatusActive,
Schedulable: true,
AutoPauseOnExpired: true,
ExpiresAt: &past,
OverloadUntil: &future,
RateLimitResetAt: &future,
}
if !svc.isAccountSchedulableForSelection(acc) {
t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
}
}
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
svc := &GatewayService{}
future := time.Now().Add(5 * time.Minute)
acc := &Account{
Platform: PlatformAnthropic,
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &future,
}
if svc.isAccountSchedulableForSelection(acc) {
t.Fatalf("expected non-sora account to keep generic schedulable checks")
}
}
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
svc := &GatewayService{}
model := "sora2-landscape-10s"
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
globalResetAt := time.Now().Add(2 * time.Minute)
acc := &Account{
Platform: PlatformSora,
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &globalResetAt,
Extra: map[string]any{
"model_rate_limits": map[string]any{
model: map[string]any{
"rate_limit_reset_at": resetAt,
},
},
},
}
if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
t.Fatalf("expected sora account to be blocked by model scope rate limit")
}
}
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
svc := &GatewayService{}
future := time.Now().Add(3 * time.Minute)
accounts := []Account{
{
ID: 1,
Platform: PlatformSora,
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &future,
},
}
stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
if stats.Unschedulable != 0 || stats.Eligible != 1 {
t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
}
}
-21
View File
@@ -26,15 +26,6 @@ type Group struct {
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置(阶段 1)
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
// Sora 存储配额
SoraStorageQuotaBytes int64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
@@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
}
}
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
switch imageSize {
case "360":
return g.SoraImagePrice360
case "540":
return g.SoraImagePrice540
default:
return g.SoraImagePrice360
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool {
if group == nil {
@@ -3,30 +3,15 @@ package service
import (
"context"
"crypto/subtle"
"encoding/json"
"io"
"log/slog"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
type soraSessionChunk struct {
index int
value string
}
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil {
@@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if resp.StatusCode != http.StatusOK {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var sessionResp struct {
AccessToken string `json:"accessToken"`
Expires string `json:"expires"`
User struct {
Email string `json:"email"`
Name string `json:"name"`
} `json:"user"`
}
if err := json.Unmarshal(body, &sessionResp); err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
}
if strings.TrimSpace(sessionResp.AccessToken) == "" {
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
}
expiresAt := time.Now().Add(time.Hour).Unix()
if strings.TrimSpace(sessionResp.Expires) != "" {
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
expiresAt = parsed.Unix()
}
}
expiresIn := expiresAt - time.Now().Unix()
if expiresIn < 0 {
expiresIn = 0
}
return &OpenAITokenInfo{
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
ClientID: openai.SoraClientID,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
func normalizeSoraSessionTokenInput(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return ""
}
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
if len(matches) == 0 {
return sanitizeSessionToken(trimmed)
}
chunkMatches := make([]soraSessionChunk, 0, len(matches))
singleValues := make([]string, 0, len(matches))
for _, match := range matches {
if len(match) < 3 {
continue
}
value := sanitizeSessionToken(match[2])
if value == "" {
continue
}
if strings.TrimSpace(match[1]) == "" {
singleValues = append(singleValues, value)
continue
}
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
if err != nil || idx < 0 {
continue
}
chunkMatches = append(chunkMatches, soraSessionChunk{
index: idx,
value: value,
})
}
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
return merged
}
if len(singleValues) > 0 {
return singleValues[len(singleValues)-1]
}
return ""
}
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
if len(chunks) == 0 {
return ""
}
byIndex := make(map[int]string, len(chunks))
for _, chunk := range chunks {
byIndex[chunk.index] = chunk.value
}
if _, ok := byIndex[0]; !ok {
return ""
}
if requireComplete {
for idx := 0; idx <= requiredMaxIndex; idx++ {
if _, ok := byIndex[idx]; !ok {
return ""
}
}
}
orderedIndexes := make([]int, 0, len(byIndex))
for idx := range byIndex {
orderedIndexes = append(orderedIndexes, idx)
}
sort.Ints(orderedIndexes)
var builder strings.Builder
for _, idx := range orderedIndexes {
if _, err := builder.WriteString(byIndex[idx]); err != nil {
return ""
}
}
return sanitizeSessionToken(builder.String())
}
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
if len(chunks) == 0 {
return ""
}
requiredMaxIndex := 0
for _, chunk := range chunks {
if chunk.index > requiredMaxIndex {
requiredMaxIndex = chunk.index
}
}
groupStarts := make([]int, 0, len(chunks))
for idx, chunk := range chunks {
if chunk.index == 0 {
groupStarts = append(groupStarts, idx)
}
}
if len(groupStarts) == 0 {
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
}
for i := len(groupStarts) - 1; i >= 0; i-- {
start := groupStarts[i]
end := len(chunks)
if i+1 < len(groupStarts) {
end = groupStarts[i+1]
}
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
return merged
}
}
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
}
func sanitizeSessionToken(raw string) string {
token := strings.TrimSpace(raw)
token = strings.Trim(token, "\"'`")
token = strings.TrimSuffix(token, ";")
return strings.TrimSpace(token)
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
// RefreshAccountToken refreshes token for an OpenAI OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
if account.Platform != PlatformOpenAI {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
@@ -609,10 +389,5 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
}
func normalizeOpenAIOAuthPlatform(platform string) string {
switch strings.ToLower(strings.TrimSpace(platform)) {
case PlatformSora:
return openai.OAuthPlatformSora
default:
return openai.OAuthPlatformOpenAI
}
return openai.OAuthPlatformOpenAI
}
@@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
require.True(t, ok)
require.Equal(t, openai.ClientID, session.ClientID)
}
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
defer svc.Stop()
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
require.NoError(t, err)
require.NotEmpty(t, result.AuthURL)
require.NotEmpty(t, result.SessionID)
parsed, err := url.Parse(result.AuthURL)
require.NoError(t, err)
q := parsed.Query()
require.Equal(t, openai.ClientID, q.Get("client_id"))
require.Empty(t, q.Get("codex_cli_simplified_flow"))
session, ok := svc.sessionStore.Get(result.SessionID)
require.True(t, ok)
require.Equal(t, openai.ClientID, session.ClientID)
}
@@ -1,173 +0,0 @@
package service
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientNoopStub struct{}
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at-token", info.AccessToken)
require.Equal(t, "demo@example.com", info.Email)
require.Greater(t, info.ExpiresAt, int64(0))
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := strings.Join([]string{
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
}, "\n")
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := strings.Join([]string{
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
}, "\n")
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := strings.Join([]string{
"set-cookie",
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
"set-cookie",
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
"set-cookie",
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
}, "\n")
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
@@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
// OpenAITokenCache token cache interface.
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
@@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai/sora oauth account")
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
@@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
refreshFailed = true
} else {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
} else if result.Refreshed {
p.metrics.refreshSuccess.Add(1)
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
} else if result.Refreshed {
p.metrics.refreshSuccess.Add(1)
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai/sora oauth account")
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai/sora oauth account")
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
-618
View File
@@ -22,8 +22,6 @@ import (
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
"default subscription group must exist and be subscription type",
@@ -104,7 +102,6 @@ type SettingService struct {
defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
onS3Update func() // Callback when Sora S3 settings are updated
version string // Application version
}
@@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
@@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
@@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
s.onS3Update = callback
}
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
@@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
@@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
@@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
@@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
@@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
@@ -1584,606 +1569,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
type soraS3ProfilesStore struct {
ActiveProfileID string `json:"active_profile_id"`
Items []soraS3ProfileStoreItem `json:"items"`
}
type soraS3ProfileStoreItem struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
UpdatedAt string `json:"updated_at"`
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
profiles, err := s.ListSoraS3Profiles(ctx)
if err != nil {
return nil, err
}
activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
if activeProfile == nil {
return &SoraS3Settings{}, nil
}
return &SoraS3Settings{
Enabled: activeProfile.Enabled,
Endpoint: activeProfile.Endpoint,
Region: activeProfile.Region,
Bucket: activeProfile.Bucket,
AccessKeyID: activeProfile.AccessKeyID,
SecretAccessKey: activeProfile.SecretAccessKey,
SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
Prefix: activeProfile.Prefix,
ForcePathStyle: activeProfile.ForcePathStyle,
CDNURL: activeProfile.CDNURL,
DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
}, nil
}
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return err
}
now := time.Now().UTC().Format(time.RFC3339)
activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
if activeIndex < 0 {
activeID := "default"
if hasSoraS3ProfileID(store.Items, activeID) {
activeID = fmt.Sprintf("default-%d", time.Now().Unix())
}
store.Items = append(store.Items, soraS3ProfileStoreItem{
ProfileID: activeID,
Name: "Default",
UpdatedAt: now,
})
store.ActiveProfileID = activeID
activeIndex = len(store.Items) - 1
}
active := store.Items[activeIndex]
active.Enabled = settings.Enabled
active.Endpoint = strings.TrimSpace(settings.Endpoint)
active.Region = strings.TrimSpace(settings.Region)
active.Bucket = strings.TrimSpace(settings.Bucket)
active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
active.Prefix = strings.TrimSpace(settings.Prefix)
active.ForcePathStyle = settings.ForcePathStyle
active.CDNURL = strings.TrimSpace(settings.CDNURL)
active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
if settings.SecretAccessKey != "" {
active.SecretAccessKey = settings.SecretAccessKey
}
active.UpdatedAt = now
store.Items[activeIndex] = active
return s.persistSoraS3ProfilesStore(ctx, store)
}
// ListSoraS3Profiles 获取 Sora S3 多配置列表
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
return convertSoraS3ProfilesStore(store), nil
}
// CreateSoraS3Profile 创建 Sora S3 配置
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
if profile == nil {
return nil, fmt.Errorf("profile cannot be nil")
}
profileID := strings.TrimSpace(profile.ProfileID)
if profileID == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
name := strings.TrimSpace(profile.Name)
if name == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
if hasSoraS3ProfileID(store.Items, profileID) {
return nil, ErrSoraS3ProfileExists
}
now := time.Now().UTC().Format(time.RFC3339)
store.Items = append(store.Items, soraS3ProfileStoreItem{
ProfileID: profileID,
Name: name,
Enabled: profile.Enabled,
Endpoint: strings.TrimSpace(profile.Endpoint),
Region: strings.TrimSpace(profile.Region),
Bucket: strings.TrimSpace(profile.Bucket),
AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
SecretAccessKey: profile.SecretAccessKey,
Prefix: strings.TrimSpace(profile.Prefix),
ForcePathStyle: profile.ForcePathStyle,
CDNURL: strings.TrimSpace(profile.CDNURL),
DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
UpdatedAt: now,
})
if setActive || store.ActiveProfileID == "" {
store.ActiveProfileID = profileID
}
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
return nil, err
}
profiles := convertSoraS3ProfilesStore(store)
created := findSoraS3ProfileByID(profiles.Items, profileID)
if created == nil {
return nil, ErrSoraS3ProfileNotFound
}
return created, nil
}
// UpdateSoraS3Profile 更新 Sora S3 配置
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
if profile == nil {
return nil, fmt.Errorf("profile cannot be nil")
}
targetID := strings.TrimSpace(profileID)
if targetID == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
if targetIndex < 0 {
return nil, ErrSoraS3ProfileNotFound
}
target := store.Items[targetIndex]
name := strings.TrimSpace(profile.Name)
if name == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
}
target.Name = name
target.Enabled = profile.Enabled
target.Endpoint = strings.TrimSpace(profile.Endpoint)
target.Region = strings.TrimSpace(profile.Region)
target.Bucket = strings.TrimSpace(profile.Bucket)
target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
target.Prefix = strings.TrimSpace(profile.Prefix)
target.ForcePathStyle = profile.ForcePathStyle
target.CDNURL = strings.TrimSpace(profile.CDNURL)
target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
if profile.SecretAccessKey != "" {
target.SecretAccessKey = profile.SecretAccessKey
}
target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
store.Items[targetIndex] = target
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
return nil, err
}
profiles := convertSoraS3ProfilesStore(store)
updated := findSoraS3ProfileByID(profiles.Items, targetID)
if updated == nil {
return nil, ErrSoraS3ProfileNotFound
}
return updated, nil
}
// DeleteSoraS3Profile 删除 Sora S3 配置
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
targetID := strings.TrimSpace(profileID)
if targetID == "" {
return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return err
}
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
if targetIndex < 0 {
return ErrSoraS3ProfileNotFound
}
store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
if store.ActiveProfileID == targetID {
store.ActiveProfileID = ""
if len(store.Items) > 0 {
store.ActiveProfileID = store.Items[0].ProfileID
}
}
return s.persistSoraS3ProfilesStore(ctx, store)
}
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
targetID := strings.TrimSpace(profileID)
if targetID == "" {
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
}
store, err := s.loadSoraS3ProfilesStore(ctx)
if err != nil {
return nil, err
}
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
if targetIndex < 0 {
return nil, ErrSoraS3ProfileNotFound
}
store.ActiveProfileID = targetID
store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
return nil, err
}
profiles := convertSoraS3ProfilesStore(store)
active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
if active == nil {
return nil, ErrSoraS3ProfileNotFound
}
return active, nil
}
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
if err == nil {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return &soraS3ProfilesStore{}, nil
}
var store soraS3ProfilesStore
if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
if legacyErr != nil {
return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
}
if isEmptyLegacySoraS3Settings(legacy) {
return &soraS3ProfilesStore{}, nil
}
now := time.Now().UTC().Format(time.RFC3339)
return &soraS3ProfilesStore{
ActiveProfileID: "default",
Items: []soraS3ProfileStoreItem{
{
ProfileID: "default",
Name: "Default",
Enabled: legacy.Enabled,
Endpoint: strings.TrimSpace(legacy.Endpoint),
Region: strings.TrimSpace(legacy.Region),
Bucket: strings.TrimSpace(legacy.Bucket),
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
SecretAccessKey: legacy.SecretAccessKey,
Prefix: strings.TrimSpace(legacy.Prefix),
ForcePathStyle: legacy.ForcePathStyle,
CDNURL: strings.TrimSpace(legacy.CDNURL),
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
UpdatedAt: now,
},
},
}, nil
}
normalized := normalizeSoraS3ProfilesStore(store)
return &normalized, nil
}
if !errors.Is(err, ErrSettingNotFound) {
return nil, fmt.Errorf("get sora s3 profiles: %w", err)
}
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
if legacyErr != nil {
return nil, legacyErr
}
if isEmptyLegacySoraS3Settings(legacy) {
return &soraS3ProfilesStore{}, nil
}
now := time.Now().UTC().Format(time.RFC3339)
return &soraS3ProfilesStore{
ActiveProfileID: "default",
Items: []soraS3ProfileStoreItem{
{
ProfileID: "default",
Name: "Default",
Enabled: legacy.Enabled,
Endpoint: strings.TrimSpace(legacy.Endpoint),
Region: strings.TrimSpace(legacy.Region),
Bucket: strings.TrimSpace(legacy.Bucket),
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
SecretAccessKey: legacy.SecretAccessKey,
Prefix: strings.TrimSpace(legacy.Prefix),
ForcePathStyle: legacy.ForcePathStyle,
CDNURL: strings.TrimSpace(legacy.CDNURL),
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
UpdatedAt: now,
},
},
}, nil
}
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
if store == nil {
return fmt.Errorf("sora s3 profiles store cannot be nil")
}
normalized := normalizeSoraS3ProfilesStore(*store)
data, err := json.Marshal(normalized)
if err != nil {
return fmt.Errorf("marshal sora s3 profiles: %w", err)
}
updates := map[string]string{
SettingKeySoraS3Profiles: string(data),
}
active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
if active == nil {
updates[SettingKeySoraS3Enabled] = "false"
updates[SettingKeySoraS3Endpoint] = ""
updates[SettingKeySoraS3Region] = ""
updates[SettingKeySoraS3Bucket] = ""
updates[SettingKeySoraS3AccessKeyID] = ""
updates[SettingKeySoraS3Prefix] = ""
updates[SettingKeySoraS3ForcePathStyle] = "false"
updates[SettingKeySoraS3CDNURL] = ""
updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
updates[SettingKeySoraS3SecretAccessKey] = ""
} else {
updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
}
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return err
}
if s.onUpdate != nil {
s.onUpdate()
}
if s.onS3Update != nil {
s.onS3Update()
}
return nil
}
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
keys := []string{
SettingKeySoraS3Enabled,
SettingKeySoraS3Endpoint,
SettingKeySoraS3Region,
SettingKeySoraS3Bucket,
SettingKeySoraS3AccessKeyID,
SettingKeySoraS3SecretAccessKey,
SettingKeySoraS3Prefix,
SettingKeySoraS3ForcePathStyle,
SettingKeySoraS3CDNURL,
SettingKeySoraDefaultStorageQuotaBytes,
}
values, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
}
result := &SoraS3Settings{
Enabled: values[SettingKeySoraS3Enabled] == "true",
Endpoint: values[SettingKeySoraS3Endpoint],
Region: values[SettingKeySoraS3Region],
Bucket: values[SettingKeySoraS3Bucket],
AccessKeyID: values[SettingKeySoraS3AccessKeyID],
SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
Prefix: values[SettingKeySoraS3Prefix],
ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
CDNURL: values[SettingKeySoraS3CDNURL],
}
if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
result.DefaultStorageQuotaBytes = v
}
return result, nil
}
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
seen := make(map[string]struct{}, len(store.Items))
normalized := soraS3ProfilesStore{
ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
}
now := time.Now().UTC().Format(time.RFC3339)
for idx := range store.Items {
item := store.Items[idx]
item.ProfileID = strings.TrimSpace(item.ProfileID)
if item.ProfileID == "" {
item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
}
if _, exists := seen[item.ProfileID]; exists {
continue
}
seen[item.ProfileID] = struct{}{}
item.Name = strings.TrimSpace(item.Name)
if item.Name == "" {
item.Name = item.ProfileID
}
item.Endpoint = strings.TrimSpace(item.Endpoint)
item.Region = strings.TrimSpace(item.Region)
item.Bucket = strings.TrimSpace(item.Bucket)
item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
item.Prefix = strings.TrimSpace(item.Prefix)
item.CDNURL = strings.TrimSpace(item.CDNURL)
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
if item.UpdatedAt == "" {
item.UpdatedAt = now
}
normalized.Items = append(normalized.Items, item)
}
if len(normalized.Items) == 0 {
normalized.ActiveProfileID = ""
return normalized
}
if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
return normalized
}
normalized.ActiveProfileID = normalized.Items[0].ProfileID
return normalized
}
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
if store == nil {
return &SoraS3ProfileList{}
}
items := make([]SoraS3Profile, 0, len(store.Items))
for idx := range store.Items {
item := store.Items[idx]
items = append(items, SoraS3Profile{
ProfileID: item.ProfileID,
Name: item.Name,
IsActive: item.ProfileID == store.ActiveProfileID,
Enabled: item.Enabled,
Endpoint: item.Endpoint,
Region: item.Region,
Bucket: item.Bucket,
AccessKeyID: item.AccessKeyID,
SecretAccessKey: item.SecretAccessKey,
SecretAccessKeyConfigured: item.SecretAccessKey != "",
Prefix: item.Prefix,
ForcePathStyle: item.ForcePathStyle,
CDNURL: item.CDNURL,
DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
UpdatedAt: item.UpdatedAt,
})
}
return &SoraS3ProfileList{
ActiveProfileID: store.ActiveProfileID,
Items: items,
}
}
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == activeProfileID {
return &items[idx]
}
}
if len(items) == 0 {
return nil
}
return &items[0]
}
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == profileID {
return &items[idx]
}
}
return nil
}
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
for idx := range items {
if items[idx].ProfileID == activeProfileID {
return &items[idx]
}
}
if len(items) == 0 {
return nil
}
return &items[0]
}
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
for idx := range items {
if items[idx].ProfileID == profileID {
return idx
}
}
return -1
}
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
return findSoraS3ProfileIndex(items, profileID) >= 0
}
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
if settings == nil {
return true
}
if settings.Enabled {
return false
}
if strings.TrimSpace(settings.Endpoint) != "" {
return false
}
if strings.TrimSpace(settings.Region) != "" {
return false
}
if strings.TrimSpace(settings.Bucket) != "" {
return false
}
if strings.TrimSpace(settings.AccessKeyID) != "" {
return false
}
if settings.SecretAccessKey != "" {
return false
}
if strings.TrimSpace(settings.Prefix) != "" {
return false
}
if strings.TrimSpace(settings.CDNURL) != "" {
return false
}
return settings.DefaultStorageQuotaBytes == 0
}
func maxInt64(value int64, min int64) int64 {
if value < min {
return min
}
return value
}
-42
View File
@@ -41,7 +41,6 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
@@ -107,7 +106,6 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
@@ -116,46 +114,6 @@ type PublicSettings struct {
Version string
}
// SoraS3Settings Sora S3 存储配置
type SoraS3Settings struct {
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 多配置项(服务内部模型)
type SoraS3Profile struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
UpdatedAt string `json:"updated_at"`
}
// SoraS3ProfileList Sora S3 多配置列表
type SoraS3ProfileList struct {
ActiveProfileID string `json:"active_profile_id"`
Items []SoraS3Profile `json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
@@ -1,40 +0,0 @@
package service
import "context"
// SoraAccountRepository Sora 账号扩展表仓储接口
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
//
// 设计说明:
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
// - Token 刷新时需要同时更新两个表以保持数据一致性
type SoraAccountRepository interface {
// Upsert 创建或更新 Sora 账号扩展信息
// accountID: 关联的 accounts.id
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
//
// 如果记录不存在则创建,存在则更新。
// 用于:
// 1. 创建 Sora 账号时初始化扩展表
// 2. Token 刷新时同步更新扩展表
Upsert(ctx context.Context, accountID int64, updates map[string]any) error
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
// 返回 nil, nil 表示记录不存在(非错误)
GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
// Delete 删除 Sora 账号扩展信息
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
Delete(ctx context.Context, accountID int64) error
}
// SoraAccount Sora 账号扩展信息
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
type SoraAccount struct {
AccountID int64 // 关联的 accounts.id
AccessToken string // OAuth access_token
RefreshToken string // OAuth refresh_token
SessionToken string // Session token(可选,用于 ST→AT 兜底)
}
-117
View File
@@ -1,117 +0,0 @@
package service
import (
"context"
"fmt"
"net/http"
)
// SoraClient 定义直连 Sora 的任务操作接口。
type SoraClient interface {
Enabled() bool
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
DeletePost(ctx context.Context, account *Account, postID string) error
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
}
// SoraImageRequest 图片生成请求参数
type SoraImageRequest struct {
Prompt string
Width int
Height int
MediaID string
}
// SoraVideoRequest 视频生成请求参数
type SoraVideoRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
VideoCount int
MediaID string
RemixTargetID string
CameoIDs []string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type SoraStoryboardRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
MediaID string
}
// SoraImageTaskStatus 图片任务状态
type SoraImageTaskStatus struct {
ID string
Status string
ProgressPct float64
URLs []string
ErrorMsg string
}
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
ID string
Status string
ProgressPct int
URLs []string
GenerationID string
ErrorMsg string
}
// SoraCameoStatus 角色处理中间态
type SoraCameoStatus struct {
Status string
StatusMessage string
DisplayNameHint string
UsernameHint string
ProfileAssetURL string
InstructionSetHint any
InstructionSet any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type SoraCharacterFinalizeRequest struct {
CameoID string
Username string
DisplayName string
ProfileAssetPointer string
InstructionSet any
}
// SoraUpstreamError 上游错误
type SoraUpstreamError struct {
StatusCode int
Message string
Headers http.Header
Body []byte
}
func (e *SoraUpstreamError) Error() string {
if e == nil {
return "sora upstream error"
}
if e.Message != "" {
return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
}
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
}
File diff suppressed because it is too large Load Diff
@@ -1,564 +0,0 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
var _ SoraClient = (*stubSoraClientForPoll)(nil)
type stubSoraClientForPoll struct {
imageStatus *SoraImageTaskStatus
videoStatus *SoraVideoTaskStatus
imageCalls int
videoCalls int
enhanced string
enhanceErr error
storyboard bool
videoReq SoraVideoRequest
parseErr error
postCalls int
deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
return "", nil
}
func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
s.videoReq = req
return "task-video", nil
}
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
s.storyboard = true
return "task-video", nil
}
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
return &SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
return nil
}
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
return nil
}
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
s.postCalls++
return "s_post", nil
}
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
s.deleteCalls++
return nil
}
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
if s.parseErr != nil {
return "", s.parseErr
}
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
}
return "enhanced prompt", s.enhanceErr
}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++
return s.imageStatus, nil
}
func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
s.videoCalls++
return s.videoStatus, nil
}
func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
client := &stubSoraClientForPoll{
imageStatus: &SoraImageTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/a.png"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
require.NoError(t, err)
require.Equal(t, []string{"https://example.com/a.png"}, urls)
require.Equal(t, 1, client.imageCalls)
}
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
client := &stubSoraClientForPoll{
enhanced: "cinematic prompt",
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{
ID: 1,
Platform: PlatformSora,
Status: StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
},
},
}
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, "prompt-enhance-short-10s", result.Model)
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
}
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 3, client.videoReq.VideoCount)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, 0, client.videoCalls)
}
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
parseErr: errors.New("parse failed"),
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 0, client.deleteCalls)
}
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 1, client.deleteCalls)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "failed",
ErrorMsg: "reject",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
service := NewSoraGatewayService(client, nil, nil, cfg)
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
cfg := &config.Config{
Gateway: config.GatewayConfig{
SoraMediaSigningKey: "test-key",
SoraMediaSignedURLTTLSeconds: 600,
},
}
service := NewSoraGatewayService(nil, nil, nil, cfg)
url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
require.Contains(t, url, "/sora/media-signed")
require.Contains(t, url, "expires=")
require.Contains(t, url, "sig=")
}
func TestNormalizeSoraMediaURLs_Empty(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
result := svc.normalizeSoraMediaURLs(nil)
require.Empty(t, result)
result = svc.normalizeSoraMediaURLs([]string{})
require.Empty(t, result)
}
func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"}
result := svc.normalizeSoraMediaURLs(urls)
require.Equal(t, urls, result)
}
func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) {
cfg := &config.Config{}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"}
result := svc.normalizeSoraMediaURLs(urls)
require.Len(t, result, 2)
require.Contains(t, result[0], "/sora/media")
require.Contains(t, result[1], "/sora/media")
}
func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"}
result := svc.normalizeSoraMediaURLs(urls)
require.Len(t, result, 2)
}
func TestBuildSoraContent_Image(t *testing.T) {
content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"})
require.Contains(t, content, "![image](https://a.com/1.png)")
require.Contains(t, content, "![image](https://a.com/2.png)")
}
func TestBuildSoraContent_Video(t *testing.T) {
content := buildSoraContent("video", []string{"https://a.com/v.mp4"})
require.Contains(t, content, "<video src='https://a.com/v.mp4'")
}
func TestBuildSoraContent_VideoEmpty(t *testing.T) {
content := buildSoraContent("video", nil)
require.Empty(t, content)
}
func TestBuildSoraContent_Prompt(t *testing.T) {
content := buildSoraContent("prompt", nil)
require.Empty(t, content)
}
func TestSoraImageSizeFromModel(t *testing.T) {
require.Equal(t, "360", soraImageSizeFromModel("gpt-image"))
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-landscape"))
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-portrait"))
require.Equal(t, "540", soraImageSizeFromModel("something-landscape"))
require.Equal(t, "360", soraImageSizeFromModel("unknown-model"))
}
func TestFirstMediaURL(t *testing.T) {
require.Equal(t, "", firstMediaURL(nil))
require.Equal(t, "", firstMediaURL([]string{}))
require.Equal(t, "a", firstMediaURL([]string{"a", "b"}))
}
func TestSoraProErrorMessage(t *testing.T) {
require.Contains(t, soraProErrorMessage("sora2pro-hd", ""), "Pro-HD")
require.Contains(t, soraProErrorMessage("sora2pro", ""), "Pro")
require.Empty(t, soraProErrorMessage("sora-basic", ""))
}
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
body := rec.Body.String()
require.Contains(t, body, "event: error\n")
require.Contains(t, body, "data: [DONE]\n\n")
lines := strings.Split(body, "\n")
require.GreaterOrEqual(t, len(lines), 2)
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "))
data := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
errObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
}
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
sourceHeaders := http.Header{}
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
err := svc.handleSoraRequestError(
context.Background(),
&Account{ID: 1, Platform: PlatformSora},
&SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "forbidden",
Headers: sourceHeaders,
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
},
"sora2-landscape-10s",
nil,
false,
)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.NotNil(t, failoverErr.ResponseHeaders)
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
sourceHeaders.Set("cf-ray", "mutated-after-return")
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
}
func TestShouldFailoverUpstreamError(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc.shouldFailoverUpstreamError(401))
require.True(t, svc.shouldFailoverUpstreamError(404))
require.True(t, svc.shouldFailoverUpstreamError(429))
require.True(t, svc.shouldFailoverUpstreamError(500))
require.True(t, svc.shouldFailoverUpstreamError(502))
require.False(t, svc.shouldFailoverUpstreamError(200))
require.False(t, svc.shouldFailoverUpstreamError(400))
}
func TestWithSoraTimeout_NilService(t *testing.T) {
var svc *SoraGatewayService
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
require.NotNil(t, ctx)
require.Nil(t, cancel)
}
func TestWithSoraTimeout_ZeroTimeout(t *testing.T) {
cfg := &config.Config{}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
require.NotNil(t, ctx)
require.Nil(t, cancel)
}
func TestWithSoraTimeout_PositiveTimeout(t *testing.T) {
cfg := &config.Config{
Gateway: config.GatewayConfig{
SoraRequestTimeoutSeconds: 30,
},
}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
require.NotNil(t, ctx)
require.NotNil(t, cancel)
cancel()
}
func TestPollInterval(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 5,
},
},
}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
require.Equal(t, 5*time.Second, svc.pollInterval())
// 默认值
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc2.pollInterval() > 0)
}
func TestPollMaxAttempts(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
MaxPollAttempts: 100,
},
},
}
svc := NewSoraGatewayService(nil, nil, nil, cfg)
require.Equal(t, 100, svc.pollMaxAttempts())
// 默认值
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc2.pollMaxAttempts() > 0)
}
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
require.Error(t, err)
}
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
encoded := "data:image/png;base64,aGVsbG8="
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
require.NoError(t, err)
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
require.Error(t, err)
require.Nil(t, data)
}
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
body := map[string]any{
"watermark_free": float64(1),
"watermark_fallback_on_failure": float64(0),
}
opts := parseSoraWatermarkOptions(body)
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
func TestParseSoraVideoCount(t *testing.T) {
require.Equal(t, 1, parseSoraVideoCount(nil))
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
}
@@ -1,532 +0,0 @@
//nolint:unused
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
const soraRewriteBufferLimit = 2048
type soraStreamingResult struct {
mediaType string
mediaURLs []string
imageCount int
imageSize string
firstTokenMs *int
}
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if c != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
}
}
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
if s.rateLimitService == nil || account == nil || resp == nil {
return
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
upstreamMsg = msg
}
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if c != nil {
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
c.JSON(resp.StatusCode, responsePayload)
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
if len(respBody) > 0 {
var payload map[string]any
if err := json.Unmarshal(respBody, &payload); err == nil {
if errObj, ok := payload["error"].(map[string]any); ok {
if overrideMessage != "" {
errObj["message"] = overrideMessage
}
payload["error"] = errObj
return payload
}
}
}
return map[string]any{
"error": map[string]any{
"type": "upstream_error",
"message": overrideMessage,
},
}
}
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
if resp == nil {
return nil, errors.New("empty response")
}
if clientStream {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
}
w := c.Writer
flusher, _ := w.(http.Flusher)
contentBuilder := strings.Builder{}
var firstTokenMs *int
var upstreamError error
rewriteBuffer := ""
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
sendLine := func(line string) error {
if !clientStream {
return nil
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return err
}
if flusher != nil {
flusher.Flush()
}
return nil
}
for scanner.Scan() {
line := scanner.Text()
if soraSSEDataRe.MatchString(line) {
data := soraSSEDataRe.ReplaceAllString(line, "")
if data == "[DONE]" {
if rewriteBuffer != "" {
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
if err != nil {
return nil, err
}
if flushLine != "" {
if flushContent != "" {
if _, err := contentBuilder.WriteString(flushContent); err != nil {
return nil, err
}
}
if err := sendLine(flushLine); err != nil {
return nil, err
}
}
rewriteBuffer = ""
}
if err := sendLine("data: [DONE]"); err != nil {
return nil, err
}
break
}
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
if errEvent != nil && upstreamError == nil {
upstreamError = errEvent
}
if contentDelta != "" {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
return nil, err
}
}
if err := sendLine(updatedLine); err != nil {
return nil, err
}
continue
}
if err := sendLine(line); err != nil {
return nil, err
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
if clientStream {
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
if flusher != nil {
flusher.Flush()
}
}
return nil, err
}
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
if clientStream {
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
if flusher != nil {
flusher.Flush()
}
}
return nil, err
}
content := contentBuilder.String()
mediaType, mediaURLs := s.extractSoraMedia(content)
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
mediaType = "prompt"
}
imageSize := ""
imageCount := 0
if mediaType == "image" {
imageSize = soraImageSizeFromModel(originalModel)
imageCount = len(mediaURLs)
}
if upstreamError != nil && !clientStream {
if c != nil {
c.JSON(http.StatusBadGateway, map[string]any{
"error": map[string]any{
"type": "upstream_error",
"message": upstreamError.Error(),
},
})
}
return nil, upstreamError
}
if !clientStream {
response := buildSoraNonStreamResponse(content, originalModel)
if len(mediaURLs) > 0 {
response["media_url"] = mediaURLs[0]
if len(mediaURLs) > 1 {
response["media_urls"] = mediaURLs
}
}
c.JSON(http.StatusOK, response)
}
return &soraStreamingResult{
mediaType: mediaType,
mediaURLs: mediaURLs,
imageCount: imageCount,
imageSize: imageSize,
firstTokenMs: firstTokenMs,
}, nil
}
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
if strings.TrimSpace(data) == "" {
return "data: ", "", nil
}
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
return "data: " + data, "", nil
}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return "data: " + data, "", errors.New(msg)
}
}
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
payload["model"] = originalModel
}
contentDelta, updated := extractSoraContent(payload)
if updated {
var rewritten string
if rewriteBuffer != nil {
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
} else {
rewritten = s.rewriteSoraContent(contentDelta)
}
if rewritten != contentDelta {
applySoraContent(payload, rewritten)
contentDelta = rewritten
}
}
updatedData, err := jsonMarshalRaw(payload)
if err != nil {
return "data: " + data, contentDelta, nil
}
return "data: " + string(updatedData), contentDelta, nil
}
func extractSoraContent(payload map[string]any) (string, bool) {
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
return "", false
}
choice, ok := choices[0].(map[string]any)
if !ok {
return "", false
}
if delta, ok := choice["delta"].(map[string]any); ok {
if content, ok := delta["content"].(string); ok {
return content, true
}
}
if message, ok := choice["message"].(map[string]any); ok {
if content, ok := message["content"].(string); ok {
return content, true
}
}
return "", false
}
func applySoraContent(payload map[string]any, content string) {
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
return
}
choice, ok := choices[0].(map[string]any)
if !ok {
return
}
if delta, ok := choice["delta"].(map[string]any); ok {
delta["content"] = content
choice["delta"] = delta
return
}
if message, ok := choice["message"].(map[string]any); ok {
message["content"] = content
choice["message"] = message
}
}
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
if buffer == nil {
return s.rewriteSoraContent(contentDelta)
}
if contentDelta == "" && *buffer == "" {
return ""
}
combined := *buffer + contentDelta
rewritten := s.rewriteSoraContent(combined)
bufferStart := s.findSoraRewriteBufferStart(rewritten)
if bufferStart < 0 {
*buffer = ""
return rewritten
}
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
bufferStart = len(rewritten) - soraRewriteBufferLimit
}
output := rewritten[:bufferStart]
*buffer = rewritten[bufferStart:]
return output
}
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
minIndex := -1
start := 0
for {
idx := strings.Index(content[start:], "![")
if idx < 0 {
break
}
idx += start
if !hasSoraImageMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + 2
}
lower := strings.ToLower(content)
start = 0
for {
idx := strings.Index(lower[start:], "<video")
if idx < 0 {
break
}
idx += start
if !hasSoraVideoMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + len("<video")
}
return minIndex
}
func hasSoraImageMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func hasSoraVideoMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
if content == "" {
return content
}
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
sub := soraImageMarkdownRe.FindStringSubmatch(match)
if len(sub) < 2 {
return match
}
rewritten := s.rewriteSoraURL(sub[1])
if rewritten == sub[1] {
return match
}
return strings.Replace(match, sub[1], rewritten, 1)
})
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
sub := soraVideoHTMLRe.FindStringSubmatch(match)
if len(sub) < 2 {
return match
}
rewritten := s.rewriteSoraURL(sub[1])
if rewritten == sub[1] {
return match
}
return strings.Replace(match, sub[1], rewritten, 1)
})
return content
}
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
if buffer == "" {
return "", "", nil
}
rewritten := s.rewriteSoraContent(buffer)
payload := map[string]any{
"choices": []any{
map[string]any{
"delta": map[string]any{
"content": rewritten,
},
"index": 0,
},
},
}
if originalModel != "" {
payload["model"] = originalModel
}
updatedData, err := jsonMarshalRaw(payload)
if err != nil {
return "", "", err
}
return "data: " + string(updatedData), rewritten, nil
}
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
parsed, err := url.Parse(raw)
if err != nil {
return raw
}
path := parsed.Path
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
return raw
}
return s.buildSoraMediaURL(path, parsed.RawQuery)
}
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
if content == "" {
return "", nil
}
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
return "video", []string{match[1]}
}
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
if len(imageMatches) == 0 {
return "", nil
}
urls := make([]string, 0, len(imageMatches))
for _, match := range imageMatches {
if len(match) > 1 {
urls = append(urls, match[1])
}
}
return "image", urls
}
func isSoraPromptEnhanceModel(model string) bool {
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
}
@@ -1,63 +0,0 @@
package service
import (
"context"
"time"
)
// SoraGeneration 代表一条 Sora 客户端生成记录。
type SoraGeneration struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
MediaType string `json:"media_type"` // video / image
Status string `json:"status"` // pending / generating / completed / failed / cancelled
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
FileSizeBytes int64 `json:"file_size_bytes"`
StorageType string `json:"storage_type"` // s3 / local / upstream / none
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
UpstreamTaskID string `json:"upstream_task_id"`
ErrorMessage string `json:"error_message"`
CreatedAt time.Time `json:"created_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
}
// Sora 生成记录状态常量
const (
SoraGenStatusPending = "pending"
SoraGenStatusGenerating = "generating"
SoraGenStatusCompleted = "completed"
SoraGenStatusFailed = "failed"
SoraGenStatusCancelled = "cancelled"
)
// Sora 存储类型常量
const (
SoraStorageTypeS3 = "s3"
SoraStorageTypeLocal = "local"
SoraStorageTypeUpstream = "upstream"
SoraStorageTypeNone = "none"
)
// SoraGenerationListParams 查询生成记录的参数。
type SoraGenerationListParams struct {
UserID int64
Status string // 可选筛选
StorageType string // 可选筛选
MediaType string // 可选筛选
Page int
PageSize int
}
// SoraGenerationRepository 生成记录持久化接口。
type SoraGenerationRepository interface {
Create(ctx context.Context, gen *SoraGeneration) error
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
Update(ctx context.Context, gen *SoraGeneration) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
}
@@ -1,332 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
ErrSoraGenerationNotActive = errors.New("sora generation is not active")
)
const soraGenerationActiveLimit = 3
type soraGenerationRepoAtomicCreator interface {
CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
}
type soraGenerationRepoConditionalUpdater interface {
UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
}
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
type SoraGenerationService struct {
genRepo SoraGenerationRepository
s3Storage *SoraS3Storage
quotaService *SoraQuotaService
}
// NewSoraGenerationService 创建生成记录服务。
func NewSoraGenerationService(
genRepo SoraGenerationRepository,
s3Storage *SoraS3Storage,
quotaService *SoraQuotaService,
) *SoraGenerationService {
return &SoraGenerationService{
genRepo: genRepo,
s3Storage: s3Storage,
quotaService: quotaService,
}
}
// CreatePending 创建一条 pending 状态的生成记录。
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
gen := &SoraGeneration{
UserID: userID,
APIKeyID: apiKeyID,
Model: model,
Prompt: prompt,
MediaType: mediaType,
Status: SoraGenStatusPending,
StorageType: SoraStorageTypeNone,
}
if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
if err := atomicCreator.CreatePendingWithLimit(
ctx,
gen,
[]string{SoraGenStatusPending, SoraGenStatusGenerating},
soraGenerationActiveLimit,
); err != nil {
if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
return nil, err
}
return nil, fmt.Errorf("create generation: %w", err)
}
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
return gen, nil
}
if err := s.genRepo.Create(ctx, gen); err != nil {
return nil, fmt.Errorf("create generation: %w", err)
}
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
return gen, nil
}
// MarkGenerating 标记为生成中。
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusGenerating
gen.UpstreamTaskID = upstreamTaskID
return s.genRepo.Update(ctx, gen)
}
// MarkCompleted 标记为已完成。
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusCompleted
gen.MediaURL = mediaURL
gen.MediaURLs = mediaURLs
gen.StorageType = storageType
gen.S3ObjectKeys = s3Keys
gen.FileSizeBytes = fileSizeBytes
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// MarkFailed 标记为失败。
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationStateConflict
}
gen.Status = SoraGenStatusFailed
gen.ErrorMessage = errMsg
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// MarkCancelled 标记为已取消。
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
now := time.Now()
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationNotActive
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
return ErrSoraGenerationNotActive
}
gen.Status = SoraGenStatusCancelled
gen.CompletedAt = &now
return s.genRepo.Update(ctx, gen)
}
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
func (s *SoraGenerationService) UpdateStorageForCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) error {
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
if err != nil {
return err
}
if !updated {
return ErrSoraGenerationStateConflict
}
return nil
}
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.Status != SoraGenStatusCompleted {
return ErrSoraGenerationStateConflict
}
gen.MediaURL = mediaURL
gen.MediaURLs = mediaURLs
gen.StorageType = storageType
gen.S3ObjectKeys = s3Keys
gen.FileSizeBytes = fileSizeBytes
return s.genRepo.Update(ctx, gen)
}
// GetByID 获取记录详情(含权限校验)。
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if gen.UserID != userID {
return nil, fmt.Errorf("无权访问此生成记录")
}
return gen, nil
}
// List 查询生成记录列表(分页 + 筛选)。
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
if params.PageSize > 100 {
params.PageSize = 100
}
return s.genRepo.List(ctx, params)
}
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
gen, err := s.genRepo.GetByID(ctx, id)
if err != nil {
return err
}
if gen.UserID != userID {
return fmt.Errorf("无权删除此生成记录")
}
// 清理 S3 文件
if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
}
}
// 释放配额(S3/本地均释放)
if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
}
}
return s.genRepo.Delete(ctx, id)
}
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
}
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
return nil
}
if len(gen.S3ObjectKeys) == 0 {
return nil
}
urls := make([]string, len(gen.S3ObjectKeys))
var wg sync.WaitGroup
var firstErr error
var errMu sync.Mutex
for idx, key := range gen.S3ObjectKeys {
wg.Add(1)
go func(i int, objectKey string) {
defer wg.Done()
url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
if err != nil {
errMu.Lock()
if firstErr == nil {
firstErr = err
}
errMu.Unlock()
return
}
urls[i] = url
}(idx, key)
}
wg.Wait()
if firstErr != nil {
return firstErr
}
gen.MediaURL = urls[0]
gen.MediaURLs = urls
return nil
}
@@ -1,881 +0,0 @@
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// ==================== Stub: SoraGenerationRepository ====================
var _ SoraGenerationRepository = (*stubGenRepo)(nil)
type stubGenRepo struct {
gens map[int64]*SoraGeneration
nextID int64
createErr error
getErr error
updateErr error
deleteErr error
listErr error
countErr error
countValue int64
}
func newStubGenRepo() *stubGenRepo {
return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
}
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
if r.createErr != nil {
return r.createErr
}
gen.ID = r.nextID
gen.CreatedAt = time.Now()
r.nextID++
r.gens[gen.ID] = gen
return nil
}
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
if r.getErr != nil {
return nil, r.getErr
}
if gen, ok := r.gens[id]; ok {
return gen, nil
}
return nil, fmt.Errorf("not found")
}
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
if r.updateErr != nil {
return r.updateErr
}
r.gens[gen.ID] = gen
return nil
}
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
if r.deleteErr != nil {
return r.deleteErr
}
delete(r.gens, id)
return nil
}
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
if r.listErr != nil {
return nil, 0, r.listErr
}
var result []*SoraGeneration
for _, gen := range r.gens {
if gen.UserID != params.UserID {
continue
}
if params.Status != "" && gen.Status != params.Status {
continue
}
if params.StorageType != "" && gen.StorageType != params.StorageType {
continue
}
if params.MediaType != "" && gen.MediaType != params.MediaType {
continue
}
result = append(result, gen)
}
return result, int64(len(result)), nil
}
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
if r.countValue > 0 {
return r.countValue, nil
}
var count int64
statusSet := make(map[string]struct{})
for _, s := range statuses {
statusSet[s] = struct{}{}
}
for _, gen := range r.gens {
if gen.UserID == userID {
if _, ok := statusSet[gen.Status]; ok {
count++
}
}
}
return count, nil
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var _ UserRepository = (*stubUserRepoForQuota)(nil)
type stubUserRepoForQuota struct {
users map[int64]*User
updateErr error
}
func newStubUserRepoForQuota() *stubUserRepoForQuota {
return &stubUserRepoForQuota{users: make(map[int64]*User)}
}
func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
if u, ok := r.users[id]; ok {
return u, nil
}
return nil, fmt.Errorf("user not found")
}
func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
if r.updateErr != nil {
return r.updateErr
}
r.users[user.ID] = user
return nil
}
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
return nil, nil
}
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubUserRepoForQuota) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
storage := &SoraS3Storage{}
storage.cfg = &SoraS3Settings{
Enabled: true,
Bucket: "test-bucket",
CDNURL: cdnURL,
}
// 需要 non-nil client 使 getClient 命中缓存
storage.client = s3.New(s3.Options{})
return storage
}
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
func newS3StorageFailingDelete() *SoraS3Storage {
return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
}
// ==================== CreatePending ====================
func TestCreatePending_Success(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
require.NoError(t, err)
require.Equal(t, int64(1), gen.ID)
require.Equal(t, int64(1), gen.UserID)
require.Equal(t, "sora2-landscape-10s", gen.Model)
require.Equal(t, "一只猫跳舞", gen.Prompt)
require.Equal(t, "video", gen.MediaType)
require.Equal(t, SoraGenStatusPending, gen.Status)
require.Equal(t, SoraStorageTypeNone, gen.StorageType)
require.Nil(t, gen.APIKeyID)
}
func TestCreatePending_WithAPIKeyID(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
apiKeyID := int64(42)
gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
require.NoError(t, err)
require.NotNil(t, gen.APIKeyID)
require.Equal(t, int64(42), *gen.APIKeyID)
}
func TestCreatePending_RepoError(t *testing.T) {
repo := newStubGenRepo()
repo.createErr = fmt.Errorf("db write error")
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
require.Error(t, err)
require.Nil(t, gen)
require.Contains(t, err.Error(), "create generation")
}
// ==================== MarkGenerating ====================
func TestMarkGenerating_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
require.NoError(t, err)
require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
}
func TestMarkGenerating_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkGenerating(context.Background(), 999, "")
require.Error(t, err)
}
func TestMarkGenerating_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkGenerating(context.Background(), 1, "")
require.Error(t, err)
}
// ==================== MarkCompleted ====================
func TestMarkCompleted_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCompleted(context.Background(), 1,
"https://cdn.example.com/video.mp4",
[]string{"https://cdn.example.com/video.mp4"},
SoraStorageTypeS3,
[]string{"sora/1/2024/01/01/uuid.mp4"},
1048576,
)
require.NoError(t, err)
gen := repo.gens[1]
require.Equal(t, SoraGenStatusCompleted, gen.Status)
require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
require.Equal(t, SoraStorageTypeS3, gen.StorageType)
require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
require.Equal(t, int64(1048576), gen.FileSizeBytes)
require.NotNil(t, gen.CompletedAt)
}
func TestMarkCompleted_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
require.Error(t, err)
}
func TestMarkCompleted_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
require.Error(t, err)
}
// ==================== MarkFailed ====================
func TestMarkFailed_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
require.NoError(t, err)
gen := repo.gens[1]
require.Equal(t, SoraGenStatusFailed, gen.Status)
require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
require.NotNil(t, gen.CompletedAt)
}
func TestMarkFailed_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkFailed(context.Background(), 999, "error")
require.Error(t, err)
}
func TestMarkFailed_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkFailed(context.Background(), 1, "err")
require.Error(t, err)
}
// ==================== MarkCancelled ====================
func TestMarkCancelled_Pending(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
require.NotNil(t, repo.gens[1].CompletedAt)
}
func TestMarkCancelled_Generating(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
}
func TestMarkCancelled_Completed(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
require.ErrorIs(t, err, ErrSoraGenerationNotActive)
}
func TestMarkCancelled_Failed(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
}
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
}
func TestMarkCancelled_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 999)
require.Error(t, err)
}
func TestMarkCancelled_UpdateError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
repo.updateErr = fmt.Errorf("update failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.MarkCancelled(context.Background(), 1)
require.Error(t, err)
}
// ==================== GetByID ====================
func TestGetByID_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.GetByID(context.Background(), 1, 1)
require.NoError(t, err)
require.Equal(t, int64(1), gen.ID)
require.Equal(t, "sora2-landscape-10s", gen.Model)
}
func TestGetByID_WrongUser(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.GetByID(context.Background(), 1, 1)
require.Error(t, err)
require.Nil(t, gen)
require.Contains(t, err.Error(), "无权访问")
}
func TestGetByID_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, err := svc.GetByID(context.Background(), 999, 1)
require.Error(t, err)
require.Nil(t, gen)
}
// ==================== List ====================
func TestList_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
svc := NewSoraGenerationService(repo, nil, nil)
gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, gens, 2) // 只有 userID=1 的
require.Equal(t, int64(2), total)
}
func TestList_DefaultPagination(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
require.NoError(t, err)
}
func TestList_MaxPageSize(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
// pageSize > 100 → 应限制为 100
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
require.NoError(t, err)
}
func TestList_Error(t *testing.T) {
repo := newStubGenRepo()
repo.listErr = fmt.Errorf("db error")
svc := NewSoraGenerationService(repo, nil, nil)
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
require.Error(t, err)
}
// ==================== Delete ====================
func TestDelete_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDelete_WrongUser(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "无权删除")
}
func TestDelete_NotFound(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 999, 1)
require.Error(t, err)
}
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err) // s3Storage 为 nil,跳过清理
}
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err) // quotaService 为 nil,跳过释放
}
func TestDelete_NonS3NoCleanup(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
}
func TestDelete_DeleteRepoError(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
repo.deleteErr = fmt.Errorf("delete failed")
svc := NewSoraGenerationService(repo, nil, nil)
err := svc.Delete(context.Background(), 1, 1)
require.Error(t, err)
}
// ==================== CountActiveByUser ====================
func TestCountActiveByUser_Success(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算
svc := NewSoraGenerationService(repo, nil, nil)
count, err := svc.CountActiveByUser(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(2), count)
}
func TestCountActiveByUser_NoActive(t *testing.T) {
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
svc := NewSoraGenerationService(repo, nil, nil)
count, err := svc.CountActiveByUser(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(0), count)
}
func TestCountActiveByUser_Error(t *testing.T) {
repo := newStubGenRepo()
repo.countErr = fmt.Errorf("db error")
svc := NewSoraGenerationService(repo, nil, nil)
_, err := svc.CountActiveByUser(context.Background(), 1)
require.Error(t, err)
}
// ==================== ResolveMediaURLs ====================
func TestResolveMediaURLs_NilGen(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
}
func TestResolveMediaURLs_NonS3(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变
}
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
}
func TestResolveMediaURLs_Local(t *testing.T) {
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变
}
// ==================== 状态流转完整测试 ====================
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
// 1. 创建 pending
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
require.NoError(t, err)
require.Equal(t, SoraGenStatusPending, gen.Status)
// 2. 标记 generating
err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
require.NoError(t, err)
require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)
// 3. 标记 completed
err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
}
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
require.NoError(t, err)
require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
}
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
err := svc.MarkCancelled(context.Background(), gen.ID)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
}
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
err := svc.MarkCancelled(context.Background(), gen.ID)
require.NoError(t, err)
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
}
// ==================== 权限隔离测试 ====================
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
// 用户 2 尝试访问用户 1 的记录
_, err := svc.GetByID(context.Background(), gen.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "无权访问")
}
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
err := svc.Delete(context.Background(), gen.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "无权删除")
}
// ==================== Delete: S3 清理 + 配额释放路径 ====================
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
// 验证 Delete 仍然成功(S3 错误只是记录日志)
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
}
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(repo, s3Storage, nil)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err) // S3 清理失败不影响删除
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
// 有配额服务时,删除 S3 类型记录会释放配额
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
FileSizeBytes: 1048576, // 1MB
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
quotaService := NewSoraQuotaService(userRepo, nil, nil)
svc := NewSoraGenerationService(repo, nil, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
// 配额应被释放: 2MB - 1MB = 1MB
require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
}
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
// S3 清理 + 配额释放同时触发
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"key1"},
FileSizeBytes: 512,
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
quotaService := NewSoraQuotaService(userRepo, nil, nil)
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(repo, s3Storage, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
_, exists := repo.gens[1]
require.False(t, exists)
require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
}
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
// 本地存储同样需要释放配额
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeLocal,
FileSizeBytes: 1024,
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
quotaService := NewSoraQuotaService(userRepo, nil, nil)
svc := NewSoraGenerationService(repo, nil, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
}
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
// FileSizeBytes=0 跳过配额释放
repo := newStubGenRepo()
repo.gens[1] = &SoraGeneration{
ID: 1, UserID: 1,
StorageType: SoraStorageTypeS3,
FileSizeBytes: 0,
}
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
quotaService := NewSoraQuotaService(userRepo, nil, nil)
svc := NewSoraGenerationService(repo, nil, quotaService)
err := svc.Delete(context.Background(), 1, 1)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
}
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
}
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{
"sora/1/2024/01/01/img1.png",
"sora/1/2024/01/01/img2.png",
"sora/1/2024/01/01/img3.png",
},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
// 主 URL 更新为第一个 key 的 CDN URL
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
// 多图 URLs 全部更新
require.Len(t, gen.MediaURLs, 3)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
}
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
require.Equal(t, "original", gen.MediaURL) // 不变
}
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
// 使用无 settingService 的 S3 StoragegetClient 会失败
s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.Error(t, err) // GetAccessURL 失败应传播错误
}
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{
"sora/1/2024/01/01/img1.png",
"sora/1/2024/01/01/img2.png",
},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
}
@@ -1,120 +0,0 @@
package service
import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/robfig/cron/v3"
)
var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
// SoraMediaCleanupService 定期清理本地媒体文件
type SoraMediaCleanupService struct {
storage *SoraMediaStorage
cfg *config.Config
cron *cron.Cron
startOnce sync.Once
stopOnce sync.Once
}
func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
return &SoraMediaCleanupService{
storage: storage,
cfg: cfg,
}
}
func (s *SoraMediaCleanupService) Start() {
if s == nil || s.cfg == nil {
return
}
if !s.cfg.Sora.Storage.Cleanup.Enabled {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (disabled)")
return
}
if s.storage == nil || !s.storage.Enabled() {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (storage disabled)")
return
}
s.startOnce.Do(func() {
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
if schedule == "" {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (empty schedule)")
return
}
loc := time.Local
if strings.TrimSpace(s.cfg.Timezone) != "" {
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
loc = parsed
}
}
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
return
}
s.cron = c
s.cron.Start()
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
})
}
func (s *SoraMediaCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.cron != nil {
ctx := s.cron.Stop()
select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cron stop timed out")
}
}
})
}
func (s *SoraMediaCleanupService) runCleanup() {
if s.cfg == nil || s.storage == nil {
return
}
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
if retention <= 0 {
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] skipped (retention_days=%d)", retention)
return
}
cutoff := time.Now().AddDate(0, 0, -retention)
deleted := 0
roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()}
for _, root := range roots {
if root == "" {
continue
}
_ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
if err != nil {
return nil
}
if info.IsDir() {
return nil
}
if info.ModTime().Before(cutoff) {
if rmErr := os.Remove(p); rmErr == nil {
deleted++
}
}
return nil
})
}
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cleanup finished, deleted=%d", deleted)
}
@@ -1,207 +0,0 @@
//go:build unit
package service
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraMediaCleanupService_RunCleanup_NilCfg(t *testing.T) {
storage := &SoraMediaStorage{}
svc := &SoraMediaCleanupService{storage: storage, cfg: nil}
// 不应 panic
svc.runCleanup()
}
func TestSoraMediaCleanupService_RunCleanup_NilStorage(t *testing.T) {
cfg := &config.Config{}
svc := &SoraMediaCleanupService{storage: nil, cfg: cfg}
// 不应 panic
svc.runCleanup()
}
func TestSoraMediaCleanupService_RunCleanup_ZeroRetention(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
RetentionDays: 0,
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
// retention=0 应跳过清理
svc.runCleanup()
}
func TestSoraMediaCleanupService_Start_NilCfg(t *testing.T) {
svc := NewSoraMediaCleanupService(nil, nil)
svc.Start() // cfg == nil 时应直接返回
}
func TestSoraMediaCleanupService_Start_StorageDisabled(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
},
},
},
}
svc := NewSoraMediaCleanupService(nil, cfg)
svc.Start() // storage == nil 时应直接返回
}
func TestSoraMediaCleanupService_Start_WithTimezone(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Timezone: "Asia/Shanghai",
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "0 3 * * *",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
t.Cleanup(svc.Stop)
}
func TestSoraMediaCleanupService_Start_Disabled(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Cleanup: config.SoraStorageCleanupConfig{
Enabled: false,
},
},
},
}
svc := NewSoraMediaCleanupService(nil, cfg)
svc.Start() // 不应 panic,也不应启动 cron
}
func TestSoraMediaCleanupService_Start_NilSelf(t *testing.T) {
var svc *SoraMediaCleanupService
svc.Start() // 不应 panic
}
func TestSoraMediaCleanupService_Start_EmptySchedule(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start() // 空 schedule 不应启动
}
func TestSoraMediaCleanupService_Start_InvalidSchedule(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "invalid-cron",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start() // 无效 schedule 不应 panic
}
func TestSoraMediaCleanupService_Start_ValidSchedule(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
Schedule: "0 3 * * *",
},
},
},
}
storage := NewSoraMediaStorage(cfg)
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
t.Cleanup(svc.Stop)
}
func TestSoraMediaCleanupService_Stop_NilSelf(t *testing.T) {
var svc *SoraMediaCleanupService
svc.Stop() // 不应 panic
}
func TestSoraMediaCleanupService_Stop_WithoutStart(t *testing.T) {
svc := NewSoraMediaCleanupService(nil, &config.Config{})
svc.Stop() // cron 未启动时 Stop 不应 panic
}
func TestSoraMediaCleanupService_RunCleanup(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
RetentionDays: 1,
},
},
},
}
storage := NewSoraMediaStorage(cfg)
require.NoError(t, storage.EnsureLocalDirs())
oldImage := filepath.Join(storage.ImageRoot(), "old.png")
newVideo := filepath.Join(storage.VideoRoot(), "new.mp4")
require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644))
require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644))
oldTime := time.Now().Add(-48 * time.Hour)
require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime))
cleanup := NewSoraMediaCleanupService(storage, cfg)
cleanup.runCleanup()
require.NoFileExists(t, oldImage)
require.FileExists(t, newVideo)
}
@@ -1,48 +0,0 @@
package service
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strconv"
"strings"
)
// SignSoraMediaURL 生成 Sora 媒体临时签名
func SignSoraMediaURL(path string, query string, expires int64, key string) string {
key = strings.TrimSpace(key)
if key == "" {
return ""
}
mac := hmac.New(sha256.New, []byte(key))
if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil {
return ""
}
if _, err := mac.Write([]byte("|")); err != nil {
return ""
}
if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil {
return ""
}
return hex.EncodeToString(mac.Sum(nil))
}
// VerifySoraMediaURL 校验 Sora 媒体签名
func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool {
signature = strings.TrimSpace(signature)
if signature == "" {
return false
}
expected := SignSoraMediaURL(path, query, expires, key)
if expected == "" {
return false
}
return hmac.Equal([]byte(signature), []byte(expected))
}
func buildSoraMediaSignPayload(path string, query string) string {
if strings.TrimSpace(query) == "" {
return path
}
return path + "?" + query
}
@@ -1,34 +0,0 @@
package service
import "testing"
func TestSoraMediaSignVerify(t *testing.T) {
key := "test-key"
path := "/tmp/abc.png"
query := "a=1&b=2"
expires := int64(1700000000)
signature := SignSoraMediaURL(path, query, expires, key)
if signature == "" {
t.Fatal("签名为空")
}
if !VerifySoraMediaURL(path, query, expires, signature, key) {
t.Fatal("签名校验失败")
}
if VerifySoraMediaURL(path, "a=1", expires, signature, key) {
t.Fatal("签名参数不同仍然通过")
}
if VerifySoraMediaURL(path, query, expires+1, signature, key) {
t.Fatal("签名过期校验未失败")
}
}
func TestSoraMediaSignWithEmptyKey(t *testing.T) {
signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "")
if signature != "" {
t.Fatalf("空密钥不应生成签名")
}
if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") {
t.Fatalf("空密钥不应通过校验")
}
}
@@ -1,381 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
)
const (
soraStorageDefaultRoot = "/app/data/sora"
)
// SoraMediaStorage 负责下载并落地 Sora 媒体
type SoraMediaStorage struct {
cfg *config.Config
root string
imageRoot string
videoRoot string
downloadTimeout time.Duration
maxDownloadBytes int64
fallbackToUpstream bool
debug bool
sem chan struct{}
ready bool
}
func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
storage := &SoraMediaStorage{cfg: cfg}
storage.refreshConfig()
if storage.Enabled() {
if err := storage.EnsureLocalDirs(); err != nil {
log.Printf("[SoraStorage] 初始化失败: %v", err)
}
}
return storage
}
func (s *SoraMediaStorage) Enabled() bool {
if s == nil || s.cfg == nil {
return false
}
return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local"
}
func (s *SoraMediaStorage) Root() string {
if s == nil {
return ""
}
return s.root
}
func (s *SoraMediaStorage) ImageRoot() string {
if s == nil {
return ""
}
return s.imageRoot
}
func (s *SoraMediaStorage) VideoRoot() string {
if s == nil {
return ""
}
return s.videoRoot
}
func (s *SoraMediaStorage) refreshConfig() {
if s == nil || s.cfg == nil {
return
}
root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath)
if root == "" {
root = soraStorageDefaultRoot
}
root = filepath.Clean(root)
if !filepath.IsAbs(root) {
if absRoot, err := filepath.Abs(root); err == nil {
root = absRoot
}
}
s.root = root
s.imageRoot = filepath.Join(root, "image")
s.videoRoot = filepath.Join(root, "video")
maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads
if maxConcurrent <= 0 {
maxConcurrent = 4
}
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
if timeoutSeconds <= 0 {
timeoutSeconds = 120
}
s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second
maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes
if maxBytes <= 0 {
maxBytes = 200 << 20
}
s.maxDownloadBytes = maxBytes
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
s.debug = s.cfg.Sora.Storage.Debug
s.sem = make(chan struct{}, maxConcurrent)
}
// EnsureLocalDirs 创建并校验本地目录
func (s *SoraMediaStorage) EnsureLocalDirs() error {
if s == nil || !s.Enabled() {
return nil
}
if err := os.MkdirAll(s.imageRoot, 0o755); err != nil {
return fmt.Errorf("create image dir: %w", err)
}
if err := os.MkdirAll(s.videoRoot, 0o755); err != nil {
return fmt.Errorf("create video dir: %w", err)
}
s.ready = true
return nil
}
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) {
if len(urls) == 0 {
return nil, nil
}
if s == nil || !s.Enabled() {
return urls, nil
}
if !s.ready {
if err := s.EnsureLocalDirs(); err != nil {
return nil, err
}
}
results := make([]string, 0, len(urls))
for _, raw := range urls {
relative, err := s.downloadAndStore(ctx, mediaType, raw)
if err != nil {
if s.fallbackToUpstream {
results = append(results, raw)
continue
}
return nil, err
}
results = append(results, relative)
}
return results, nil
}
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) {
if s == nil || len(paths) == 0 {
return 0, nil
}
var total int64
for _, p := range paths {
localPath, err := s.resolveLocalPath(p)
if err != nil {
continue
}
info, err := os.Stat(localPath)
if err != nil {
if os.IsNotExist(err) {
continue
}
return 0, err
}
if info.Mode().IsRegular() {
total += info.Size()
}
}
return total, nil
}
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error {
if s == nil || len(paths) == 0 {
return nil
}
var lastErr error
for _, p := range paths {
localPath, err := s.resolveLocalPath(p)
if err != nil {
continue
}
if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) {
lastErr = err
}
}
return lastErr
}
func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) {
if s == nil || strings.TrimSpace(relativePath) == "" {
return "", errors.New("empty path")
}
cleaned := path.Clean(relativePath)
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
return "", errors.New("not a local media path")
}
if strings.TrimSpace(s.root) == "" {
return "", errors.New("storage root not configured")
}
relative := strings.TrimPrefix(cleaned, "/")
return filepath.Join(s.root, filepath.FromSlash(relative)), nil
}
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
if strings.TrimSpace(rawURL) == "" {
return "", errors.New("empty url")
}
root := s.imageRoot
if mediaType == "video" {
root = s.videoRoot
}
if root == "" {
return "", errors.New("storage root not configured")
}
retries := 3
for attempt := 1; attempt <= retries; attempt++ {
release, err := s.acquire(ctx)
if err != nil {
return "", err
}
relative, err := s.downloadOnce(ctx, root, mediaType, rawURL)
release()
if err == nil {
return relative, nil
}
if s.debug {
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err)
}
if attempt < retries {
time.Sleep(time.Duration(attempt*attempt) * time.Second)
continue
}
return "", err
}
return "", errors.New("download retries exhausted")
}
func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", err
}
client := &http.Client{Timeout: s.downloadTimeout}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body))
}
ext := normalizeSoraFileExt(fileExtFromURL(rawURL))
if ext == "" {
ext = normalizeSoraFileExt(fileExtFromContentType(resp.Header.Get("Content-Type")))
}
if ext == "" {
ext = ".bin"
}
if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes {
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
}
storageRoot, err := os.OpenRoot(root)
if err != nil {
return "", err
}
defer func() { _ = storageRoot.Close() }()
datePath := time.Now().Format("2006/01/02")
datePathFS := filepath.FromSlash(datePath)
if err := storageRoot.MkdirAll(datePathFS, 0o755); err != nil {
return "", err
}
filename := uuid.NewString() + ext
filePath := filepath.Join(datePathFS, filename)
out, err := storageRoot.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
if err != nil {
return "", err
}
defer func() { _ = out.Close() }()
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
written, err := io.Copy(out, limited)
if err != nil {
removePartialDownload(storageRoot, filePath)
return "", err
}
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
removePartialDownload(storageRoot, filePath)
return "", fmt.Errorf("download size exceeds limit: %d", written)
}
relative := path.Join("/", mediaType, datePath, filename)
if s.debug {
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative)
}
return relative, nil
}
func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) {
if s.sem == nil {
return func() {}, nil
}
select {
case s.sem <- struct{}{}:
return func() { <-s.sem }, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func fileExtFromURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
return ""
}
ext := path.Ext(parsed.Path)
return strings.ToLower(ext)
}
func fileExtFromContentType(ct string) string {
if ct == "" {
return ""
}
if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 {
return strings.ToLower(exts[0])
}
return ""
}
func normalizeSoraFileExt(ext string) string {
ext = strings.ToLower(strings.TrimSpace(ext))
switch ext {
case ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg", ".tif", ".tiff", ".heic",
".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
return ext
default:
return ""
}
}
func removePartialDownload(root *os.Root, filePath string) {
if root == nil || strings.TrimSpace(filePath) == "" {
return
}
_ = root.Remove(filePath)
}
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
func sanitizeMediaLogURL(rawURL string) string {
parsed, err := url.Parse(rawURL)
if err != nil {
if len(rawURL) > 80 {
return rawURL[:80] + "..."
}
return rawURL
}
safe := parsed.Scheme + "://" + parsed.Host + parsed.Path
if len(safe) > 120 {
return safe[:120] + "..."
}
return safe
}
@@ -1,119 +0,0 @@
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraMediaStorage_StoreFromURLs(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("data"))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
MaxConcurrentDownloads: 1,
},
},
}
storage := NewSoraMediaStorage(cfg)
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.NoError(t, err)
require.Len(t, urls, 1)
require.True(t, strings.HasPrefix(urls[0], "/image/"))
require.True(t, strings.HasSuffix(urls[0], ".png"))
localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/")))
require.FileExists(t, localPath)
}
func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
FallbackToUpstream: true,
},
},
}
storage := NewSoraMediaStorage(cfg)
url := server.URL + "/broken.png"
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url})
require.NoError(t, err)
require.Equal(t, []string{url}, urls)
}
func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("too-large"))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
MaxDownloadBytes: 1,
},
},
}
storage := NewSoraMediaStorage(cfg)
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.Error(t, err)
}
func TestNormalizeSoraFileExt(t *testing.T) {
require.Equal(t, ".png", normalizeSoraFileExt(".PNG"))
require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4"))
require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd"))
require.Equal(t, "", normalizeSoraFileExt(".php"))
}
func TestRemovePartialDownload(t *testing.T) {
tmpDir := t.TempDir()
root, err := os.OpenRoot(tmpDir)
require.NoError(t, err)
defer func() { _ = root.Close() }()
filePath := "partial.bin"
f, err := root.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
require.NoError(t, err)
_, _ = f.WriteString("partial")
_ = f.Close()
removePartialDownload(root, filePath)
_, err = root.Stat(filePath)
require.Error(t, err)
require.True(t, os.IsNotExist(err))
}
-488
View File
@@ -1,488 +0,0 @@
package service
import (
"regexp"
"sort"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
// SoraModelConfig Sora 模型配置
type SoraModelConfig struct {
Type string
Width int
Height int
Orientation string
Frames int
Model string
Size string
RequirePro bool
// Prompt-enhance 专用参数
ExpansionLevel string
DurationS int
}
var soraModelConfigs = map[string]SoraModelConfig{
"gpt-image": {
Type: "image",
Width: 360,
Height: 360,
},
"gpt-image-landscape": {
Type: "image",
Width: 540,
Height: 360,
},
"gpt-image-portrait": {
Type: "image",
Width: 360,
Height: 540,
},
"sora2-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_8",
Size: "small",
},
"sora2-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_8",
Size: "small",
},
"sora2-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_8",
Size: "small",
},
"sora2-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_8",
Size: "small",
},
"sora2-landscape-25s": {
Type: "video",
Orientation: "landscape",
Frames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2-portrait-25s": {
Type: "video",
Orientation: "portrait",
Frames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-25s": {
Type: "video",
Orientation: "landscape",
Frames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-25s": {
Type: "video",
Orientation: "portrait",
Frames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-hd-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"prompt-enhance-short-10s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 10,
},
"prompt-enhance-short-15s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 15,
},
"prompt-enhance-short-20s": {
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 20,
},
"prompt-enhance-medium-10s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 10,
},
"prompt-enhance-medium-15s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 15,
},
"prompt-enhance-medium-20s": {
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 20,
},
"prompt-enhance-long-10s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 10,
},
"prompt-enhance-long-15s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 15,
},
"prompt-enhance-long-20s": {
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 20,
},
}
var soraModelIDs = []string{
"gpt-image",
"gpt-image-landscape",
"gpt-image-portrait",
"sora2-landscape-10s",
"sora2-portrait-10s",
"sora2-landscape-15s",
"sora2-portrait-15s",
"sora2-landscape-25s",
"sora2-portrait-25s",
"sora2pro-landscape-10s",
"sora2pro-portrait-10s",
"sora2pro-landscape-15s",
"sora2pro-portrait-15s",
"sora2pro-landscape-25s",
"sora2pro-portrait-25s",
"sora2pro-hd-landscape-10s",
"sora2pro-hd-portrait-10s",
"sora2pro-hd-landscape-15s",
"sora2pro-hd-portrait-15s",
"prompt-enhance-short-10s",
"prompt-enhance-short-15s",
"prompt-enhance-short-20s",
"prompt-enhance-medium-10s",
"prompt-enhance-medium-15s",
"prompt-enhance-medium-20s",
"prompt-enhance-long-10s",
"prompt-enhance-long-15s",
"prompt-enhance-long-20s",
}
// GetSoraModelConfig 返回 Sora 模型配置
func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
key := strings.ToLower(strings.TrimSpace(model))
cfg, ok := soraModelConfigs[key]
return cfg, ok
}
// SoraModelFamily 模型家族(前端 Sora 客户端使用)
type SoraModelFamily struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Orientations []string `json:"orientations"`
Durations []int `json:"durations,omitempty"`
}
var (
videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`)
imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`)
soraFamilyNames = map[string]string{
"sora2": "Sora 2",
"sora2pro": "Sora 2 Pro",
"sora2pro-hd": "Sora 2 Pro HD",
"gpt-image": "GPT Image",
}
)
// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长
func BuildSoraModelFamilies() []SoraModelFamily {
type familyData struct {
modelType string
orientations map[string]bool
durations map[int]bool
}
families := make(map[string]*familyData)
for id, cfg := range soraModelConfigs {
if cfg.Type == "prompt_enhance" {
continue
}
var famID, orientation string
var duration int
switch cfg.Type {
case "video":
if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
famID = id[:len(id)-len(m[0])]
orientation = m[1]
duration, _ = strconv.Atoi(m[2])
}
case "image":
if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
famID = id[:len(id)-len(m[0])]
orientation = m[1]
} else {
famID = id
orientation = "square"
}
}
if famID == "" {
continue
}
fd, ok := families[famID]
if !ok {
fd = &familyData{
modelType: cfg.Type,
orientations: make(map[string]bool),
durations: make(map[int]bool),
}
families[famID] = fd
}
if orientation != "" {
fd.orientations[orientation] = true
}
if duration > 0 {
fd.durations[duration] = true
}
}
// 排序:视频在前、图像在后,同类按名称排序
famIDs := make([]string, 0, len(families))
for id := range families {
famIDs = append(famIDs, id)
}
sort.Slice(famIDs, func(i, j int) bool {
fi, fj := families[famIDs[i]], families[famIDs[j]]
if fi.modelType != fj.modelType {
return fi.modelType == "video"
}
return famIDs[i] < famIDs[j]
})
result := make([]SoraModelFamily, 0, len(famIDs))
for _, famID := range famIDs {
fd := families[famID]
fam := SoraModelFamily{
ID: famID,
Name: soraFamilyNames[famID],
Type: fd.modelType,
}
if fam.Name == "" {
fam.Name = famID
}
for o := range fd.orientations {
fam.Orientations = append(fam.Orientations, o)
}
sort.Strings(fam.Orientations)
for d := range fd.durations {
fam.Durations = append(fam.Durations, d)
}
sort.Ints(fam.Durations)
result = append(result, fam)
}
return result
}
// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。
// 通过命名约定自动识别视频/图像模型并分组。
func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily {
type familyData struct {
modelType string
orientations map[string]bool
durations map[int]bool
}
families := make(map[string]*familyData)
for _, id := range modelIDs {
id = strings.ToLower(strings.TrimSpace(id))
if id == "" || strings.HasPrefix(id, "prompt-enhance") {
continue
}
var famID, orientation, modelType string
var duration int
if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
// 视频模型: {family}-{orientation}-{duration}s
famID = id[:len(id)-len(m[0])]
orientation = m[1]
duration, _ = strconv.Atoi(m[2])
modelType = "video"
} else if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
// 图像模型(带方向): {family}-{orientation}
famID = id[:len(id)-len(m[0])]
orientation = m[1]
modelType = "image"
} else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" {
// 已知的无后缀图像模型(如 gpt-image)
famID = id
orientation = "square"
modelType = "image"
} else if strings.Contains(id, "image") {
// 未知但名称包含 image 的模型,推断为图像模型
famID = id
orientation = "square"
modelType = "image"
} else {
continue
}
if famID == "" {
continue
}
fd, ok := families[famID]
if !ok {
fd = &familyData{
modelType: modelType,
orientations: make(map[string]bool),
durations: make(map[int]bool),
}
families[famID] = fd
}
if orientation != "" {
fd.orientations[orientation] = true
}
if duration > 0 {
fd.durations[duration] = true
}
}
famIDs := make([]string, 0, len(families))
for id := range families {
famIDs = append(famIDs, id)
}
sort.Slice(famIDs, func(i, j int) bool {
fi, fj := families[famIDs[i]], families[famIDs[j]]
if fi.modelType != fj.modelType {
return fi.modelType == "video"
}
return famIDs[i] < famIDs[j]
})
result := make([]SoraModelFamily, 0, len(famIDs))
for _, famID := range famIDs {
fd := families[famID]
fam := SoraModelFamily{
ID: famID,
Name: soraFamilyNames[famID],
Type: fd.modelType,
}
if fam.Name == "" {
fam.Name = famID
}
for o := range fd.orientations {
fam.Orientations = append(fam.Orientations, o)
}
sort.Strings(fam.Orientations)
for d := range fd.durations {
fam.Durations = append(fam.Durations, d)
}
sort.Ints(fam.Durations)
result = append(result, fam)
}
return result
}
// DefaultSoraModels returns the default Sora model list.
func DefaultSoraModels(cfg *config.Config) []openai.Model {
models := make([]openai.Model, 0, len(soraModelIDs))
for _, id := range soraModelIDs {
models = append(models, openai.Model{
ID: id,
Object: "model",
OwnedBy: "openai",
Type: "model",
DisplayName: id,
})
}
if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance {
filtered := models[:0]
for _, model := range models {
if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") {
continue
}
filtered = append(filtered, model)
}
models = filtered
}
return models
}
@@ -1,257 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraQuotaService 管理 Sora 用户存储配额。
// 配额优先级:用户级 → 分组级 → 系统默认值。
type SoraQuotaService struct {
userRepo UserRepository
groupRepo GroupRepository
settingService *SettingService
}
// NewSoraQuotaService 创建配额服务实例。
func NewSoraQuotaService(
userRepo UserRepository,
groupRepo GroupRepository,
settingService *SettingService,
) *SoraQuotaService {
return &SoraQuotaService{
userRepo: userRepo,
groupRepo: groupRepo,
settingService: settingService,
}
}
// QuotaInfo 返回给客户端的配额信息。
type QuotaInfo struct {
QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制)
UsedBytes int64 `json:"used_bytes"` // 已使用
AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0
QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited
Source string `json:"source,omitempty"` // 兼容旧字段
}
// ErrSoraStorageQuotaExceeded 表示配额不足。
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
// QuotaExceededError 包含配额不足的上下文信息。
type QuotaExceededError struct {
QuotaBytes int64
UsedBytes int64
}
func (e *QuotaExceededError) Error() string {
if e == nil {
return "存储配额不足"
}
return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
}
type soraQuotaAtomicUserRepository interface {
AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
}
// GetQuota 获取用户的存储配额信息。
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
info := &QuotaInfo{
UsedBytes: user.SoraStorageUsedBytes,
}
// 1. 用户级配额
if user.SoraStorageQuotaBytes > 0 {
info.QuotaBytes = user.SoraStorageQuotaBytes
info.QuotaSource = "user"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
// 2. 分组级配额(取用户可用分组中最大的配额)
if len(user.AllowedGroups) > 0 {
var maxGroupQuota int64
for _, gid := range user.AllowedGroups {
group, err := s.groupRepo.GetByID(ctx, gid)
if err != nil {
continue
}
if group.SoraStorageQuotaBytes > maxGroupQuota {
maxGroupQuota = group.SoraStorageQuotaBytes
}
}
if maxGroupQuota > 0 {
info.QuotaBytes = maxGroupQuota
info.QuotaSource = "group"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
}
// 3. 系统默认值
defaultQuota := s.getSystemDefaultQuota(ctx)
if defaultQuota > 0 {
info.QuotaBytes = defaultQuota
info.QuotaSource = "system"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
// 无配额限制
info.QuotaSource = "unlimited"
info.Source = info.QuotaSource
info.AvailableBytes = 0
return info, nil
}
// CheckQuota 检查用户是否有足够的存储配额。
// 返回 nil 表示配额充足或无限制。
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
quota, err := s.GetQuota(ctx, userID)
if err != nil {
return err
}
// 0 表示无限制
if quota.QuotaBytes == 0 {
return nil
}
if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
return nil
}
// AddUsage 原子累加用量(上传成功后调用)。
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
if bytes <= 0 {
return nil
}
quota, err := s.GetQuota(ctx, userID)
if err != nil {
return err
}
if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
if err != nil {
if errors.Is(err, ErrSoraStorageQuotaExceeded) {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
return fmt.Errorf("update user quota usage (atomic): %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
return nil
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user for quota update: %w", err)
}
user.SoraStorageUsedBytes += bytes
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user quota usage: %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
return nil
}
// ReleaseUsage 释放用量(删除文件后调用)。
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
if bytes <= 0 {
return nil
}
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
if err != nil {
return fmt.Errorf("update user quota release (atomic): %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
return nil
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user for quota release: %w", err)
}
user.SoraStorageUsedBytes -= bytes
if user.SoraStorageUsedBytes < 0 {
user.SoraStorageUsedBytes = 0
}
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user quota release: %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
return nil
}
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
if quotaBytes <= 0 {
return 0
}
if usedBytes >= quotaBytes {
return 0
}
return quotaBytes - usedBytes
}
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
if s.settingService == nil {
return 0
}
settings, err := s.settingService.GetSoraS3Settings(ctx)
if err != nil {
return 0
}
return settings.DefaultStorageQuotaBytes
}
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
return s.getSystemDefaultQuota(ctx)
}
// SetUserQuota 设置用户级配额(管理员操作)。
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
user, err := userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
user.SoraStorageQuotaBytes = quotaBytes
return userRepo.Update(ctx, user)
}
// ParseQuotaBytes 解析配额字符串为字节数。
func ParseQuotaBytes(s string) int64 {
v, _ := strconv.ParseInt(s, 10, 64)
return v
}
@@ -1,492 +0,0 @@
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ==================== Stub: GroupRepository (用于 SoraQuotaService) ====================
var _ GroupRepository = (*stubGroupRepoForQuota)(nil)
type stubGroupRepoForQuota struct {
groups map[int64]*Group
}
func newStubGroupRepoForQuota() *stubGroupRepoForQuota {
return &stubGroupRepoForQuota{groups: make(map[int64]*Group)}
}
func (r *stubGroupRepoForQuota) GetByID(_ context.Context, id int64) (*Group, error) {
if g, ok := r.groups[id]; ok {
return g, nil
}
return nil, fmt.Errorf("group not found")
}
func (r *stubGroupRepoForQuota) Create(context.Context, *Group) error { return nil }
func (r *stubGroupRepoForQuota) GetByIDLite(_ context.Context, id int64) (*Group, error) {
return r.GetByID(context.Background(), id)
}
func (r *stubGroupRepoForQuota) Update(context.Context, *Group) error { return nil }
func (r *stubGroupRepoForQuota) Delete(context.Context, int64) error { return nil }
func (r *stubGroupRepoForQuota) DeleteCascade(context.Context, int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepoForQuota) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepoForQuota) ListActive(context.Context) ([]Group, error) { return nil, nil }
func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([]Group, error) {
return nil, nil
}
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepoForQuota) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepoForQuota) BindAccountsToGroup(context.Context, int64, []int64) error {
return nil
}
func (r *stubGroupRepoForQuota) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
return nil
}
// ==================== Stub: SettingRepository (用于 SettingService) ====================
var _ SettingRepository = (*stubSettingRepoForQuota)(nil)
type stubSettingRepoForQuota struct {
values map[string]string
}
func newStubSettingRepoForQuota(values map[string]string) *stubSettingRepoForQuota {
if values == nil {
values = make(map[string]string)
}
return &stubSettingRepoForQuota{values: values}
}
func (r *stubSettingRepoForQuota) Get(_ context.Context, key string) (*Setting, error) {
if v, ok := r.values[key]; ok {
return &Setting{Key: key, Value: v}, nil
}
return nil, ErrSettingNotFound
}
func (r *stubSettingRepoForQuota) GetValue(_ context.Context, key string) (string, error) {
if v, ok := r.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (r *stubSettingRepoForQuota) Set(_ context.Context, key, value string) error {
r.values[key] = value
return nil
}
func (r *stubSettingRepoForQuota) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string)
for _, k := range keys {
if v, ok := r.values[k]; ok {
result[k] = v
}
}
return result, nil
}
func (r *stubSettingRepoForQuota) SetMultiple(_ context.Context, settings map[string]string) error {
for k, v := range settings {
r.values[k] = v
}
return nil
}
func (r *stubSettingRepoForQuota) GetAll(_ context.Context) (map[string]string, error) {
return r.values, nil
}
func (r *stubSettingRepoForQuota) Delete(_ context.Context, key string) error {
delete(r.values, key)
return nil
}
// ==================== GetQuota ====================
func TestGetQuota_UserLevel(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024, // 10MB
SoraStorageUsedBytes: 3 * 1024 * 1024, // 3MB
}
svc := NewSoraQuotaService(userRepo, nil, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(10*1024*1024), quota.QuotaBytes)
require.Equal(t, int64(3*1024*1024), quota.UsedBytes)
require.Equal(t, "user", quota.Source)
}
func TestGetQuota_GroupLevel(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0, // 用户级无配额
SoraStorageUsedBytes: 1024,
AllowedGroups: []int64{10, 20},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 5 * 1024 * 1024}
groupRepo.groups[20] = &Group{ID: 20, SoraStorageQuotaBytes: 20 * 1024 * 1024}
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) // 取最大值
require.Equal(t, "group", quota.Source)
}
func TestGetQuota_SystemLevel(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 512}
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(userRepo, nil, settingService)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(104857600), quota.QuotaBytes)
require.Equal(t, "system", quota.Source)
}
func TestGetQuota_NoLimit(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 0}
svc := NewSoraQuotaService(userRepo, nil, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, int64(0), quota.QuotaBytes)
require.Equal(t, "unlimited", quota.Source)
}
func TestGetQuota_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
svc := NewSoraQuotaService(userRepo, nil, nil)
_, err := svc.GetQuota(context.Background(), 999)
require.Error(t, err)
require.Contains(t, err.Error(), "get user")
}
func TestGetQuota_GroupRepoError(t *testing.T) {
// 分组获取失败时跳过该分组(不影响整体)
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1, SoraStorageQuotaBytes: 0,
AllowedGroups: []int64{999}, // 不存在的分组
}
groupRepo := newStubGroupRepoForQuota()
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "unlimited", quota.Source) // 分组获取失败,回退到无限制
}
// ==================== CheckQuota ====================
func TestCheckQuota_Sufficient(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 3 * 1024 * 1024,
}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.CheckQuota(context.Background(), 1, 1024)
require.NoError(t, err)
}
func TestCheckQuota_Exceeded(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 10 * 1024 * 1024, // 已满
}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.CheckQuota(context.Background(), 1, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "配额不足")
}
func TestCheckQuota_NoLimit(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0, // 无限制
SoraStorageUsedBytes: 1000000000,
}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.CheckQuota(context.Background(), 1, 999999999)
require.NoError(t, err) // 无限制时始终通过
}
func TestCheckQuota_ExactBoundary(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 1024,
SoraStorageUsedBytes: 1024, // 恰好满
}
svc := NewSoraQuotaService(userRepo, nil, nil)
// 额外 0 字节不超
require.NoError(t, svc.CheckQuota(context.Background(), 1, 0))
// 额外 1 字节超出
require.Error(t, svc.CheckQuota(context.Background(), 1, 1))
}
// ==================== AddUsage ====================
func TestAddUsage_Success(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, 2048)
require.NoError(t, err)
require.Equal(t, int64(3072), userRepo.users[1].SoraStorageUsedBytes)
}
func TestAddUsage_ZeroBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, 0)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestAddUsage_NegativeBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, -100)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestAddUsage_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 999, 1024)
require.Error(t, err)
}
func TestAddUsage_UpdateError(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 0}
userRepo.updateErr = fmt.Errorf("db error")
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.AddUsage(context.Background(), 1, 1024)
require.Error(t, err)
require.Contains(t, err.Error(), "update user quota usage")
}
// ==================== ReleaseUsage ====================
func TestReleaseUsage_Success(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 3072}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 1024)
require.NoError(t, err)
require.Equal(t, int64(2048), userRepo.users[1].SoraStorageUsedBytes)
}
func TestReleaseUsage_ClampToZero(t *testing.T) {
// 释放量大于已用量时,应 clamp 到 0
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 500}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 1000)
require.NoError(t, err)
require.Equal(t, int64(0), userRepo.users[1].SoraStorageUsedBytes)
}
func TestReleaseUsage_ZeroBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 0)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestReleaseUsage_NegativeBytes(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, -50)
require.NoError(t, err)
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
}
func TestReleaseUsage_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 999, 1024)
require.Error(t, err)
}
func TestReleaseUsage_UpdateError(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
userRepo.updateErr = fmt.Errorf("db error")
svc := NewSoraQuotaService(userRepo, nil, nil)
err := svc.ReleaseUsage(context.Background(), 1, 512)
require.Error(t, err)
require.Contains(t, err.Error(), "update user quota release")
}
// ==================== GetQuotaFromSettings ====================
func TestGetQuotaFromSettings_NilSettingService(t *testing.T) {
svc := NewSoraQuotaService(nil, nil, nil)
require.Equal(t, int64(0), svc.GetQuotaFromSettings(context.Background()))
}
func TestGetQuotaFromSettings_WithSettings(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(nil, nil, settingService)
require.Equal(t, int64(52428800), svc.GetQuotaFromSettings(context.Background()))
}
// ==================== SetUserSoraQuota ====================
func TestSetUserSoraQuota_Success(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0}
err := SetUserSoraQuota(context.Background(), userRepo, 1, 10*1024*1024)
require.NoError(t, err)
require.Equal(t, int64(10*1024*1024), userRepo.users[1].SoraStorageQuotaBytes)
}
func TestSetUserSoraQuota_UserNotFound(t *testing.T) {
userRepo := newStubUserRepoForQuota()
err := SetUserSoraQuota(context.Background(), userRepo, 999, 1024)
require.Error(t, err)
}
// ==================== ParseQuotaBytes ====================
func TestParseQuotaBytes(t *testing.T) {
require.Equal(t, int64(1048576), ParseQuotaBytes("1048576"))
require.Equal(t, int64(0), ParseQuotaBytes(""))
require.Equal(t, int64(0), ParseQuotaBytes("abc"))
require.Equal(t, int64(-1), ParseQuotaBytes("-1"))
}
// ==================== 优先级完整测试 ====================
func TestQuotaPriority_UserOverridesGroup(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 5 * 1024 * 1024,
AllowedGroups: []int64{10},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "user", quota.Source) // 用户级优先
require.Equal(t, int64(5*1024*1024), quota.QuotaBytes)
}
func TestQuotaPriority_GroupOverridesSystem(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0,
AllowedGroups: []int64{10},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "group", quota.Source) // 分组级优先于系统
require.Equal(t, int64(20*1024*1024), quota.QuotaBytes)
}
func TestQuotaPriority_FallbackToSystem(t *testing.T) {
userRepo := newStubUserRepoForQuota()
userRepo.users[1] = &User{
ID: 1,
SoraStorageQuotaBytes: 0,
AllowedGroups: []int64{10},
}
groupRepo := newStubGroupRepoForQuota()
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 0} // 分组无配额
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
})
settingService := NewSettingService(settingRepo, &config.Config{})
svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
quota, err := svc.GetQuota(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "system", quota.Source)
require.Equal(t, int64(52428800), quota.QuotaBytes)
}
-392
View File
@@ -1,392 +0,0 @@
package service
import (
"context"
"fmt"
"io"
"net/http"
"path"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。
// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。
type SoraS3Storage struct {
settingService *SettingService
mu sync.RWMutex
client *s3.Client
cfg *SoraS3Settings // 上次加载的配置快照
healthCheckedAt time.Time
healthErr error
healthTTL time.Duration
}
const defaultSoraS3HealthTTL = 30 * time.Second
// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。
type UpstreamDownloadError struct {
StatusCode int
}
func (e *UpstreamDownloadError) Error() string {
if e == nil {
return "upstream download failed"
}
return fmt.Sprintf("upstream returned %d", e.StatusCode)
}
// NewSoraS3Storage 创建 S3 存储服务实例。
func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage {
return &SoraS3Storage{
settingService: settingService,
healthTTL: defaultSoraS3HealthTTL,
}
}
// Enabled 返回 S3 存储是否已启用且配置有效。
func (s *SoraS3Storage) Enabled(ctx context.Context) bool {
cfg, err := s.getConfig(ctx)
if err != nil || cfg == nil {
return false
}
return cfg.Enabled && cfg.Bucket != ""
}
// getConfig 获取当前 S3 配置(从 settings 表读取)。
func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) {
if s.settingService == nil {
return nil, fmt.Errorf("setting service not available")
}
return s.settingService.GetSoraS3Settings(ctx)
}
// getClient 获取或初始化 S3 客户端(带缓存)。
// 配置变更时调用 RefreshClient 清除缓存。
func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
s.mu.RLock()
if s.client != nil && s.cfg != nil {
client, cfg := s.client, s.cfg
s.mu.RUnlock()
return client, cfg, nil
}
s.mu.RUnlock()
return s.initClient(ctx)
}
func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
s.mu.Lock()
defer s.mu.Unlock()
// 双重检查
if s.client != nil && s.cfg != nil {
return s.client, s.cfg, nil
}
cfg, err := s.getConfig(ctx)
if err != nil {
return nil, nil, fmt.Errorf("load s3 config: %w", err)
}
if !cfg.Enabled {
return nil, nil, fmt.Errorf("sora s3 storage is disabled")
}
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required")
}
client, region, err := buildSoraS3Client(ctx, cfg)
if err != nil {
return nil, nil, err
}
s.client = client
s.cfg = cfg
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region)
return client, cfg, nil
}
// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。
// 应在系统设置中 S3 配置变更时调用。
func (s *SoraS3Storage) RefreshClient() {
s.mu.Lock()
defer s.mu.Unlock()
s.client = nil
s.cfg = nil
s.healthCheckedAt = time.Time{}
s.healthErr = nil
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化")
}
// TestConnection 测试 S3 连接(HeadBucket)。
func (s *SoraS3Storage) TestConnection(ctx context.Context) error {
client, cfg, err := s.getClient(ctx)
if err != nil {
return err
}
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &cfg.Bucket,
})
if err != nil {
return fmt.Errorf("s3 HeadBucket failed: %w", err)
}
return nil
}
// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。
func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool {
if s == nil {
return false
}
now := time.Now()
s.mu.RLock()
lastCheck := s.healthCheckedAt
lastErr := s.healthErr
ttl := s.healthTTL
s.mu.RUnlock()
if ttl <= 0 {
ttl = defaultSoraS3HealthTTL
}
if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl {
return lastErr == nil
}
err := s.TestConnection(ctx)
s.mu.Lock()
s.healthCheckedAt = time.Now()
s.healthErr = err
s.mu.Unlock()
return err == nil
}
// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。
func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error {
if cfg == nil {
return fmt.Errorf("s3 config is required")
}
if !cfg.Enabled {
return fmt.Errorf("sora s3 storage is disabled")
}
if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required")
}
client, _, err := buildSoraS3Client(ctx, cfg)
if err != nil {
return err
}
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &cfg.Bucket,
})
if err != nil {
return fmt.Errorf("s3 HeadBucket failed: %w", err)
}
return nil
}
// GenerateObjectKey 生成 S3 object key。
// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext}
func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string {
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
datePath := time.Now().Format("2006/01/02")
key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext)
if prefix != "" {
prefix = strings.TrimRight(prefix, "/") + "/"
key = prefix + key
}
return key
}
// UploadFromURL 从上游 URL 下载并流式上传到 S3。
// 返回 S3 object key。
func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) {
client, cfg, err := s.getClient(ctx)
if err != nil {
return "", 0, err
}
// 下载源文件
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
if err != nil {
return "", 0, fmt.Errorf("create download request: %w", err)
}
httpClient := &http.Client{Timeout: 5 * time.Minute}
resp, err := httpClient.Do(req)
if err != nil {
return "", 0, fmt.Errorf("download from upstream: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode}
}
// 推断文件扩展名
ext := fileExtFromURL(sourceURL)
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
if ext == "" {
ext = ".bin"
}
objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext)
// 检测 Content-Type
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream"
}
reader, writer := io.Pipe()
uploadErrCh := make(chan error, 1)
go func() {
defer close(uploadErrCh)
input := &s3.PutObjectInput{
Bucket: &cfg.Bucket,
Key: &objectKey,
Body: reader,
ContentType: &contentType,
}
if resp.ContentLength >= 0 {
input.ContentLength = &resp.ContentLength
}
_, uploadErr := client.PutObject(ctx, input)
uploadErrCh <- uploadErr
}()
written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024))
_ = writer.CloseWithError(copyErr)
uploadErr := <-uploadErrCh
if copyErr != nil {
return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr)
}
if uploadErr != nil {
return "", 0, fmt.Errorf("s3 upload: %w", uploadErr)
}
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written)
return objectKey, written, nil
}
func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) {
if cfg == nil {
return nil, "", fmt.Errorf("s3 config is required")
}
region := cfg.Region
if region == "" {
region = "us-east-1"
}
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
),
)
if err != nil {
return nil, "", fmt.Errorf("load aws config: %w", err)
}
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.BaseEndpoint = &cfg.Endpoint
}
if cfg.ForcePathStyle {
o.UsePathStyle = true
}
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
// 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
})
return client, region, nil
}
// DeleteObjects 删除一组 S3 object(遍历逐一删除)。
func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error {
if len(objectKeys) == 0 {
return nil
}
client, cfg, err := s.getClient(ctx)
if err != nil {
return err
}
var lastErr error
for _, key := range objectKeys {
k := key
_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: &cfg.Bucket,
Key: &k,
})
if err != nil {
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err)
lastErr = err
}
}
return lastErr
}
// GetAccessURL 获取 S3 文件的访问 URL。
// CDN URL 优先,否则生成 24h 预签名 URL。
func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) {
_, cfg, err := s.getClient(ctx)
if err != nil {
return "", err
}
// CDN URL 优先
if cfg.CDNURL != "" {
cdnBase := strings.TrimRight(cfg.CDNURL, "/")
return cdnBase + "/" + objectKey, nil
}
// 生成 24h 预签名 URL
return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour)
}
// GeneratePresignedURL 生成预签名 URL。
func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) {
client, cfg, err := s.getClient(ctx)
if err != nil {
return "", err
}
presignClient := s3.NewPresignClient(client)
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: &cfg.Bucket,
Key: &objectKey,
}, s3.WithPresignExpires(ttl))
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return result.URL, nil
}
// GetMediaType 从 object key 推断媒体类型(image/video)。
func GetMediaTypeFromKey(objectKey string) string {
ext := strings.ToLower(path.Ext(objectKey))
switch ext {
case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
return "video"
default:
return "image"
}
}
@@ -1,263 +0,0 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ==================== RefreshClient ====================
func TestRefreshClient(t *testing.T) {
s := newS3StorageWithCDN("https://cdn.example.com")
require.NotNil(t, s.client)
require.NotNil(t, s.cfg)
s.RefreshClient()
require.Nil(t, s.client)
require.Nil(t, s.cfg)
}
func TestRefreshClient_AlreadyNil(t *testing.T) {
s := NewSoraS3Storage(nil)
s.RefreshClient() // 不应 panic
require.Nil(t, s.client)
require.Nil(t, s.cfg)
}
// ==================== GetMediaTypeFromKey ====================
func TestGetMediaTypeFromKey_VideoExtensions(t *testing.T) {
for _, ext := range []string{".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv"} {
require.Equal(t, "video", GetMediaTypeFromKey("path/to/file"+ext), "ext=%s", ext)
}
}
func TestGetMediaTypeFromKey_VideoUpperCase(t *testing.T) {
require.Equal(t, "video", GetMediaTypeFromKey("file.MP4"))
require.Equal(t, "video", GetMediaTypeFromKey("file.MOV"))
}
func TestGetMediaTypeFromKey_ImageExtensions(t *testing.T) {
require.Equal(t, "image", GetMediaTypeFromKey("file.png"))
require.Equal(t, "image", GetMediaTypeFromKey("file.jpg"))
require.Equal(t, "image", GetMediaTypeFromKey("file.jpeg"))
require.Equal(t, "image", GetMediaTypeFromKey("file.gif"))
require.Equal(t, "image", GetMediaTypeFromKey("file.webp"))
}
func TestGetMediaTypeFromKey_NoExtension(t *testing.T) {
require.Equal(t, "image", GetMediaTypeFromKey("file"))
require.Equal(t, "image", GetMediaTypeFromKey("path/to/file"))
}
func TestGetMediaTypeFromKey_UnknownExtension(t *testing.T) {
require.Equal(t, "image", GetMediaTypeFromKey("file.bin"))
require.Equal(t, "image", GetMediaTypeFromKey("file.xyz"))
}
// ==================== Enabled ====================
func TestEnabled_NilSettingService(t *testing.T) {
s := NewSoraS3Storage(nil)
require.False(t, s.Enabled(context.Background()))
}
func TestEnabled_ConfigDisabled(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "false",
SettingKeySoraS3Bucket: "test-bucket",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
require.False(t, s.Enabled(context.Background()))
}
func TestEnabled_ConfigEnabledWithBucket(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "my-bucket",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
require.True(t, s.Enabled(context.Background()))
}
func TestEnabled_ConfigEnabledEmptyBucket(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
require.False(t, s.Enabled(context.Background()))
}
// ==================== initClient ====================
func TestInitClient_Disabled(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "false",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
_, _, err := s.getClient(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "disabled")
}
func TestInitClient_IncompleteConfig(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "test-bucket",
// 缺少 access_key_id 和 secret_access_key
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
_, _, err := s.getClient(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "incomplete")
}
func TestInitClient_DefaultRegion(t *testing.T) {
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "test-bucket",
SettingKeySoraS3AccessKeyID: "AKID",
SettingKeySoraS3SecretAccessKey: "SECRET",
// Region 为空 → 默认 us-east-1
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
client, cfg, err := s.getClient(context.Background())
require.NoError(t, err)
require.NotNil(t, client)
require.Equal(t, "test-bucket", cfg.Bucket)
}
func TestInitClient_DoubleCheck(t *testing.T) {
// 验证双重检查锁定:第二次 getClient 命中缓存
settingRepo := newStubSettingRepoForQuota(map[string]string{
SettingKeySoraS3Enabled: "true",
SettingKeySoraS3Bucket: "test-bucket",
SettingKeySoraS3AccessKeyID: "AKID",
SettingKeySoraS3SecretAccessKey: "SECRET",
})
settingService := NewSettingService(settingRepo, &config.Config{})
s := NewSoraS3Storage(settingService)
client1, _, err1 := s.getClient(context.Background())
require.NoError(t, err1)
client2, _, err2 := s.getClient(context.Background())
require.NoError(t, err2)
require.Equal(t, client1, client2) // 同一客户端实例
}
func TestInitClient_NilSettingService(t *testing.T) {
s := NewSoraS3Storage(nil)
_, _, err := s.getClient(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "setting service not available")
}
// ==================== GenerateObjectKey ====================
func TestGenerateObjectKey_ExtWithoutDot(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("", 1, "mp4")
require.Contains(t, key, ".mp4")
require.True(t, len(key) > 0)
}
func TestGenerateObjectKey_ExtWithDot(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("", 1, ".mp4")
require.Contains(t, key, ".mp4")
// 不应出现 ..mp4
require.NotContains(t, key, "..mp4")
}
func TestGenerateObjectKey_WithPrefix(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("uploads/", 42, ".png")
require.True(t, len(key) > 0)
require.Contains(t, key, "uploads/sora/42/")
}
func TestGenerateObjectKey_PrefixWithoutTrailingSlash(t *testing.T) {
s := NewSoraS3Storage(nil)
key := s.GenerateObjectKey("uploads", 42, ".png")
require.Contains(t, key, "uploads/sora/42/")
}
// ==================== GeneratePresignedURL ====================
func TestGeneratePresignedURL_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil) // settingService=nil → getClient 失败
_, err := s.GeneratePresignedURL(context.Background(), "key", 3600)
require.Error(t, err)
}
// ==================== GetAccessURL ====================
func TestGetAccessURL_CDN(t *testing.T) {
s := newS3StorageWithCDN("https://cdn.example.com")
url, err := s.GetAccessURL(context.Background(), "sora/1/2024/01/01/video.mp4")
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", url)
}
func TestGetAccessURL_CDNTrailingSlash(t *testing.T) {
s := newS3StorageWithCDN("https://cdn.example.com/")
url, err := s.GetAccessURL(context.Background(), "key.mp4")
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/key.mp4", url)
}
func TestGetAccessURL_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
_, err := s.GetAccessURL(context.Background(), "key")
require.Error(t, err)
}
// ==================== TestConnection ====================
func TestTestConnection_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.TestConnection(context.Background())
require.Error(t, err)
}
// ==================== UploadFromURL ====================
func TestUploadFromURL_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
_, _, err := s.UploadFromURL(context.Background(), 1, "https://example.com/file.mp4")
require.Error(t, err)
}
// ==================== DeleteObjects ====================
func TestDeleteObjects_EmptyKeys(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.DeleteObjects(context.Background(), []string{})
require.NoError(t, err) // 空列表直接返回
}
func TestDeleteObjects_NilKeys(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.DeleteObjects(context.Background(), nil)
require.NoError(t, err) // nil 列表直接返回
}
func TestDeleteObjects_GetClientError(t *testing.T) {
s := NewSoraS3Storage(nil)
err := s.DeleteObjects(context.Background(), []string{"key1", "key2"})
require.Error(t, err)
}
File diff suppressed because it is too large Load Diff
@@ -1,149 +0,0 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions"
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
// 支持流式和非流式响应的直接透传。
func (s *SoraGatewayService) forwardToUpstream(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
clientStream bool,
startTime time.Time,
) (*ForwardResult, error) {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
}
baseURL := account.GetBaseURL()
if baseURL == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
}
// 校验 scheme 合法性(仅允许 http/https
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
}
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
// 构建上游请求
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
return nil, fmt.Errorf("create upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
// 透传客户端的部分请求头
for _, header := range []string{"Accept", "Accept-Encoding"} {
if v := c.GetHeader(header); v != "" {
upstreamReq.Header.Set(header, v)
}
}
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
}
}
defer func() {
_ = resp.Body.Close()
}()
// 错误响应处理
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
// 非转移错误,直接透传给客户端
c.Status(resp.StatusCode)
for key, values := range resp.Header {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
if _, err := c.Writer.Write(respBody); err != nil {
return nil, fmt.Errorf("write upstream error response: %w", err)
}
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 成功响应 — 直接透传
c.Status(resp.StatusCode)
for key, values := range resp.Header {
lower := strings.ToLower(key)
// 透传内容相关头部
if lower == "content-type" || lower == "transfer-encoding" ||
lower == "cache-control" || lower == "x-request-id" {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
}
// 流式复制响应体
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
buf := make([]byte, 4096)
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
if _, err := c.Writer.Write(buf[:n]); err != nil {
return nil, fmt.Errorf("stream upstream response write: %w", err)
}
flusher.Flush()
}
if readErr != nil {
break
}
}
} else {
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return nil, fmt.Errorf("copy upstream response: %w", err)
}
}
duration := time.Since(startTime)
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Model: "", // 由调用方填充
Stream: clientStream,
Duration: duration,
}, nil
}
@@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
// Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformOpenAI, PlatformSora:
case PlatformOpenAI:
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic:
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
@@ -60,7 +60,6 @@ func NewTokenRefreshService(
}
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
claudeRefresher := NewClaudeTokenRefresher(oauthService)
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
@@ -85,18 +84,6 @@ func NewTokenRefreshService(
return s
}
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
// 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
openaiRefresher.SetSoraAccountRepo(repo)
}
}
}
// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖
func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) {
s.privacyClientFactory = factory
@@ -2,7 +2,6 @@ package service
import (
"context"
"log"
"time"
)
@@ -73,8 +72,6 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
syncLinkedSora bool
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -90,20 +87,7 @@ func (r *OpenAITokenRefresher) CacheKey(account *Account) string {
return OpenAITokenCacheKey(account)
}
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 Token 刷新时同步更新 sora_accounts 表
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo
}
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
r.syncLinkedSora = enabled
}
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
}
@@ -121,7 +105,6 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time
// Refresh 执行token刷新
// 保留原有credentials中的所有字段,只更新token相关字段
// 刷新成功后,异步同步关联的 Sora 账号
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
@@ -132,68 +115,5 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
newCredentials = MergeCredentials(account.Credentials, newCredentials)
// 异步同步关联的 Sora 账号(不阻塞主流程)
if r.accountRepo != nil && r.syncLinkedSora {
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
}
return newCredentials, nil
}
// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步)
// 该方法异步执行,失败只记录日志,不影响主流程
//
// 同步策略:
// 1. 更新 accounts.credentials(主表)
// 2. 更新 sora_accounts 扩展表(如果 soraAccountRepo 已设置)
//
// 超时控制:30 秒,防止数据库阻塞导致 goroutine 泄漏
func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
// 添加超时控制,防止 goroutine 泄漏
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// 1. 查找所有关联的 Sora 账号(限定 platform='sora'
soraAccounts, err := r.accountRepo.FindByExtraField(ctx, "linked_openai_account_id", openaiAccountID)
if err != nil {
log.Printf("[TokenSync] 查找关联 Sora 账号失败: openai_account_id=%d err=%v", openaiAccountID, err)
return
}
if len(soraAccounts) == 0 {
// 没有关联的 Sora 账号,直接返回
return
}
// 2. 同步更新每个 Sora 账号的双表数据
for _, soraAccount := range soraAccounts {
// 2.1 更新 accounts.credentials(主表)
soraAccount.Credentials["access_token"] = newCredentials["access_token"]
soraAccount.Credentials["refresh_token"] = newCredentials["refresh_token"]
if expiresAt, ok := newCredentials["expires_at"]; ok {
soraAccount.Credentials["expires_at"] = expiresAt
}
if err := r.accountRepo.Update(ctx, &soraAccount); err != nil {
log.Printf("[TokenSync] 更新 Sora accounts 表失败: sora_account_id=%d openai_account_id=%d err=%v",
soraAccount.ID, openaiAccountID, err)
continue
}
// 2.2 更新 sora_accounts 扩展表(如果仓储已设置)
if r.soraAccountRepo != nil {
soraUpdates := map[string]any{
"access_token": newCredentials["access_token"],
"refresh_token": newCredentials["refresh_token"],
}
if err := r.soraAccountRepo.Upsert(ctx, soraAccount.ID, soraUpdates); err != nil {
log.Printf("[TokenSync] 更新 sora_accounts 表失败: account_id=%d openai_account_id=%d err=%v",
soraAccount.ID, openaiAccountID, err)
// 继续处理其他账号,不中断
}
}
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
}
}
@@ -242,12 +242,6 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
accType: AccountTypeOAuth,
want: true,
},
{
name: "sora oauth - cannot refresh directly",
platform: PlatformSora,
accType: AccountTypeOAuth,
want: false,
},
{
name: "openai apikey - cannot refresh",
platform: PlatformOpenAI,
+1 -1
View File
@@ -110,7 +110,7 @@ type UsageLog struct {
ModelMappingChain *string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string
// BillingMode 计费模式:token/imagesora 路径为 nil
// BillingMode 计费模式:token/image
BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string
-4
View File
@@ -25,10 +25,6 @@ type User struct {
// map[groupID]rateMultiplier
GroupRates map[int64]float64
// Sora 存储配额
SoraStorageQuotaBytes int64 // 用户级 Sora 存储配额(0 表示使用分组或系统默认值)
SoraStorageUsedBytes int64 // Sora 存储已用量
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
-32
View File
@@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
// ProvideTokenRefreshService creates and starts TokenRefreshService
func ProvideTokenRefreshService(
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
@@ -54,8 +53,6 @@ func ProvideTokenRefreshService(
refreshAPI *OAuthRefreshAPI,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo)
// 注入 OpenAI privacy opt-out 依赖
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
@@ -281,30 +278,6 @@ func ProvideOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink {
return sink
}
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
func ProvideSoraSDKClient(
cfg *config.Config,
httpUpstream HTTPUpstream,
tokenProvider *OpenAITokenProvider,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
) *SoraSDKClient {
client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider)
client.SetAccountRepositories(accountRepo, soraAccountRepo)
return client
}
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
return svc
}
func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
idempotencyCfg := DefaultIdempotencyConfig()
if cfg != nil {
@@ -425,11 +398,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementService,
NewAdminService,
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
ProvideSoraSDKClient,
wire.Bind(new(SoraClient), new(*SoraSDKClient)),
NewSoraGatewayService,
NewOpenAIGatewayService,
NewOAuthService,
NewOpenAIOAuthService,