Merge pull request #1242 from Ethan0x0000/feat/anthropic-openai-endpoint-compat
支持 Anthropic Responses / Chat Completions 兼容端点并完善会话一致性与错误可观测性
This commit is contained in:
@@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
|
||||
// POST /v1/chat/completions
|
||||
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
|
||||
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
|
||||
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// Claude Code only restriction
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group is restricted to Claude Code clients (/v1/messages only)")
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
|
||||
// 1. Acquire user concurrency slot
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.chatCompletionsErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request for session hash
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
|
||||
if parsedReq == nil {
|
||||
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
|
||||
}
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default:
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 4. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleCCFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
}
|
||||
h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.cc.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Record usage
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.cc.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
|
||||
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
|
||||
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
if streamStarted {
|
||||
return
|
||||
}
|
||||
statusCode := http.StatusBadGateway
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
@@ -0,0 +1,295 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
|
||||
// POST /v1/responses
|
||||
// This converts Responses API requests to Anthropic format, forwards to Anthropic
|
||||
// upstream, and converts responses back to Responses format.
|
||||
func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.responses",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream using gjson (like OpenAI handler)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// Claude Code only restriction:
|
||||
// /v1/responses is never a Claude Code endpoint.
|
||||
// When claude_code_only is enabled, this endpoint is rejected.
|
||||
// The existing service-layer checkClaudeCodeRestriction handles degradation
|
||||
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
|
||||
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group is restricted to Claude Code clients (/v1/messages only)")
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
|
||||
// 1. Acquire user concurrency slot
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.responsesErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request for session hash
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
|
||||
if parsedReq == nil {
|
||||
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
|
||||
}
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default:
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 4. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// Can't failover if streaming content already sent
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleResponsesFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
}
|
||||
h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.responses.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Record usage
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.responses.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// responsesErrorResponse writes an error in OpenAI Responses API format.
|
||||
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
|
||||
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
if streamStarted {
|
||||
return // Can't write error after stream started
|
||||
}
|
||||
statusCode := http.StatusBadGateway
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
@@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
@@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
@@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
@@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
|
||||
@@ -27,6 +27,9 @@ const (
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
opsUpstreamModelKey = "ops_upstream_model"
|
||||
opsRequestTypeKey = "ops_request_type"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
@@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
}
|
||||
}
|
||||
|
||||
// setOpsEndpointContext stores upstream model and request type for ops error logging.
|
||||
// Called by handlers after model mapping and request type determination.
|
||||
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" {
|
||||
c.Set(opsUpstreamModelKey, upstreamModel)
|
||||
}
|
||||
c.Set(opsRequestTypeKey, requestType)
|
||||
}
|
||||
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
@@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: stream,
|
||||
Stream: stream,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
|
||||
RequestedModel: modelName,
|
||||
UpstreamModel: func() string {
|
||||
if v, ok := c.Get(opsUpstreamModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestType: func() *int16 {
|
||||
if v, ok := c.Get(opsRequestTypeKey); ok {
|
||||
switch t := v.(type) {
|
||||
case int16:
|
||||
return &t
|
||||
case int:
|
||||
v16 := int16(t)
|
||||
return &v16
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: "upstream",
|
||||
@@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: stream,
|
||||
Stream: stream,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
|
||||
RequestedModel: modelName,
|
||||
UpstreamModel: func() string {
|
||||
if v, ok := c.Get(opsUpstreamModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestType: func() *int16 {
|
||||
if v, ok := c.Get(opsRequestTypeKey); ok {
|
||||
switch t := v.(type) {
|
||||
case int16:
|
||||
return &t
|
||||
case int:
|
||||
v16 := int16(t)
|
||||
return &v16
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: phase,
|
||||
|
||||
@@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream
|
||||
|
||||
v, ok := c.Get(opsUpstreamModelKey)
|
||||
require.True(t, ok)
|
||||
vStr, ok := v.(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "claude-3-5-sonnet-20241022", vStr)
|
||||
|
||||
rt, ok := c.Get(opsRequestTypeKey)
|
||||
require.True(t, ok)
|
||||
rtVal, ok := rt.(int16)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int16(2), rtVal)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
setOpsEndpointContext(c, "", int16(1))
|
||||
|
||||
_, ok := c.Get(opsUpstreamModelKey)
|
||||
require.False(t, ok, "empty upstream model should not be stored")
|
||||
|
||||
rt, ok := c.Get(opsRequestTypeKey)
|
||||
require.True(t, ok)
|
||||
rtVal, ok := rt.(int16)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int16(1), rtVal)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_NilContext(t *testing.T) {
|
||||
require.NotPanics(t, func() {
|
||||
setOpsEndpointContext(nil, "model", int16(1))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
|
||||
|
||||
platform := ""
|
||||
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
|
||||
Reference in New Issue
Block a user