feat: support Gemini chat completions gateway
Release Image / image (push) Successful in 2m58s

This commit is contained in:
kone
2026-06-05 10:37:47 +08:00
parent 2a00019d81
commit ac1e273009
4 changed files with 292 additions and 2 deletions
+1 -1
View File
@@ -1 +1 @@
0.1.141 0.1.142
@@ -219,7 +219,16 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
if channelMapping.Mapped { if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
} }
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) var result *service.ForwardResult
if account.Platform == service.PlatformGemini {
if h.geminiCompatService == nil {
err = errors.New("gemini compatibility service not configured")
} else {
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
}
} else {
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
@@ -14,11 +14,13 @@ import (
"math" "math"
mathrand "math/rand" mathrand "math/rand"
"net/http" "net/http"
"net/http/httptest"
"regexp" "regexp"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
@@ -1090,6 +1092,223 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil }, nil
} }
// ForwardAsChatCompletions accepts an OpenAI Chat Completions request, converts
// it through the existing Responses/Anthropic compatibility chain, forwards it
// to Gemini, then translates the captured Anthropic-compatible response back to
// Chat Completions format.
func (s *GeminiMessagesCompatService) ForwardAsChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
var ccReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &ccReq); err != nil {
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
if strings.TrimSpace(ccReq.Model) == "" {
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return nil, errors.New("model is required")
}
originalModel := ccReq.Model
clientStream := ccReq.Stream
includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
if err != nil {
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
if err != nil {
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
}
// ChatCompletionsToResponses intentionally forces upstream streaming for
// OpenAI/Codex paths. Gemini's Claude-compat forwarder writes the client
// protocol directly, so preserve the client's requested stream mode here.
anthropicReq.Stream = clientStream
anthropicBody, err := json.Marshal(anthropicReq)
if err != nil {
writeGeminiChatCompletionsError(c, http.StatusInternalServerError, "api_error", "Failed to build upstream request")
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
captureC, recorder := newCapturedGinContext(ctx, c)
result, err := s.Forward(ctx, captureC, account, anthropicBody)
if err != nil {
var failoverErr *UpstreamFailoverError
if errors.As(err, &failoverErr) {
return nil, err
}
writeCapturedClaudeErrorAsChatCompletions(c, recorder)
return nil, err
}
if recorder.Code >= 400 {
writeCapturedClaudeErrorAsChatCompletions(c, recorder)
return nil, fmt.Errorf("gemini chat completions upstream failed with status %d", recorder.Code)
}
if clientStream {
if err := writeCapturedAnthropicStreamAsChatCompletions(c, recorder, originalModel, includeUsage); err != nil {
return nil, err
}
} else {
if err := writeCapturedAnthropicMessageAsChatCompletions(c, recorder, originalModel); err != nil {
return nil, err
}
}
if result == nil {
result = &ForwardResult{}
}
result.Model = originalModel
result.Stream = clientStream
result.Duration = time.Since(startTime)
return result, nil
}
func newCapturedGinContext(ctx context.Context, src *gin.Context) (*gin.Context, *httptest.ResponseRecorder) {
recorder := httptest.NewRecorder()
captureC, _ := gin.CreateTestContext(recorder)
if src != nil {
captureC.Params = src.Params
if src.Request != nil {
captureC.Request = src.Request.Clone(ctx)
}
}
if captureC.Request == nil {
captureC.Request, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/messages", nil)
}
return captureC, recorder
}
func writeGeminiChatCompletionsError(c *gin.Context, status int, errType, message string) {
if c == nil || c.Writer.Written() {
return
}
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
func writeCapturedClaudeErrorAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder) {
if c == nil || c.Writer.Written() {
return
}
status := recorder.Code
if status < 400 {
status = http.StatusBadGateway
}
body := recorder.Body.Bytes()
errType := strings.TrimSpace(gjson.GetBytes(body, "error.type").String())
if errType == "" {
errType = "server_error"
}
message := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.GetBytes(body, "message").String())
}
if message == "" {
message = "Upstream request failed"
}
writeGeminiChatCompletionsError(c, status, errType, message)
}
func writeCapturedAnthropicMessageAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder, originalModel string) error {
var anthropicResp apicompat.AnthropicResponse
if err := json.Unmarshal(recorder.Body.Bytes(), &anthropicResp); err != nil {
writeGeminiChatCompletionsError(c, http.StatusBadGateway, "server_error", "Failed to parse upstream response")
return fmt.Errorf("parse captured anthropic response: %w", err)
}
responsesResp := apicompat.AnthropicToResponsesResponse(&anthropicResp)
ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)
c.JSON(http.StatusOK, ccResp)
return nil
}
func writeCapturedAnthropicStreamAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder, originalModel string, includeUsage bool) error {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
anthState := apicompat.NewAnthropicEventToResponsesState()
anthState.Model = originalModel
ccState := apicompat.NewResponsesEventToChatState()
ccState.Model = originalModel
ccState.IncludeUsage = includeUsage
writeChunk := func(chunk apicompat.ChatCompletionsChunk) error {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
return err
}
_, err = io.WriteString(c.Writer, sse)
return err
}
scanner := bufio.NewScanner(bytes.NewReader(recorder.Body.Bytes()))
scanner.Buffer(make([]byte, 0, 64*1024), defaultMaxLineSize)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(dataLine[6:]), &event); err != nil {
continue
}
responsesEvents := apicompat.AnthropicEventToResponsesEvents(&event, anthState)
for _, resEvt := range responsesEvents {
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range ccChunks {
if err := writeChunk(chunk); err != nil {
return err
}
}
}
c.Writer.Flush()
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("read captured anthropic stream: %w", err)
}
for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) {
for _, chunk := range apicompat.ResponsesEventToChatChunks(&resEvt, ccState) {
if err := writeChunk(chunk); err != nil {
return err
}
}
}
for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) {
if err := writeChunk(chunk); err != nil {
return err
}
}
if _, err := io.WriteString(c.Writer, "data: [DONE]\n\n"); err != nil {
return err
}
c.Writer.Flush()
return nil
}
func isGeminiSignatureRelatedError(respBody []byte) bool { func isGeminiSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" { if msg == "" {
@@ -261,6 +261,68 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:") require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
} }
func TestGeminiMessagesCompatServiceForwardAsChatCompletions_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
httpStub := &geminiCompatHTTPUpstreamStub{
response: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"x-request-id": []string{"gemini-cc-1"}},
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"OK"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":1}}`)),
},
}
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
account := &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-key",
},
}
body := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"Return OK"}],"max_tokens":64}`)
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "gemini-2.5-flash", result.Model)
require.Equal(t, "gemini-2.5-flash", result.UpstreamModel)
require.False(t, result.Stream)
require.Equal(t, 1, httpStub.calls)
require.NotNil(t, httpStub.lastReq)
require.Contains(t, httpStub.lastReq.URL.String(), "/v1beta/models/gemini-2.5-flash:generateContent")
var out struct {
Object string `json:"object"`
Model string `json:"model"`
Choices []struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &out))
require.Equal(t, "chat.completion", out.Object)
require.Equal(t, "gemini-2.5-flash", out.Model)
require.Len(t, out.Choices, 1)
require.Equal(t, "assistant", out.Choices[0].Message.Role)
require.Equal(t, "OK", out.Choices[0].Message.Content)
require.Equal(t, "stop", out.Choices[0].FinishReason)
require.Equal(t, 7, out.Usage.PromptTokens)
require.Equal(t, 1, out.Usage.CompletionTokens)
require.Equal(t, 8, out.Usage.TotalTokens)
}
func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) { func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
w := httptest.NewRecorder() w := httptest.NewRecorder()