feat(risk-control): add content moderation audit

This commit is contained in:
shaw
2026-05-07 09:01:48 +08:00
parent a1106e8167
commit fff4a300c6
54 changed files with 6840 additions and 34 deletions
@@ -0,0 +1,234 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type ContentModerationHandler struct {
service *service.ContentModerationService
}
func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler {
return &ContentModerationHandler{service: svc}
}
type contentModerationConfigRequest struct {
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
APIKey *string `json:"api_key"`
APIKeys *[]string `json:"api_keys"`
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
}
type contentModerationAPIKeyTestRequest struct {
APIKeys []string `json:"api_keys"`
BaseURL string `json:"base_url"`
Model string `json:"model"`
TimeoutMS int `json:"timeout_ms"`
Prompt string `json:"prompt"`
Images []string `json:"images"`
}
type contentModerationHashRequest struct {
InputHash string `json:"input_hash"`
}
func (h *ContentModerationHandler) GetConfig(c *gin.Context) {
cfg, err := h.service.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
var req contentModerationConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{
Enabled: req.Enabled,
Mode: req.Mode,
BaseURL: req.BaseURL,
Model: req.Model,
APIKey: req.APIKey,
APIKeys: req.APIKeys,
ClearAPIKey: req.ClearAPIKey,
TimeoutMS: req.TimeoutMS,
SampleRate: req.SampleRate,
AllGroups: req.AllGroups,
GroupIDs: req.GroupIDs,
RecordNonHits: req.RecordNonHits,
WorkerCount: req.WorkerCount,
QueueSize: req.QueueSize,
BlockStatus: req.BlockStatus,
BlockMessage: req.BlockMessage,
EmailOnHit: req.EmailOnHit,
AutoBanEnabled: req.AutoBanEnabled,
BanThreshold: req.BanThreshold,
ViolationWindowHours: req.ViolationWindowHours,
RetryCount: req.RetryCount,
HitRetentionDays: req.HitRetentionDays,
NonHitRetentionDays: req.NonHitRetentionDays,
PreHashCheckEnabled: req.PreHashCheckEnabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) {
var req contentModerationAPIKeyTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{
APIKeys: req.APIKeys,
BaseURL: req.BaseURL,
Model: req.Model,
TimeoutMS: req.TimeoutMS,
Prompt: req.Prompt,
Images: req.Images,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *ContentModerationHandler) GetStatus(c *gin.Context) {
status, err := h.service.GetStatus(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, status)
}
func (h *ContentModerationHandler) ListLogs(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := service.ContentModerationLogFilter{
Pagination: pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortOrder: pagination.SortOrderDesc,
},
Result: c.Query("result"),
Endpoint: c.Query("endpoint"),
Search: c.Query("search"),
}
if raw := strings.TrimSpace(c.Query("group_id")); raw != "" {
groupID, err := strconv.ParseInt(raw, 10, 64)
if err != nil || groupID <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &groupID
}
if raw := strings.TrimSpace(c.Query("from")); raw != "" {
t, _, err := parseContentModerationDate(raw)
if err != nil {
response.BadRequest(c, "Invalid from")
return
}
filter.From = &t
}
if raw := strings.TrimSpace(c.Query("to")); raw != "" {
t, dateOnly, err := parseContentModerationDate(raw)
if err != nil {
response.BadRequest(c, "Invalid to")
return
}
if dateOnly {
t = t.Add(24*time.Hour - time.Nanosecond)
}
filter.To = &t
}
items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize)
}
func (h *ContentModerationHandler) UnbanUser(c *gin.Context) {
userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64)
if err != nil || userID <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
result, err := h.service.UnbanUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) {
var req contentModerationHashRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) {
result, err := h.service.ClearFlaggedInputHashes(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func parseContentModerationDate(raw string) (time.Time, bool, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return time.Time{}, false, nil
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return t, false, nil
}
t, err := time.Parse("2006-01-02", raw)
return t, err == nil, err
}
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
RiskControlEnabled: settings.RiskControlEnabled,
AffiliateRebateRate: settings.AffiliateRebateRate,
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
@@ -497,6 +498,9 @@ type UpdateSettingsRequest struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
// 风控中心功能开关
RiskControlEnabled *bool `json:"risk_control_enabled"`
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
@@ -1365,6 +1369,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AffiliateEnabled
}(),
RiskControlEnabled: func() bool {
if req.RiskControlEnabled != nil {
return *req.RiskControlEnabled
}
return previousSettings.RiskControlEnabled
}(),
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
@@ -1616,6 +1626,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
AffiliateEnabled: updatedSettings.AffiliateEnabled,
RiskControlEnabled: updatedSettings.RiskControlEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
@@ -2004,6 +2016,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AffiliateEnabled != after.AffiliateEnabled {
changed = append(changed, "affiliate_enabled")
}
if before.RiskControlEnabled != after.RiskControlEnabled {
changed = append(changed, "risk_control_enabled")
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
@@ -0,0 +1,130 @@
package handler
import (
"context"
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
if h == nil || h.contentModerationService == nil {
return nil
}
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
}
func contentModerationStatus(decision *service.ContentModerationDecision) int {
if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 {
return http.StatusForbidden
}
return decision.StatusCode
}
func contentModerationErrorCode(decision *service.ContentModerationDecision) string {
return "content_policy_violation"
}
func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
if h == nil || h.contentModerationService == nil {
return nil
}
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
}
func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
if svc == nil || c == nil || c.Request == nil {
return nil
}
input := buildContentModerationInput(c, apiKey, subject, protocol, model, body)
if reqLog != nil {
reqLog.Info("content_moderation.gateway_check_start",
zap.String("request_id", input.RequestID),
zap.Int64("user_id", input.UserID),
zap.Int64("api_key_id", input.APIKeyID),
zap.String("api_key_name", input.APIKeyName),
zap.Int64p("group_id", input.GroupID),
zap.String("group_name", input.GroupName),
zap.String("endpoint", input.Endpoint),
zap.String("provider", input.Provider),
zap.String("protocol", input.Protocol),
zap.String("model", input.Model),
zap.Int("body_bytes", len(body)),
)
}
decision, err := svc.Check(c.Request.Context(), input)
if err != nil {
if reqLog != nil {
reqLog.Warn("content_moderation.check_failed", zap.Error(err))
}
return nil
}
if reqLog != nil && decision != nil {
reqLog.Info("content_moderation.gateway_check_done",
zap.String("request_id", input.RequestID),
zap.Bool("allowed", decision.Allowed),
zap.Bool("blocked", decision.Blocked),
zap.Bool("flagged", decision.Flagged),
zap.String("action", decision.Action),
zap.Int("status_code", decision.StatusCode),
zap.String("highest_category", decision.HighestCategory),
zap.Float64("highest_score", decision.HighestScore),
)
}
return decision
}
func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput {
input := service.ContentModerationCheckInput{
RequestID: contentModerationRequestID(c.Request.Context()),
UserID: subject.UserID,
Endpoint: GetInboundEndpoint(c),
Provider: contentModerationProvider(apiKey),
Model: strings.TrimSpace(model),
Protocol: protocol,
Body: body,
}
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
input.Provider = strings.TrimSpace(forcedPlatform)
}
if apiKey != nil {
input.APIKeyID = apiKey.ID
input.APIKeyName = apiKey.Name
if apiKey.User != nil {
input.UserEmail = apiKey.User.Email
}
if apiKey.GroupID != nil {
groupID := *apiKey.GroupID
input.GroupID = &groupID
}
if apiKey.Group != nil {
input.GroupName = apiKey.Group.Name
}
}
if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil {
input.Endpoint = c.Request.URL.Path
}
return input
}
func contentModerationProvider(apiKey *service.APIKey) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.Platform)
}
func contentModerationRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok {
return strings.TrimSpace(requestID)
}
return ""
}
+5
View File
@@ -197,6 +197,9 @@ type SystemSettings struct {
// Available Channels feature switch (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
// 风控中心功能开关
RiskControlEnabled bool `json:"risk_control_enabled"`
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
@@ -256,6 +259,8 @@ type PublicSettings struct {
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
RiskControlEnabled bool `json:"risk_control_enabled"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
@@ -45,6 +45,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
contentModerationService *service.ContentModerationService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int
@@ -65,6 +66,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
contentModerationService *service.ContentModerationService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
@@ -98,6 +100,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
contentModerationService: contentModerationService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches,
@@ -189,6 +192,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
@@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
@@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
@@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
googleError(c, contentModerationStatus(decision), decision.Message)
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
+1
View File
@@ -33,6 +33,7 @@ type AdminHandlers struct {
Channel *admin.ChannelHandler
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
ContentModeration *admin.ContentModerationHandler
Payment *admin.PaymentHandler
Affiliate *admin.AffiliateHandler
}
@@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
@@ -27,15 +27,16 @@ import (
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int
cfg *config.Config
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
contentModerationService *service.ContentModerationService
concurrencyHelper *ConcurrencyHelper
imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int
cfg *config.Config
}
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
@@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
contentModerationService *service.ContentModerationService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
@@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler(
}
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
contentModerationService: contentModerationService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
}
@@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
@@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
@@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message)
return
}
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
return
@@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
@@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
_ = conn.CloseNow()
}
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
if conn == nil || decision == nil {
return
}
if ctx == nil {
ctx = context.Background()
}
message := strings.TrimSpace(decision.Message)
if message == "" {
message = "content moderation blocked this request"
}
payload, err := json.Marshal(gin.H{
"event_id": "evt_content_moderation_blocked",
"type": "error",
"error": gin.H{
"type": "invalid_request_error",
"code": contentModerationErrorCode(decision),
"message": message,
},
})
if err != nil {
payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`)
}
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
_ = conn.Write(writeCtx, coderws.MessageText, payload)
}
func summarizeWSCloseErrorForLog(err error) (string, string) {
if err == nil {
return "-", "-"
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
@@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
type contentModerationHandlerSettingRepo struct {
values map[string]string
}
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
if value, ok := r.values[key]; ok {
return &service.Setting{Key: key, Value: value}, nil
}
return nil, service.ErrSettingNotFound
}
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
if value, ok := r.values[key]; ok {
return value, nil
}
return "", service.ErrSettingNotFound
}
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
if r.values == nil {
r.values = map[string]string{}
}
r.values[key] = value
return nil
}
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := map[string]string{}
for _, key := range keys {
if value, ok := r.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
if r.values == nil {
r.values = map[string]string{}
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.values))
for key, value := range r.values {
out[key] = value
}
return out, nil
}
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.values, key)
return nil
}
type contentModerationHandlerTestRepo struct {
logs []service.ContentModerationLog
}
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
if log != nil {
r.logs = append(r.logs, *log)
}
return nil
}
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
return 0, nil
}
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
return &service.ContentModerationCleanupResult{}, nil
}
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
gin.SetMode(gin.TestMode)
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
}))
defer moderationServer.Close()
cfg := &service.ContentModerationConfig{
Enabled: true,
Mode: service.ContentModerationModePreBlock,
BaseURL: moderationServer.URL,
Model: "omni-moderation-latest",
APIKeys: []string{"sk-test"},
SampleRate: 100,
AllGroups: true,
BlockMessage: "内容审计测试阻断",
}
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationHandlerTestRepo{}
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
service.SettingKeyRiskControlEnabled: "true",
service.SettingKeyContentModerationConfig: string(rawCfg),
}}
moderationSvc := service.NewContentModerationService(
settingRepo,
repo,
nil,
nil,
nil,
nil,
nil,
)
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
UserID: 1,
Endpoint: "/v1/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: service.ContentModerationProtocolOpenAIResponses,
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
repo.logs = nil
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
contentModerationService: moderationSvc,
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
}
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
"type":"response.create",
"model":"gpt-5.5",
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, payload, readErr := clientConn.Read(readCtx)
cancelRead()
if readErr == nil {
require.Contains(t, string(payload), "content_policy_violation")
require.Contains(t, string(payload), "内容审计测试阻断")
} else {
var closeErr coderws.CloseError
require.ErrorAs(t, readErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
}
require.Len(t, repo.logs, 1)
require.True(t, repo.logs[0].Flagged)
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
@@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
if !acquired {
return
@@ -77,5 +77,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
RiskControlEnabled: settings.RiskControlEnabled,
})
}
+3
View File
@@ -36,6 +36,7 @@ func ProvideAdminHandlers(
channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
contentModerationHandler *admin.ContentModerationHandler,
paymentHandler *admin.PaymentHandler,
affiliateHandler *admin.AffiliateHandler,
) *AdminHandlers {
@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
Channel: channelHandler,
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
ContentModeration: contentModerationHandler,
Payment: paymentHandler,
Affiliate: affiliateHandler,
}
@@ -170,6 +172,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelHandler,
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
admin.NewContentModerationHandler,
admin.NewPaymentHandler,
admin.NewAffiliateHandler,
@@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldName,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
@@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S
got, err := repo.GetByKeyForAuth(ctx, key.Key)
require.NoError(t, err)
require.Equal(t, key.Name, got.Name)
require.NotNil(t, got.Group)
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
}
@@ -0,0 +1,71 @@
package repository
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes"
type contentModerationHashCache struct {
rdb *redis.Client
}
func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache {
return &contentModerationHashCache{rdb: rdb}
}
func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
inputHash = strings.TrimSpace(inputHash)
if c == nil || c.rdb == nil || inputHash == "" {
return nil
}
return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err()
}
func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
inputHash = strings.TrimSpace(inputHash)
if c == nil || c.rdb == nil || inputHash == "" {
return false, nil
}
return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
}
func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
inputHash = strings.TrimSpace(inputHash)
if c == nil || c.rdb == nil || inputHash == "" {
return false, nil
}
deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
if err != nil {
return false, err
}
return deleted > 0, nil
}
func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
if c == nil || c.rdb == nil {
return 0, nil
}
deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
if err != nil {
return 0, err
}
if deleted == 0 {
return 0, nil
}
if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil {
return 0, err
}
return deleted, nil
}
func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
if c == nil || c.rdb == nil {
return 0, nil
}
return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
}
@@ -0,0 +1,274 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type contentModerationRepository struct {
db *sql.DB
}
func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository {
return &contentModerationRepository{db: db}
}
func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
if log == nil {
return nil
}
categoryScores, err := json.Marshal(log.CategoryScores)
if err != nil {
return fmt.Errorf("marshal moderation category scores: %w", err)
}
thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot)
if err != nil {
return fmt.Errorf("marshal moderation thresholds: %w", err)
}
var userID any
if log.UserID != nil {
userID = *log.UserID
}
var apiKeyID any
if log.APIKeyID != nil {
apiKeyID = *log.APIKeyID
}
var groupID any
if log.GroupID != nil {
groupID = *log.GroupID
}
var latency any
if log.UpstreamLatencyMS != nil {
latency = *log.UpstreamLatencyMS
}
err = r.db.QueryRowContext(ctx, `
INSERT INTO content_moderation_logs (
request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name,
endpoint, provider, model, mode, action, flagged, highest_category, highest_score,
category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error,
violation_count, auto_banned, email_sent, queue_delay_ms
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9, $10, $11, $12, $13, $14, $15,
$16::jsonb, $17::jsonb, $18, $19, $20,
$21, $22, $23, $24
) RETURNING id, created_at`,
log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName,
log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore,
string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error,
log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS),
).Scan(&log.ID, &log.CreatedAt)
if err != nil {
return fmt.Errorf("insert content moderation log: %w", err)
}
return nil
}
func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
where, args := buildContentModerationLogWhere(filter)
whereSQL := "WHERE " + strings.Join(where, " AND ")
var total int64
if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil {
return nil, nil, fmt.Errorf("count content moderation logs: %w", err)
}
params := filter.Pagination
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
if params.PageSize > 100 {
params.PageSize = 100
}
queryArgs := append([]any{}, args...)
queryArgs = append(queryArgs, params.Limit(), params.Offset())
rows, err := r.db.QueryContext(ctx, `
SELECT
l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name,
l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score,
l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error,
l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at
FROM content_moderation_logs l
LEFT JOIN users u ON u.id = l.user_id `+whereSQL+`
ORDER BY l.created_at DESC, l.id DESC
LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)),
queryArgs...,
)
if err != nil {
return nil, nil, fmt.Errorf("list content moderation logs: %w", err)
}
defer func() { _ = rows.Close() }()
items := make([]service.ContentModerationLog, 0)
for rows.Next() {
var item service.ContentModerationLog
var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64
var scoresRaw, thresholdsRaw []byte
if err := rows.Scan(
&item.ID,
&item.RequestID,
&userID,
&item.UserEmail,
&apiKeyID,
&item.APIKeyName,
&groupID,
&item.GroupName,
&item.Endpoint,
&item.Provider,
&item.Model,
&item.Mode,
&item.Action,
&item.Flagged,
&item.HighestCategory,
&item.HighestScore,
&scoresRaw,
&thresholdsRaw,
&item.InputExcerpt,
&latency,
&item.Error,
&item.ViolationCount,
&item.AutoBanned,
&item.EmailSent,
&item.UserStatus,
&queueDelay,
&item.CreatedAt,
); err != nil {
return nil, nil, fmt.Errorf("scan content moderation log: %w", err)
}
if userID.Valid {
v := userID.Int64
item.UserID = &v
}
if apiKeyID.Valid {
v := apiKeyID.Int64
item.APIKeyID = &v
}
if groupID.Valid {
v := groupID.Int64
item.GroupID = &v
}
if latency.Valid {
v := int(latency.Int64)
item.UpstreamLatencyMS = &v
}
if queueDelay.Valid {
v := int(queueDelay.Int64)
item.QueueDelayMS = &v
}
item.CategoryScores = map[string]float64{}
_ = json.Unmarshal(scoresRaw, &item.CategoryScores)
item.ThresholdSnapshot = map[string]float64{}
_ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot)
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err)
}
return items, paginationResultFromTotal(total, params), nil
}
func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
if userID <= 0 {
return 0, nil
}
var count int
err := r.db.QueryRowContext(ctx, `
WITH last_auto_ban AS (
SELECT MAX(created_at) AS at
FROM content_moderation_logs
WHERE user_id = $1 AND auto_banned = TRUE
)
SELECT COUNT(*)
FROM content_moderation_logs
WHERE user_id = $1
AND flagged = TRUE
AND created_at >= $2
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
`, userID, since).Scan(&count)
if err != nil {
return 0, fmt.Errorf("count user content moderation flagged logs: %w", err)
}
return count, nil
}
func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()}
if r == nil || r.db == nil {
return result, nil
}
hitExec, err := r.db.ExecContext(ctx, `
DELETE FROM content_moderation_logs
WHERE flagged = TRUE AND created_at < $1
`, hitBefore)
if err != nil {
return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err)
}
result.DeletedHit, _ = hitExec.RowsAffected()
nonHitExec, err := r.db.ExecContext(ctx, `
DELETE FROM content_moderation_logs
WHERE flagged = FALSE AND created_at < $1
`, nonHitBefore)
if err != nil {
return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err)
}
result.DeletedNonHit, _ = nonHitExec.RowsAffected()
result.FinishedAt = time.Now()
return result, nil
}
func nullableIntPtr(value *int) any {
if value == nil {
return nil
}
return *value
}
func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) {
where := []string{"l.id IS NOT NULL"}
args := make([]any, 0)
add := func(expr string, value any) {
args = append(args, value)
where = append(where, fmt.Sprintf(expr, len(args)))
}
switch strings.ToLower(strings.TrimSpace(filter.Result)) {
case "hit", "flagged":
where = append(where, "l.flagged = TRUE")
case "blocked", "block":
where = append(where, "l.action = 'block'")
case "pass", "allow":
where = append(where, "l.flagged = FALSE AND l.error = ''")
case "error":
where = append(where, "l.error <> ''")
}
if filter.GroupID != nil {
add("l.group_id = $%d", *filter.GroupID)
}
if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" {
add("l.endpoint = $%d", endpoint)
}
if search := strings.TrimSpace(filter.Search); search != "" {
like := "%" + search + "%"
args = append(args, like, like, like, like, like)
idx := len(args) - 4
where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4))
}
if filter.From != nil && !filter.From.IsZero() {
add("l.created_at >= $%d", *filter.From)
}
if filter.To != nil && !filter.To.IsZero() {
add("l.created_at <= $%d", *filter.To)
}
return where, args
}
+2
View File
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewContentModerationRepository,
NewAffiliateRepository,
// Cache implementations
@@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet(
NewRefreshTokenCache,
NewErrorPassthroughCache,
NewTLSFingerprintProfileCache,
NewContentModerationHashCache,
// Encryptors
NewAESEncryptor,
@@ -792,6 +792,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"risk_control_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
@@ -983,6 +984,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"risk_control_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
+17
View File
@@ -92,11 +92,28 @@ func RegisterAdminRoutes(
// 渠道监控
registerChannelMonitorRoutes(admin, h)
// 风控中心
registerContentModerationRoutes(admin, h)
// 邀请返利(专属用户管理)
registerAffiliateRoutes(admin, h)
}
}
func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
risk := admin.Group("/risk-control")
{
risk.GET("/config", h.Admin.ContentModeration.GetConfig)
risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig)
risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys)
risk.GET("/status", h.Admin.ContentModeration.GetStatus)
risk.GET("/logs", h.Admin.ContentModeration.ListLogs)
risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser)
risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash)
risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes)
}
}
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys")
{
@@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Name string `json:"name"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`
@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Name: apiKey.Name,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
@@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Name: snapshot.Name,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
@@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Name: "Audit Key",
Status: StatusActive,
User: &User{
ID: 2,
@@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.Equal(t, apiKey.Name, roundTrip.Name)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,117 @@
package service
import (
"fmt"
"html"
"strings"
"time"
)
func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
if log == nil {
return ""
}
userName := strings.TrimSpace(log.UserEmail)
if userName == "" && log.UserID != nil {
userName = fmt.Sprintf("UID %d", *log.UserID)
}
threshold := cfg.BanThreshold
if threshold <= 0 {
threshold = defaultContentModerationBanThreshold
}
statusBlock := ""
if log.AutoBanned {
statusBlock = `<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>`
}
return fmt.Sprintf(`<!doctype html>
<html>
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 风控提醒</div>
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户触发内容审计规则</h1>
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>您的 API 请求在内容审计中触发平台风控策略详情如下</p>
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">触发详情</h2>
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 阈值 %d</td></tr>
</table>
</div>
%s
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送请勿回复</p>
</div>
</div>
</body>
</html>`,
html.EscapeString(userName),
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
log.HighestScore,
log.ViolationCount,
threshold,
statusBlock,
html.EscapeString(siteName),
)
}
func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
if log == nil {
return ""
}
userName := strings.TrimSpace(log.UserEmail)
if userName == "" && log.UserID != nil {
userName = fmt.Sprintf("UID %d", *log.UserID)
}
threshold := cfg.BanThreshold
if threshold <= 0 {
threshold = defaultContentModerationBanThreshold
}
return fmt.Sprintf(`<!doctype html>
<html>
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 账户封禁</div>
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户已被自动禁用</h1>
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>您的账户在计数周期内多次触发平台风控策略系统已自动禁用该账户详情如下</p>
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">封禁详情</h2>
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">封禁时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 阈值 %d</td></tr>
</table>
</div>
<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态所有 API 请求将被拒绝</div>
<p style="font-size:15px;line-height:1.8;color:#666;margin-top:24px;">如需申诉或恢复账号请联系平台管理员处理</p>
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送请勿回复</p>
</div>
</div>
</body>
</html>`,
html.EscapeString(userName),
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
log.HighestScore,
log.ViolationCount,
threshold,
html.EscapeString(siteName),
)
}
func defaultContentModerationString(value string, fallback string) string {
if strings.TrimSpace(value) == "" {
return fallback
}
return strings.TrimSpace(value)
}
@@ -0,0 +1,307 @@
package service
import (
"fmt"
"strings"
"github.com/tidwall/gjson"
)
func ExtractContentModerationText(protocol string, body []byte) string {
return ExtractContentModerationInput(protocol, body).Text
}
func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput {
if len(body) == 0 || !gjson.ValidBytes(body) {
return ContentModerationInput{}
}
var parts []string
var images []string
switch protocol {
case ContentModerationProtocolAnthropicMessages:
collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images)
case ContentModerationProtocolOpenAIChat:
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
case ContentModerationProtocolOpenAIResponses:
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
case ContentModerationProtocolGemini:
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
case ContentModerationProtocolOpenAIImages:
addModerationText(&parts, gjson.GetBytes(body, "prompt").String())
collectContentValue(gjson.GetBytes(body, "images"), &parts, &images)
default:
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
}
out := ContentModerationInput{
Text: normalizeContentModerationText(strings.Join(parts, "\n")),
Images: normalizeModerationImages(images),
}
out.Normalize()
return out
}
func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
var candidate []string
var candidateImages []string
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
var candidate []string
var candidateImages []string
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
switch {
case !value.Exists():
return
case value.Type == gjson.String:
if !isAnthropicSystemReminderText(value.String()) {
addModerationText(parts, value.String())
}
case value.IsArray():
value.ForEach(func(_, item gjson.Result) bool {
collectAnthropicUserContentValue(item, parts, images)
return true
})
case value.IsObject():
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
switch typ {
case "", "text", "input_text", "message":
if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) {
addModerationText(parts, value.Get("text").String())
}
if value.Get("content").Exists() {
collectAnthropicUserContentValue(value.Get("content"), parts, images)
}
case "image_url", "input_image", "image":
collectContentValue(value, parts, images)
}
}
}
func isAnthropicSystemReminderText(text string) bool {
return strings.HasPrefix(strings.TrimSpace(text), "<system-reminder>")
}
func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) {
switch {
case !input.Exists():
return
case input.Type == gjson.String:
addModerationText(parts, input.String())
case input.IsArray():
var last gjson.Result
input.ForEach(func(_, item gjson.Result) bool {
if isResponsesUserTextItem(item) {
last = item
}
return true
})
if last.Exists() {
collectContentValue(last.Get("content"), parts, images)
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
collectContentValue(last, parts, images)
}
}
case input.IsObject():
if isResponsesUserTextItem(input) {
collectContentValue(input.Get("content"), parts, images)
if input.Get("type").String() == "input_text" || input.Get("text").Exists() {
collectContentValue(input, parts, images)
}
}
}
}
func isResponsesUserTextItem(item gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(item.Get("role").String()))
if role == "user" {
return responseItemHasModerationText(item)
}
if role != "" {
return false
}
return responseItemHasModerationText(item)
}
func responseItemHasModerationText(item gjson.Result) bool {
var parts []string
var images []string
collectContentValue(item.Get("content"), &parts, &images)
if item.Get("type").String() == "input_text" || item.Get("text").Exists() {
collectContentValue(item, &parts, &images)
}
return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0
}
func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) {
if !contents.IsArray() {
return
}
var lastParts []string
var lastImages []string
contents.ForEach(func(_, content gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
if role == "" || role == "user" {
var candidate []string
var candidateImages []string
if arr := content.Get("parts"); arr.IsArray() {
arr.ForEach(func(_, part gjson.Result) bool {
addModerationText(&candidate, part.Get("text").String())
addGeminiModerationImage(&candidateImages, part)
return true
})
}
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {
switch {
case !value.Exists():
return
case value.Type == gjson.String:
addModerationText(parts, value.String())
case value.IsArray():
value.ForEach(func(_, item gjson.Result) bool {
collectContentValue(item, parts, images)
return true
})
case value.IsObject():
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
addModerationImage(images, value.Get("image_url.url").String())
addModerationImage(images, value.Get("image_url").String())
addModerationImage(images, value.Get("url").String())
addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String())
addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String())
addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String())
addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String())
addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String())
addModerationImage(images, value.Get("source.data").String())
addModerationImage(images, value.Get("data").String())
addModerationImage(images, value.Get("base64").String())
switch typ {
case "", "text", "input_text", "message":
if value.Get("text").Exists() {
addModerationText(parts, value.Get("text").String())
}
if value.Get("content").Exists() {
collectContentValue(value.Get("content"), parts, images)
}
case "image_url", "input_image", "image":
}
}
}
func addGeminiModerationImage(images *[]string, part gjson.Result) {
if inlineData := part.Get("inline_data"); inlineData.IsObject() {
mimeType := strings.TrimSpace(inlineData.Get("mime_type").String())
data := strings.TrimSpace(inlineData.Get("data").String())
if mimeType != "" && data != "" {
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
}
if inlineData := part.Get("inlineData"); inlineData.IsObject() {
mimeType := strings.TrimSpace(inlineData.Get("mimeType").String())
data := strings.TrimSpace(inlineData.Get("data").String())
if mimeType != "" && data != "" {
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
}
addModerationImage(images, part.Get("file_data.file_uri").String())
addModerationImage(images, part.Get("fileData.fileUri").String())
}
func addModerationImageData(images *[]string, mimeType string, data string) {
mimeType = strings.TrimSpace(mimeType)
data = strings.TrimSpace(data)
if mimeType == "" || data == "" {
return
}
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
func addModerationImage(images *[]string, image string) {
image = strings.TrimSpace(image)
if image == "" {
return
}
if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") {
*images = append(*images, image)
}
}
func normalizeModerationImages(images []string) []string {
out := make([]string, 0, len(images))
seen := make(map[string]struct{}, len(images))
for _, image := range images {
image = strings.TrimSpace(image)
if image == "" {
continue
}
if _, ok := seen[image]; ok {
continue
}
seen[image] = struct{}{}
out = append(out, image)
}
return out
}
func addModerationText(parts *[]string, text string) {
text = strings.TrimSpace(text)
if text == "" {
return
}
if strings.Contains(text, "<system-reminder>") {
return
}
*parts = append(*parts, text)
}
func normalizeContentModerationText(text string) string {
return strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
}
@@ -0,0 +1,36 @@
package service
import (
"regexp"
"strings"
)
var contentModerationSecretPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`),
regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`),
regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`),
regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`),
regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`),
regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`),
regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`),
regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`),
}
func redactContentModerationSecrets(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
out := text
for idx, pattern := range contentModerationSecretPatterns {
switch idx {
case 0:
out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`)
case 1:
out = pattern.ReplaceAllString(out, `${1}[已脱敏]`)
default:
out = pattern.ReplaceAllString(out, `[已脱敏]`)
}
}
return out
}
@@ -0,0 +1,811 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type contentModerationTestSettingRepo struct {
values map[string]string
}
func (r *contentModerationTestSettingRepo) Get(ctx context.Context, key string) (*Setting, error) {
if value, ok := r.values[key]; ok {
return &Setting{Key: key, Value: value}, nil
}
return nil, ErrSettingNotFound
}
func (r *contentModerationTestSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
if value, ok := r.values[key]; ok {
return value, nil
}
return "", ErrSettingNotFound
}
func (r *contentModerationTestSettingRepo) Set(ctx context.Context, key, value string) error {
if r.values == nil {
r.values = map[string]string{}
}
r.values[key] = value
return nil
}
func (r *contentModerationTestSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := map[string]string{}
for _, key := range keys {
if value, ok := r.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (r *contentModerationTestSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
if r.values == nil {
r.values = map[string]string{}
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *contentModerationTestSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.values))
for key, value := range r.values {
out[key] = value
}
return out, nil
}
func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.values, key)
return nil
}
type contentModerationTestRepo struct {
logs []ContentModerationLog
}
func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error {
if log != nil {
r.logs = append(r.logs, *log)
}
return nil
}
func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
return 0, nil
}
func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) {
return &ContentModerationCleanupResult{}, nil
}
type contentModerationTestHashCache struct {
hashes map[string]struct{}
recorded []string
checked []string
deleted []string
hasResult bool
hasResultUsed bool
}
type contentModerationTestUserRepo struct {
user *User
updated []User
}
func (r *contentModerationTestUserRepo) Create(ctx context.Context, user *User) error {
panic("unexpected Create call")
}
func (r *contentModerationTestUserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
if r.user == nil {
return nil, ErrUserNotFound
}
clone := *r.user
return &clone, nil
}
func (r *contentModerationTestUserRepo) GetByEmail(ctx context.Context, email string) (*User, error) {
panic("unexpected GetByEmail call")
}
func (r *contentModerationTestUserRepo) GetFirstAdmin(ctx context.Context) (*User, error) {
panic("unexpected GetFirstAdmin call")
}
func (r *contentModerationTestUserRepo) Update(ctx context.Context, user *User) error {
if user == nil {
return nil
}
clone := *user
r.updated = append(r.updated, clone)
r.user = &clone
return nil
}
func (r *contentModerationTestUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (r *contentModerationTestUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
panic("unexpected GetUserAvatar call")
}
func (r *contentModerationTestUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (r *contentModerationTestUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (r *contentModerationTestUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (r *contentModerationTestUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserIDs call")
}
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserID call")
}
func (r *contentModerationTestUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
panic("unexpected UpdateUserLastActiveAt call")
}
func (r *contentModerationTestUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (r *contentModerationTestUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (r *contentModerationTestUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (r *contentModerationTestUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (r *contentModerationTestUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (r *contentModerationTestUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (r *contentModerationTestUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected RemoveGroupFromUserAllowedGroups call")
}
func (r *contentModerationTestUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected ListUserAuthIdentities call")
}
func (r *contentModerationTestUserRepo) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
panic("unexpected UnbindUserAuthProvider call")
}
func (r *contentModerationTestUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (r *contentModerationTestUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (r *contentModerationTestUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type contentModerationTestAuthCacheInvalidator struct {
userIDs []int64
}
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByKey(ctx context.Context, key string) {
}
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
i.userIDs = append(i.userIDs, userID)
}
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
}
func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
if c.hashes == nil {
c.hashes = map[string]struct{}{}
}
c.hashes[inputHash] = struct{}{}
c.recorded = append(c.recorded, inputHash)
return nil
}
func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
c.checked = append(c.checked, inputHash)
if c.hasResultUsed {
return c.hasResult, nil
}
_, ok := c.hashes[inputHash]
return ok, nil
}
func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
c.deleted = append(c.deleted, inputHash)
if c.hashes == nil {
return false, nil
}
if _, ok := c.hashes[inputHash]; !ok {
return false, nil
}
delete(c.hashes, inputHash)
return true, nil
}
func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
deleted := int64(len(c.hashes))
c.hashes = map[string]struct{}{}
return deleted, nil
}
func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
return int64(len(c.hashes)), nil
}
func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) {
svc := &ContentModerationService{}
cfg := defaultContentModerationConfig()
input := ContentModerationCheckInput{
RequestID: "req-1",
Endpoint: "/v1/chat/completions",
Provider: "openai",
}
log := svc.buildLog(input, cfg, ContentModerationActionAllow, true, "sexual", 0.8, map[string]float64{"sexual": 0.8}, "hello sk-proj-1234567890abcdef", nil, nil, "")
require.NotContains(t, log.InputExcerpt, "sk-proj-1234567890abcdef")
require.Contains(t, log.InputExcerpt, "[已脱敏]")
}
func TestRedactContentModerationSecrets_LongHexAndTokens(t *testing.T) {
input := "你哈市多大事cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554 token=abc123456789xyz Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signaturepart"
out := redactContentModerationSecrets(input)
require.NotContains(t, out, "cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554")
require.NotContains(t, out, "abc123456789xyz")
require.NotContains(t, out, "eyJhbGciOiJIUzI1NiJ9")
require.Contains(t, out, "[已脱敏]")
}
func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.NonHitRetentionDays = 30
cfg.normalize()
require.Equal(t, 3, cfg.NonHitRetentionDays)
}
func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"old"},
{"role":"assistant","content":"ok"},
{"role":"user","content":[
{"type":"text","text":"检查这张图"},
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"aGVsbG8="}}
]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "检查这张图", input.Text)
require.Equal(t, []string{"data:image/png;base64,aGVsbG8="}, input.Images)
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
require.Equal(t, "检查这张图", log.InputExcerpt)
require.NotContains(t, log.InputExcerpt, "aGVsbG8=")
}
func TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders(t *testing.T) {
body := []byte(`{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "<system-reminder>工具说明</system-reminder>"},
{"type": "text", "text": "<system-reminder>Ainder>\n\n"},
{"type": "text", "text": "hid", "cache_control": {"type": "ephemeral"}}
]
}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "hid", input.Text)
require.Empty(t, input.Images)
}
func TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage(t *testing.T) {
body := []byte(`{
"model":"gpt-5.5",
"messages":[
{"role":"system","content":"system prompt"},
{"role":"user","content":"old user"},
{"role":"assistant","content":"ok"},
{"role":"user","content":[{"type":"text","text":"latest user"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
require.Equal(t, "latest user", input.Text)
require.Equal(t, []string{"https://example.com/a.png"}, input.Images)
require.NotContains(t, input.Text, "old user")
require.NotContains(t, input.Text, "system prompt")
}
func TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages(t *testing.T) {
body := []byte(`{
"prompt":"replace background",
"images":[
{"image_url":"https://example.com/source.png"},
{"image_url":"data:image/png;base64,aGVsbG8="}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, body)
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{"https://example.com/source.png", "data:image/png;base64,aGVsbG8="}, input.Images)
}
func TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage(t *testing.T) {
body := []byte(`{
"model":"gpt-5.5",
"instructions":"instructions.....",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer permissions sk-proj-1234567890abcdef"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
],
"prompt_cache_key":"cache-key"
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
require.Equal(t, "last user prompt", input.Text)
require.Empty(t, input.Images)
require.NotContains(t, input.Text, "developer permissions")
require.NotContains(t, input.Text, "first user prompt")
}
func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *testing.T) {
var moderationRequest moderationAPIRequest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest))
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.01},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
cfg.RecordNonHits = true
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
&contentModerationTestHashCache{},
nil,
nil,
nil,
nil,
)
body := []byte(`{
"model":"gpt-5.5",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
]
}`)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
UserID: 1001,
Endpoint: "/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIResponses,
Body: body,
})
require.NoError(t, err)
require.False(t, decision.Blocked)
require.Len(t, repo.logs, 1)
require.False(t, repo.logs[0].Flagged)
require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action)
require.Equal(t, "/responses", repo.logs[0].Endpoint)
require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt)
require.Equal(t, "last user prompt", moderationRequest.Input)
}
func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *testing.T) {
var moderationRequest moderationAPIRequest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest))
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.9},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
cfg.BlockStatus = http.StatusUnavailableForLegalReasons
cfg.BlockMessage = "内容审计测试阻断"
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
&contentModerationTestHashCache{},
nil,
nil,
nil,
nil,
)
body := []byte(`{
"model":"gpt-5.5",
"instructions":"instructions.....",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"environment context"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"latest blocked prompt"}]}
]
}`)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
UserID: 1001,
Endpoint: "/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIResponses,
Body: body,
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionBlock, decision.Action)
require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode)
require.Equal(t, "内容审计测试阻断", decision.Message)
require.Len(t, repo.logs, 1)
require.True(t, repo.logs[0].Flagged)
require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action)
require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode)
require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt)
require.Equal(t, "latest blocked prompt", moderationRequest.Input)
}
func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) {
result := buildContentModerationTestAuditResult(&moderationAPIResult{
Flagged: true,
CategoryScores: map[string]float64{
"harassment": 0.65,
},
}, nil)
require.NotNil(t, result)
require.False(t, result.Flagged)
require.Equal(t, "harassment", result.HighestCategory)
require.Equal(t, 0.65, result.HighestScore)
require.Equal(t, 0.65, result.CompositeScore)
require.Equal(t, 0.98, result.Thresholds["harassment"])
}
func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.PreHashCheckEnabled = true
cfg.APIKeys = []string{"sk-test"}
cfg.BlockStatus = http.StatusConflict
cfg.BlockMessage = "命中历史风险输入"
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{}}
content := ContentModerationInput{Text: "blocked prompt"}
content.Normalize()
hashCache.hashes[content.Hash()] = struct{}{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
&contentModerationTestRepo{},
hashCache,
nil,
nil,
nil,
nil,
)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionHashBlock, decision.Action)
require.Equal(t, http.StatusConflict, decision.StatusCode)
require.Equal(t, content.Hash(), decision.InputHash)
require.Contains(t, decision.Message, "命中历史风险输入")
require.Contains(t, decision.Message, content.Hash())
require.Len(t, hashCache.checked, 1)
}
func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) {
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.9},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.PreHashCheckEnabled = true
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
cfg.BlockStatus = http.StatusConflict
cfg.BlockMessage = "命中风险输入"
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
hashCache := &contentModerationTestHashCache{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
hashCache,
nil,
nil,
nil,
nil,
)
body := []byte(`{"messages":[{"role":"user","content":"repeat blocked prompt"}]}`)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: body,
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionBlock, decision.Action)
require.Equal(t, 1, requestCount)
require.Len(t, hashCache.recorded, 1)
require.Len(t, repo.logs, 1)
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: body,
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionHashBlock, decision.Action)
require.Equal(t, hashCache.recorded[0], decision.InputHash)
require.Equal(t, 1, requestCount)
require.Len(t, repo.logs, 1)
}
func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) {
existingHash := strings.Repeat("a", 64)
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{
existingHash: {},
}}
svc := &ContentModerationService{hashCache: hashCache}
result, err := svc.DeleteFlaggedInputHash(context.Background(), strings.ToUpper(existingHash))
require.NoError(t, err)
require.Equal(t, existingHash, result.InputHash)
require.True(t, result.Deleted)
require.NotContains(t, hashCache.hashes, existingHash)
require.Equal(t, []string{existingHash}, hashCache.deleted)
result, err = svc.DeleteFlaggedInputHash(context.Background(), existingHash)
require.NoError(t, err)
require.Equal(t, existingHash, result.InputHash)
require.False(t, result.Deleted)
}
func TestContentModerationClearFlaggedInputHashesAndStatusCount(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.Enabled = true
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{
strings.Repeat("a", 64): {},
strings.Repeat("b", 64): {},
}}
svc := &ContentModerationService{
settingRepo: &contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
hashCache: hashCache,
keyHealth: make(map[string]*contentModerationKeyHealth),
}
status, err := svc.GetStatus(context.Background())
require.NoError(t, err)
require.Equal(t, int64(2), status.FlaggedHashCount)
result, err := svc.ClearFlaggedInputHashes(context.Background())
require.NoError(t, err)
require.Equal(t, int64(2), result.Deleted)
status, err = svc.GetStatus(context.Background())
require.NoError(t, err)
require.Equal(t, int64(0), status.FlaggedHashCount)
}
func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
Results: []moderationAPIResult{{
CategoryScores: map[string]float64{"sexual": 0.9},
}},
})
}))
defer server.Close()
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModeObserve
cfg.BaseURL = server.URL
cfg.APIKeys = []string{"sk-test"}
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
hashCache := &contentModerationTestHashCache{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
hashCache,
nil,
nil,
nil,
nil,
)
decision := svc.checkSync(context.Background(), ContentModerationCheckInput{
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"bad prompt"}]}`),
}, cfg, ContentModerationInput{Text: "bad prompt"}, strings.Repeat("b", 64), contentModerationIntPtr(25), false)
require.False(t, decision.Blocked)
require.Len(t, hashCache.recorded, 1)
require.Len(t, repo.logs, 1)
}
func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) {
userID := int64(1001)
cfg := defaultContentModerationConfig()
cfg.BanThreshold = 10
body := buildContentModerationAccountDisabledEmailBody("Sub2API <Admin>", &ContentModerationLog{
UserID: &userID,
UserEmail: "user@example.com",
GroupName: "vip_2",
HighestCategory: "sexual",
HighestScore: 0.926,
ViolationCount: 10,
}, cfg)
require.Contains(t, body, "账户已被自动禁用")
require.Contains(t, body, "封禁详情")
require.Contains(t, body, "账户当前处于封禁状态,所有 API 请求将被拒绝")
require.Contains(t, body, "10 次(阈值 10")
require.Contains(t, body, "sexual / 0.926")
require.Contains(t, body, "Sub2API &lt;Admin&gt;")
}
func TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache(t *testing.T) {
userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusDisabled}}
invalidator := &contentModerationTestAuthCacheInvalidator{}
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil)
result, err := svc.UnbanUser(context.Background(), 1001)
require.NoError(t, err)
require.Equal(t, int64(1001), result.UserID)
require.Equal(t, StatusActive, result.Status)
require.Len(t, userRepo.updated, 1)
require.Equal(t, StatusActive, userRepo.updated[0].Status)
require.Equal(t, []int64{1001}, invalidator.userIDs)
}
func TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache(t *testing.T) {
userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusActive}}
invalidator := &contentModerationTestAuthCacheInvalidator{}
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil)
result, err := svc.UnbanUser(context.Background(), 1001)
require.NoError(t, err)
require.Equal(t, StatusActive, result.Status)
require.Empty(t, userRepo.updated)
require.Equal(t, []int64{1001}, invalidator.userIDs)
}
func contentModerationIntPtr(v int) *int {
return &v
}
@@ -107,6 +107,8 @@ const (
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
+63
View File
@@ -90,6 +90,69 @@ type OpenAIImagesRequest struct {
bodyHash string
}
func (r *OpenAIImagesRequest) ModerationBody() []byte {
if r == nil {
return nil
}
payload := map[string]any{}
if prompt := strings.TrimSpace(r.Prompt); prompt != "" {
payload["prompt"] = prompt
}
images := r.moderationImages()
if len(images) > 0 {
payload["images"] = images
}
if len(payload) == 0 {
return nil
}
body, err := json.Marshal(payload)
if err != nil {
return nil
}
return body
}
func (r *OpenAIImagesRequest) moderationImages() []map[string]string {
if r == nil {
return nil
}
images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1)
for _, imageURL := range r.InputImageURLs {
imageURL = strings.TrimSpace(imageURL)
if imageURL != "" {
images = append(images, map[string]string{"image_url": imageURL})
}
}
for _, upload := range r.Uploads {
if dataURL := upload.ModerationDataURL(); dataURL != "" {
images = append(images, map[string]string{"image_url": dataURL})
}
}
if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" {
images = append(images, map[string]string{"image_url": maskURL})
}
if r.MaskUpload != nil {
if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" {
images = append(images, map[string]string{"image_url": dataURL})
}
}
return images
}
func (u OpenAIImagesUpload) ModerationDataURL() string {
if len(u.Data) == 0 {
return ""
}
contentType := strings.TrimSpace(u.ContentType)
if contentType == "" {
contentType = http.DetectContentType(u.Data)
}
if !strings.HasPrefix(strings.ToLower(contentType), "image/") {
return ""
}
return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data))
}
func (r *OpenAIImagesRequest) IsEdits() bool {
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
}
@@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,
Prompt: "replace background",
InputImageURLs: []string{"https://example.com/source.png"},
MaskImageURL: "https://example.com/mask.png",
}
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images)
}
func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,
Prompt: "replace background",
Uploads: []OpenAIImagesUpload{{
FieldName: "image",
FileName: "source.png",
ContentType: "image/png",
Data: []byte("fake-image-bytes"),
}},
MaskUpload: &OpenAIImagesUpload{
FieldName: "mask",
FileName: "mask.png",
ContentType: "image/png",
Data: []byte("fake-mask-bytes"),
},
}
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{
"data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==",
"data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz",
}, input.Images)
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
require.Equal(t, "replace background", log.InputExcerpt)
require.NotContains(t, log.InputExcerpt, "ZmFrZS")
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct {
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
InitialRequestModel string
BeforeTurn func(turn int) error
BeforeRequest func(turn int, payload []byte, originalModel string) error
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
}
@@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return true
}
for {
if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil {
if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil {
return err
}
}
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
if err := hooks.BeforeTurn(turn); err != nil {
return err
@@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
if msgType != coderws.MessageText {
return payload, nil, nil
}
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil {
turnNo := int(completedTurns.Load()) + 1
if turnNo < 2 {
turnNo = 2
}
requestModel := usageMeta.requestModelForFrame(payload)
if requestModel == "" {
requestModel = capturedSessionModel
}
if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil {
return payload, nil, err
}
}
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
// session.update 修改 session-level modelRealtime /
// Responses WS 协议允许),如果不刷新就会出现
@@ -456,6 +456,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyChannelMonitorDefaultIntervalSeconds,
SettingKeyAvailableChannelsEnabled,
SettingKeyAffiliateEnabled,
SettingKeyRiskControlEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -545,6 +546,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true",
}, nil
}
@@ -692,6 +695,7 @@ type PublicSettingsInjectionPayload struct {
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
RiskControlEnabled bool `json:"risk_control_enabled"`
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
@@ -745,6 +749,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
RiskControlEnabled: settings.RiskControlEnabled,
}, nil
}
@@ -1232,6 +1237,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// Affiliate (邀请返利) feature switch
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
// 风控中心功能开关
updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled)
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@@ -1903,6 +1911,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Affiliate (邀请返利) feature (default disabled; opt-in)
SettingKeyAffiliateEnabled: "false",
// 风控中心功能(默认关闭,显式启用)
SettingKeyRiskControlEnabled: "false",
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
@@ -2242,6 +2253,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Affiliate (邀请返利) feature (default: disabled; strict true)
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
// 风控中心功能(默认关闭,严格 true 才启用)
result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true"
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
@@ -106,6 +106,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
RiskControlEnabled bool
AffiliateEnabled bool
AffiliateRebateRate float64
AffiliateRebateFreezeHours int
@@ -233,6 +234,9 @@ type PublicSettings struct {
// Affiliate (邀请返利) feature toggle
AffiliateEnabled bool `json:"affiliate_enabled"`
// 风控中心功能开关
RiskControlEnabled bool `json:"risk_control_enabled"`
}
type WeChatConnectOAuthConfig struct {
+1
View File
@@ -509,6 +509,7 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
NewContentModerationService,
NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,