feat: support Gemini chat completions gateway
This commit is contained in:
@@ -1 +1 @@
|
||||
0.1.141
|
||||
0.1.142
|
||||
|
||||
@@ -219,7 +219,16 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformGemini {
|
||||
if h.geminiCompatService == nil {
|
||||
err = errors.New("gemini compatibility service not configured")
|
||||
} else {
|
||||
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
|
||||
}
|
||||
} else {
|
||||
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
}
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
|
||||
@@ -14,11 +14,13 @@ import (
|
||||
"math"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
@@ -1090,6 +1092,223 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardAsChatCompletions accepts an OpenAI Chat Completions request, converts
|
||||
// it through the existing Responses/Anthropic compatibility chain, forwards it
|
||||
// to Gemini, then translates the captured Anthropic-compatible response back to
|
||||
// Chat Completions format.
|
||||
func (s *GeminiMessagesCompatService) ForwardAsChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var ccReq apicompat.ChatCompletionsRequest
|
||||
if err := json.Unmarshal(body, &ccReq); err != nil {
|
||||
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return nil, fmt.Errorf("parse chat completions request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(ccReq.Model) == "" {
|
||||
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
|
||||
originalModel := ccReq.Model
|
||||
clientStream := ccReq.Stream
|
||||
includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
|
||||
|
||||
responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
|
||||
if err != nil {
|
||||
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||
}
|
||||
|
||||
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
|
||||
if err != nil {
|
||||
writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
|
||||
}
|
||||
// ChatCompletionsToResponses intentionally forces upstream streaming for
|
||||
// OpenAI/Codex paths. Gemini's Claude-compat forwarder writes the client
|
||||
// protocol directly, so preserve the client's requested stream mode here.
|
||||
anthropicReq.Stream = clientStream
|
||||
|
||||
anthropicBody, err := json.Marshal(anthropicReq)
|
||||
if err != nil {
|
||||
writeGeminiChatCompletionsError(c, http.StatusInternalServerError, "api_error", "Failed to build upstream request")
|
||||
return nil, fmt.Errorf("marshal anthropic request: %w", err)
|
||||
}
|
||||
|
||||
captureC, recorder := newCapturedGinContext(ctx, c)
|
||||
result, err := s.Forward(ctx, captureC, account, anthropicBody)
|
||||
if err != nil {
|
||||
var failoverErr *UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
return nil, err
|
||||
}
|
||||
writeCapturedClaudeErrorAsChatCompletions(c, recorder)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if recorder.Code >= 400 {
|
||||
writeCapturedClaudeErrorAsChatCompletions(c, recorder)
|
||||
return nil, fmt.Errorf("gemini chat completions upstream failed with status %d", recorder.Code)
|
||||
}
|
||||
|
||||
if clientStream {
|
||||
if err := writeCapturedAnthropicStreamAsChatCompletions(c, recorder, originalModel, includeUsage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := writeCapturedAnthropicMessageAsChatCompletions(c, recorder, originalModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = &ForwardResult{}
|
||||
}
|
||||
result.Model = originalModel
|
||||
result.Stream = clientStream
|
||||
result.Duration = time.Since(startTime)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func newCapturedGinContext(ctx context.Context, src *gin.Context) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
recorder := httptest.NewRecorder()
|
||||
captureC, _ := gin.CreateTestContext(recorder)
|
||||
if src != nil {
|
||||
captureC.Params = src.Params
|
||||
if src.Request != nil {
|
||||
captureC.Request = src.Request.Clone(ctx)
|
||||
}
|
||||
}
|
||||
if captureC.Request == nil {
|
||||
captureC.Request, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/messages", nil)
|
||||
}
|
||||
return captureC, recorder
|
||||
}
|
||||
|
||||
func writeGeminiChatCompletionsError(c *gin.Context, status int, errType, message string) {
|
||||
if c == nil || c.Writer.Written() {
|
||||
return
|
||||
}
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func writeCapturedClaudeErrorAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder) {
|
||||
if c == nil || c.Writer.Written() {
|
||||
return
|
||||
}
|
||||
status := recorder.Code
|
||||
if status < 400 {
|
||||
status = http.StatusBadGateway
|
||||
}
|
||||
body := recorder.Body.Bytes()
|
||||
errType := strings.TrimSpace(gjson.GetBytes(body, "error.type").String())
|
||||
if errType == "" {
|
||||
errType = "server_error"
|
||||
}
|
||||
message := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(gjson.GetBytes(body, "message").String())
|
||||
}
|
||||
if message == "" {
|
||||
message = "Upstream request failed"
|
||||
}
|
||||
writeGeminiChatCompletionsError(c, status, errType, message)
|
||||
}
|
||||
|
||||
func writeCapturedAnthropicMessageAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder, originalModel string) error {
|
||||
var anthropicResp apicompat.AnthropicResponse
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &anthropicResp); err != nil {
|
||||
writeGeminiChatCompletionsError(c, http.StatusBadGateway, "server_error", "Failed to parse upstream response")
|
||||
return fmt.Errorf("parse captured anthropic response: %w", err)
|
||||
}
|
||||
|
||||
responsesResp := apicompat.AnthropicToResponsesResponse(&anthropicResp)
|
||||
ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)
|
||||
c.JSON(http.StatusOK, ccResp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeCapturedAnthropicStreamAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder, originalModel string, includeUsage bool) error {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
anthState := apicompat.NewAnthropicEventToResponsesState()
|
||||
anthState.Model = originalModel
|
||||
ccState := apicompat.NewResponsesEventToChatState()
|
||||
ccState.Model = originalModel
|
||||
ccState.IncludeUsage = includeUsage
|
||||
|
||||
writeChunk := func(chunk apicompat.ChatCompletionsChunk) error {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.WriteString(c.Writer, sse)
|
||||
return err
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(recorder.Body.Bytes()))
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), defaultMaxLineSize)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
dataLine := scanner.Text()
|
||||
if !strings.HasPrefix(dataLine, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
var event apicompat.AnthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(dataLine[6:]), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
responsesEvents := apicompat.AnthropicEventToResponsesEvents(&event, anthState)
|
||||
for _, resEvt := range responsesEvents {
|
||||
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
|
||||
for _, chunk := range ccChunks {
|
||||
if err := writeChunk(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("read captured anthropic stream: %w", err)
|
||||
}
|
||||
|
||||
for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) {
|
||||
for _, chunk := range apicompat.ResponsesEventToChatChunks(&resEvt, ccState) {
|
||||
if err := writeChunk(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) {
|
||||
if err := writeChunk(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := io.WriteString(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isGeminiSignatureRelatedError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
|
||||
@@ -261,6 +261,68 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
|
||||
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatServiceForwardAsChatCompletions_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
httpStub := &geminiCompatHTTPUpstreamStub{
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"x-request-id": []string{"gemini-cc-1"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"OK"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":1}}`)),
|
||||
},
|
||||
}
|
||||
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-key",
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"Return OK"}],"max_tokens":64}`)
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gemini-2.5-flash", result.Model)
|
||||
require.Equal(t, "gemini-2.5-flash", result.UpstreamModel)
|
||||
require.False(t, result.Stream)
|
||||
require.Equal(t, 1, httpStub.calls)
|
||||
require.NotNil(t, httpStub.lastReq)
|
||||
require.Contains(t, httpStub.lastReq.URL.String(), "/v1beta/models/gemini-2.5-flash:generateContent")
|
||||
|
||||
var out struct {
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &out))
|
||||
require.Equal(t, "chat.completion", out.Object)
|
||||
require.Equal(t, "gemini-2.5-flash", out.Model)
|
||||
require.Len(t, out.Choices, 1)
|
||||
require.Equal(t, "assistant", out.Choices[0].Message.Role)
|
||||
require.Equal(t, "OK", out.Choices[0].Message.Content)
|
||||
require.Equal(t, "stop", out.Choices[0].FinishReason)
|
||||
require.Equal(t, 7, out.Usage.PromptTokens)
|
||||
require.Equal(t, 1, out.Usage.CompletionTokens)
|
||||
require.Equal(t, 8, out.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
Reference in New Issue
Block a user