Merge pull request #2202 from Michael-Jetson/main

新增三大功能:兑换码邀请返利、批量修改用户并发数、Markdown页面渲染
This commit is contained in:
Wesley Liddick
2026-05-07 09:35:14 +08:00
committed by GitHub
26 changed files with 765 additions and 18 deletions
+34
View File
@@ -33,6 +33,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
@@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
cleaned := make([]int64, 0, len(userIDs))
for _, uid := range userIDs {
if uid > 0 {
cleaned = append(cleaned, uid)
}
}
if len(cleaned) == 0 {
return 0, nil
}
var affected int
var err error
switch mode {
case "set":
affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value)
case "add":
affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value)
default:
return 0, errors.New("invalid mode: must be 'set' or 'add'")
}
if err != nil {
return 0, err
}
if s.authCacheInvalidator != nil {
for _, uid := range cleaned {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid)
}
}
return affected, nil
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
panic("unexpected")
}
@@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i
panic("unexpected UpdateConcurrency call")
}
func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
if s.existsErr != nil {
return false, s.existsErr
@@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64)
return 0, nil
}
func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
@@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
return ok, nil
}
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
@@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
case redeemActionRedeem:
// Code exists but unused — skip creation, proceed to redeem
}
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
@@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -28,6 +29,15 @@ const (
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
type ctxKeySkipRedeemAffiliate struct{}
// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level
// affiliate rebate. Used by payment fulfillment which handles rebate separately
// via applyAffiliateRebateForOrder (with audit-log deduplication).
func ContextSkipRedeemAffiliate(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true)
}
// RedeemCache defines cache operations for redeem service
type RedeemCache interface {
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
@@ -80,6 +90,7 @@ type RedeemService struct {
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
affiliateService *AffiliateService
}
// NewRedeemService 创建兑换码服务实例
@@ -91,6 +102,7 @@ func NewRedeemService(
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
affiliateService *AffiliateService,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
@@ -100,6 +112,7 @@ func NewRedeemService(
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
affiliateService: affiliateService,
}
}
@@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 事务提交成功后失效缓存
s.invalidateRedeemCaches(ctx, userID, redeemCode)
// 余额类正数兑换码触发邀请返利(best-effort,失败不影响兑换结果)
if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 {
s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value)
}
// 重新获取更新后的兑换码
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
if err != nil {
@@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64
}
}
func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) {
if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil {
return
}
if s.affiliateService == nil {
return
}
if !s.affiliateService.IsEnabled(ctx) {
return
}
rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount)
if err != nil {
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err)
return
}
if rebate > 0 {
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID)
}
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
@@ -1542,6 +1542,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
return value == "true"
}
// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting.
func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems)
if err != nil {
return "[]"
}
return value
}
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
+2
View File
@@ -96,6 +96,8 @@ type UserRepository interface {
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
UpdateConcurrency(ctx context.Context, id int64, amount int) error
BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error)
BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error)
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
@@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
out := make([]UserAuthIdentityRecord, len(m.identities))