228 lines
7.9 KiB
Go
228 lines
7.9 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// swapMonitorHTTPClient 临时替换 monitorHTTPClient 为不带 SSRF 校验的普通 client,
|
||
// 让 httptest (127.0.0.1) 能连通。测试结束后恢复。
|
||
func swapMonitorHTTPClient(t *testing.T) {
|
||
t.Helper()
|
||
orig := monitorHTTPClient
|
||
monitorHTTPClient = &http.Client{Timeout: 5 * time.Second}
|
||
t.Cleanup(func() { monitorHTTPClient = orig })
|
||
}
|
||
|
||
// captureHandler 把每次收到的请求 body 和 headers 存起来,测试断言用。
|
||
type captureHandler struct {
|
||
lastBody map[string]any
|
||
lastHeaders http.Header
|
||
respondText string // 写到 Anthropic content[0].text 里(校验用)
|
||
status int
|
||
}
|
||
|
||
func (h *captureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
h.lastHeaders = r.Header.Clone()
|
||
defer func() { _ = r.Body.Close() }()
|
||
var parsed map[string]any
|
||
_ = json.NewDecoder(r.Body).Decode(&parsed)
|
||
h.lastBody = parsed
|
||
|
||
respondText := h.respondText
|
||
if respondText == autoChallengeResponse {
|
||
respondText = solveMonitorChallengeFromBody(parsed)
|
||
}
|
||
|
||
if h.status == 0 {
|
||
h.status = 200
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(h.status)
|
||
// 构造 Anthropic 格式的响应:content[0].text = h.respondText
|
||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||
"content": []map[string]any{
|
||
{"type": "text", "text": respondText},
|
||
},
|
||
})
|
||
}
|
||
|
||
const autoChallengeResponse = "__auto_challenge__"
|
||
|
||
var testChallengeLineRegex = regexp.MustCompile(`Q:\s*(\d+)\s*([+-])\s*(\d+)\s*=\s*\?`)
|
||
|
||
func solveMonitorChallengeFromBody(body map[string]any) string {
|
||
msgs, _ := body["messages"].([]any)
|
||
if len(msgs) == 0 {
|
||
return ""
|
||
}
|
||
first, _ := msgs[0].(map[string]any)
|
||
prompt, _ := first["content"].(string)
|
||
matches := testChallengeLineRegex.FindAllStringSubmatch(prompt, -1)
|
||
if len(matches) == 0 {
|
||
return ""
|
||
}
|
||
last := matches[len(matches)-1]
|
||
a, _ := strconv.Atoi(last[1])
|
||
b, _ := strconv.Atoi(last[3])
|
||
if last[2] == "-" {
|
||
return strconv.Itoa(a - b)
|
||
}
|
||
return strconv.Itoa(a + b)
|
||
}
|
||
|
||
func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
|
||
t.Helper()
|
||
swapMonitorHTTPClient(t)
|
||
srv := httptest.NewServer(handler)
|
||
t.Cleanup(srv.Close)
|
||
return srv.URL
|
||
}
|
||
|
||
func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
|
||
h := &captureHandler{respondText: autoChallengeResponse}
|
||
endpoint := setupFakeAnthropic(t, h)
|
||
|
||
// 跑一次 off 模式(opts=nil),确认默认 body 行为未变
|
||
_ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", nil)
|
||
|
||
if h.lastBody["model"] != "claude-x" {
|
||
t.Errorf("default body should contain model=claude-x, got %v", h.lastBody["model"])
|
||
}
|
||
if _, ok := h.lastBody["messages"]; !ok {
|
||
t.Error("default body should contain messages")
|
||
}
|
||
if h.lastHeaders.Get("x-api-key") != "sk-fake" {
|
||
t.Errorf("expected adapter's x-api-key header, got %q", h.lastHeaders.Get("x-api-key"))
|
||
}
|
||
}
|
||
|
||
func TestRunCheckForModel_KiroUsesAnthropicCompatibleMessages(t *testing.T) {
|
||
h := &captureHandler{respondText: autoChallengeResponse}
|
||
endpoint := setupFakeAnthropic(t, h)
|
||
|
||
res := runCheckForModel(context.Background(), MonitorProviderKiro, endpoint, "sk-kiro", "kiro-model", nil)
|
||
|
||
if res.Status != MonitorStatusOperational {
|
||
t.Fatalf("expected operational, got status=%s message=%q", res.Status, res.Message)
|
||
}
|
||
if h.lastBody["model"] != "kiro-model" {
|
||
t.Errorf("kiro body should contain model=kiro-model, got %v", h.lastBody["model"])
|
||
}
|
||
if _, ok := h.lastBody["messages"]; !ok {
|
||
t.Error("kiro body should contain Anthropic-compatible messages")
|
||
}
|
||
if h.lastHeaders.Get("x-api-key") != "sk-kiro" {
|
||
t.Errorf("expected x-api-key header, got %q", h.lastHeaders.Get("x-api-key"))
|
||
}
|
||
if h.lastHeaders.Get("anthropic-version") != monitorAnthropicAPIVersion {
|
||
t.Errorf("expected anthropic-version header, got %q", h.lastHeaders.Get("anthropic-version"))
|
||
}
|
||
}
|
||
|
||
func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
|
||
h := &captureHandler{respondText: "the answer is 42"}
|
||
endpoint := setupFakeAnthropic(t, h)
|
||
|
||
opts := &CheckOptions{
|
||
BodyOverrideMode: MonitorBodyOverrideModeMerge,
|
||
BodyOverride: map[string]any{
|
||
"system": "You are Claude Code...",
|
||
"max_tokens": float64(999), // 应该覆盖默认 50
|
||
"model": "hacked-model", // 应该被黑名单挡住,保留原 model
|
||
"messages": []any{}, // 同上,被挡
|
||
},
|
||
ExtraHeaders: map[string]string{
|
||
"User-Agent": "claude-cli/1.0",
|
||
"Content-Length": "999", // 黑名单
|
||
"x-custom": "ok",
|
||
},
|
||
}
|
||
_ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
|
||
|
||
if h.lastBody["system"] != "You are Claude Code..." {
|
||
t.Errorf("merge mode should inject system, got %v", h.lastBody["system"])
|
||
}
|
||
// max_tokens 覆盖生效
|
||
if mt, ok := h.lastBody["max_tokens"].(float64); !ok || mt != 999 {
|
||
t.Errorf("merge mode should override max_tokens to 999, got %v", h.lastBody["max_tokens"])
|
||
}
|
||
// model 在黑名单 — 应该保留默认值
|
||
if h.lastBody["model"] != "claude-x" {
|
||
t.Errorf("model should be protected by deny list, got %v", h.lastBody["model"])
|
||
}
|
||
// messages 在黑名单 — 应该保留默认值(非空)
|
||
msgs, _ := h.lastBody["messages"].([]any)
|
||
if len(msgs) == 0 {
|
||
t.Error("messages should be protected by deny list (kept default, non-empty)")
|
||
}
|
||
// header 合并
|
||
if h.lastHeaders.Get("User-Agent") != "claude-cli/1.0" {
|
||
t.Errorf("extra User-Agent should override, got %q", h.lastHeaders.Get("User-Agent"))
|
||
}
|
||
if h.lastHeaders.Get("x-custom") != "ok" {
|
||
t.Errorf("extra custom header should be present, got %q", h.lastHeaders.Get("x-custom"))
|
||
}
|
||
// Content-Length 黑名单:会被 net/http 自动重算,但不应由用户的 "999" 决定。
|
||
// 我们无法直接断言丢弃(http.Client 总会填上),只断言请求成功即可。
|
||
}
|
||
|
||
func TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped(t *testing.T) {
|
||
// replace 模式下我们的 body 完全自定义,challenge 数学题不会出现在请求里,
|
||
// 上游也不会回正确答案 — 但只要 2xx + 响应文本非空,就算 operational
|
||
h := &captureHandler{respondText: "any non-empty text"}
|
||
endpoint := setupFakeAnthropic(t, h)
|
||
|
||
userBody := map[string]any{
|
||
"model": "user-forced-model",
|
||
"messages": []any{map[string]any{"role": "user", "content": "hi"}},
|
||
"max_tokens": float64(10),
|
||
"system": "You are someone else",
|
||
}
|
||
opts := &CheckOptions{
|
||
BodyOverrideMode: MonitorBodyOverrideModeReplace,
|
||
BodyOverride: userBody,
|
||
}
|
||
res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
|
||
|
||
// 请求 body = 用户提供的原样
|
||
if h.lastBody["model"] != "user-forced-model" {
|
||
t.Errorf("replace mode should use user's model, got %v", h.lastBody["model"])
|
||
}
|
||
if h.lastBody["system"] != "You are someone else" {
|
||
t.Errorf("replace mode should use user's system, got %v", h.lastBody["system"])
|
||
}
|
||
// challenge 虽然没命中,但由于 replace 模式跳过 challenge 校验 + 响应非空 → operational
|
||
if res.Status != MonitorStatusOperational {
|
||
t.Errorf("replace mode with 2xx + non-empty text should be operational, got status=%s message=%q",
|
||
res.Status, res.Message)
|
||
}
|
||
}
|
||
|
||
func TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed(t *testing.T) {
|
||
h := &captureHandler{respondText: ""} // 上游 200 但 content[0].text 为空
|
||
endpoint := setupFakeAnthropic(t, h)
|
||
|
||
opts := &CheckOptions{
|
||
BodyOverrideMode: MonitorBodyOverrideModeReplace,
|
||
BodyOverride: map[string]any{"model": "x", "messages": []any{}},
|
||
}
|
||
res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
|
||
|
||
if res.Status != MonitorStatusFailed {
|
||
t.Errorf("replace mode with empty text should be failed, got status=%s", res.Status)
|
||
}
|
||
if !strings.Contains(res.Message, "replace-mode") {
|
||
t.Errorf("failure message should hint replace-mode, got %q", res.Message)
|
||
}
|
||
}
|