package service import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "strings" "sync" kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro" ) const kiroMaxWebSearchIterations = 5 var ( errKiroWebSearchFallback = errors.New("kiro web search fallback") kiroWebSearchDescCache sync.Map ) type kiroWebSearchExecution struct { ResponseBody []byte Usage ClaudeUsage RequestID string } type kiroWebSearchHTTPError struct { Response *http.Response } type kiroStreamChunkCollector struct { chunks [][]byte } func (e *kiroWebSearchHTTPError) Error() string { if e == nil || e.Response == nil { return "kiro web search http error" } return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode) } func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) { if len(p) > 0 { w.chunks = append(w.chunks, append([]byte(nil), p...)) } return len(p), nil } func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) { collector := &kiroStreamChunkCollector{} result, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, collector, mappedModel, inputTokens) if err != nil { return nil, nil, err } return collector.chunks, result, nil } func writeSSEChunks(w io.Writer, chunks [][]byte) error { for _, chunk := range chunks { if len(chunk) == 0 { continue } if _, err := w.Write(chunk); err != nil { return err } } return nil } func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error { if strings.TrimSpace(msgID) == "" { msgID = "msg_" + kiropkg.GenerateToolUseID() } if strings.TrimSpace(model) == "" { model = "kiro" } payload, err := json.Marshal(map[string]any{ "type": "message_start", "message": map[string]any{ "id": msgID, "type": "message", "role": "assistant", "model": model, "content": []any{}, "stop_reason": nil, "stop_sequence": nil, "usage": map[string]any{ "input_tokens": inputTokens, "output_tokens": 0, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0, }, }, }) if err != nil { return err } _, err = io.WriteString(w, "event: message_start\ndata: "+string(payload)+"\n\n") return err } func (s *GatewayService) streamKiroWebSearchAsAnthropic( ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer, ) error { query := kiropkg.ExtractSearchQuery(anthropicBody) if strings.TrimSpace(query) == "" { return errKiroWebSearchFallback } currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody) if err != nil { currentBody = anthropicBody } currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID() nextContentBlockIndex := 0 if err := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); err != nil { return err } for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ { s.prefetchKiroWebSearchDescription(ctx, account, token) results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query) if strings.TrimSpace(nextToken) != "" { token = nextToken } if mcpErr != nil { results = nil } if err := writeSSEChunks(w, kiropkg.GenerateSearchIndicatorEvents(query, currentToolUseID, results, nextContentBlockIndex)); err != nil { return err } nextContentBlockIndex += 2 currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results) if err != nil { return errKiroWebSearchFallback } resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers) if err != nil { return err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return &kiroWebSearchHTTPError{Response: resp} } chunks, _, streamErr := func() ([][]byte, *kiropkg.StreamResult, error) { defer func() { _ = resp.Body.Close() }() return bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens) }() if streamErr != nil { return streamErr } analysis := kiropkg.AnalyzeBufferedStream(chunks) if analysis.HasWebSearchToolUse && strings.TrimSpace(analysis.WebSearchQuery) != "" && iteration+1 < kiroMaxWebSearchIterations { filtered := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, nextContentBlockIndex) if err := writeSSEChunks(w, filtered); err != nil { return err } if maxIndex := kiropkg.MaxContentBlockIndex(filtered); maxIndex >= nextContentBlockIndex { nextContentBlockIndex = maxIndex + 1 } query = analysis.WebSearchQuery if strings.TrimSpace(analysis.WebSearchToolUseID) == "" { currentToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID() } else { currentToolUseID = analysis.WebSearchToolUseID } continue } for _, chunk := range chunks { adjusted, shouldForward := kiropkg.AdjustSSEChunk(chunk, nextContentBlockIndex) if !shouldForward { continue } if _, err := w.Write(adjusted); err != nil { return err } } return nil } return fmt.Errorf("kiro web search exceeded max iterations") } func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) { query := kiropkg.ExtractSearchQuery(anthropicBody) if strings.TrimSpace(query) == "" { return nil, errKiroWebSearchFallback } currentBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody) if err != nil { currentBody = anthropicBody } inputTokens := estimateKiroInputTokens(anthropicBody) currentToolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID() searches := make([]kiropkg.SearchIndicator, 0, 2) requestID := "" for iteration := 0; iteration < kiroMaxWebSearchIterations; iteration++ { s.prefetchKiroWebSearchDescription(ctx, account, token) results, nextToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, query) if strings.TrimSpace(nextToken) != "" { token = nextToken } if mcpErr != nil { results = nil } searches = append(searches, kiropkg.SearchIndicator{ ToolUseID: currentToolUseID, Query: query, Results: results, }) currentBody, err = kiropkg.InjectToolResultsClaude(currentBody, currentToolUseID, query, results) if err != nil { return nil, errKiroWebSearchFallback } resp, _, err := s.executeKiroUpstream(ctx, account, currentBody, mappedModel, requestModel, token, headers) if err != nil { return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, &kiroWebSearchHTTPError{Response: resp} } parseResult, parseErr := func() (*kiropkg.ParseResult, error) { defer func() { _ = resp.Body.Close() }() return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel) }() if parseErr != nil { return nil, parseErr } if requestID == "" { requestID = buildKiroRequestID(resp) } nextToolUseID, nextQuery, hasNext := kiropkg.ExtractWebSearchToolUseFromResponse(parseResult.ResponseBody) if !hasNext || strings.TrimSpace(nextQuery) == "" || iteration+1 >= kiroMaxWebSearchIterations { finalBody, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parseResult.ResponseBody, searches) if injectErr == nil { parseResult.ResponseBody = finalBody } return &kiroWebSearchExecution{ ResponseBody: parseResult.ResponseBody, Usage: kiroUsageToClaude(parseResult.Usage, inputTokens), RequestID: requestID, }, nil } query = nextQuery if strings.TrimSpace(nextToolUseID) == "" { nextToolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID() } currentToolUseID = nextToolUseID } return nil, fmt.Errorf("kiro web search exceeded max iterations") } func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) { endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account)) if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok { if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" { kiropkg.SetCachedWebSearchDescription(desc) } return } reqBody, _ := json.Marshal(kiropkg.MCPRequest{ ID: "tools_list", JSONRPC: "2.0", Method: "tools/list", }) resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token) if err != nil || resp == nil { return } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return } body, err := io.ReadAll(resp.Body) if err != nil { return } var result kiropkg.MCPResponse if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { return } for _, tool := range result.Result.Tools { if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" { kiroWebSearchDescCache.Store(endpoint, tool.Description) kiropkg.SetCachedWebSearchDescription(tool.Description) return } } } func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) { reqBody, err := json.Marshal(buildKiroWebSearchMCPRequest(query)) if err != nil { return nil, token, err } endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account)) resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token) if err != nil { return nil, nextToken, err } if resp == nil { return nil, nextToken, fmt.Errorf("kiro web search returned nil response") } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return nil, nextToken, err } if resp.StatusCode != http.StatusOK { return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) } var parsed kiropkg.MCPResponse if err := json.Unmarshal(body, &parsed); err != nil { return nil, nextToken, err } if parsed.Error != nil { msg := "unknown error" if parsed.Error.Message != nil && strings.TrimSpace(*parsed.Error.Message) != "" { msg = strings.TrimSpace(*parsed.Error.Message) } code := 0 if parsed.Error.Code != nil { code = *parsed.Error.Code } return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg) } return kiropkg.ParseSearchResults(&parsed), nextToken, nil } func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest { return kiropkg.MCPRequest{ ID: fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()), JSONRPC: "2.0", Method: "tools/call", Params: map[string]interface{}{ "name": "web_search", "arguments": map[string]interface{}{ "query": query, "_meta": map[string]interface{}{ "_isValid": true, "_activePath": []string{"query"}, "_completedPaths": [][]string{{"query"}}, }, }, }, } } func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) { currentToken := token accountKey := buildKiroAccountKey(account) proxyURL := kiroProxyURL(account) tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account) for attempt := 0; attempt < 3; attempt++ { if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil { if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil { return nil, currentToken, failoverErr } return nil, currentToken, err } req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account) if err != nil { return nil, currentToken, err } resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile) if err != nil { return nil, currentToken, err } if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { respBody, readErr := io.ReadAll(resp.Body) _ = resp.Body.Close() if readErr != nil { return nil, currentToken, readErr } if resp.StatusCode == http.StatusForbidden && isKiroSuspendedBody(respBody) { if _, err := s.markKiroSuspended(ctx, accountKey); err != nil { return nil, currentToken, err } resp.Body = io.NopCloser(strings.NewReader(string(respBody))) return resp, currentToken, nil } if resp.StatusCode == http.StatusForbidden && !isKiroTokenErrorBody(respBody) { resp.Body = io.NopCloser(strings.NewReader(string(respBody))) return resp, currentToken, nil } if s.kiroTokenProvider == nil { resp.Body = io.NopCloser(strings.NewReader(string(respBody))) return resp, currentToken, nil } refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account) if refreshErr != nil { resp.Body = io.NopCloser(strings.NewReader(string(respBody))) return resp, currentToken, nil } currentToken = refreshedToken accountKey = buildKiroAccountKey(account) if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil { return nil, currentToken, sleepErr } continue } if resp.StatusCode == http.StatusTooManyRequests { if _, err := s.markKiro429(ctx, accountKey); err != nil { _ = resp.Body.Close() return nil, currentToken, err } } if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 { if attempt < 2 { _ = resp.Body.Close() if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil { return nil, currentToken, sleepErr } continue } } if resp.StatusCode >= 200 && resp.StatusCode < 300 { if err := s.markKiroSuccess(ctx, accountKey); err != nil { _ = resp.Body.Close() return nil, currentToken, err } } return resp, currentToken, nil } return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted") }