diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index e164af0b..6be7166f 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.141 +0.1.142 diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index c6b73190..be5fa548 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -219,7 +219,16 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) + var result *service.ForwardResult + if account.Platform == service.PlatformGemini { + if h.geminiCompatService == nil { + err = errors.New("gemini compatibility service not configured") + } else { + result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody) + } + } else { + result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) + } if accountReleaseFunc != nil { accountReleaseFunc() diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index ea0c0d7d..5f4a4f0a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -14,11 +14,13 @@ import ( "math" mathrand "math/rand" "net/http" + "net/http/httptest" "regexp" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" @@ -1090,6 +1092,223 @@ func 
(s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
 	}, nil
 }
 
+// ForwardAsChatCompletions serves an OpenAI Chat Completions request for the
+// given account: body is converted Chat Completions -> Responses -> Anthropic,
+// run through s.Forward against a capturing gin context, and the captured
+// output is rewritten to c as Chat Completions JSON or SSE chunks.
+func (s *GeminiMessagesCompatService) ForwardAsChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
+	startTime := time.Now()
+
+	var ccReq apicompat.ChatCompletionsRequest
+	if err := json.Unmarshal(body, &ccReq); err != nil {
+		writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+		return nil, fmt.Errorf("parse chat completions request: %w", err)
+	}
+	if strings.TrimSpace(ccReq.Model) == "" {
+		writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
+		return nil, errors.New("model is required")
+	}
+
+	originalModel := ccReq.Model // echoed back in the response and in result.Model
+	clientStream := ccReq.Stream
+	includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
+
+	responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
+	if err != nil {
+		writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+		return nil, fmt.Errorf("convert chat completions to responses: %w", err)
+	}
+
+	anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
+	if err != nil {
+		writeGeminiChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+		return nil, fmt.Errorf("convert responses to anthropic: %w", err)
+	}
+	// ChatCompletionsToResponses intentionally forces upstream streaming for
+	// OpenAI/Codex paths. Gemini's Claude-compat forwarder writes the client
+	// protocol directly, so preserve the client's requested stream mode here.
+	anthropicReq.Stream = clientStream
+
+	anthropicBody, err := json.Marshal(anthropicReq)
+	if err != nil {
+		writeGeminiChatCompletionsError(c, http.StatusInternalServerError, "api_error", "Failed to build upstream request")
+		return nil, fmt.Errorf("marshal anthropic request: %w", err)
+	}
+
+	captureC, recorder := newCapturedGinContext(ctx, c) // Forward writes into recorder, not to the client
+	result, err := s.Forward(ctx, captureC, account, anthropicBody)
+	if err != nil {
+		var failoverErr *UpstreamFailoverError
+		if errors.As(err, &failoverErr) { // failover errors are returned with nothing written to c
+			return nil, err
+		}
+		writeCapturedClaudeErrorAsChatCompletions(c, recorder)
+		return nil, err
+	}
+
+	if recorder.Code >= 400 { // Forward returned no error but captured an error status
+		writeCapturedClaudeErrorAsChatCompletions(c, recorder)
+		return nil, fmt.Errorf("gemini chat completions upstream failed with status %d", recorder.Code)
+	}
+
+	if clientStream {
+		if err := writeCapturedAnthropicStreamAsChatCompletions(c, recorder, originalModel, includeUsage); err != nil {
+			return nil, err
+		}
+	} else {
+		if err := writeCapturedAnthropicMessageAsChatCompletions(c, recorder, originalModel); err != nil {
+			return nil, err
+		}
+	}
+
+	if result == nil { // tolerate Forward returning (nil, nil)
+		result = &ForwardResult{}
+	}
+	result.Model = originalModel // overwrite with the client-facing values
+	result.Stream = clientStream
+	result.Duration = time.Since(startTime) // measured from entry, so it includes conversion and replay
+	return result, nil
+}
+
+func newCapturedGinContext(ctx context.Context, src *gin.Context) (*gin.Context, *httptest.ResponseRecorder) { // recorder-backed gin context used to capture Forward's output
+	recorder := httptest.NewRecorder()
+	captureC, _ := gin.CreateTestContext(recorder)
+	if src != nil {
+		captureC.Params = src.Params // NOTE(review): values stored on src via c.Set are not copied; confirm Forward does not read them
+		if src.Request != nil {
+			captureC.Request = src.Request.Clone(ctx)
+		}
+	}
+	if captureC.Request == nil { // no source request available: substitute a placeholder
+		captureC.Request, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/messages", nil) // error ignored; method and URL are fixed literals
+	}
+	return captureC, recorder
+}
+
+func writeGeminiChatCompletionsError(c *gin.Context, status int, errType, message string) { // writes {"error":{"type","message"}} as JSON with the given status
+	if c == nil || c.Writer.Written() { // no-op without a context or once a response has been written
+		return
+	}
+	c.JSON(status, gin.H{
+		"error": gin.H{
+			"type":    errType,
+			"message": message,
+		},
+	})
+}
+
+func writeCapturedClaudeErrorAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder) { // re-emits the error captured in recorder in the Chat Completions error shape
+	if c == nil || c.Writer.Written() {
+		return
+	}
+	status := recorder.Code
+	if status < 400 { // captured status is not an error code: report 502 instead
+		status = http.StatusBadGateway
+	}
+	body := recorder.Body.Bytes()
+	errType := strings.TrimSpace(gjson.GetBytes(body, "error.type").String())
+	if errType == "" {
+		errType = "server_error"
+	}
+	message := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
+	if message == "" { // fall back to a top-level "message" field
+		message = strings.TrimSpace(gjson.GetBytes(body, "message").String())
+	}
+	if message == "" {
+		message = "Upstream request failed"
+	}
+	writeGeminiChatCompletionsError(c, status, errType, message)
+}
+
+func writeCapturedAnthropicMessageAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder, originalModel string) error { // non-streaming path: captured Anthropic JSON -> Responses -> Chat Completions
+	var anthropicResp apicompat.AnthropicResponse
+	if err := json.Unmarshal(recorder.Body.Bytes(), &anthropicResp); err != nil {
+		writeGeminiChatCompletionsError(c, http.StatusBadGateway, "server_error", "Failed to parse upstream response")
+		return fmt.Errorf("parse captured anthropic response: %w", err)
+	}
+
+	responsesResp := apicompat.AnthropicToResponsesResponse(&anthropicResp)
+	ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)
+	c.JSON(http.StatusOK, ccResp)
+	return nil
+}
+
+func writeCapturedAnthropicStreamAsChatCompletions(c *gin.Context, recorder *httptest.ResponseRecorder, originalModel string, includeUsage bool) error { // replays the captured Anthropic SSE body to c as Chat Completions chunks
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Writer.WriteHeader(http.StatusOK)
+
+	anthState := apicompat.NewAnthropicEventToResponsesState() // state for the Anthropic -> Responses event converter
+	anthState.Model = originalModel
+	ccState := apicompat.NewResponsesEventToChatState() // state for the Responses -> Chat chunk converter
+	ccState.Model = originalModel
+	ccState.IncludeUsage = includeUsage
+
+	writeChunk := func(chunk apicompat.ChatCompletionsChunk) error { // encodes one chunk as SSE and writes it to the client
+		sse, err := apicompat.ChatChunkToSSE(chunk)
+		if err != nil {
+			return err
+		}
+		_, err = io.WriteString(c.Writer, sse)
+		return err
+	}
+
+	scanner := bufio.NewScanner(bytes.NewReader(recorder.Body.Bytes())) // NOTE(review): reads the fully buffered body, so the client gets nothing until Forward has returned
+	scanner.Buffer(make([]byte, 0, 64*1024), defaultMaxLineSize)
+	for scanner.Scan() {
+		line := scanner.Text()
+		if !strings.HasPrefix(line, "event: ") { // skip everything until an "event: " line
+			continue
+		}
+		if !scanner.Scan() { // the payload is expected on the line right after the event line
+			break
+		}
+		dataLine := scanner.Text()
+		if !strings.HasPrefix(dataLine, "data: ") { // NOTE(review): this line is consumed and dropped, even if it is another "event: " line
+			continue
+		}
+
+		var event apicompat.AnthropicStreamEvent
+		if err := json.Unmarshal([]byte(dataLine[6:]), &event); err != nil { // 6 == len("data: "); undecodable payloads are skipped silently
+			continue
+		}
+
+		responsesEvents := apicompat.AnthropicEventToResponsesEvents(&event, anthState)
+		for _, resEvt := range responsesEvents {
+			ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
+			for _, chunk := range ccChunks {
+				if err := writeChunk(chunk); err != nil {
+					return err
+				}
+			}
+		}
+		c.Writer.Flush() // flush once per captured event
+	}
+	if err := scanner.Err(); err != nil {
+		return fmt.Errorf("read captured anthropic stream: %w", err)
+	}
+
+	for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) { // events the Anthropic converter emits on finalize, after input ends
+		for _, chunk := range apicompat.ResponsesEventToChatChunks(&resEvt, ccState) {
+			if err := writeChunk(chunk); err != nil {
+				return err
+			}
+		}
+	}
+	for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) { // chunks the Chat converter emits on finalize
+		if err := writeChunk(chunk); err != nil {
+			return err
+		}
+	}
+	if _, err := io.WriteString(c.Writer, "data: [DONE]\n\n"); err != nil { // stream terminator line
+		return err
+	}
+	c.Writer.Flush()
+	return nil
+}
+
 func isGeminiSignatureRelatedError(respBody []byte) bool {
 	msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
 	if msg == "" {
diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go
index c2adf45d..bf72fdf1 100644
---
a/backend/internal/service/gemini_messages_compat_service_test.go
+++ b/backend/internal/service/gemini_messages_compat_service_test.go
@@ -261,6 +261,68 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
 	require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
 }
 
+func TestGeminiMessagesCompatServiceForwardAsChatCompletions_NonStreaming(t *testing.T) { // non-streaming Chat Completions request against a stubbed upstream
+	gin.SetMode(gin.TestMode)
+	w := httptest.NewRecorder() // stands in for the client connection; assertions read w.Body
+	c, _ := gin.CreateTestContext(w)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	httpStub := &geminiCompatHTTPUpstreamStub{
+		response: &http.Response{
+			StatusCode: http.StatusOK,
+			Header:     http.Header{"x-request-id": []string{"gemini-cc-1"}},
+			Body:       io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"OK"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":1}}`)), // canned reply: text "OK", 7 prompt / 1 candidate tokens
+		},
+	}
+	svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
+	account := &Account{
+		ID:       1,
+		Platform: PlatformGemini,
+		Type:     AccountTypeAPIKey,
+		Credentials: map[string]any{
+			"api_key": "test-key",
+		},
+	}
+	body := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"Return OK"}],"max_tokens":64}`) // no "stream" field in the request
+
+	result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body)
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	require.Equal(t, "gemini-2.5-flash", result.Model)
+	require.Equal(t, "gemini-2.5-flash", result.UpstreamModel)
+	require.False(t, result.Stream)
+	require.Equal(t, 1, httpStub.calls) // exactly one upstream call
+	require.NotNil(t, httpStub.lastReq)
+	require.Contains(t, httpStub.lastReq.URL.String(), "/v1beta/models/gemini-2.5-flash:generateContent")
+
+	var out struct { // only the Chat Completions fields asserted below
+		Object  string `json:"object"`
+		Model   string `json:"model"`
+		Choices []struct {
+			Message struct {
+				Role    string `json:"role"`
+				Content string `json:"content"`
+			} `json:"message"`
+			FinishReason string `json:"finish_reason"`
+		} `json:"choices"`
+		Usage struct {
+			PromptTokens     int `json:"prompt_tokens"`
+			CompletionTokens int `json:"completion_tokens"`
+			TotalTokens      int `json:"total_tokens"`
+		} `json:"usage"`
+	}
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &out)) // decode what was written to the client recorder
+	require.Equal(t, "chat.completion", out.Object)
+	require.Equal(t, "gemini-2.5-flash", out.Model)
+	require.Len(t, out.Choices, 1)
+	require.Equal(t, "assistant", out.Choices[0].Message.Role)
+	require.Equal(t, "OK", out.Choices[0].Message.Content)
+	require.Equal(t, "stop", out.Choices[0].FinishReason) // stub's finishReason "STOP" comes back as "stop"
+	require.Equal(t, 7, out.Usage.PromptTokens)
+	require.Equal(t, 1, out.Usage.CompletionTokens)
+	require.Equal(t, 8, out.Usage.TotalTokens) // 7 prompt + 1 completion
+}
+
 func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	w := httptest.NewRecorder()