Files
sub2api/backend/internal/pkg/kiro/websearch_test.go
2026-04-30 14:02:05 +08:00

139 lines
5.3 KiB
Go

package kiro
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestReplaceWebSearchToolDescriptionUsesTypeFallback verifies that a tool
// identified only by its "type" (the fixture has no "name") is rewritten by
// ReplaceWebSearchToolDescription: it is named "web_search", gets
// minimalWebSearchDescription, and receives an input schema with a single
// required string "query" property and additionalProperties set to false.
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
	body := []byte(`{
"tools":[{"type":"web_search_20250305","description":"old"}],
"messages":[{"role":"user","content":"golang"}]
}`)
	updated, err := ReplaceWebSearchToolDescription(body)
	require.NoError(t, err)

	tool := gjson.GetBytes(updated, "tools.0")
	require.Equal(t, "web_search", tool.Get("name").String())
	require.Equal(t, minimalWebSearchDescription, tool.Get("description").String())
	require.Equal(t, "string", tool.Get("input_schema.properties.query.type").String())
	require.Equal(t, "The search query to execute", tool.Get("input_schema.properties.query.description").String())
	require.Equal(t, "query", tool.Get("input_schema.required.0").String())

	// gjson's Bool() yields false for a missing key, so checking Bool() alone
	// would pass even if additionalProperties were never written. Require the
	// value to be present and be the JSON literal false.
	additional := tool.Get("input_schema.additionalProperties")
	require.Equal(t, gjson.False, additional.Type, "additionalProperties must be present and false, got %q", additional.Raw)
}
// TestInjectToolResultsClaudeAppendsMessages checks that InjectToolResultsClaude
// leaves the original user turn at index 0 and appends two turns after it:
//   - an assistant turn whose first content block is a tool_use with the given id;
//   - a user turn whose first block is a tool_result containing the serialized
//     search results, and whose second block carries the search guidance text.
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
	request := []byte(`{
"messages":[{"role":"user","content":"what is golang"}]
}`)
	searchResults := &WebSearchResults{
		Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev"}},
	}

	out, err := InjectToolResultsClaude(request, "srvtoolu_test", "golang", searchResults)
	require.NoError(t, err)

	assistantTurn := gjson.GetBytes(out, "messages.1")
	require.Equal(t, "assistant", assistantTurn.Get("role").String())
	require.Equal(t, "tool_use", assistantTurn.Get("content.0.type").String())
	require.Equal(t, "srvtoolu_test", assistantTurn.Get("content.0.id").String())

	userTurn := gjson.GetBytes(out, "messages.2")
	require.Equal(t, "user", userTurn.Get("role").String())
	require.Equal(t, "tool_result", userTurn.Get("content.0.type").String())

	// The tool_result payload is expected to embed the results as JSON text.
	resultText := userTurn.Get("content.0.content").String()
	require.Contains(t, resultText, "https://go.dev")
	require.Contains(t, resultText, `"title": "Go"`)
	require.Contains(t, userTurn.Get("content.1.text").String(), "<search_guidance>")
}
// TestExtractWebSearchToolUseFromResponse feeds a response where a text block
// precedes a "remote_web_search" tool_use block. It checks that
// ExtractWebSearchToolUseFromResponse reports a match and returns that block's
// id and its input.query value.
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
	payload := []byte(`{
"content":[
{"type":"text","text":"let me search"},
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
]
}`)

	gotID, gotQuery, found := ExtractWebSearchToolUseFromResponse(payload)

	require.True(t, found)
	require.Equal(t, "srvtoolu_next", gotID)
	require.Equal(t, "golang concurrency", gotQuery)
}
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
response := []byte(`{
"id":"msg_1",
"type":"message",
"role":"assistant",
"model":"kiro",
"content":[{"type":"text","text":"final"}],
"stop_reason":"end_turn",
"usage":{"input_tokens":1,"output_tokens":1}
}`)
snippet := "result snippet"
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
{
ToolUseID: "srvtoolu_test",
Query: "golang",
Results: &WebSearchResults{
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
},
},
})
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(updated, &decoded))
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
}
// TestParseSearchResults_PreservesExtendedFields verifies that ParseSearchResults
// decodes the JSON results payload carried in an MCP "text" content item. It
// checks the basic title/url/snippet fields and the optional extended metadata
// (publishedDate, id, domain, maxVerbatimWordLimit, publicDomain).
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
	resp := &MCPResponse{
		Result: &struct {
			Content []struct {
				Type string `json:"type"`
				Text string `json:"text"`
			} `json:"content"`
			Tools []struct {
				Name        string `json:"name"`
				Description string `json:"description"`
			} `json:"tools"`
		}{
			Content: []struct {
				Type string `json:"type"`
				Text string `json:"text"`
			}{
				{
					Type: "text",
					Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
				},
			},
		},
	}

	results := ParseSearchResults(resp)
	require.NotNil(t, results)
	require.Len(t, results.Results, 1)

	first := results.Results[0]
	require.Equal(t, "Go", first.Title)
	require.Equal(t, "https://go.dev", first.URL)

	// The optional fields are pointers. Assert presence before dereferencing so
	// a dropped field fails with a clear message instead of a nil-pointer panic.
	require.NotNil(t, first.Snippet)
	require.Equal(t, "snippet", *first.Snippet)
	require.NotNil(t, first.PublishedDate)
	require.Equal(t, int64(1710000000), *first.PublishedDate)
	require.NotNil(t, first.ID)
	require.Equal(t, "doc-1", *first.ID)
	require.NotNil(t, first.Domain)
	require.Equal(t, "go.dev", *first.Domain)
	require.NotNil(t, first.MaxVerbatimWordLimit)
	require.Equal(t, 25, *first.MaxVerbatimWordLimit)
	require.NotNil(t, first.PublicDomain)
	require.True(t, *first.PublicDomain)
}
// TestSearchGuidanceText_IsStructured checks that searchGuidanceText contains
// each of these, in the listed order of checking: the opening
// <search_guidance> tag, a "Current date:" line, the instruction to search
// again with a refined query, and the hint to rephrase in English.
func TestSearchGuidanceText_IsStructured(t *testing.T) {
	text := searchGuidanceText()
	for _, fragment := range []string{
		"<search_guidance>",
		"Current date:",
		"Then you MUST use the web_search tool again with a refined query.",
		"Rephrasing in English for better coverage",
	} {
		require.Contains(t, text, fragment)
	}
}