feat(risk-control): add content moderation audit

This commit is contained in:
shaw
2026-05-07 09:01:48 +08:00
parent a1106e8167
commit fff4a300c6
54 changed files with 6840 additions and 34 deletions
@@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Name string `json:"name"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`
@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Name: apiKey.Name,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
@@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Name: snapshot.Name,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
@@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Name: "Audit Key",
Status: StatusActive,
User: &User{
ID: 2,
@@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.Equal(t, apiKey.Name, roundTrip.Name)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,117 @@
package service
import (
"fmt"
"html"
"strings"
"time"
)
func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
if log == nil {
return ""
}
userName := strings.TrimSpace(log.UserEmail)
if userName == "" && log.UserID != nil {
userName = fmt.Sprintf("UID %d", *log.UserID)
}
threshold := cfg.BanThreshold
if threshold <= 0 {
threshold = defaultContentModerationBanThreshold
}
statusBlock := ""
if log.AutoBanned {
statusBlock = `<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>`
}
return fmt.Sprintf(`<!doctype html>
<html>
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 风控提醒</div>
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户触发内容审计规则</h1>
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>,您的 API 请求在内容审计中触发平台风控策略。详情如下。</p>
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">触发详情</h2>
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 次(阈值 %d</td></tr>
</table>
</div>
%s
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送,请勿回复。</p>
</div>
</div>
</body>
</html>`,
html.EscapeString(userName),
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
log.HighestScore,
log.ViolationCount,
threshold,
statusBlock,
html.EscapeString(siteName),
)
}
func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
if log == nil {
return ""
}
userName := strings.TrimSpace(log.UserEmail)
if userName == "" && log.UserID != nil {
userName = fmt.Sprintf("UID %d", *log.UserID)
}
threshold := cfg.BanThreshold
if threshold <= 0 {
threshold = defaultContentModerationBanThreshold
}
return fmt.Sprintf(`<!doctype html>
<html>
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 账户封禁</div>
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户已被自动禁用</h1>
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>,您的账户在计数周期内多次触发平台风控策略,系统已自动禁用该账户。详情如下。</p>
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">封禁详情</h2>
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">封禁时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 次(阈值 %d</td></tr>
</table>
</div>
<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>
<p style="font-size:15px;line-height:1.8;color:#666;margin-top:24px;">如需申诉或恢复账号,请联系平台管理员处理。</p>
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送,请勿回复。</p>
</div>
</div>
</body>
</html>`,
html.EscapeString(userName),
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
log.HighestScore,
log.ViolationCount,
threshold,
html.EscapeString(siteName),
)
}
func defaultContentModerationString(value string, fallback string) string {
if strings.TrimSpace(value) == "" {
return fallback
}
return strings.TrimSpace(value)
}
@@ -0,0 +1,307 @@
package service
import (
"fmt"
"strings"
"github.com/tidwall/gjson"
)
func ExtractContentModerationText(protocol string, body []byte) string {
return ExtractContentModerationInput(protocol, body).Text
}
func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput {
if len(body) == 0 || !gjson.ValidBytes(body) {
return ContentModerationInput{}
}
var parts []string
var images []string
switch protocol {
case ContentModerationProtocolAnthropicMessages:
collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images)
case ContentModerationProtocolOpenAIChat:
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
case ContentModerationProtocolOpenAIResponses:
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
case ContentModerationProtocolGemini:
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
case ContentModerationProtocolOpenAIImages:
addModerationText(&parts, gjson.GetBytes(body, "prompt").String())
collectContentValue(gjson.GetBytes(body, "images"), &parts, &images)
default:
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
}
out := ContentModerationInput{
Text: normalizeContentModerationText(strings.Join(parts, "\n")),
Images: normalizeModerationImages(images),
}
out.Normalize()
return out
}
func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
var candidate []string
var candidateImages []string
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
var candidate []string
var candidateImages []string
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
switch {
case !value.Exists():
return
case value.Type == gjson.String:
if !isAnthropicSystemReminderText(value.String()) {
addModerationText(parts, value.String())
}
case value.IsArray():
value.ForEach(func(_, item gjson.Result) bool {
collectAnthropicUserContentValue(item, parts, images)
return true
})
case value.IsObject():
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
switch typ {
case "", "text", "input_text", "message":
if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) {
addModerationText(parts, value.Get("text").String())
}
if value.Get("content").Exists() {
collectAnthropicUserContentValue(value.Get("content"), parts, images)
}
case "image_url", "input_image", "image":
collectContentValue(value, parts, images)
}
}
}
func isAnthropicSystemReminderText(text string) bool {
return strings.HasPrefix(strings.TrimSpace(text), "<system-reminder>")
}
func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) {
switch {
case !input.Exists():
return
case input.Type == gjson.String:
addModerationText(parts, input.String())
case input.IsArray():
var last gjson.Result
input.ForEach(func(_, item gjson.Result) bool {
if isResponsesUserTextItem(item) {
last = item
}
return true
})
if last.Exists() {
collectContentValue(last.Get("content"), parts, images)
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
collectContentValue(last, parts, images)
}
}
case input.IsObject():
if isResponsesUserTextItem(input) {
collectContentValue(input.Get("content"), parts, images)
if input.Get("type").String() == "input_text" || input.Get("text").Exists() {
collectContentValue(input, parts, images)
}
}
}
}
func isResponsesUserTextItem(item gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(item.Get("role").String()))
if role == "user" {
return responseItemHasModerationText(item)
}
if role != "" {
return false
}
return responseItemHasModerationText(item)
}
func responseItemHasModerationText(item gjson.Result) bool {
var parts []string
var images []string
collectContentValue(item.Get("content"), &parts, &images)
if item.Get("type").String() == "input_text" || item.Get("text").Exists() {
collectContentValue(item, &parts, &images)
}
return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0
}
func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) {
if !contents.IsArray() {
return
}
var lastParts []string
var lastImages []string
contents.ForEach(func(_, content gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
if role == "" || role == "user" {
var candidate []string
var candidateImages []string
if arr := content.Get("parts"); arr.IsArray() {
arr.ForEach(func(_, part gjson.Result) bool {
addModerationText(&candidate, part.Get("text").String())
addGeminiModerationImage(&candidateImages, part)
return true
})
}
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {
switch {
case !value.Exists():
return
case value.Type == gjson.String:
addModerationText(parts, value.String())
case value.IsArray():
value.ForEach(func(_, item gjson.Result) bool {
collectContentValue(item, parts, images)
return true
})
case value.IsObject():
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
addModerationImage(images, value.Get("image_url.url").String())
addModerationImage(images, value.Get("image_url").String())
addModerationImage(images, value.Get("url").String())
addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String())
addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String())
addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String())
addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String())
addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String())
addModerationImage(images, value.Get("source.data").String())
addModerationImage(images, value.Get("data").String())
addModerationImage(images, value.Get("base64").String())
switch typ {
case "", "text", "input_text", "message":
if value.Get("text").Exists() {
addModerationText(parts, value.Get("text").String())
}
if value.Get("content").Exists() {
collectContentValue(value.Get("content"), parts, images)
}
case "image_url", "input_image", "image":
}
}
}
func addGeminiModerationImage(images *[]string, part gjson.Result) {
if inlineData := part.Get("inline_data"); inlineData.IsObject() {
mimeType := strings.TrimSpace(inlineData.Get("mime_type").String())
data := strings.TrimSpace(inlineData.Get("data").String())
if mimeType != "" && data != "" {
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
}
if inlineData := part.Get("inlineData"); inlineData.IsObject() {
mimeType := strings.TrimSpace(inlineData.Get("mimeType").String())
data := strings.TrimSpace(inlineData.Get("data").String())
if mimeType != "" && data != "" {
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
}
addModerationImage(images, part.Get("file_data.file_uri").String())
addModerationImage(images, part.Get("fileData.fileUri").String())
}
func addModerationImageData(images *[]string, mimeType string, data string) {
mimeType = strings.TrimSpace(mimeType)
data = strings.TrimSpace(data)
if mimeType == "" || data == "" {
return
}
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
func addModerationImage(images *[]string, image string) {
image = strings.TrimSpace(image)
if image == "" {
return
}
if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") {
*images = append(*images, image)
}
}
func normalizeModerationImages(images []string) []string {
out := make([]string, 0, len(images))
seen := make(map[string]struct{}, len(images))
for _, image := range images {
image = strings.TrimSpace(image)
if image == "" {
continue
}
if _, ok := seen[image]; ok {
continue
}
seen[image] = struct{}{}
out = append(out, image)
}
return out
}
func addModerationText(parts *[]string, text string) {
text = strings.TrimSpace(text)
if text == "" {
return
}
if strings.Contains(text, "<system-reminder>") {
return
}
*parts = append(*parts, text)
}
func normalizeContentModerationText(text string) string {
return strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
}
@@ -0,0 +1,36 @@
package service
import (
"regexp"
"strings"
)
var contentModerationSecretPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`),
regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`),
regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`),
regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`),
regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`),
regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`),
regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`),
regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`),
}
func redactContentModerationSecrets(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
out := text
for idx, pattern := range contentModerationSecretPatterns {
switch idx {
case 0:
out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`)
case 1:
out = pattern.ReplaceAllString(out, `${1}[已脱敏]`)
default:
out = pattern.ReplaceAllString(out, `[已脱敏]`)
}
}
return out
}
@@ -0,0 +1,811 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type contentModerationTestSettingRepo struct {
values map[string]string
}
func (r *contentModerationTestSettingRepo) Get(ctx context.Context, key string) (*Setting, error) {
if value, ok := r.values[key]; ok {
return &Setting{Key: key, Value: value}, nil
}
return nil, ErrSettingNotFound
}
func (r *contentModerationTestSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
if value, ok := r.values[key]; ok {
return value, nil
}
return "", ErrSettingNotFound
}
func (r *contentModerationTestSettingRepo) Set(ctx context.Context, key, value string) error {
if r.values == nil {
r.values = map[string]string{}
}
r.values[key] = value
return nil
}
func (r *contentModerationTestSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := map[string]string{}
for _, key := range keys {
if value, ok := r.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (r *contentModerationTestSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
if r.values == nil {
r.values = map[string]string{}
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *contentModerationTestSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.values))
for key, value := range r.values {
out[key] = value
}
return out, nil
}
func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.values, key)
return nil
}
type contentModerationTestRepo struct {
logs []ContentModerationLog
}
func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error {
if log != nil {
r.logs = append(r.logs, *log)
}
return nil
}
func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
return 0, nil
}
func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) {
return &ContentModerationCleanupResult{}, nil
}
type contentModerationTestHashCache struct {
hashes map[string]struct{}
recorded []string
checked []string
deleted []string
hasResult bool
hasResultUsed bool
}
type contentModerationTestUserRepo struct {
user *User
updated []User
}
func (r *contentModerationTestUserRepo) Create(ctx context.Context, user *User) error {
panic("unexpected Create call")
}
func (r *contentModerationTestUserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
if r.user == nil {
return nil, ErrUserNotFound
}
clone := *r.user
return &clone, nil
}
func (r *contentModerationTestUserRepo) GetByEmail(ctx context.Context, email string) (*User, error) {
panic("unexpected GetByEmail call")
}
func (r *contentModerationTestUserRepo) GetFirstAdmin(ctx context.Context) (*User, error) {
panic("unexpected GetFirstAdmin call")
}
func (r *contentModerationTestUserRepo) Update(ctx context.Context, user *User) error {
if user == nil {
return nil
}
clone := *user
r.updated = append(r.updated, clone)
r.user = &clone
return nil
}
func (r *contentModerationTestUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (r *contentModerationTestUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
panic("unexpected GetUserAvatar call")
}
func (r *contentModerationTestUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (r *contentModerationTestUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (r *contentModerationTestUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (r *contentModerationTestUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserIDs call")
}
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserID call")
}
func (r *contentModerationTestUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
panic("unexpected UpdateUserLastActiveAt call")
}
func (r *contentModerationTestUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (r *contentModerationTestUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (r *contentModerationTestUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (r *contentModerationTestUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (r *contentModerationTestUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (r *contentModerationTestUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (r *contentModerationTestUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected RemoveGroupFromUserAllowedGroups call")
}
func (r *contentModerationTestUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected ListUserAuthIdentities call")
}
func (r *contentModerationTestUserRepo) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
panic("unexpected UnbindUserAuthProvider call")
}
func (r *contentModerationTestUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (r *contentModerationTestUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (r *contentModerationTestUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type contentModerationTestAuthCacheInvalidator struct {
userIDs []int64
}
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByKey(ctx context.Context, key string) {
}
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
i.userIDs = append(i.userIDs, userID)
}
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
}
func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
if c.hashes == nil {
c.hashes = map[string]struct{}{}
}
c.hashes[inputHash] = struct{}{}
c.recorded = append(c.recorded, inputHash)
return nil
}
func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
c.checked = append(c.checked, inputHash)
if c.hasResultUsed {
return c.hasResult, nil
}
_, ok := c.hashes[inputHash]
return ok, nil
}
func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
c.deleted = append(c.deleted, inputHash)
if c.hashes == nil {
return false, nil
}
if _, ok := c.hashes[inputHash]; !ok {
return false, nil
}
delete(c.hashes, inputHash)
return true, nil
}
func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
deleted := int64(len(c.hashes))
c.hashes = map[string]struct{}{}
return deleted, nil
}
func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
return int64(len(c.hashes)), nil
}
func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) {
svc := &ContentModerationService{}
cfg := defaultContentModerationConfig()
input := ContentModerationCheckInput{
RequestID: "req-1",
Endpoint: "/v1/chat/completions",
Provider: "openai",
}
log := svc.buildLog(input, cfg, ContentModerationActionAllow, true, "sexual", 0.8, map[string]float64{"sexual": 0.8}, "hello sk-proj-1234567890abcdef", nil, nil, "")
require.NotContains(t, log.InputExcerpt, "sk-proj-1234567890abcdef")
require.Contains(t, log.InputExcerpt, "[已脱敏]")
}
func TestRedactContentModerationSecrets_LongHexAndTokens(t *testing.T) {
input := "你哈市多大事cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554 token=abc123456789xyz Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signaturepart"
out := redactContentModerationSecrets(input)
require.NotContains(t, out, "cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554")
require.NotContains(t, out, "abc123456789xyz")
require.NotContains(t, out, "eyJhbGciOiJIUzI1NiJ9")
require.Contains(t, out, "[已脱敏]")
}
func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.NonHitRetentionDays = 30
cfg.normalize()
require.Equal(t, 3, cfg.NonHitRetentionDays)
}
func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"old"},
{"role":"assistant","content":"ok"},
{"role":"user","content":[
{"type":"text","text":"检查这张图"},
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"aGVsbG8="}}
]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "检查这张图", input.Text)
require.Equal(t, []string{"data:image/png;base64,aGVsbG8="}, input.Images)
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
require.Equal(t, "检查这张图", log.InputExcerpt)
require.NotContains(t, log.InputExcerpt, "aGVsbG8=")
}
func TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders(t *testing.T) {
body := []byte(`{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "<system-reminder>工具说明</system-reminder>"},
{"type": "text", "text": "<system-reminder>Ainder>\n\n"},
{"type": "text", "text": "hid", "cache_control": {"type": "ephemeral"}}
]
}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "hid", input.Text)
require.Empty(t, input.Images)
}
func TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage(t *testing.T) {
body := []byte(`{
"model":"gpt-5.5",
"messages":[
{"role":"system","content":"system prompt"},
{"role":"user","content":"old user"},
{"role":"assistant","content":"ok"},
{"role":"user","content":[{"type":"text","text":"latest user"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
require.Equal(t, "latest user", input.Text)
require.Equal(t, []string{"https://example.com/a.png"}, input.Images)
require.NotContains(t, input.Text, "old user")
require.NotContains(t, input.Text, "system prompt")
}
func TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages(t *testing.T) {
body := []byte(`{
"prompt":"replace background",
"images":[
{"image_url":"https://example.com/source.png"},
{"image_url":"data:image/png;base64,aGVsbG8="}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, body)
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{"https://example.com/source.png", "data:image/png;base64,aGVsbG8="}, input.Images)
}
func TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage(t *testing.T) {
body := []byte(`{
"model":"gpt-5.5",
"instructions":"instructions.....",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer permissions sk-proj-1234567890abcdef"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
],
"prompt_cache_key":"cache-key"
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
require.Equal(t, "last user prompt", input.Text)
require.Empty(t, input.Images)
require.NotContains(t, input.Text, "developer permissions")
require.NotContains(t, input.Text, "first user prompt")
}
func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *testing.T) {
var moderationRequest moderationAPIRequest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest))
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.01},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
cfg.RecordNonHits = true
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
&contentModerationTestHashCache{},
nil,
nil,
nil,
nil,
)
body := []byte(`{
"model":"gpt-5.5",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
]
}`)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
UserID: 1001,
Endpoint: "/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIResponses,
Body: body,
})
require.NoError(t, err)
require.False(t, decision.Blocked)
require.Len(t, repo.logs, 1)
require.False(t, repo.logs[0].Flagged)
require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action)
require.Equal(t, "/responses", repo.logs[0].Endpoint)
require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt)
require.Equal(t, "last user prompt", moderationRequest.Input)
}
func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *testing.T) {
var moderationRequest moderationAPIRequest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest))
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.9},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
cfg.BlockStatus = http.StatusUnavailableForLegalReasons
cfg.BlockMessage = "内容审计测试阻断"
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
&contentModerationTestHashCache{},
nil,
nil,
nil,
nil,
)
body := []byte(`{
"model":"gpt-5.5",
"instructions":"instructions.....",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"environment context"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"latest blocked prompt"}]}
]
}`)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
UserID: 1001,
Endpoint: "/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIResponses,
Body: body,
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionBlock, decision.Action)
require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode)
require.Equal(t, "内容审计测试阻断", decision.Message)
require.Len(t, repo.logs, 1)
require.True(t, repo.logs[0].Flagged)
require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action)
require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode)
require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt)
require.Equal(t, "latest blocked prompt", moderationRequest.Input)
}
func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) {
result := buildContentModerationTestAuditResult(&moderationAPIResult{
Flagged: true,
CategoryScores: map[string]float64{
"harassment": 0.65,
},
}, nil)
require.NotNil(t, result)
require.False(t, result.Flagged)
require.Equal(t, "harassment", result.HighestCategory)
require.Equal(t, 0.65, result.HighestScore)
require.Equal(t, 0.65, result.CompositeScore)
require.Equal(t, 0.98, result.Thresholds["harassment"])
}
func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.PreHashCheckEnabled = true
cfg.APIKeys = []string{"sk-test"}
cfg.BlockStatus = http.StatusConflict
cfg.BlockMessage = "命中历史风险输入"
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{}}
content := ContentModerationInput{Text: "blocked prompt"}
content.Normalize()
hashCache.hashes[content.Hash()] = struct{}{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
&contentModerationTestRepo{},
hashCache,
nil,
nil,
nil,
nil,
)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionHashBlock, decision.Action)
require.Equal(t, http.StatusConflict, decision.StatusCode)
require.Equal(t, content.Hash(), decision.InputHash)
require.Contains(t, decision.Message, "命中历史风险输入")
require.Contains(t, decision.Message, content.Hash())
require.Len(t, hashCache.checked, 1)
}
func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) {
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.9},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.PreHashCheckEnabled = true
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
cfg.BlockStatus = http.StatusConflict
cfg.BlockMessage = "命中风险输入"
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
hashCache := &contentModerationTestHashCache{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
hashCache,
nil,
nil,
nil,
nil,
)
body := []byte(`{"messages":[{"role":"user","content":"repeat blocked prompt"}]}`)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: body,
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionBlock, decision.Action)
require.Equal(t, 1, requestCount)
require.Len(t, hashCache.recorded, 1)
require.Len(t, repo.logs, 1)
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: body,
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionHashBlock, decision.Action)
require.Equal(t, hashCache.recorded[0], decision.InputHash)
require.Equal(t, 1, requestCount)
require.Len(t, repo.logs, 1)
}
func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) {
existingHash := strings.Repeat("a", 64)
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{
existingHash: {},
}}
svc := &ContentModerationService{hashCache: hashCache}
result, err := svc.DeleteFlaggedInputHash(context.Background(), strings.ToUpper(existingHash))
require.NoError(t, err)
require.Equal(t, existingHash, result.InputHash)
require.True(t, result.Deleted)
require.NotContains(t, hashCache.hashes, existingHash)
require.Equal(t, []string{existingHash}, hashCache.deleted)
result, err = svc.DeleteFlaggedInputHash(context.Background(), existingHash)
require.NoError(t, err)
require.Equal(t, existingHash, result.InputHash)
require.False(t, result.Deleted)
}
func TestContentModerationClearFlaggedInputHashesAndStatusCount(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.Enabled = true
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{
strings.Repeat("a", 64): {},
strings.Repeat("b", 64): {},
}}
svc := &ContentModerationService{
settingRepo: &contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
hashCache: hashCache,
keyHealth: make(map[string]*contentModerationKeyHealth),
}
status, err := svc.GetStatus(context.Background())
require.NoError(t, err)
require.Equal(t, int64(2), status.FlaggedHashCount)
result, err := svc.ClearFlaggedInputHashes(context.Background())
require.NoError(t, err)
require.Equal(t, int64(2), result.Deleted)
status, err = svc.GetStatus(context.Background())
require.NoError(t, err)
require.Equal(t, int64(0), status.FlaggedHashCount)
}
func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.9},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModeObserve
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
hashCache := &contentModerationTestHashCache{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
hashCache,
nil,
nil,
nil,
nil,
)
decision := svc.checkSync(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"bad prompt"}]}`),
}, cfg, ContentModerationInput{Text: "bad prompt"}, strings.Repeat("b", 64), contentModerationIntPtr(25), false)
require.False(t, decision.Blocked)
require.Len(t, hashCache.recorded, 1)
require.Len(t, repo.logs, 1)
}
func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) {
userID := int64(1001)
cfg := defaultContentModerationConfig()
cfg.BanThreshold = 10
body := buildContentModerationAccountDisabledEmailBody("Sub2API <Admin>", &ContentModerationLog{
UserID: &userID,
UserEmail: "user@example.com",
GroupName: "vip_2",
HighestCategory: "sexual",
HighestScore: 0.926,
ViolationCount: 10,
}, cfg)
require.Contains(t, body, "账户已被自动禁用")
require.Contains(t, body, "封禁详情")
require.Contains(t, body, "账户当前处于封禁状态,所有 API 请求将被拒绝")
require.Contains(t, body, "10 次(阈值 10")
require.Contains(t, body, "sexual / 0.926")
require.Contains(t, body, "Sub2API &lt;Admin&gt;")
}
func TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache(t *testing.T) {
userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusDisabled}}
invalidator := &contentModerationTestAuthCacheInvalidator{}
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil)
result, err := svc.UnbanUser(context.Background(), 1001)
require.NoError(t, err)
require.Equal(t, int64(1001), result.UserID)
require.Equal(t, StatusActive, result.Status)
require.Len(t, userRepo.updated, 1)
require.Equal(t, StatusActive, userRepo.updated[0].Status)
require.Equal(t, []int64{1001}, invalidator.userIDs)
}
func TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache(t *testing.T) {
userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusActive}}
invalidator := &contentModerationTestAuthCacheInvalidator{}
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil)
result, err := svc.UnbanUser(context.Background(), 1001)
require.NoError(t, err)
require.Equal(t, StatusActive, result.Status)
require.Empty(t, userRepo.updated)
require.Equal(t, []int64{1001}, invalidator.userIDs)
}
func contentModerationIntPtr(v int) *int {
return &v
}
@@ -107,6 +107,8 @@ const (
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
+63
View File
@@ -90,6 +90,69 @@ type OpenAIImagesRequest struct {
bodyHash string
}
func (r *OpenAIImagesRequest) ModerationBody() []byte {
if r == nil {
return nil
}
payload := map[string]any{}
if prompt := strings.TrimSpace(r.Prompt); prompt != "" {
payload["prompt"] = prompt
}
images := r.moderationImages()
if len(images) > 0 {
payload["images"] = images
}
if len(payload) == 0 {
return nil
}
body, err := json.Marshal(payload)
if err != nil {
return nil
}
return body
}
func (r *OpenAIImagesRequest) moderationImages() []map[string]string {
if r == nil {
return nil
}
images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1)
for _, imageURL := range r.InputImageURLs {
imageURL = strings.TrimSpace(imageURL)
if imageURL != "" {
images = append(images, map[string]string{"image_url": imageURL})
}
}
for _, upload := range r.Uploads {
if dataURL := upload.ModerationDataURL(); dataURL != "" {
images = append(images, map[string]string{"image_url": dataURL})
}
}
if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" {
images = append(images, map[string]string{"image_url": maskURL})
}
if r.MaskUpload != nil {
if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" {
images = append(images, map[string]string{"image_url": dataURL})
}
}
return images
}
func (u OpenAIImagesUpload) ModerationDataURL() string {
if len(u.Data) == 0 {
return ""
}
contentType := strings.TrimSpace(u.ContentType)
if contentType == "" {
contentType = http.DetectContentType(u.Data)
}
if !strings.HasPrefix(strings.ToLower(contentType), "image/") {
return ""
}
return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data))
}
func (r *OpenAIImagesRequest) IsEdits() bool {
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
}
@@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,
Prompt: "replace background",
InputImageURLs: []string{"https://example.com/source.png"},
MaskImageURL: "https://example.com/mask.png",
}
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images)
}
func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,
Prompt: "replace background",
Uploads: []OpenAIImagesUpload{{
FieldName: "image",
FileName: "source.png",
ContentType: "image/png",
Data: []byte("fake-image-bytes"),
}},
MaskUpload: &OpenAIImagesUpload{
FieldName: "mask",
FileName: "mask.png",
ContentType: "image/png",
Data: []byte("fake-mask-bytes"),
},
}
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{
"data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==",
"data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz",
}, input.Images)
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
require.Equal(t, "replace background", log.InputExcerpt)
require.NotContains(t, log.InputExcerpt, "ZmFrZS")
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct {
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
InitialRequestModel string
BeforeTurn func(turn int) error
BeforeRequest func(turn int, payload []byte, originalModel string) error
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
}
@@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return true
}
for {
if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil {
if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil {
return err
}
}
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
if err := hooks.BeforeTurn(turn); err != nil {
return err
@@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
if msgType != coderws.MessageText {
return payload, nil, nil
}
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil {
turnNo := int(completedTurns.Load()) + 1
if turnNo < 2 {
turnNo = 2
}
requestModel := usageMeta.requestModelForFrame(payload)
if requestModel == "" {
requestModel = capturedSessionModel
}
if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil {
return payload, nil, err
}
}
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
// session.update 修改 session-level modelRealtime /
// Responses WS 协议允许),如果不刷新就会出现
@@ -456,6 +456,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyChannelMonitorDefaultIntervalSeconds,
SettingKeyAvailableChannelsEnabled,
SettingKeyAffiliateEnabled,
SettingKeyRiskControlEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -545,6 +546,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true",
}, nil
}
@@ -692,6 +695,7 @@ type PublicSettingsInjectionPayload struct {
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
RiskControlEnabled bool `json:"risk_control_enabled"`
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
@@ -745,6 +749,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
RiskControlEnabled: settings.RiskControlEnabled,
}, nil
}
@@ -1232,6 +1237,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// Affiliate (邀请返利) feature switch
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
// 风控中心功能开关
updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled)
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@@ -1903,6 +1911,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Affiliate (邀请返利) feature (default disabled; opt-in)
SettingKeyAffiliateEnabled: "false",
// 风控中心功能(默认关闭,显式启用)
SettingKeyRiskControlEnabled: "false",
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
@@ -2242,6 +2253,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Affiliate (邀请返利) feature (default: disabled; strict true)
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
// 风控中心功能(默认关闭,严格 true 才启用)
result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true"
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
@@ -106,6 +106,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
RiskControlEnabled bool
AffiliateEnabled bool
AffiliateRebateRate float64
AffiliateRebateFreezeHours int
@@ -233,6 +234,9 @@ type PublicSettings struct {
// Affiliate (邀请返利) feature toggle
AffiliateEnabled bool `json:"affiliate_enabled"`
// 风控中心功能开关
RiskControlEnabled bool `json:"risk_control_enabled"`
}
type WeChatConnectOAuthConfig struct {
+1
View File
@@ -509,6 +509,7 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
NewContentModerationService,
NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,