Files
sub2api/backend/internal/service/kiro_websearch.go
T
2026-05-17 06:45:35 +08:00

459 lines
14 KiB
Go

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
// kiroMaxWebSearchIterations is the upper bound on search/model round trips
// performed by the web search loops in this file.
const kiroMaxWebSearchIterations = 5
var (
// errKiroWebSearchFallback is returned by the web search entry points when no
// search query can be extracted from the request or when tool results cannot
// be injected into the request body.
// NOTE(review): presumably callers react by falling back to the regular
// (non web search) path - confirm against the call sites.
errKiroWebSearchFallback = errors.New("kiro web search fallback")
// kiroWebSearchDescCache maps an MCP endpoint string to the description string
// of the "web_search" tool fetched from that endpoint.
kiroWebSearchDescCache sync.Map
)
// kiroWebSearchExecution is the result returned by executeKiroWebSearch for a
// non-streaming web search request.
type kiroWebSearchExecution struct {
// ResponseBody is the body of the last upstream response, with search
// indicators injected when that injection succeeds.
ResponseBody []byte
// Usage is the usage of the last upstream response converted with
// kiroUsageToClaude.
Usage ClaudeUsage
// RequestID is built from the first upstream response that yields a
// non-empty ID.
RequestID string
}
// kiroWebSearchHTTPError carries an upstream HTTP response whose status code
// was outside the 2xx range.
type kiroWebSearchHTTPError struct {
	// Response is the upstream response that triggered the error; may be nil.
	Response *http.Response
}

// Error implements the error interface. The status code is appended only when
// both the receiver and its Response are non-nil.
func (e *kiroWebSearchHTTPError) Error() string {
	if e != nil && e.Response != nil {
		return fmt.Sprintf("kiro web search http error: %d", e.Response.StatusCode)
	}
	return "kiro web search http error"
}

// kiroStreamChunkCollector is an io.Writer that records every non-empty write
// as a separate chunk instead of forwarding it anywhere.
type kiroStreamChunkCollector struct {
	chunks [][]byte
}

// Write stores a private copy of p (so later reuse of p by the caller cannot
// alter the recorded chunk) and always reports the full length as written.
// Empty writes are accepted but not recorded.
func (w *kiroStreamChunkCollector) Write(p []byte) (int, error) {
	size := len(p)
	if size == 0 {
		return 0, nil
	}
	owned := make([]byte, size)
	copy(owned, p)
	w.chunks = append(w.chunks, owned)
	return size, nil
}
// bufferKiroAnthropicStream runs kiropkg.StreamEventStreamAsAnthropic over body
// with an in-memory collector as the destination, so the converted output is
// captured chunk by chunk instead of being sent to a client.
//
// It returns the captured chunks together with the stream result. When the
// conversion fails, both are nil and only the error is returned.
func bufferKiroAnthropicStream(ctx context.Context, body io.Reader, mappedModel string, inputTokens int) ([][]byte, *kiropkg.StreamResult, error) {
	sink := new(kiroStreamChunkCollector)
	streamResult, err := kiropkg.StreamEventStreamAsAnthropic(ctx, body, sink, mappedModel, inputTokens)
	if err == nil {
		return sink.chunks, streamResult, nil
	}
	return nil, nil, err
}
// writeSSEChunks writes every non-empty chunk to w in order, one Write call per
// chunk. It stops at the first write error and returns it; empty and nil
// chunks are skipped without touching w.
func writeSSEChunks(w io.Writer, chunks [][]byte) error {
	for i := range chunks {
		if len(chunks[i]) > 0 {
			if _, err := w.Write(chunks[i]); err != nil {
				return err
			}
		}
	}
	return nil
}
// writeAnthropicMessageStart emits a single "message_start" SSE event to w.
//
// A blank msgID is replaced with "msg_" plus a freshly generated ID, and a
// blank model is replaced with "kiro". The usage section reports inputTokens
// as input and zero for output and both cache counters.
// The returned error is either the JSON encoding error or the write error.
func writeAnthropicMessageStart(w io.Writer, msgID, model string, inputTokens int) error {
	if strings.TrimSpace(msgID) == "" {
		msgID = "msg_" + kiropkg.GenerateToolUseID()
	}
	if strings.TrimSpace(model) == "" {
		model = "kiro"
	}

	usage := map[string]any{
		"input_tokens":                inputTokens,
		"output_tokens":               0,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
	}
	message := map[string]any{
		"id":            msgID,
		"type":          "message",
		"role":          "assistant",
		"model":         model,
		"content":       []any{},
		"stop_reason":   nil,
		"stop_sequence": nil,
		"usage":         usage,
	}

	encoded, err := json.Marshal(map[string]any{
		"type":    "message_start",
		"message": message,
	})
	if err != nil {
		return err
	}

	frame := "event: message_start\ndata: " + string(encoded) + "\n\n"
	_, writeErr := io.WriteString(w, frame)
	return writeErr
}
// streamKiroWebSearchAsAnthropic serves a streaming web search request and
// writes the result to w as Anthropic-style SSE.
//
// After a message_start event it loops up to kiroMaxWebSearchIterations times:
// run the search through the MCP endpoint, emit the search indicator events,
// inject the results into the request body, call the upstream and buffer its
// converted stream. If the buffered stream asks for another web search (and
// rounds remain), the filtered chunks are forwarded and the loop repeats with
// the new query; otherwise the chunks are index-adjusted, forwarded, and the
// function returns nil.
//
// Errors:
//   - errKiroWebSearchFallback when no query can be extracted or the tool
//     results cannot be injected;
//   - *kiroWebSearchHTTPError when the upstream status is not 2xx (the
//     response body is not closed here);
//   - any write, upstream or stream-conversion error, unchanged.
//
// A failed MCP search is not fatal: the round continues with nil results.
func (s *GatewayService) streamKiroWebSearchAsAnthropic(
	ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, inputTokens int, headers http.Header, w io.Writer,
) error {
	searchQuery := kiropkg.ExtractSearchQuery(anthropicBody)
	if strings.TrimSpace(searchQuery) == "" {
		return errKiroWebSearchFallback
	}

	// Fall back to the untouched body when the description cannot be replaced.
	upstreamBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
	if err != nil {
		upstreamBody = anthropicBody
	}

	toolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
	blockIndex := 0
	if startErr := writeAnthropicMessageStart(w, "", mappedModel, inputTokens); startErr != nil {
		return startErr
	}

	// collect buffers one upstream stream and always closes its body.
	collect := func(resp *http.Response) ([][]byte, error) {
		defer func() { _ = resp.Body.Close() }()
		buffered, _, collectErr := bufferKiroAnthropicStream(ctx, resp.Body, mappedModel, inputTokens)
		return buffered, collectErr
	}

	for round := 1; round <= kiroMaxWebSearchIterations; round++ {
		s.prefetchKiroWebSearchDescription(ctx, account, token)

		results, refreshedToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, searchQuery)
		if strings.TrimSpace(refreshedToken) != "" {
			token = refreshedToken
		}
		if mcpErr != nil {
			results = nil
		}

		indicatorEvents := kiropkg.GenerateSearchIndicatorEvents(searchQuery, toolUseID, results, blockIndex)
		if writeErr := writeSSEChunks(w, indicatorEvents); writeErr != nil {
			return writeErr
		}
		// The indicator events are accounted for as two content blocks.
		blockIndex += 2

		upstreamBody, err = kiropkg.InjectToolResultsClaude(upstreamBody, toolUseID, searchQuery, results)
		if err != nil {
			return errKiroWebSearchFallback
		}

		resp, _, upstreamErr := s.executeKiroUpstream(ctx, account, upstreamBody, mappedModel, requestModel, token, headers)
		if upstreamErr != nil {
			return upstreamErr
		}
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return &kiroWebSearchHTTPError{Response: resp}
		}

		chunks, streamErr := collect(resp)
		if streamErr != nil {
			return streamErr
		}

		analysis := kiropkg.AnalyzeBufferedStream(chunks)
		followUp := analysis.HasWebSearchToolUse &&
			strings.TrimSpace(analysis.WebSearchQuery) != "" &&
			round < kiroMaxWebSearchIterations
		if !followUp {
			// Final answer: shift indices and forward what should be forwarded.
			for _, chunk := range chunks {
				adjusted, forward := kiropkg.AdjustSSEChunk(chunk, blockIndex)
				if forward {
					if _, writeErr := w.Write(adjusted); writeErr != nil {
						return writeErr
					}
				}
			}
			return nil
		}

		visible := kiropkg.FilterChunksForClient(chunks, analysis.WebSearchToolUseIndex, blockIndex)
		if writeErr := writeSSEChunks(w, visible); writeErr != nil {
			return writeErr
		}
		if highest := kiropkg.MaxContentBlockIndex(visible); highest >= blockIndex {
			blockIndex = highest + 1
		}

		searchQuery = analysis.WebSearchQuery
		toolUseID = analysis.WebSearchToolUseID
		if strings.TrimSpace(toolUseID) == "" {
			toolUseID = "srvtoolu_" + kiropkg.GenerateToolUseID()
		}
	}
	return fmt.Errorf("kiro web search exceeded max iterations")
}
// executeKiroWebSearch serves a non-streaming web search request.
//
// It loops up to kiroMaxWebSearchIterations times: run the search through the
// MCP endpoint, record it as a search indicator, inject the results into the
// request body, call the upstream and parse its event stream. When the parsed
// response asks for another web search (and rounds remain) the loop repeats
// with the new query; otherwise the recorded indicators are injected into the
// response body (best effort) and the execution result is returned.
//
// Errors:
//   - errKiroWebSearchFallback when no query can be extracted or the tool
//     results cannot be injected;
//   - *kiroWebSearchHTTPError when the upstream status is not 2xx (the
//     response body is not closed here);
//   - any upstream or parse error, unchanged.
//
// A failed MCP search is not fatal: the round continues with nil results.
func (s *GatewayService) executeKiroWebSearch(ctx context.Context, account *Account, anthropicBody []byte, mappedModel, requestModel, token string, headers http.Header) (*kiroWebSearchExecution, error) {
	searchQuery := kiropkg.ExtractSearchQuery(anthropicBody)
	if strings.TrimSpace(searchQuery) == "" {
		return nil, errKiroWebSearchFallback
	}

	// Fall back to the untouched body when the description cannot be replaced.
	upstreamBody, err := kiropkg.ReplaceWebSearchToolDescription(anthropicBody)
	if err != nil {
		upstreamBody = anthropicBody
	}

	inputTokens := estimateKiroInputTokens(anthropicBody)
	toolUseID := "srvtoolu_" + kiropkg.GenerateToolUseID()
	indicators := make([]kiropkg.SearchIndicator, 0, 2)
	var requestID string

	// parse decodes one upstream response and always closes its body.
	parse := func(resp *http.Response) (*kiropkg.ParseResult, error) {
		defer func() { _ = resp.Body.Close() }()
		return kiropkg.ParseNonStreamingEventStream(resp.Body, mappedModel)
	}

	for round := 1; round <= kiroMaxWebSearchIterations; round++ {
		s.prefetchKiroWebSearchDescription(ctx, account, token)

		results, refreshedToken, mcpErr := s.callKiroWebSearchMCP(ctx, account, token, searchQuery)
		if strings.TrimSpace(refreshedToken) != "" {
			token = refreshedToken
		}
		if mcpErr != nil {
			results = nil
		}
		indicators = append(indicators, kiropkg.SearchIndicator{
			ToolUseID: toolUseID,
			Query:     searchQuery,
			Results:   results,
		})

		upstreamBody, err = kiropkg.InjectToolResultsClaude(upstreamBody, toolUseID, searchQuery, results)
		if err != nil {
			return nil, errKiroWebSearchFallback
		}

		resp, _, upstreamErr := s.executeKiroUpstream(ctx, account, upstreamBody, mappedModel, requestModel, token, headers)
		if upstreamErr != nil {
			return nil, upstreamErr
		}
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return nil, &kiroWebSearchHTTPError{Response: resp}
		}

		parsed, parseErr := parse(resp)
		if parseErr != nil {
			return nil, parseErr
		}
		// Only the first response that yields an ID determines the request ID.
		if requestID == "" {
			requestID = buildKiroRequestID(resp)
		}

		followID, followQuery, hasFollow := kiropkg.ExtractWebSearchToolUseFromResponse(parsed.ResponseBody)
		if hasFollow && strings.TrimSpace(followQuery) != "" && round < kiroMaxWebSearchIterations {
			searchQuery = followQuery
			if strings.TrimSpace(followID) == "" {
				followID = "srvtoolu_" + kiropkg.GenerateToolUseID()
			}
			toolUseID = followID
			continue
		}

		// Final answer: keep the raw body if the indicators cannot be injected.
		finalBody := parsed.ResponseBody
		if withIndicators, injectErr := kiropkg.InjectSearchIndicatorsInResponse(parsed.ResponseBody, indicators); injectErr == nil {
			finalBody = withIndicators
		}
		return &kiroWebSearchExecution{
			ResponseBody: finalBody,
			Usage:        kiroUsageToClaude(parsed.Usage, inputTokens),
			RequestID:    requestID,
		}, nil
	}
	return nil, fmt.Errorf("kiro web search exceeded max iterations")
}
// prefetchKiroWebSearchDescription makes sure the description of the
// "web_search" MCP tool is known for the account's MCP endpoint.
//
// On a cache hit in kiroWebSearchDescCache the cached description is passed to
// kiropkg.SetCachedWebSearchDescription and no request is made. On a miss a
// "tools/list" request is sent to the endpoint; the first tool named
// "web_search" (case-insensitive) with a non-blank description is stored in
// the cache and passed to kiropkg.SetCachedWebSearchDescription.
//
// The function is best effort: every failure (encoding, transport, non-200
// status, unreadable or unparsable body, tool not listed) returns silently and
// leaves the cache untouched, so a later call tries again. A token refreshed
// by the request helper is not propagated back to the caller.
func (s *GatewayService) prefetchKiroWebSearchDescription(ctx context.Context, account *Account, token string) {
	endpoint := kiropkg.BuildMcpEndpoint(kiroAPIRegion(account))
	if cached, ok := kiroWebSearchDescCache.Load(endpoint); ok {
		if desc, ok := cached.(string); ok && strings.TrimSpace(desc) != "" {
			kiropkg.SetCachedWebSearchDescription(desc)
		}
		return
	}

	reqBody, err := json.Marshal(kiropkg.MCPRequest{
		ID:      "tools_list",
		JSONRPC: "2.0",
		Method:  "tools/list",
	})
	if err != nil {
		// Never send a request with an empty payload; skip the prefetch instead.
		return
	}

	resp, _, err := s.doKiroMCPJSONRequest(ctx, account, endpoint, reqBody, token)
	if err != nil || resp == nil {
		return
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return
	}
	var result kiropkg.MCPResponse
	if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
		return
	}

	for _, tool := range result.Result.Tools {
		if strings.EqualFold(tool.Name, "web_search") && strings.TrimSpace(tool.Description) != "" {
			kiroWebSearchDescCache.Store(endpoint, tool.Description)
			kiropkg.SetCachedWebSearchDescription(tool.Description)
			return
		}
	}
}
// callKiroWebSearchMCP runs one web search for query against the account's MCP
// endpoint and returns the parsed results.
//
// The second return value is the token to use from now on: the input token if
// the request could not even be encoded, otherwise whatever the request helper
// reports (which may be a refreshed token), on success and failure alike.
//
// An error is returned when the request cannot be encoded or sent, the
// response is nil or unreadable, the status is not 200, the body is not valid
// JSON, or the JSON-RPC response carries an error object.
func (s *GatewayService) callKiroWebSearchMCP(ctx context.Context, account *Account, token, query string) (*kiropkg.WebSearchResults, string, error) {
	payload, err := json.Marshal(buildKiroWebSearchMCPRequest(query))
	if err != nil {
		return nil, token, err
	}

	resp, nextToken, err := s.doKiroMCPJSONRequest(ctx, account, kiropkg.BuildMcpEndpoint(kiroAPIRegion(account)), payload, token)
	switch {
	case err != nil:
		return nil, nextToken, err
	case resp == nil:
		return nil, nextToken, fmt.Errorf("kiro web search returned nil response")
	}
	defer func() { _ = resp.Body.Close() }()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nextToken, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, nextToken, fmt.Errorf("kiro mcp status %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	var rpc kiropkg.MCPResponse
	if err := json.Unmarshal(raw, &rpc); err != nil {
		return nil, nextToken, err
	}
	if rpcErr := rpc.Error; rpcErr != nil {
		// Both fields are optional; substitute defaults for whatever is missing.
		code := 0
		if rpcErr.Code != nil {
			code = *rpcErr.Code
		}
		msg := "unknown error"
		if rpcErr.Message != nil {
			if trimmed := strings.TrimSpace(*rpcErr.Message); trimmed != "" {
				msg = trimmed
			}
		}
		return nil, nextToken, fmt.Errorf("kiro mcp error %d: %s", code, msg)
	}
	return kiropkg.ParseSearchResults(&rpc), nextToken, nil
}
// buildKiroWebSearchMCPRequest builds the JSON-RPC 2.0 "tools/call" request
// that invokes the "web_search" tool with the given query.
//
// Every request gets a fresh ID of the form "web_search_<generated id>". The
// arguments carry a "_meta" object next to the query, marking "query" as the
// active and completed path.
func buildKiroWebSearchMCPRequest(query string) kiropkg.MCPRequest {
	meta := map[string]any{
		"_isValid":        true,
		"_activePath":     []string{"query"},
		"_completedPaths": [][]string{{"query"}},
	}
	arguments := map[string]any{
		"query": query,
		"_meta": meta,
	}
	return kiropkg.MCPRequest{
		ID:      fmt.Sprintf("web_search_%s", kiropkg.GenerateToolUseID()),
		JSONRPC: "2.0",
		Method:  "tools/call",
		Params: map[string]any{
			"name":      "web_search",
			"arguments": arguments,
		},
	}
}
// doKiroMCPJSONRequest sends payload to the MCP endpoint on behalf of account,
// making up to three attempts.
//
// It returns the response, the token that was current when the function
// returned (it differs from the input token after a successful forced
// refresh), and an error. A non-nil response always has a readable body that
// the caller must close; a nil response always comes with a non-nil error.
//
// Per attempt:
//   - the account cooldown is checked first; a cooldown error is returned as a
//     failover error when it can be converted to one;
//   - 401/403: the body is read in full. A 403 whose body marks the account as
//     suspended is recorded and returned; a 403 that is not a token error, a
//     missing token provider, or a failed refresh all return the response
//     with its body restored. Otherwise the token is force-refreshed and the
//     request is retried after a backoff;
//   - 429: recorded, then the response is returned to the caller;
//   - 408 and 5xx: retried after a backoff unless this was the last attempt,
//     in which case the response is returned;
//   - 2xx: success is recorded and the response is returned.
//
// When the last attempt ends in a token refresh, the function returns the
// "retries exhausted" error.
func (s *GatewayService) doKiroMCPJSONRequest(ctx context.Context, account *Account, endpoint string, payload []byte, token string) (*http.Response, string, error) {
	const (
		maxAttempts = 3
		// maxDrainBytes bounds how much of a discarded body is read before it is
		// closed, so a huge error body cannot stall the retry.
		maxDrainBytes = 64 << 10
	)

	// restoreBody hands an already consumed body back to the caller.
	restoreBody := func(resp *http.Response, consumed []byte) *http.Response {
		resp.Body = io.NopCloser(strings.NewReader(string(consumed)))
		return resp
	}

	currentToken := token
	accountKey := buildKiroAccountKey(account)
	proxyURL := kiroProxyURL(account)
	tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)

	for attempt := 0; attempt < maxAttempts; attempt++ {
		if err := s.checkAndWaitKiroCooldown(ctx, accountKey); err != nil {
			if failoverErr := asKiroCooldownFailoverError(err); failoverErr != nil {
				return nil, currentToken, failoverErr
			}
			return nil, currentToken, err
		}

		req, err := newKiroJSONRequest(ctx, endpoint, payload, currentToken, accountKey, buildKiroMachineID(account), "", account)
		if err != nil {
			return nil, currentToken, err
		}
		resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
		if err != nil {
			return nil, currentToken, err
		}

		if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
			respBody, readErr := io.ReadAll(resp.Body)
			_ = resp.Body.Close()
			if readErr != nil {
				return nil, currentToken, readErr
			}

			forbidden := resp.StatusCode == http.StatusForbidden
			if forbidden && isKiroSuspendedBody(respBody) {
				if _, err := s.markKiroSuspended(ctx, accountKey); err != nil {
					return nil, currentToken, err
				}
				return restoreBody(resp, respBody), currentToken, nil
			}
			if forbidden && !isKiroTokenErrorBody(respBody) {
				return restoreBody(resp, respBody), currentToken, nil
			}
			if s.kiroTokenProvider == nil {
				return restoreBody(resp, respBody), currentToken, nil
			}

			refreshedToken, refreshErr := s.kiroTokenProvider.ForceRefreshAccessToken(ctx, account)
			if refreshErr != nil {
				// The refresh error is dropped on purpose: the caller gets the
				// original auth response instead.
				return restoreBody(resp, respBody), currentToken, nil
			}
			currentToken = refreshedToken
			// Recomputed after the refresh, as the key is derived from account.
			accountKey = buildKiroAccountKey(account)
			if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
				return nil, currentToken, sleepErr
			}
			continue
		}

		if resp.StatusCode == http.StatusTooManyRequests {
			if _, err := s.markKiro429(ctx, accountKey); err != nil {
				_ = resp.Body.Close()
				return nil, currentToken, err
			}
		}

		if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode >= 500 {
			if attempt < maxAttempts-1 {
				// Read the (bounded) remainder before closing so the transport can
				// reuse the connection for the retry that follows.
				_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
				_ = resp.Body.Close()
				if sleepErr := sleepKiroRetry(ctx, attempt); sleepErr != nil {
					return nil, currentToken, sleepErr
				}
				continue
			}
		}

		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			if err := s.markKiroSuccess(ctx, accountKey); err != nil {
				_ = resp.Body.Close()
				return nil, currentToken, err
			}
		}
		return resp, currentToken, nil
	}
	return nil, currentToken, fmt.Errorf("kiro mcp request retries exhausted")
}