feat(rpm): RPM 限流模块优化
P0: - rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7) - 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数) P1: - ClearAll 按钮直连 DELETE API,带 loading 防重复 - 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点 优化: - checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效 - Override/Group 变更后自动失效 auth cache - fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
@@ -61,8 +61,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingCache := repository.NewBillingCache(redisClient)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
|
||||
userRPMCache := repository.NewUserRPMCache(redisClient)
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
@@ -104,7 +105,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -137,7 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
@@ -184,6 +185,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||
registry := payment.ProvideRegistry()
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
@@ -211,16 +221,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||
registry := payment.ProvideRegistry()
|
||||
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
@@ -249,6 +249,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
||||
pricingSvc := service.NewPricingService(cfg, nil)
|
||||
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||
|
||||
+12
-1
@@ -79,6 +79,8 @@ type Group struct {
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
|
||||
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
||||
// 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流
|
||||
RpmLimit int `json:"rpm_limit,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
|
||||
}
|
||||
}
|
||||
case group.FieldRpmLimit:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RpmLimit = int(value.Int64)
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -599,6 +607,9 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("messages_dispatch_model_config=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rpm_limit=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -76,6 +76,8 @@ const (
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
|
||||
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
|
||||
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
|
||||
FieldRpmLimit = "rpm_limit"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -181,6 +183,7 @@ var Columns = []string{
|
||||
FieldRequirePrivacySet,
|
||||
FieldDefaultMappedModel,
|
||||
FieldMessagesDispatchModelConfig,
|
||||
FieldRpmLimit,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -258,6 +261,8 @@ var (
|
||||
DefaultMappedModelValidator func(string) error
|
||||
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
|
||||
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
|
||||
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
|
||||
DefaultRpmLimit int
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRpmLimit orders the results by the rpm_limit field.
|
||||
func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
|
||||
func RpmLimit(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
|
||||
func RpmLimitEQ(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
|
||||
func RpmLimitNEQ(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitIn applies the In predicate on the "rpm_limit" field.
|
||||
func RpmLimitIn(vs ...int) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...))
|
||||
}
|
||||
|
||||
// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
|
||||
func RpmLimitNotIn(vs ...int) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...))
|
||||
}
|
||||
|
||||
// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
|
||||
func RpmLimitGT(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
|
||||
func RpmLimitGTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
|
||||
func RpmLimitLT(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
|
||||
func RpmLimitLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate {
|
||||
_c.mutation.SetRpmLimit(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetRpmLimit(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultMessagesDispatchModelConfig
|
||||
_c.mutation.SetMessagesDispatchModelConfig(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||
v := group.DefaultRpmLimit
|
||||
_c.mutation.SetRpmLimit(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
|
||||
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||
_node.MessagesDispatchModelConfig = value
|
||||
}
|
||||
if value, ok := _c.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
_node.RpmLimit = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert {
|
||||
u.Set(group.FieldRpmLimit, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldRpmLimit)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||
func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert {
|
||||
u.Add(group.FieldRpmLimit, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||
func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateRpmLimit()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||
func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateRpmLimit()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate {
|
||||
_u.mutation.ResetRpmLimit()
|
||||
_u.mutation.SetRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetRpmLimit(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||
func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate {
|
||||
_u.mutation.AddRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||
_spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne {
|
||||
_u.mutation.ResetRpmLimit()
|
||||
_u.mutation.SetRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRpmLimit(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||
func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne {
|
||||
_u.mutation.AddRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||
_spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -491,6 +491,7 @@ var (
|
||||
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "rpm_limit", Type: field.TypeInt, Default: 0},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
@@ -1276,7 +1277,7 @@ var (
|
||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"},
|
||||
{Name: "signup_source", Type: field.TypeString, Default: "email"},
|
||||
{Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
|
||||
@@ -1284,6 +1285,7 @@ var (
|
||||
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "rpm_limit", Type: field.TypeInt, Default: 0},
|
||||
}
|
||||
// UsersTable holds the schema information for the "users" table.
|
||||
UsersTable = &schema.Table{
|
||||
|
||||
+176
-2
@@ -10102,6 +10102,8 @@ type GroupMutation struct {
|
||||
require_privacy_set *bool
|
||||
default_mapped_model *string
|
||||
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
|
||||
rpm_limit *int
|
||||
addrpm_limit *int
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -11690,6 +11692,62 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
|
||||
m.messages_dispatch_model_config = nil
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (m *GroupMutation) SetRpmLimit(i int) {
|
||||
m.rpm_limit = &i
|
||||
m.addrpm_limit = nil
|
||||
}
|
||||
|
||||
// RpmLimit returns the value of the "rpm_limit" field in the mutation.
|
||||
func (m *GroupMutation) RpmLimit() (r int, exists bool) {
|
||||
v := m.rpm_limit
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldRpmLimit requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
|
||||
}
|
||||
return oldValue.RpmLimit, nil
|
||||
}
|
||||
|
||||
// AddRpmLimit adds i to the "rpm_limit" field.
|
||||
func (m *GroupMutation) AddRpmLimit(i int) {
|
||||
if m.addrpm_limit != nil {
|
||||
*m.addrpm_limit += i
|
||||
} else {
|
||||
m.addrpm_limit = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
|
||||
func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) {
|
||||
v := m.addrpm_limit
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetRpmLimit resets all changes to the "rpm_limit" field.
|
||||
func (m *GroupMutation) ResetRpmLimit() {
|
||||
m.rpm_limit = nil
|
||||
m.addrpm_limit = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -12048,7 +12106,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 30)
|
||||
fields := make([]string, 0, 31)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -12139,6 +12197,9 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.messages_dispatch_model_config != nil {
|
||||
fields = append(fields, group.FieldMessagesDispatchModelConfig)
|
||||
}
|
||||
if m.rpm_limit != nil {
|
||||
fields = append(fields, group.FieldRpmLimit)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -12207,6 +12268,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.DefaultMappedModel()
|
||||
case group.FieldMessagesDispatchModelConfig:
|
||||
return m.MessagesDispatchModelConfig()
|
||||
case group.FieldRpmLimit:
|
||||
return m.RpmLimit()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -12276,6 +12339,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
case group.FieldMessagesDispatchModelConfig:
|
||||
return m.OldMessagesDispatchModelConfig(ctx)
|
||||
case group.FieldRpmLimit:
|
||||
return m.OldRpmLimit(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -12495,6 +12560,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetMessagesDispatchModelConfig(v)
|
||||
return nil
|
||||
case group.FieldRpmLimit:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetRpmLimit(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -12536,6 +12608,9 @@ func (m *GroupMutation) AddedFields() []string {
|
||||
if m.addsort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
if m.addrpm_limit != nil {
|
||||
fields = append(fields, group.FieldRpmLimit)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -12566,6 +12641,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedFallbackGroupIDOnInvalidRequest()
|
||||
case group.FieldSortOrder:
|
||||
return m.AddedSortOrder()
|
||||
case group.FieldRpmLimit:
|
||||
return m.AddedRpmLimit()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -12652,6 +12729,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddSortOrder(v)
|
||||
return nil
|
||||
case group.FieldRpmLimit:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddRpmLimit(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||
}
|
||||
@@ -12838,6 +12922,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldMessagesDispatchModelConfig:
|
||||
m.ResetMessagesDispatchModelConfig()
|
||||
return nil
|
||||
case group.FieldRpmLimit:
|
||||
m.ResetRpmLimit()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -32681,6 +32768,8 @@ type UserMutation struct {
|
||||
balance_notify_extra_emails *string
|
||||
total_recharged *float64
|
||||
addtotal_recharged *float64
|
||||
rpm_limit *int
|
||||
addrpm_limit *int
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -33772,6 +33861,62 @@ func (m *UserMutation) ResetTotalRecharged() {
|
||||
m.addtotal_recharged = nil
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (m *UserMutation) SetRpmLimit(i int) {
|
||||
m.rpm_limit = &i
|
||||
m.addrpm_limit = nil
|
||||
}
|
||||
|
||||
// RpmLimit returns the value of the "rpm_limit" field in the mutation.
|
||||
func (m *UserMutation) RpmLimit() (r int, exists bool) {
|
||||
v := m.rpm_limit
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldRpmLimit returns the old "rpm_limit" field's value of the User entity.
|
||||
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldRpmLimit requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
|
||||
}
|
||||
return oldValue.RpmLimit, nil
|
||||
}
|
||||
|
||||
// AddRpmLimit adds i to the "rpm_limit" field.
|
||||
func (m *UserMutation) AddRpmLimit(i int) {
|
||||
if m.addrpm_limit != nil {
|
||||
*m.addrpm_limit += i
|
||||
} else {
|
||||
m.addrpm_limit = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
|
||||
func (m *UserMutation) AddedRpmLimit() (r int, exists bool) {
|
||||
v := m.addrpm_limit
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetRpmLimit resets all changes to the "rpm_limit" field.
|
||||
func (m *UserMutation) ResetRpmLimit() {
|
||||
m.rpm_limit = nil
|
||||
m.addrpm_limit = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -34454,7 +34599,7 @@ func (m *UserMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UserMutation) Fields() []string {
|
||||
fields := make([]string, 0, 22)
|
||||
fields := make([]string, 0, 23)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, user.FieldCreatedAt)
|
||||
}
|
||||
@@ -34521,6 +34666,9 @@ func (m *UserMutation) Fields() []string {
|
||||
if m.total_recharged != nil {
|
||||
fields = append(fields, user.FieldTotalRecharged)
|
||||
}
|
||||
if m.rpm_limit != nil {
|
||||
fields = append(fields, user.FieldRpmLimit)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -34573,6 +34721,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.BalanceNotifyExtraEmails()
|
||||
case user.FieldTotalRecharged:
|
||||
return m.TotalRecharged()
|
||||
case user.FieldRpmLimit:
|
||||
return m.RpmLimit()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -34626,6 +34776,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
|
||||
return m.OldBalanceNotifyExtraEmails(ctx)
|
||||
case user.FieldTotalRecharged:
|
||||
return m.OldTotalRecharged(ctx)
|
||||
case user.FieldRpmLimit:
|
||||
return m.OldRpmLimit(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown User field %s", name)
|
||||
}
|
||||
@@ -34789,6 +34941,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetTotalRecharged(v)
|
||||
return nil
|
||||
case user.FieldRpmLimit:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetRpmLimit(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown User field %s", name)
|
||||
}
|
||||
@@ -34809,6 +34968,9 @@ func (m *UserMutation) AddedFields() []string {
|
||||
if m.addtotal_recharged != nil {
|
||||
fields = append(fields, user.FieldTotalRecharged)
|
||||
}
|
||||
if m.addrpm_limit != nil {
|
||||
fields = append(fields, user.FieldRpmLimit)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -34825,6 +34987,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedBalanceNotifyThreshold()
|
||||
case user.FieldTotalRecharged:
|
||||
return m.AddedTotalRecharged()
|
||||
case user.FieldRpmLimit:
|
||||
return m.AddedRpmLimit()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -34862,6 +35026,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddTotalRecharged(v)
|
||||
return nil
|
||||
case user.FieldRpmLimit:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddRpmLimit(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown User numeric field %s", name)
|
||||
}
|
||||
@@ -34994,6 +35165,9 @@ func (m *UserMutation) ResetField(name string) error {
|
||||
case user.FieldTotalRecharged:
|
||||
m.ResetTotalRecharged()
|
||||
return nil
|
||||
case user.FieldRpmLimit:
|
||||
m.ResetRpmLimit()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown User field %s", name)
|
||||
}
|
||||
|
||||
@@ -595,6 +595,10 @@ func init() {
|
||||
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
|
||||
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
|
||||
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
|
||||
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||
groupDescRpmLimit := groupFields[27].Descriptor()
|
||||
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
@@ -1575,6 +1579,10 @@ func init() {
|
||||
userDescTotalRecharged := userFields[18].Descriptor()
|
||||
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
|
||||
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
|
||||
// userDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||
userDescRpmLimit := userFields[19].Descriptor()
|
||||
// user.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||
user.DefaultRpmLimit = userDescRpmLimit.Default.(int)
|
||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||
_ = userallowedgroupFields
|
||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||
|
||||
@@ -145,6 +145,11 @@ func (Group) Fields() []ent.Field {
|
||||
Default(domain.OpenAIMessagesDispatchModelConfig{}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
|
||||
|
||||
// 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。
|
||||
field.Int("rpm_limit").
|
||||
Default(0).
|
||||
Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -108,6 +108,10 @@ func (User) Fields() []ent.Field {
|
||||
field.Float("total_recharged").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0),
|
||||
|
||||
// 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。
|
||||
field.Int("rpm_limit").
|
||||
Default(0),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+12
-1
@@ -61,6 +61,8 @@ type User struct {
|
||||
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
|
||||
// TotalRecharged holds the value of the "total_recharged" field.
|
||||
TotalRecharged float64 `json:"total_recharged,omitempty"`
|
||||
// RpmLimit holds the value of the "rpm_limit" field.
|
||||
RpmLimit int `json:"rpm_limit,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the UserQuery when eager-loading is set.
|
||||
Edges UserEdges `json:"edges"`
|
||||
@@ -226,7 +228,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case user.FieldID, user.FieldConcurrency:
|
||||
case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -391,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.TotalRecharged = value.Float64
|
||||
}
|
||||
case user.FieldRpmLimit:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RpmLimit = int(value.Int64)
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -569,6 +577,9 @@ func (_m *User) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("total_recharged=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rpm_limit=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -59,6 +59,8 @@ const (
|
||||
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
|
||||
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
|
||||
FieldTotalRecharged = "total_recharged"
|
||||
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
|
||||
FieldRpmLimit = "rpm_limit"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -203,6 +205,7 @@ var Columns = []string{
|
||||
FieldBalanceNotifyThreshold,
|
||||
FieldBalanceNotifyExtraEmails,
|
||||
FieldTotalRecharged,
|
||||
FieldRpmLimit,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -271,6 +274,8 @@ var (
|
||||
DefaultBalanceNotifyExtraEmails string
|
||||
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
|
||||
DefaultTotalRecharged float64
|
||||
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
|
||||
DefaultRpmLimit int
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the User queries.
|
||||
@@ -391,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRpmLimit orders the results by the rpm_limit field.
|
||||
func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -165,6 +165,11 @@ func TotalRecharged(v float64) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
|
||||
}
|
||||
|
||||
// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
|
||||
func RpmLimit(v int) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1295,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
|
||||
}
|
||||
|
||||
// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
|
||||
func RpmLimitEQ(v int) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
|
||||
func RpmLimitNEQ(v int) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitIn applies the In predicate on the "rpm_limit" field.
|
||||
func RpmLimitIn(vs ...int) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldRpmLimit, vs...))
|
||||
}
|
||||
|
||||
// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
|
||||
func RpmLimitNotIn(vs ...int) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...))
|
||||
}
|
||||
|
||||
// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
|
||||
func RpmLimitGT(v int) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
|
||||
func RpmLimitGTE(v int) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
|
||||
func RpmLimitLT(v int) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
|
||||
func RpmLimitLTE(v int) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldRpmLimit, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@@ -325,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_c *UserCreate) SetRpmLimit(v int) *UserCreate {
|
||||
_c.mutation.SetRpmLimit(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetRpmLimit(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -604,6 +618,10 @@ func (_c *UserCreate) defaults() error {
|
||||
v := user.DefaultTotalRecharged
|
||||
_c.mutation.SetTotalRecharged(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||
v := user.DefaultRpmLimit
|
||||
_c.mutation.SetRpmLimit(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -687,6 +705,9 @@ func (_c *UserCreate) check() error {
|
||||
if _, ok := _c.mutation.TotalRecharged(); !ok {
|
||||
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -802,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||
_node.TotalRecharged = value
|
||||
}
|
||||
if value, ok := _c.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
|
||||
_node.RpmLimit = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1362,6 +1387,24 @@ func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert {
|
||||
u.Set(user.FieldRpmLimit, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateRpmLimit() *UserUpsert {
|
||||
u.SetExcluded(user.FieldRpmLimit)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||
func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert {
|
||||
u.Add(user.FieldRpmLimit, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1771,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||
func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.AddRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateRpmLimit()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -2346,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRpmLimit adds v to the "rpm_limit" field.
|
||||
func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.AddRpmLimit(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateRpmLimit()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -389,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate {
|
||||
_u.mutation.ResetRpmLimit()
|
||||
_u.mutation.SetRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetRpmLimit(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||
func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate {
|
||||
_u.mutation.AddRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1008,6 +1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
||||
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||
_spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1930,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne {
|
||||
_u.mutation.ResetRpmLimit()
|
||||
_u.mutation.SetRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRpmLimit(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRpmLimit adds value to the "rpm_limit" field.
|
||||
func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne {
|
||||
_u.mutation.AddRpmLimit(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2579,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
||||
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRpmLimit(); ok {
|
||||
_spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -104,6 +104,7 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
|
||||
@@ -162,6 +162,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -216,6 +218,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -249,6 +253,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -278,6 +284,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@@ -310,6 +318,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@@ -183,6 +183,17 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
return map[string]any{"user_id": userID}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
|
||||
user, err := s.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.UserRPMStatus{
|
||||
UserRPMUsed: 0,
|
||||
UserRPMLimit: user.RPMLimit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
|
||||
s.boundAuthIdentityFor = userID
|
||||
copied := input
|
||||
@@ -276,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
|
||||
s.lastListAccounts.platform = platform
|
||||
s.lastListAccounts.accountType = accountType
|
||||
|
||||
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
// 分组 RPM 上限(0 = 不限制)
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
|
||||
RequirePrivacySet *bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
|
||||
RPMLimit *int `json:"rpm_limit"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||
RPMLimit: req.RPMLimit,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||
RPMLimit: req.RPMLimit,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
|
||||
type BatchSetGroupRPMOverridesRequest struct {
|
||||
Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
|
||||
// PUT /api/v1/admin/groups/:id/rpm-overrides
|
||||
func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRPMOverridesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
|
||||
}
|
||||
|
||||
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rpm-overrides
|
||||
func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
|
||||
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
RPMLimit *int `json:"rpm_limit"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
RPMLimit: req.RPMLimit,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
RPMLimit: req.RPMLimit,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
|
||||
"migrated_keys": result.MigratedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
|
||||
// GET /api/v1/admin/users/:id/rpm-status
|
||||
func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, status)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
|
||||
TotalRecharged: u.TotalRecharged,
|
||||
RPMLimit: u.RPMLimit,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: g.RequireOAuthOnly,
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
RPMLimit: g.RPMLimit,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ type SystemSettings struct {
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
|
||||
@@ -26,6 +26,9 @@ type User struct {
|
||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
|
||||
TotalRecharged float64 `json:"total_recharged"`
|
||||
|
||||
// RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
@@ -108,6 +111,9 @@ type Group struct {
|
||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
|
||||
// RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.errorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string) {
|
||||
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
|
||||
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = "Billing service temporarily unavailable. Please retry later."
|
||||
}
|
||||
return http.StatusServiceUnavailable, "billing_service_error", msg
|
||||
return http.StatusServiceUnavailable, "billing_service_error", msg, 0
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||
}
|
||||
// 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
|
||||
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
|
||||
if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
retrySeconds := 60 - int(time.Now().Unix()%60)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
).Warn("gateway.billing_error_missing_message")
|
||||
msg = "Billing error"
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
return http.StatusForbidden, "billing_error", msg, 0
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) metadataBridgeEnabled() bool {
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
|
||||
status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
|
||||
require.Equal(t, http.StatusTooManyRequests, status)
|
||||
require.Equal(t, "rate_limit_exceeded", code)
|
||||
require.NotEmpty(t, msg)
|
||||
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
|
||||
require.LessOrEqual(t, retryAfter, 60)
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
|
||||
status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
|
||||
require.Equal(t, http.StatusTooManyRequests, status)
|
||||
require.Equal(t, "rate_limit_exceeded", code)
|
||||
require.NotEmpty(t, msg)
|
||||
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
|
||||
require.LessOrEqual(t, retryAfter, 60)
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
|
||||
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
|
||||
for _, err := range []error{
|
||||
service.ErrAPIKeyRateLimit5hExceeded,
|
||||
service.ErrAPIKeyRateLimit1dExceeded,
|
||||
service.ErrAPIKeyRateLimit7dExceeded,
|
||||
} {
|
||||
status, code, _, _ := billingErrorDetails(err)
|
||||
require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
|
||||
require.Equal(t, "rate_limit_exceeded", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
|
||||
status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
|
||||
require.Equal(t, http.StatusServiceUnavailable, status)
|
||||
require.Equal(t, "billing_service_error", code)
|
||||
require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
|
||||
status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
|
||||
require.Equal(t, http.StatusForbidden, status)
|
||||
require.Equal(t, "billing_error", code)
|
||||
require.NotEmpty(t, msg)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.responsesErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message := billingErrorDetails(err)
|
||||
status, _, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
googleError(c, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
user.FieldSignupSource,
|
||||
user.FieldLastLoginAt,
|
||||
user.FieldLastActiveAt,
|
||||
user.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldMessagesDispatchModelConfig,
|
||||
group.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||
TotalRecharged: u.TotalRecharged,
|
||||
RPMLimit: u.RpmLimit,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
||||
RPMLimit: g.RpmLimit,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
|
||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||
return &userGroupRateRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserIDs 批量获取多个用户的专属分组倍率。
|
||||
// 返回结构:map[userID]map[groupID]rate
|
||||
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
||||
result := make(map[int64]map[int64]float64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT user_id, group_id, rate_multiplier
|
||||
FROM user_group_rate_multipliers
|
||||
WHERE user_id = ANY($1)
|
||||
WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||
WHERE ugr.group_id = $1
|
||||
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||
var rate sql.NullFloat64
|
||||
var rpm sql.NullInt32
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rate.Valid {
|
||||
v := rate.Float64
|
||||
entry.RateMultiplier = &v
|
||||
}
|
||||
if rpm.Valid {
|
||||
v := int(rpm.Int32)
|
||||
entry.RPMOverride = &v
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rate float64
|
||||
var rate sql.NullFloat64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rate, nil
|
||||
if !rate.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
v := rate.Float64
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
|
||||
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
|
||||
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rpm sql.NullInt32
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rpm.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
v := int(rpm.Int32)
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
|
||||
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
|
||||
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
|
||||
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
|
||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||
if len(rates) == 0 {
|
||||
// 如果传入空 map,删除该用户的所有专属倍率
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||
userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 分离需要删除和需要 upsert 的记录
|
||||
var toDelete []int64
|
||||
var clearGroupIDs []int64
|
||||
upsertGroupIDs := make([]int64, 0, len(rates))
|
||||
upsertRates := make([]float64, 0, len(rates))
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
clearGroupIDs = append(clearGroupIDs, groupID)
|
||||
} else {
|
||||
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
||||
upsertRates = append(upsertRates, *rate)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的记录
|
||||
if len(toDelete) > 0 {
|
||||
if len(clearGroupIDs) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1 AND group_id = ANY($2)
|
||||
`, userID, pq.Array(clearGroupIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
|
||||
userID, pq.Array(toDelete)); err != nil {
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||
userID, pq.Array(clearGroupIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert 记录
|
||||
now := time.Now()
|
||||
if len(upsertGroupIDs) > 0 {
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
SELECT
|
||||
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
|
||||
// 语义:
|
||||
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
|
||||
// - 出现的用户行:upsert rate_multiplier。
|
||||
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||
keepUserIDs := make([]int64, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||
}
|
||||
|
||||
// 未在 entries 列表中的行:清空 rate_multiplier。
|
||||
if len(keepUserIDs) == 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 清空后若整行 NULL 则删除。
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
userIDs := make([]int64, len(entries))
|
||||
rates := make([]float64, len(entries))
|
||||
for i, e := range entries {
|
||||
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
|
||||
// 语义:
|
||||
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
|
||||
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
|
||||
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
|
||||
keepUserIDs := make([]int64, 0, len(entries))
|
||||
var clearUserIDs []int64
|
||||
upsertUserIDs := make([]int64, 0, len(entries))
|
||||
upsertValues := make([]int32, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||
if e.RPMOverride == nil {
|
||||
clearUserIDs = append(clearUserIDs, e.UserID)
|
||||
} else {
|
||||
upsertUserIDs = append(upsertUserIDs, e.UserID)
|
||||
upsertValues = append(upsertValues, int32(*e.RPMOverride))
|
||||
}
|
||||
}
|
||||
|
||||
// 未在 entries 列表中的行:清空 rpm_override。
|
||||
if len(keepUserIDs) == 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 显式 clear 的行。
|
||||
if len(clearUserIDs) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id = ANY($2)
|
||||
`, groupID, pq.Array(clearUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 清空后若整行 NULL 则删除。
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(upsertUserIDs) > 0 {
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
|
||||
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
|
||||
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
|
||||
ON CONFLICT (user_id, group_id)
|
||||
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
|
||||
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
|
||||
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属条目
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||||
// DeleteByUserID 删除指定用户的所有专属条目
|
||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
|
||||
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||
SetRpmLimit(userIn.RPMLimit).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
||||
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||||
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||||
SetTotalRecharged(userIn.TotalRecharged)
|
||||
SetTotalRecharged(userIn.TotalRecharged).
|
||||
SetRpmLimit(userIn.RPMLimit)
|
||||
if userIn.SignupSource != "" {
|
||||
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 用户/分组级 RPM 计数器 Redis 实现。
|
||||
//
|
||||
// 设计说明:
|
||||
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
|
||||
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
|
||||
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
|
||||
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
|
||||
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
|
||||
const (
|
||||
userGroupRPMKeyPrefix = "rpm:ug:"
|
||||
userRPMKeyPrefix = "rpm:u:"
|
||||
|
||||
userRPMKeyTTL = 120 * time.Second
|
||||
)
|
||||
|
||||
type userRPMCacheImpl struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
|
||||
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
|
||||
return &userRPMCacheImpl{rdb: rdb}
|
||||
}
|
||||
|
||||
// minuteTS 获取当前 Redis 服务端分钟时间戳。
|
||||
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
|
||||
t, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
return t.Unix() / 60, nil
|
||||
}
|
||||
|
||||
// atomicIncr 原子 INCR+EXPIRE。
|
||||
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
|
||||
pipe := c.rdb.TxPipeline()
|
||||
incr := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, userRPMKeyTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("user rpm increment: %w", err)
|
||||
}
|
||||
return int(incr.Val()), nil
|
||||
}
|
||||
|
||||
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
|
||||
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||
return c.atomicIncr(ctx, key)
|
||||
}
|
||||
|
||||
// IncrementUserRPM 递增用户分钟计数。
|
||||
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||
return c.atomicIncr(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
|
||||
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("user group rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
|
||||
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("user rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
NewUserRPMCache,
|
||||
NewUserMsgQueueCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
|
||||
@@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"role": "user",
|
||||
"balance": 12.5,
|
||||
"concurrency": 5,
|
||||
"rpm_limit": 0,
|
||||
"status": "active",
|
||||
"allowed_groups": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
@@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"require_oauth_only": false,
|
||||
"require_privacy_set": false,
|
||||
"rpm_limit": 0,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"force_email_on_third_party_signup": false,
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||
@@ -889,6 +892,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"custom_endpoints": [],
|
||||
"default_concurrency": 0,
|
||||
"default_balance": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||
@@ -1084,7 +1088,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@@ -221,6 +221,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
||||
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
||||
|
||||
// User attribute values
|
||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||
@@ -244,6 +245,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
||||
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
||||
groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
|
||||
groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,7 @@ type AdminService interface {
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
// codeType is optional - pass empty string to return all types.
|
||||
// Also returns totalRecharged (sum of all positive balance top-ups).
|
||||
@@ -50,6 +52,8 @@ type AdminService interface {
|
||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
|
||||
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// API Key management (admin)
|
||||
@@ -114,6 +118,7 @@ type CreateUserInput struct {
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
RPMLimit int
|
||||
AllowedGroups []int64
|
||||
}
|
||||
|
||||
@@ -124,6 +129,7 @@ type UpdateUserInput struct {
|
||||
Notes *string
|
||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
@@ -199,6 +205,8 @@ type CreateGroupInput struct {
|
||||
RequireOAuthOnly bool
|
||||
RequirePrivacySet bool
|
||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
||||
// RPMLimit 分组 RPM 上限(0 = 不限制)
|
||||
RPMLimit int
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
|
||||
RequireOAuthOnly *bool
|
||||
RequirePrivacySet *bool
|
||||
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
|
||||
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
|
||||
RPMLimit *int
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
|
||||
MigratedKeys int64 // 迁移的 Key 数量
|
||||
}
|
||||
|
||||
// UserRPMStatus describes a user's current per-minute RPM usage.
|
||||
type UserRPMStatus struct {
|
||||
UserRPMUsed int `json:"user_rpm_used"`
|
||||
UserRPMLimit int `json:"user_rpm_limit"`
|
||||
PerGroup []UserGroupRPMStatus `json:"per_group"`
|
||||
}
|
||||
|
||||
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
|
||||
type UserGroupRPMStatus struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Used int `json:"used"`
|
||||
Limit int `json:"limit"`
|
||||
Source string `json:"source"` // "group" | "override"
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||
type BulkUpdateAccountsResult struct {
|
||||
Success int `json:"success"`
|
||||
@@ -463,6 +489,8 @@ const (
|
||||
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo UserRepository
|
||||
@@ -472,6 +500,7 @@ type adminServiceImpl struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
userRPMCache UserRPMCache
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
@@ -496,6 +525,7 @@ func NewAdminService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
userRPMCache UserRPMCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
@@ -514,6 +544,7 @@ func NewAdminService(
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
userRPMCache: userRPMCache,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
RPMLimit: input.RPMLimit,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
}
|
||||
@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
oldConcurrency := user.Concurrency
|
||||
oldStatus := user.Status
|
||||
oldRole := user.Role
|
||||
oldRPMLimit := user.RPMLimit
|
||||
|
||||
if input.Email != "" {
|
||||
user.Email = input.Email
|
||||
@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
user.Concurrency = *input.Concurrency
|
||||
}
|
||||
|
||||
if input.RPMLimit != nil {
|
||||
user.RPMLimit = *input.RPMLimit
|
||||
}
|
||||
|
||||
if input.AllowedGroups != nil {
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||
// RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
|
||||
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||
}
|
||||
}
|
||||
@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
|
||||
if s.userRPMCache == nil {
|
||||
return nil, ErrRPMStatusUnavailable
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
|
||||
keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDSet := make(map[int64]struct{})
|
||||
for _, key := range keys {
|
||||
if key.GroupID != nil && *key.GroupID > 0 {
|
||||
groupIDSet[*key.GroupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groupIDSet))
|
||||
for groupID := range groupIDSet {
|
||||
groupIDs = append(groupIDs, groupID)
|
||||
}
|
||||
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
|
||||
|
||||
var perGroup []UserGroupRPMStatus
|
||||
for _, groupID := range groupIDs {
|
||||
used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
|
||||
if getErr != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
|
||||
}
|
||||
|
||||
entry := UserGroupRPMStatus{
|
||||
GroupID: groupID,
|
||||
Used: used,
|
||||
}
|
||||
|
||||
if s.groupRepo != nil {
|
||||
if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
|
||||
entry.GroupName = group.Name
|
||||
entry.Limit = group.RPMLimit
|
||||
entry.Source = "group"
|
||||
} else if groupErr != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
|
||||
}
|
||||
}
|
||||
|
||||
if s.userGroupRateRepo != nil {
|
||||
override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
|
||||
if overrideErr != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
|
||||
} else if override != nil {
|
||||
entry.Limit = *override
|
||||
entry.Source = "override"
|
||||
}
|
||||
}
|
||||
|
||||
perGroup = append(perGroup, entry)
|
||||
}
|
||||
|
||||
return &UserRPMStatus{
|
||||
UserRPMUsed: userRPMUsed,
|
||||
UserRPMLimit: user.RPMLimit,
|
||||
PerGroup: perGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
// Return mock data for now
|
||||
return map[string]any{
|
||||
@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
RequirePrivacySet: input.RequirePrivacySet,
|
||||
DefaultMappedModel: input.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
|
||||
RPMLimit: input.RPMLimit,
|
||||
}
|
||||
sanitizeGroupMessagesDispatchFields(group)
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.MessagesDispatchModelConfig != nil {
|
||||
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
|
||||
}
|
||||
if input.RPMLimit != nil {
|
||||
group.RPMLimit = *input.RPMLimit
|
||||
}
|
||||
sanitizeGroupMessagesDispatchFields(group)
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
|
||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.RPMOverride != nil && *e.RPMOverride < 0 {
|
||||
return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
|
||||
}
|
||||
}
|
||||
if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
|
||||
return err
|
||||
}
|
||||
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
|
||||
syncedGroupID int64
|
||||
syncedEntries []GroupRateMultiplierInput
|
||||
syncGroupErr error
|
||||
|
||||
rpmSyncedGroupID int64
|
||||
rpmSyncedEntries []GroupRPMOverrideInput
|
||||
rpmSyncErr error
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||
panic("unexpected GetRPMOverrideByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.getByGroupIDErr != nil {
|
||||
return nil, s.getByGroupIDErr
|
||||
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
|
||||
return s.syncGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
|
||||
s.rpmSyncedGroupID = groupID
|
||||
s.rpmSyncedEntries = entries
|
||||
return s.rpmSyncErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||
panic("unexpected ClearGroupRPMOverrides call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||
return s.deleteByGroupErr
|
||||
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||
10: {
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, int64(1), entries[0].UserID)
|
||||
require.Equal(t, "alice", entries[0].UserName)
|
||||
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||
require.NotNil(t, entries[0].RateMultiplier)
|
||||
require.Equal(t, 1.5, *entries[0].RateMultiplier)
|
||||
require.Equal(t, int64(2), entries[1].UserID)
|
||||
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||
require.NotNil(t, entries[1].RateMultiplier)
|
||||
require.Equal(t, 0.8, *entries[1].RateMultiplier)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "sync failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
|
||||
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
override := 20
|
||||
entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
|
||||
|
||||
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), repo.rpmSyncedGroupID)
|
||||
require.Equal(t, entries, repo.rpmSyncedEntries)
|
||||
})
|
||||
|
||||
t.Run("rejects negative override as bad request", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
negative := -1
|
||||
|
||||
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
|
||||
{UserID: 2, RPMOverride: &negative},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
|
||||
require.Zero(t, repo.rpmSyncedGroupID)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
RPMLimit: 10,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
groupRepo: repo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
rpmLimit := 60
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
RPMLimit: &rpmLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.Equal(t, 60, repo.updated.RPMLimit)
|
||||
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||
panic("unexpected GetRPMOverrideByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
|
||||
panic("unexpected SyncGroupRateMultipliers call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
|
||||
panic("unexpected SyncGroupRPMOverrides call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||
panic("unexpected ClearGroupRPMOverrides call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type rpmStatusUserRepoStub struct {
|
||||
UserRepository
|
||||
user *User
|
||||
}
|
||||
|
||||
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
|
||||
return s.user, nil
|
||||
}
|
||||
|
||||
type rpmStatusAPIKeyRepoStub struct {
|
||||
APIKeyRepository
|
||||
keys []APIKey
|
||||
}
|
||||
|
||||
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
|
||||
}
|
||||
|
||||
type rpmStatusGroupRepoStub struct {
|
||||
GroupRepository
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
return s.groups[id], nil
|
||||
}
|
||||
|
||||
type rpmStatusRateRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
overrides map[int64]*int
|
||||
}
|
||||
|
||||
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
|
||||
return s.overrides[groupID], nil
|
||||
}
|
||||
|
||||
type rpmStatusCacheStub struct {
|
||||
UserRPMCache
|
||||
userUsed int
|
||||
groupUsed map[int64]int
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
|
||||
return s.groupUsed[groupID], nil
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
|
||||
return s.userUsed, nil
|
||||
}
|
||||
|
||||
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
|
||||
groupOneID := int64(1)
|
||||
groupTwoID := int64(2)
|
||||
override := 7
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: &rpmStatusUserRepoStub{user: &User{
|
||||
ID: 42,
|
||||
RPMLimit: 20,
|
||||
}},
|
||||
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
|
||||
{ID: 100, UserID: 42, GroupID: &groupTwoID},
|
||||
{ID: 101, UserID: 42, GroupID: &groupOneID},
|
||||
{ID: 102, UserID: 42, GroupID: &groupTwoID},
|
||||
{ID: 103, UserID: 42},
|
||||
}},
|
||||
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
|
||||
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
|
||||
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
|
||||
}},
|
||||
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
|
||||
groupTwoID: &override,
|
||||
}},
|
||||
userRPMCache: &rpmStatusCacheStub{
|
||||
userUsed: 5,
|
||||
groupUsed: map[int64]int{
|
||||
groupOneID: 3,
|
||||
groupTwoID: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
status, err := svc.GetUserRPMStatus(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &UserRPMStatus{
|
||||
UserRPMUsed: 5,
|
||||
UserRPMLimit: 20,
|
||||
PerGroup: []UserGroupRPMStatus{
|
||||
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
|
||||
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
|
||||
},
|
||||
}, status)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
|
||||
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
|
||||
type rpmUserRepoStub struct {
|
||||
*userRepoStub
|
||||
lastUpdated *User
|
||||
}
|
||||
|
||||
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *user
|
||||
s.lastUpdated = &clone
|
||||
if s.userRepoStub != nil {
|
||||
s.userRepoStub.user = &clone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
|
||||
repo := &rpmUserRepoStub{userRepoStub: base}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: &redeemRepoStub{},
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
newRPM := 60
|
||||
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
|
||||
RPMLimit: &newRPM,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 60, updated.RPMLimit)
|
||||
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
|
||||
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
|
||||
repo := &rpmUserRepoStub{userRepoStub: base}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: &redeemRepoStub{},
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
newName := "new"
|
||||
sameRPM := 10
|
||||
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
|
||||
Username: &newName,
|
||||
RPMLimit: &sameRPM,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
|
||||
}
|
||||
@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
|
||||
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
||||
TotalRecharged float64 `json:"total_recharged"`
|
||||
|
||||
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
|
||||
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
|
||||
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
|
||||
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
||||
|
||||
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
|
||||
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||
snapshot := s.snapshotFromAPIKey(ctx, apiKey)
|
||||
if snapshot == nil {
|
||||
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
|
||||
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
if apiKey == nil || apiKey.User == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
||||
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
||||
TotalRecharged: apiKey.User.TotalRecharged,
|
||||
RPMLimit: apiKey.User.RPMLimit,
|
||||
},
|
||||
}
|
||||
|
||||
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
|
||||
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
|
||||
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
|
||||
if err == nil && override != nil {
|
||||
snapshot.User.UserGroupRPMOverride = override
|
||||
}
|
||||
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
|
||||
RPMLimit: apiKey.Group.RPMLimit,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
||||
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
||||
TotalRecharged: snapshot.User.TotalRecharged,
|
||||
RPMLimit: snapshot.User.RPMLimit,
|
||||
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
|
||||
},
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
|
||||
RPMLimit: snapshot.Group.RPMLimit,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
||||
},
|
||||
}
|
||||
|
||||
snapshot := svc.snapshotFromAPIKey(apiKey)
|
||||
snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
|
||||
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
||||
|
||||
require.NotNil(t, roundTrip)
|
||||
|
||||
@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
|
||||
|
||||
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &User{
|
||||
Email: email,
|
||||
@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
|
||||
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
|
||||
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@@ -87,6 +90,8 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||||
userRPMCache UserRPMCache
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
@@ -104,12 +109,22 @@ type BillingCacheService struct {
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
|
||||
func NewBillingCacheService(
|
||||
cache BillingCache,
|
||||
userRepo UserRepository,
|
||||
subRepo UserSubscriptionRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRPMCache UserRPMCache,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cfg *config.Config,
|
||||
) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
apiKeyRateLimitLoader: apiKeyRepo,
|
||||
userRPMCache: userRPMCache,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
}
|
||||
}
|
||||
|
||||
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
|
||||
if err := s.checkRPM(ctx, user, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
|
||||
//
|
||||
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
|
||||
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
|
||||
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
|
||||
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
|
||||
//
|
||||
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
|
||||
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
|
||||
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
|
||||
if s == nil || s.userRPMCache == nil || user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
|
||||
if group != nil {
|
||||
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
|
||||
var override *int
|
||||
if user.UserGroupRPMOverride != nil {
|
||||
override = user.UserGroupRPMOverride
|
||||
} else if s.userGroupRateRepo != nil {
|
||||
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm override lookup failed for user=%d group=%d: %v",
|
||||
user.ID, group.ID, err,
|
||||
)
|
||||
} else {
|
||||
override = dbOverride
|
||||
}
|
||||
}
|
||||
|
||||
if override != nil {
|
||||
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
|
||||
if *override > 0 {
|
||||
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||||
if incErr != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
|
||||
user.ID, group.ID, incErr,
|
||||
)
|
||||
// fail-open
|
||||
} else if count > *override {
|
||||
return ErrGroupRPMExceeded
|
||||
}
|
||||
}
|
||||
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
|
||||
} else if group.RPMLimit > 0 {
|
||||
// 无 override,检查 group.rpm_limit。
|
||||
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
|
||||
user.ID, group.ID, err,
|
||||
)
|
||||
// fail-open
|
||||
} else if count > group.RPMLimit {
|
||||
return ErrGroupRPMExceeded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 第二层:用户级全局硬上限(始终生效) ──
|
||||
if user.RPMLimit > 0 {
|
||||
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm increment (user) failed for user=%d: %v",
|
||||
user.ID, err,
|
||||
)
|
||||
return nil // fail-open
|
||||
}
|
||||
if count > user.RPMLimit {
|
||||
return ErrUserRPMExceeded
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,253 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
|
||||
type userRPMCacheStub struct {
|
||||
userGroupCalls int32
|
||||
userCalls int32
|
||||
|
||||
userGroupCounts []int // 依次返回的计数值
|
||||
userGroupErr error
|
||||
userCounts []int
|
||||
userErr error
|
||||
}
|
||||
|
||||
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
|
||||
idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
|
||||
if s.userGroupErr != nil {
|
||||
return 0, s.userGroupErr
|
||||
}
|
||||
if idx < len(s.userGroupCounts) {
|
||||
return s.userGroupCounts[idx], nil
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
|
||||
idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
|
||||
if s.userErr != nil {
|
||||
return 0, s.userErr
|
||||
}
|
||||
if idx < len(s.userCounts) {
|
||||
return s.userCounts[idx], nil
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
|
||||
type rpmOverrideRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
|
||||
override *int
|
||||
err error
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.override, nil
|
||||
}
|
||||
|
||||
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
|
||||
t.Helper()
|
||||
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
|
||||
// 我们只直接测 checkRPM。
|
||||
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
|
||||
override := 2
|
||||
// user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
|
||||
cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
|
||||
repo := &rpmOverrideRepoStub{override: &override}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试
|
||||
group := &Group{ID: 10, RPMLimit: 100}
|
||||
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
|
||||
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
|
||||
// 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
|
||||
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
|
||||
override := 100 // override 很高
|
||||
// user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
|
||||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||||
repo := &rpmOverrideRepoStub{override: &override}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100
|
||||
group := &Group{ID: 10, RPMLimit: 100}
|
||||
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
|
||||
zero := 0
|
||||
// user 计数: 依次返回 1..6
|
||||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
|
||||
repo := &rpmOverrideRepoStub{override: &zero}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 5}
|
||||
group := &Group{ID: 10, RPMLimit: 100}
|
||||
|
||||
// override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
|
||||
for i := 0; i < 5; i++ {
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
|
||||
}
|
||||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
|
||||
"override=0 跳过分组但 user 全局上限仍应生效")
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
|
||||
require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
|
||||
zero := 0
|
||||
cache := &userRPMCacheStub{}
|
||||
repo := &rpmOverrideRepoStub{override: &zero}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 0} // user 也不限
|
||||
group := &Group{ID: 10, RPMLimit: 100}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
}
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
|
||||
// user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
|
||||
cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
|
||||
repo := &rpmOverrideRepoStub{override: nil}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超
|
||||
group := &Group{ID: 10, RPMLimit: 5}
|
||||
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超
|
||||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
|
||||
|
||||
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
|
||||
// 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
|
||||
cache := &userRPMCacheStub{userGroupCounts: []int{3}}
|
||||
repo := &rpmOverrideRepoStub{err: errors.New("db down")}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 0}
|
||||
group := &Group{ID: 10, RPMLimit: 10}
|
||||
|
||||
// override 查询失败后应继续尝试 group 分支(不直接拒绝)
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
|
||||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||||
repo := &rpmOverrideRepoStub{override: nil}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 2}
|
||||
group := &Group{ID: 10, RPMLimit: 0} // 分组未设限
|
||||
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
|
||||
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
|
||||
cache := &userRPMCacheStub{}
|
||||
repo := &rpmOverrideRepoStub{override: nil}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 0}
|
||||
group := &Group{ID: 10, RPMLimit: 0}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
}
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
|
||||
cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
|
||||
repo := &rpmOverrideRepoStub{override: nil}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 0}
|
||||
group := &Group{ID: 10, RPMLimit: 5}
|
||||
|
||||
// Redis 故障时应 fail-open,不拒绝请求
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
|
||||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||||
repo := &rpmOverrideRepoStub{}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
user := &User{ID: 1, RPMLimit: 2}
|
||||
|
||||
// 无 group(纯用户级限流场景),不应查询 rpm_override。
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
|
||||
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
|
||||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
|
||||
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
|
||||
}
|
||||
|
||||
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
|
||||
cache := &userRPMCacheStub{}
|
||||
repo := &rpmOverrideRepoStub{}
|
||||
svc := newBillingServiceForRPM(t, cache, repo)
|
||||
|
||||
require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
|
||||
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
|
||||
@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
|
||||
@@ -170,9 +170,10 @@ const (
|
||||
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
|
||||
|
||||
// 第三方认证来源默认授予配置
|
||||
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
|
||||
|
||||
@@ -59,6 +59,10 @@ type Group struct {
|
||||
DefaultMappedModel string
|
||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
||||
|
||||
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
|
||||
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
|
||||
RPMLimit int
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
|
||||
@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
|
||||
@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
|
||||
func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
|
||||
if err != nil || value == "" {
|
||||
return 0
|
||||
}
|
||||
if v, err := strconv.Atoi(value); err == nil && v >= 0 {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
|
||||
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
|
||||
@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyOIDCConnectUserInfoUsernamePath: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultUserRPMLimit: "0",
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||
@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
|
||||
}
|
||||
|
||||
if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
|
||||
result.DefaultUserRPMLimit = rpm
|
||||
}
|
||||
|
||||
// 解析浮点数类型
|
||||
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
|
||||
result.DefaultBalance = balance
|
||||
|
||||
@@ -106,6 +106,7 @@ type SystemSettings struct {
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
|
||||
// Model fallback configuration
|
||||
|
||||
@@ -49,6 +49,15 @@ type User struct {
|
||||
BalanceNotifyExtraEmails []NotifyEmailEntry
|
||||
TotalRecharged float64
|
||||
|
||||
// RPMLimit 用户级每分钟请求数上限(0 = 不限制)。仅在所用分组未设置 rpm_limit
|
||||
// 且该 (用户, 分组) 无 rpm_override 时作为全局兜底生效,计数键 rpm:u:{userID}:{min}。
|
||||
RPMLimit int
|
||||
|
||||
// UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。
|
||||
// nil = 该 API Key 对应的 (user, group) 无 override;非 nil 时 checkRPM 直接使用,
|
||||
// 避免每请求查 DB。字段不持久化到数据库。
|
||||
UserGroupRPMOverride *int
|
||||
|
||||
APIKeys []APIKey
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
@@ -2,14 +2,16 @@ package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserGroupRateEntry 分组下用户专属倍率条目
|
||||
// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。
|
||||
// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。
|
||||
type UserGroupRateEntry struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserEmail string `json:"user_email"`
|
||||
UserNotes string `json:"user_notes"`
|
||||
UserStatus string `json:"user_status"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
UserID int64 `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserEmail string `json:"user_email"`
|
||||
UserNotes string `json:"user_notes"`
|
||||
UserStatus string `json:"user_status"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
|
||||
RPMOverride *int `json:"rpm_override,omitempty"`
|
||||
}
|
||||
|
||||
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
|
||||
@@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct {
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
}
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
// GroupRPMOverrideInput 批量设置分组 RPM override 的输入条目。
|
||||
// RPMOverride 为 *int 以支持清除(nil)语义。
|
||||
type GroupRPMOverrideInput struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
RPMOverride *int `json:"rpm_override"`
|
||||
}
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率/RPM 仓储接口。
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率与 RPM 上限,覆盖分组默认值。
|
||||
type UserGroupRateRepository interface {
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// 如果未设置专属倍率,返回 nil
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
|
||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
|
||||
GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error)
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
|
||||
GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率;nil 表示清空该分组的 rate_multiplier
|
||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据)
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组 rate 部分)
|
||||
SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
|
||||
// SyncGroupRPMOverrides 批量同步分组的用户专属 RPM(替换整组 rpm_override 部分)。
|
||||
// 条目中 RPMOverride 为 nil 时清空对应行的 rpm_override;非 nil 时 upsert。
|
||||
SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
|
||||
|
||||
// ClearGroupRPMOverrides 清空指定分组的所有 rpm_override(整组 rpm 部分归 NULL)
|
||||
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属条目(分组删除时调用)
|
||||
DeleteByGroupID(ctx context.Context, groupID int64) error
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
|
||||
// DeleteByUserID 删除指定用户的所有专属条目(用户删除时调用)
|
||||
DeleteByUserID(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserRPMCache 用户/分组级 RPM 计数器接口。
|
||||
//
|
||||
// 与账号级 RPMCache 的区别:
|
||||
// - RPMCache —— 按外部 AI provider 账号聚合(key: rpm:{accountID}:{min})。
|
||||
// - UserRPMCache —— 按用户或 (用户, 分组) 聚合,杜绝"同一用户创建多个 API Key 绕过 RPM"的路径。
|
||||
// key 形如 rpm:ug:{userID}:{groupID}:{min} 或 rpm:u:{userID}:{min}。
|
||||
type UserRPMCache interface {
|
||||
// IncrementUserGroupRPM 原子递增 (user, group) 级分钟计数并返回最新值。
|
||||
// 用于分组 rpm_limit 与 user-group rpm_override 两种命中分支。
|
||||
IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)
|
||||
|
||||
// IncrementUserRPM 原子递增用户级分钟计数并返回最新值。
|
||||
// 用于用户全局 rpm_limit 兜底分支(分组未设且无 override 时)。
|
||||
IncrementUserRPM(ctx context.Context, userID int64) (count int, err error)
|
||||
|
||||
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读,不递增)。
|
||||
GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)
|
||||
|
||||
// GetUserRPM 获取用户当前分钟已用 RPM(只读,不递增)。
|
||||
GetUserRPM(ctx context.Context, userID int64) (count int, err error)
|
||||
}
|
||||
@@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
return NewEmailQueueService(emailService, 3)
|
||||
}
|
||||
|
||||
// ProvideOAuthRefreshAPI creates OAuthRefreshAPI with the default lock TTL.
|
||||
func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
||||
return NewOAuthRefreshAPI(accountRepo, tokenCache)
|
||||
}
|
||||
|
||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
@@ -383,6 +388,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
|
||||
func ProvideBillingCacheService(
|
||||
cache BillingCache,
|
||||
userRepo UserRepository,
|
||||
subRepo UserSubscriptionRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
rpmCache UserRPMCache,
|
||||
rateRepo UserGroupRateRepository,
|
||||
cfg *config.Config,
|
||||
) *BillingCacheService {
|
||||
return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -399,7 +417,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewDashboardService,
|
||||
ProvidePricingService,
|
||||
NewBillingService,
|
||||
NewBillingCacheService,
|
||||
ProvideBillingCacheService,
|
||||
NewAnnouncementService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
@@ -411,7 +429,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewOAuthRefreshAPI,
|
||||
ProvideOAuthRefreshAPI,
|
||||
ProvideGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
ProvideAntigravityTokenProvider,
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Add per-group Requests-Per-Minute limit.
|
||||
-- rpm_limit: 分组统一 RPM 上限(0 = 不限制)。
|
||||
-- 一旦配置即接管该用户在该分组的限流,覆盖用户级 users.rpm_limit。
|
||||
-- 计数键:rpm:ug:{user_id}:{group_id}:{minute}。
|
||||
ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
|
||||
|
||||
COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。';
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Add per-user Requests-Per-Minute cap.
|
||||
-- rpm_limit: 用户全局 RPM 兜底(0 = 不限制)。
|
||||
-- 仅当所访问分组未设置 rpm_limit 且无 user-group rpm_override 时作为兜底生效。
|
||||
-- 计数键:rpm:u:{user_id}:{minute}。
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
|
||||
|
||||
COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。';
|
||||
@@ -0,0 +1,16 @@
|
||||
-- 在已有的"用户专属分组倍率表"上扩展 rpm_override 列;同时放宽 rate_multiplier 为可空,
|
||||
-- 使一行记录可以只覆盖 rate、只覆盖 rpm,或同时覆盖两者。
|
||||
-- 语义:
|
||||
-- - rate_multiplier NULL → 该用户在此分组使用 groups.rate_multiplier 默认值
|
||||
-- - rate_multiplier 非 NULL → 覆盖分组默认计费倍率
|
||||
-- - rpm_override NULL → 该用户在此分组使用 groups.rpm_limit 默认值
|
||||
-- - rpm_override 非 NULL → 覆盖分组默认 RPM(0 = 不限制)
|
||||
-- 用户级 users.rpm_limit 仍独立生效(跨分组总配额)。
|
||||
ALTER TABLE user_group_rate_multipliers
|
||||
ADD COLUMN IF NOT EXISTS rpm_override integer NULL;
|
||||
|
||||
ALTER TABLE user_group_rate_multipliers
|
||||
ALTER COLUMN rate_multiplier DROP NOT NULL;
|
||||
|
||||
COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。';
|
||||
COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。';
|
||||
Reference in New Issue
Block a user