feat(channels): add custom account stats pricing rules
Allow channels to configure independent model pricing for account statistics cost calculation, decoupled from user billing. Backend: - Migration 101: channels.apply_pricing_to_account_stats toggle, channel_account_stats_pricing_rules/model_pricing tables, usage_logs.account_stats_cost column - resolveAccountStatsCost: match rules by group/account, then channel pricing, fallback to original formula when unconfigured - Integrate into both GatewayService.recordUsageCore and OpenAIGatewayService.RecordUsage - Update 8 account stats SQL queries to use COALESCE(account_stats_cost, total_cost) * account_rate_multiplier - 23 unit tests for matching, pricing lookup, and cost calculation Frontend: - Channel edit dialog: toggle + custom rules UI with group/account multi-select and pricing entry cards - API types and i18n (zh/en)
This commit is contained in:
@@ -26,28 +26,30 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
|
||||
// --- Request / Response types ---
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Features *string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Features *string `json:"features"`
|
||||
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
|
||||
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
@@ -75,20 +77,28 @@ type pricingIntervalRequest struct {
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type accountStatsPricingRuleRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Pricing []channelModelPricingRequest `json:"pricing"`
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type channelModelPricingResponse struct {
|
||||
@@ -118,6 +128,14 @@ type pricingIntervalResponse struct {
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type accountStatsPricingRuleResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Pricing []channelModelPricingResponse `json:"pricing"`
|
||||
}
|
||||
|
||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
if ch == nil {
|
||||
return nil
|
||||
@@ -129,7 +147,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
Status: ch.Status,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
Features: ch.Features,
|
||||
FeaturesConfig: ch.FeaturesConfig,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
@@ -150,6 +167,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
for _, p := range ch.ModelPricing {
|
||||
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
||||
}
|
||||
|
||||
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
|
||||
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
|
||||
for _, rule := range ch.AccountStatsPricingRules {
|
||||
ruleResp := accountStatsPricingRuleResponse{
|
||||
ID: rule.ID,
|
||||
Name: rule.Name,
|
||||
GroupIDs: rule.GroupIDs,
|
||||
AccountIDs: rule.AccountIDs,
|
||||
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
|
||||
}
|
||||
if ruleResp.GroupIDs == nil {
|
||||
ruleResp.GroupIDs = []int64{}
|
||||
}
|
||||
if ruleResp.AccountIDs == nil {
|
||||
ruleResp.AccountIDs = []int64{}
|
||||
}
|
||||
for i := range rule.Pricing {
|
||||
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
|
||||
}
|
||||
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -241,6 +281,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
||||
return result
|
||||
}
|
||||
|
||||
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
|
||||
return service.AccountStatsPricingRule{
|
||||
Name: r.Name,
|
||||
GroupIDs: r.GroupIDs,
|
||||
AccountIDs: r.AccountIDs,
|
||||
Pricing: pricingRequestToService(r.Pricing),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// List handles listing channels with pagination
|
||||
@@ -300,16 +349,24 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
|
||||
pricing := pricingRequestToService(req.ModelPricing)
|
||||
|
||||
var statsRules []service.AccountStatsPricingRule
|
||||
for i, r := range req.AccountStatsPricingRules {
|
||||
rule := accountStatsPricingRuleRequestToService(r)
|
||||
rule.SortOrder = i
|
||||
statsRules = append(statsRules, rule)
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
FeaturesConfig: req.FeaturesConfig,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: statsRules,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -335,20 +392,29 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.UpdateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
FeaturesConfig: req.FeaturesConfig,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
input.ModelPricing = &pricing
|
||||
}
|
||||
if req.AccountStatsPricingRules != nil {
|
||||
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
|
||||
for i, r := range *req.AccountStatsPricingRules {
|
||||
rule := accountStatsPricingRuleRequestToService(r)
|
||||
rule.SortOrder = i
|
||||
statsRules = append(statsRules, rule)
|
||||
}
|
||||
input.AccountStatsPricingRules = &statsRules
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
|
||||
@@ -41,14 +41,10 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -71,17 +67,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
}
|
||||
}
|
||||
|
||||
// 设置账号统计定价规则
|
||||
if len(channel.AccountStatsPricingRules) > 0 {
|
||||
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
var modelMappingJSON []byte
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
@@ -89,7 +92,6 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
@@ -103,6 +105,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
}
|
||||
ch.ModelPricing = pricing
|
||||
|
||||
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.AccountStatsPricingRules = statsPricingRules
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
@@ -112,14 +120,10 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -146,6 +150,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
}
|
||||
}
|
||||
|
||||
// 更新账号统计定价规则
|
||||
if channel.AccountStatsPricingRules != nil {
|
||||
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -196,7 +207,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||
)
|
||||
@@ -212,12 +223,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -235,9 +245,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,7 +298,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
@@ -294,12 +309,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -323,9 +337,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 批量加载账号统计定价规则
|
||||
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
@@ -467,28 +488,6 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
return m
|
||||
}
|
||||
|
||||
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal features_config: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// --- 账号统计定价规则 ---
|
||||
|
||||
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
|
||||
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
|
||||
// 1. 查询规则
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
|
||||
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var allRules []service.AccountStatsPricingRule
|
||||
var ruleIDs []int64
|
||||
for rows.Next() {
|
||||
var rule service.AccountStatsPricingRule
|
||||
if err := rows.Scan(
|
||||
&rule.ID, &rule.ChannelID, &rule.Name,
|
||||
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
|
||||
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
|
||||
}
|
||||
ruleIDs = append(ruleIDs, rule.ID)
|
||||
allRules = append(allRules, rule)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
|
||||
}
|
||||
|
||||
// 2. 批量加载规则的模型定价
|
||||
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. 按 channelID 分组并关联定价
|
||||
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
|
||||
for i := range allRules {
|
||||
allRules[i].Pricing = pricingMap[allRules[i].ID]
|
||||
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
|
||||
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
if len(ruleIDs) == 0 {
|
||||
return make(map[int64][]service.ChannelModelPricing), nil
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
|
||||
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
|
||||
pq.Array(ruleIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
|
||||
for rows.Next() {
|
||||
var p service.ChannelModelPricing
|
||||
var ruleID int64
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||
p.Models = []string{}
|
||||
}
|
||||
pricingMap[ruleID] = append(pricingMap[ruleID], p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
|
||||
}
|
||||
return pricingMap, nil
|
||||
}
|
||||
|
||||
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
|
||||
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
|
||||
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result[channelID], nil
|
||||
}
|
||||
|
||||
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
|
||||
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
|
||||
// CASCADE 会自动删除关联的 model_pricing
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
|
||||
); err != nil {
|
||||
return fmt.Errorf("delete old account stats pricing rules: %w", err)
|
||||
}
|
||||
|
||||
for i := range rules {
|
||||
rules[i].ChannelID = channelID
|
||||
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
|
||||
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
|
||||
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
|
||||
err := tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
|
||||
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
|
||||
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||
}
|
||||
|
||||
for j := range rule.Pricing {
|
||||
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
|
||||
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := pricing.Platform
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
ruleID, platform, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert account stats model pricing: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
|
||||
|
||||
// usageLogInsertArgTypes must stay in the same order as:
|
||||
// 1. prepareUsageLogInsert().args
|
||||
@@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text", // model_mapping_chain
|
||||
"text", // billing_tier
|
||||
"text", // billing_mode
|
||||
"numeric", // account_stats_cost
|
||||
"timestamptz", // created_at
|
||||
}
|
||||
|
||||
@@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
@@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
@@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
@@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
@@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*45)
|
||||
args := make([]any, 0, len(preparedList)*46)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
@@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
@@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
@@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$10, $11, $12, $13,
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
modelMappingChain,
|
||||
billingTier,
|
||||
billingMode,
|
||||
log.AccountStatsCost, // account_stats_cost
|
||||
createdAt,
|
||||
},
|
||||
}
|
||||
@@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
SELECT
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
@@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
SELECT
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
@@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
|
||||
account_id,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
@@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
modelExpr := resolveModelDimensionExpression(source)
|
||||
|
||||
@@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
%s
|
||||
@@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
|
||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
|
||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
@@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
modelMappingChain sql.NullString
|
||||
billingTier sql.NullString
|
||||
billingMode sql.NullString
|
||||
accountStatsCost sql.NullFloat64
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&modelMappingChain,
|
||||
&billingTier,
|
||||
&billingMode,
|
||||
&accountStatsCost,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if billingMode.Valid {
|
||||
log.BillingMode = &billingMode.String
|
||||
}
|
||||
if accountStatsCost.Valid {
|
||||
log.AccountStatsCost = &accountStatsCost.Float64
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
@@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
sqlmock.AnyArg(), // account_stats_cost
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||
@@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
sqlmock.AnyArg(), // account_stats_cost
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
@@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullFloat64{}, // account_stats_cost
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullFloat64{}, // account_stats_cost
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullFloat64{}, // account_stats_cost
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||||
//
|
||||
// 匹配优先级(先命中为准):
|
||||
// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历)
|
||||
// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时)
|
||||
// 3. nil → 走默认公式
|
||||
func resolveAccountStatsCost(
|
||||
ctx context.Context,
|
||||
channelService *ChannelService,
|
||||
billingService *BillingService,
|
||||
accountID int64,
|
||||
groupID int64,
|
||||
billingModel string,
|
||||
tokens UsageTokens,
|
||||
requestCount int,
|
||||
serviceTier string,
|
||||
) *float64 {
|
||||
if channelService == nil || billingService == nil {
|
||||
return nil
|
||||
}
|
||||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||||
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||||
modelLower := strings.ToLower(billingModel)
|
||||
|
||||
// 优先级 1:自定义规则
|
||||
if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
|
||||
return cost
|
||||
}
|
||||
|
||||
// 优先级 2:渠道已有模型定价
|
||||
return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
|
||||
}
|
||||
|
||||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||||
func tryCustomRules(
|
||||
channel *Channel, accountID, groupID int64,
|
||||
platform, modelLower string, tokens UsageTokens, requestCount int,
|
||||
) *float64 {
|
||||
for _, rule := range channel.AccountStatsPricingRules {
|
||||
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||||
continue
|
||||
}
|
||||
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||||
if pricing == nil {
|
||||
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||||
}
|
||||
return calculateStatsCost(pricing, tokens, requestCount)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。
|
||||
func tryChannelPricing(
|
||||
ctx context.Context, channelService *ChannelService,
|
||||
groupID int64, billingModel string, tokens UsageTokens, requestCount int,
|
||||
) *float64 {
|
||||
pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
return calculateStatsCost(pricing, tokens, requestCount)
|
||||
}
|
||||
|
||||
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||||
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||||
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||||
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||||
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, id := range rule.AccountIDs {
|
||||
if id == accountID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, id := range rule.GroupIDs {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wildcardMatch 通配符匹配候选项(用于排序)
|
||||
type wildcardMatch struct {
|
||||
prefixLen int
|
||||
pricing *ChannelModelPricing
|
||||
}
|
||||
|
||||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||||
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||||
// 精确匹配优先
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
continue
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||||
var matches []wildcardMatch
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
continue
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
ml := strings.ToLower(m)
|
||||
if !strings.HasSuffix(ml, "*") {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSuffix(ml, "*")
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].prefixLen > matches[j].prefixLen
|
||||
})
|
||||
return matches[0].pricing
|
||||
}
|
||||
|
||||
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
|
||||
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
|
||||
if queryPlatform == "" || pricingPlatform == "" {
|
||||
return true
|
||||
}
|
||||
return queryPlatform == pricingPlatform
|
||||
}
|
||||
|
||||
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||||
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
switch pricing.BillingMode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return calculatePerRequestStatsCost(pricing, requestCount)
|
||||
default:
|
||||
return calculateTokenStatsCost(pricing, tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// calculatePerRequestStatsCost 按次/图片计费。
|
||||
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||||
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||||
return nil
|
||||
}
|
||||
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||||
return &cost
|
||||
}
|
||||
|
||||
// calculateTokenStatsCost Token 计费。
|
||||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||||
deref := func(p *float64) float64 {
|
||||
if p == nil {
|
||||
return 0
|
||||
}
|
||||
return *p
|
||||
}
|
||||
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
||||
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
||||
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
||||
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
||||
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
||||
if cost == 0 {
|
||||
return nil
|
||||
}
|
||||
return &cost
|
||||
}
|
||||
@@ -0,0 +1,430 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// matchAccountStatsRule
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{}
|
||||
require.False(t, matchAccountStatsRule(rule, 1, 10))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
|
||||
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
|
||||
require.True(t, matchAccountStatsRule(rule, 999, 20))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.True(t, matchAccountStatsRule(rule, 999, 10))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.False(t, matchAccountStatsRule(rule, 999, 999))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// findPricingForModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestFindPricingForModel(t *testing.T) {
|
||||
exactPricing := ChannelModelPricing{
|
||||
ID: 1,
|
||||
Models: []string{"claude-opus-4"},
|
||||
}
|
||||
wildcardPricing := ChannelModelPricing{
|
||||
ID: 2,
|
||||
Models: []string{"claude-*"},
|
||||
}
|
||||
platformPricing := ChannelModelPricing{
|
||||
ID: 3,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4o"},
|
||||
}
|
||||
emptyPlatformPricing := ChannelModelPricing{
|
||||
ID: 4,
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
list []ChannelModelPricing
|
||||
platform string
|
||||
model string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
list: []ChannelModelPricing{exactPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 1,
|
||||
},
|
||||
{
|
||||
name: "exact match case insensitive",
|
||||
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 5,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
list: []ChannelModelPricing{wildcardPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 2,
|
||||
},
|
||||
{
|
||||
name: "exact match takes priority over wildcard",
|
||||
list: []ChannelModelPricing{wildcardPricing, exactPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 1,
|
||||
},
|
||||
{
|
||||
name: "platform mismatch skipped",
|
||||
list: []ChannelModelPricing{platformPricing},
|
||||
platform: "anthropic",
|
||||
model: "gpt-4o",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty platform in pricing matches any",
|
||||
list: []ChannelModelPricing{emptyPlatformPricing},
|
||||
platform: "gemini",
|
||||
model: "gemini-2.5-pro",
|
||||
wantID: 4,
|
||||
},
|
||||
{
|
||||
name: "empty platform in query matches any pricing platform",
|
||||
list: []ChannelModelPricing{platformPricing},
|
||||
platform: "",
|
||||
model: "gpt-4o",
|
||||
wantID: 3,
|
||||
},
|
||||
{
|
||||
name: "no match at all",
|
||||
list: []ChannelModelPricing{exactPricing, wildcardPricing},
|
||||
platform: "anthropic",
|
||||
model: "gpt-4o",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty list returns nil",
|
||||
list: nil,
|
||||
model: "claude-opus-4",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "longer wildcard prefix wins over shorter",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
|
||||
},
|
||||
{
|
||||
name: "shorter wildcard used when longer does not match",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-sonnet-4",
|
||||
wantID: 10, // only "claude-*" matches
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findPricingForModel(tt.list, tt.platform, tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantID, result.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// calculateStatsCost
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCalculateStatsCost_NilPricing(t *testing.T) {
|
||||
result := calculateStatsCost(nil, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
CacheWritePrice: testPtrFloat64(0.003),
|
||||
CacheReadPrice: testPtrFloat64(0.0005),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheCreationTokens: 200,
|
||||
CacheReadTokens: 300,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
|
||||
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||
require.InDelta(t, 0.95, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
ImageOutputPrice: testPtrFloat64(0.01),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
ImageOutputTokens: 10,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
|
||||
require.InDelta(t, 0.3, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheCreationTokens: 200,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// Only input contributes: 100*0.001 = 0.1
|
||||
require.InDelta(t, 0.1, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{} // all zeros
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
// totalCost == 0 → returns nil (does not override, falls back to default formula)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
|
||||
result := calculateStatsCost(pricing, tokens, 3)
|
||||
require.NotNil(t, result)
|
||||
// 0.05 * 3 = 0.15
|
||||
require.InDelta(t, 0.15, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
// PerRequestPrice is nil
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0),
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.10),
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 2)
|
||||
require.NotNil(t, result)
|
||||
// 0.10 * 2 = 0.20
|
||||
require.InDelta(t, 0.20, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
// PerRequestPrice is nil
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
|
||||
// BillingMode is empty string (default) → falls into token billing
|
||||
pricing := &ChannelModelPricing{
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tryCustomRules — 多规则顺序测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
|
||||
require.InDelta(t, 2.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
AccountIDs: []int64{888}, // 不匹配
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1}, // 匹配
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
|
||||
require.InDelta(t, 5.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
AccountIDs: []int64{888},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
|
||||
require.Nil(t, result) // 账号和分组都不匹配
|
||||
}
|
||||
|
||||
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
|
||||
}
|
||||
@@ -49,21 +49,25 @@ type Channel struct {
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
||||
FeaturesConfig map[string]any
|
||||
|
||||
// 账号统计定价
|
||||
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
|
||||
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
||||
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||
if c == nil || c.FeaturesConfig == nil {
|
||||
return false
|
||||
}
|
||||
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := wse[platform].(bool)
|
||||
return ok && enabled
|
||||
// AccountStatsPricingRule 账号统计定价规则
|
||||
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
|
||||
// 多条规则按 SortOrder 排序,先命中为准。
|
||||
type AccountStatsPricingRule struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Name string
|
||||
GroupIDs []int64
|
||||
AccountIDs []int64
|
||||
SortOrder int
|
||||
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
@@ -192,6 +196,26 @@ func (c *Channel) Clone() *Channel {
|
||||
cp.ModelMapping[platform] = inner
|
||||
}
|
||||
}
|
||||
if c.AccountStatsPricingRules != nil {
|
||||
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
|
||||
for i, rule := range c.AccountStatsPricingRules {
|
||||
cp.AccountStatsPricingRules[i] = rule
|
||||
if rule.GroupIDs != nil {
|
||||
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
|
||||
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
|
||||
}
|
||||
if rule.AccountIDs != nil {
|
||||
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
|
||||
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
|
||||
}
|
||||
if rule.Pricing != nil {
|
||||
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
|
||||
for j := range rule.Pricing {
|
||||
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
|
||||
@@ -416,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// GetGroupPlatform 获取分组的平台标识(从缓存)
|
||||
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cache.groupPlatform[groupID]
|
||||
}
|
||||
|
||||
// channelLookup 热路径公共查找结果
|
||||
type channelLookup struct {
|
||||
cache *channelCache
|
||||
@@ -656,16 +665,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
FeaturesConfig: input.FeaturesConfig,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: input.AccountStatsPricingRules,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
@@ -754,8 +764,11 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
if input.FeaturesConfig != nil {
|
||||
channel.FeaturesConfig = input.FeaturesConfig
|
||||
if input.ApplyPricingToAccountStats != nil {
|
||||
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
|
||||
}
|
||||
if input.AccountStatsPricingRules != nil {
|
||||
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -922,27 +935,29 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
FeaturesConfig map[string]any
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
ApplyPricingToAccountStats bool
|
||||
AccountStatsPricingRules []AccountStatsPricingRule
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
FeaturesConfig map[string]any
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
ApplyPricingToAccountStats *bool
|
||||
AccountStatsPricingRules *[]AccountStatsPricingRule
|
||||
}
|
||||
|
||||
@@ -7559,6 +7559,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
// 计算账号统计定价费用
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, billingModel,
|
||||
UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
},
|
||||
1, // requestCount
|
||||
"", // serviceTier: Anthropic 平台不使用 service tier
|
||||
)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
|
||||
@@ -4569,6 +4569,15 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
// 计算账号统计定价费用
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, billingModel,
|
||||
tokens, 1, serviceTier,
|
||||
)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
|
||||
@@ -146,6 +146,8 @@ type UsageLog struct {
|
||||
RateMultiplier float64
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
|
||||
AccountRateMultiplier *float64
|
||||
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
|
||||
AccountStatsCost *float64
|
||||
|
||||
BillingType int8
|
||||
RequestType RequestType
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
-- Account statistics pricing: allow channels to configure custom pricing for account cost tracking.
|
||||
|
||||
-- 1. Channel-level toggle
|
||||
ALTER TABLE channels ADD COLUMN IF NOT EXISTS apply_pricing_to_account_stats BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- 2. Account stats pricing rules (ordered list per channel)
|
||||
CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_rules (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
name VARCHAR(100) NOT NULL DEFAULT '',
|
||||
group_ids BIGINT[] NOT NULL DEFAULT '{}',
|
||||
account_ids BIGINT[] NOT NULL DEFAULT '{}',
|
||||
sort_order INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cas_pricing_rules_channel_id ON channel_account_stats_pricing_rules(channel_id);
|
||||
|
||||
-- 3. Model pricing for each rule (same structure as channel_model_pricing)
|
||||
CREATE TABLE IF NOT EXISTS channel_account_stats_model_pricing (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
rule_id BIGINT NOT NULL REFERENCES channel_account_stats_pricing_rules(id) ON DELETE CASCADE,
|
||||
platform VARCHAR(50) NOT NULL DEFAULT '',
|
||||
models JSONB NOT NULL DEFAULT '[]',
|
||||
billing_mode VARCHAR(20) NOT NULL DEFAULT 'token',
|
||||
input_price NUMERIC(20,10),
|
||||
output_price NUMERIC(20,10),
|
||||
cache_write_price NUMERIC(20,10),
|
||||
cache_read_price NUMERIC(20,10),
|
||||
image_output_price NUMERIC(20,10),
|
||||
per_request_price NUMERIC(20,10),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cas_model_pricing_rule_id ON channel_account_stats_model_pricing(rule_id);
|
||||
|
||||
-- 4. Usage logs: pre-computed account stats cost (NULL = use default formula)
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS account_stats_cost NUMERIC(20,10);
|
||||
Reference in New Issue
Block a user