feat(risk-control): add content moderation audit
This commit is contained in:
@@ -27,15 +27,16 @@ import (
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
contentModerationService *service.ContentModerationService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
||||
@@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler(
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
contentModerationService *service.ContentModerationService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler(
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
contentModerationService: contentModerationService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
@@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
@@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
@@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
@@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||
if conn == nil || decision == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
message := strings.TrimSpace(decision.Message)
|
||||
if message == "" {
|
||||
message = "content moderation blocked this request"
|
||||
}
|
||||
payload, err := json.Marshal(gin.H{
|
||||
"event_id": "evt_content_moderation_blocked",
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"code": contentModerationErrorCode(decision),
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`)
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, payload)
|
||||
}
|
||||
|
||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||
if err == nil {
|
||||
return "-", "-"
|
||||
|
||||
Reference in New Issue
Block a user