feat: add OpenAI image generation controls
This commit is contained in:
@@ -92,6 +92,9 @@ type CreateGroupRequest struct {
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
@@ -129,6 +132,9 @@ type UpdateGroupRequest struct {
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
AllowImageGeneration *bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent *bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
@@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: req.ImageRateMultiplier,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: req.ImageRateMultiplier,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
|
||||
@@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
AllowImageGeneration: g.AllowImageGeneration,
|
||||
ImageRateIndependent: g.ImageRateIndependent,
|
||||
ImageRateMultiplier: g.ImageRateMultiplier,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
|
||||
@@ -94,9 +94,12 @@ type Group struct {
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// imageConcurrencyLimiter caps how many image-generation requests run at the
// same time and optionally lets callers queue for a free slot.
//
// The zero value is ready to use. The enabled flag and the limit are supplied
// on every call; the most recently supplied values are the ones enforced.
type imageConcurrencyLimiter struct {
	mu sync.Mutex

	// notify is closed and replaced on every release so that every goroutine
	// parked in waitForSlot wakes up and retries.
	notify chan struct{}

	// limit is the slot count passed by the most recent caller.
	limit int

	// active is the number of slots currently held.
	active int

	// waiting is the number of callers currently queued for a slot.
	waiting int

	// enabled is the enabled flag passed by the most recent caller.
	enabled bool
}

// TryAcquire attempts to take a slot without waiting.
//
// It returns (nil, true) when limiting is off (enabled is false or limit is
// not positive), (release, true) when a slot was taken, and (nil, false) when
// every slot is in use. release is safe to call more than once.
func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) {
	return l.acquire(context.Background(), enabled, limit, false, 0, 0)
}

// Acquire takes a slot, optionally queueing for one.
//
// With wait set, the caller blocks until a slot frees up, timeout elapses, or
// ctx is done. maxWaiting bounds the number of queued callers; zero or a
// negative value means the queue is unbounded. The results follow the same
// convention as TryAcquire.
func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
	return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting)
}

// acquire implements TryAcquire and Acquire.
//
// A caller that has to queue stays registered in the waiting count for the
// whole time it is blocked, including across wake-ups that lose the race for
// the freed slot. Giving the registration up between retries would let a
// newcomer take the queue position and get the earlier waiter rejected long
// before its timeout.
func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
	if !enabled || limit <= 0 {
		// Limiting is off: admit the request, nothing to release.
		return nil, true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if wait {
		if timeout <= 0 {
			// Wait mode without a usable timeout is rejected outright.
			return nil, false
		}
		waitCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		ctx = waitCtx
	}
	if maxWaiting < 0 {
		maxWaiting = 0
	}

	release, acquired, leaveQueue, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting)
	if acquired {
		return release, true
	}
	if !wait || notify == nil {
		return nil, false
	}
	if leaveQueue != nil {
		// Hold the queue position until this call returns, whatever the outcome.
		defer leaveQueue()
	}

	for {
		if !l.waitForSlot(ctx, notify) {
			return nil, false
		}
		release, notify = l.retryAcquireQueued(enabled, limit)
		if release != nil {
			return release, true
		}
	}
}

// tryAcquireLocked makes one acquisition attempt while holding l.mu.
//
// It returns, in order: the release function and true when a slot was taken;
// otherwise, when the caller was admitted to the wait queue, a function that
// removes it from the queue and the channel that is closed on the next
// release. All four results are zero when the caller must be rejected.
func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.notify == nil {
		l.notify = make(chan struct{})
	}
	l.enabled = enabled
	l.limit = limit

	if l.active < l.limit {
		l.active++
		return l.releaseFunc(), true, nil, nil
	}
	if !wait {
		return nil, false, nil, nil
	}
	if maxWaiting > 0 && l.waiting >= maxWaiting {
		return nil, false, nil, nil
	}
	l.waiting++
	return nil, false, l.waiterReleaseFunc(), l.notify
}

// retryAcquireQueued is the retry step for a caller that is already counted
// in l.waiting. It either takes a free slot and returns its release function,
// or returns nil together with the channel to wait on next. The waiting count
// is left untouched; the caller deregisters itself when it stops waiting.
func (l *imageConcurrencyLimiter) retryAcquireQueued(enabled bool, limit int) (func(), <-chan struct{}) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.notify == nil {
		l.notify = make(chan struct{})
	}
	l.enabled = enabled
	l.limit = limit

	if l.active < l.limit {
		l.active++
		return l.releaseFunc(), nil
	}
	return nil, l.notify
}

// waitForSlot blocks until notify is closed (true) or ctx is done (false).
func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool {
	select {
	case <-notify:
		return true
	case <-ctx.Done():
		return false
	}
}

// releaseFunc returns a function that gives one slot back and wakes every
// queued caller. Only the first invocation has an effect.
func (l *imageConcurrencyLimiter) releaseFunc() func() {
	var once sync.Once
	return func() {
		once.Do(func() {
			l.mu.Lock()
			defer l.mu.Unlock()
			if l.active > 0 {
				l.active--
			}
			if l.notify != nil {
				// Closing broadcasts to all waiters; a fresh channel serves the next round.
				close(l.notify)
				l.notify = make(chan struct{})
			}
		})
	}
}

// waiterReleaseFunc returns a function that removes one caller from the wait
// queue. Only the first invocation has an effect.
func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() {
	var once sync.Once
	return func() {
		once.Do(func() {
			l.mu.Lock()
			defer l.mu.Unlock()
			if l.waiting > 0 {
				l.waiting--
			}
		})
	}
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
|
||||
release, acquired := limiter.TryAcquire(false, 1)
|
||||
|
||||
require.True(t, acquired)
|
||||
require.Nil(t, release)
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
|
||||
release, acquired := limiter.TryAcquire(true, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
secondRelease, secondAcquired := limiter.TryAcquire(true, 1)
|
||||
require.False(t, secondAcquired)
|
||||
require.Nil(t, secondRelease)
|
||||
|
||||
release()
|
||||
thirdRelease, thirdAcquired := limiter.TryAcquire(true, 1)
|
||||
require.True(t, thirdAcquired)
|
||||
require.NotNil(t, thirdRelease)
|
||||
thirdRelease()
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
acquiredCh := make(chan func(), 1)
|
||||
go func() {
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, waitAcquired)
|
||||
acquiredCh <- waitRelease
|
||||
}()
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
release()
|
||||
|
||||
select {
|
||||
case waitRelease := <-acquiredCh:
|
||||
require.NotNil(t, waitRelease)
|
||||
waitRelease()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for image concurrency slot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1)
|
||||
|
||||
require.False(t, waitAcquired)
|
||||
require.Nil(t, waitRelease)
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
waitingStarted := make(chan struct{})
|
||||
waitingDone := make(chan struct{})
|
||||
go func() {
|
||||
close(waitingStarted)
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
if waitAcquired && waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
close(waitingDone)
|
||||
}()
|
||||
<-waitingStarted
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
overflowRelease, overflowAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
|
||||
require.False(t, overflowAcquired)
|
||||
require.Nil(t, overflowRelease)
|
||||
release()
|
||||
<-waitingDone
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
},
|
||||
},
|
||||
},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
blockedRelease, blocked := h.acquireImageGenerationSlot(c, false)
|
||||
|
||||
require.False(t, blocked)
|
||||
require.Nil(t, blockedRelease)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
||||
groupID := int64(1)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
User: &service.User{ID: 20},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
errorPassthroughService: nil,
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
}}},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
rec.Body.Reset()
|
||||
rec.Code = 0
|
||||
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := `{"model":"gpt-5.4","input":"write code"}`
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
||||
groupID := int64(1)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
User: &service.User{ID: 20},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
}}},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
rec.Body.Reset()
|
||||
rec.Code = 0
|
||||
|
||||
h.Responses(c)
|
||||
|
||||
require.NotEqual(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
@@ -187,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
@@ -242,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
|
||||
@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
@@ -69,6 +70,7 @@ func NewOpenAIGatewayHandler(
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
@@ -187,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
var imageReleaseFunc func()
|
||||
if imageIntent {
|
||||
var imageAcquired bool
|
||||
imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !imageAcquired {
|
||||
return
|
||||
}
|
||||
if imageReleaseFunc != nil {
|
||||
defer imageReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
@@ -318,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
@@ -383,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -701,52 +730,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai_messages.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
@@ -757,16 +794,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -1114,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the channel-level model mapping
|
||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
|
||||
@@ -1257,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil || result == nil {
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
@@ -1440,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
h.submitMandatoryUsageRecordTask(task)
|
||||
return
|
||||
}
|
||||
h.submitUsageRecordTask(task)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
|
||||
return
|
||||
}
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.usage"),
|
||||
).Warn("openai.usage_record_task_mandatory_sync_fallback")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.usage"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("openai.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) {
|
||||
if h == nil || h.cfg == nil || h.imageLimiter == nil {
|
||||
return nil, true
|
||||
}
|
||||
imageConcurrency := h.cfg.Gateway.ImageConcurrency
|
||||
wait := strings.TrimSpace(imageConcurrency.OverflowMode) == config.ImageConcurrencyOverflowModeWait
|
||||
release, acquired := h.imageLimiter.Acquire(
|
||||
c.Request.Context(),
|
||||
imageConcurrency.Enabled,
|
||||
imageConcurrency.MaxConcurrentRequests,
|
||||
wait,
|
||||
time.Duration(imageConcurrency.WaitTimeoutSeconds)*time.Second,
|
||||
imageConcurrency.MaxWaitingRequests,
|
||||
)
|
||||
if acquired {
|
||||
return release, true
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
|
||||
@@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
zap.String("capability", string(parsed.RequiredCapability)),
|
||||
)
|
||||
|
||||
if !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if imageReleaseFunc != nil {
|
||||
defer imageReleaseFunc()
|
||||
}
|
||||
|
||||
if parsed.Multipart {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
|
||||
} else {
|
||||
@@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
if result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai.images.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
@@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
if parsed.Multipart {
|
||||
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
|
||||
}
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
upstreamModel := ""
|
||||
if result != nil {
|
||||
upstreamModel = result.UpstreamModel
|
||||
}
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.images"),
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
groupID := int64(111)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 222,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: false,
|
||||
},
|
||||
User: &service.User{ID: 333},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}},
|
||||
}
|
||||
|
||||
h.Images(c)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage())
|
||||
}
|
||||
@@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) {
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
block := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
pool.Submit(func(ctx context.Context) {
|
||||
close(block)
|
||||
<-release
|
||||
})
|
||||
<-block
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) {
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
block := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
pool.Submit(func(ctx context.Context) {
|
||||
close(block)
|
||||
<-release
|
||||
})
|
||||
<-block
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user