Merge pull request #2202 from Michael-Jetson/main
新增三大功能:兑换码邀请返利、批量修改用户并发数、Markdown页面渲染
This commit is contained in:
@@ -175,6 +175,10 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||
return len(userIDs), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
@@ -998,17 +998,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(item.URL) == "" {
|
||||
response.BadRequest(c, "Custom menu item URL is required")
|
||||
return
|
||||
}
|
||||
if len(item.URL) > maxMenuItemURLLen {
|
||||
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
|
||||
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
|
||||
return
|
||||
urlTrimmed := strings.TrimSpace(item.URL)
|
||||
if strings.HasPrefix(urlTrimmed, "md:") {
|
||||
// Markdown page mode: URL = "md:<slug>"
|
||||
slug := strings.TrimPrefix(urlTrimmed, "md:")
|
||||
if slug == "" {
|
||||
response.BadRequest(c, "Custom menu item markdown slug cannot be empty (use md:slug format)")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if urlTrimmed == "" {
|
||||
response.BadRequest(c, "Custom menu item URL is required (use md:slug for markdown pages)")
|
||||
return
|
||||
}
|
||||
if len(item.URL) > maxMenuItemURLLen {
|
||||
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(urlTrimmed); err != nil {
|
||||
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL or md:<slug>")
|
||||
return
|
||||
}
|
||||
}
|
||||
if item.Visibility != "user" && item.Visibility != "admin" {
|
||||
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
|
||||
|
||||
@@ -477,3 +477,63 @@ func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
|
||||
|
||||
response.Success(c, status)
|
||||
}
|
||||
|
||||
// BatchUpdateConcurrency 批量修改用户并发数
|
||||
// POST /api/v1/admin/users/batch-concurrency
|
||||
type BatchUpdateConcurrencyRequest struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
All bool `json:"all"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Mode string `json:"mode" binding:"required,oneof=set add"`
|
||||
}
|
||||
|
||||
func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
|
||||
var req BatchUpdateConcurrencyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.All && len(req.UserIDs) == 0 {
|
||||
response.BadRequest(c, "user_ids is required unless all=true")
|
||||
return
|
||||
}
|
||||
if len(req.UserIDs) > 500 {
|
||||
response.BadRequest(c, "user_ids cannot exceed 500")
|
||||
return
|
||||
}
|
||||
|
||||
var userIDs []int64
|
||||
if req.All {
|
||||
// Fetch all user IDs via pagination
|
||||
page := 1
|
||||
const pageSize = 500
|
||||
for {
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, service.UserListFilters{}, "id", "asc")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
for _, u := range users {
|
||||
userIDs = append(userIDs, u.ID)
|
||||
}
|
||||
if len(users) < pageSize {
|
||||
break
|
||||
}
|
||||
page++
|
||||
}
|
||||
} else {
|
||||
userIDs = req.UserIDs
|
||||
}
|
||||
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, gin.H{"affected": 0})
|
||||
return
|
||||
}
|
||||
|
||||
affected, err := h.adminService.BatchUpdateConcurrency(c.Request.Context(), userIDs, req.Concurrency, req.Mode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"affected": affected})
|
||||
}
|
||||
|
||||
@@ -2798,6 +2798,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
|
||||
panic("unexpected BatchSetConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
|
||||
panic("unexpected BatchAddConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||
return map[int64]*time.Time{}, nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type CustomMenuItem struct {
|
||||
Label string `json:"label"`
|
||||
IconSVG string `json:"icon_svg"`
|
||||
URL string `json:"url"`
|
||||
PageSlug string `json:"page_slug,omitempty"`
|
||||
Visibility string `json:"visibility"` // "user" or "admin"
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var validSlugPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
|
||||
|
||||
const maxPageFileSize = 1 << 20 // 1MB
|
||||
|
||||
type PageHandler struct {
|
||||
pagesDir string
|
||||
settingService *service.SettingService
|
||||
}
|
||||
|
||||
func NewPageHandler(dataDir string, settingService *service.SettingService) *PageHandler {
|
||||
pagesDir := filepath.Join(dataDir, "pages")
|
||||
_ = os.MkdirAll(pagesDir, 0755)
|
||||
return &PageHandler{pagesDir: pagesDir, settingService: settingService}
|
||||
}
|
||||
|
||||
// GetPageContent serves raw markdown content for a given slug.
|
||||
// GET /api/v1/pages/:slug
|
||||
func (h *PageHandler) GetPageContent(c *gin.Context) {
|
||||
slug := c.Param("slug")
|
||||
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
|
||||
response.BadRequest(c, "Invalid page slug")
|
||||
return
|
||||
}
|
||||
|
||||
// Visibility check: slug must be configured in custom_menu_items
|
||||
// and the user must have permission based on visibility setting
|
||||
if !h.checkSlugVisibility(c, slug) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
|
||||
return
|
||||
}
|
||||
|
||||
filePath := filepath.Join(h.pagesDir, slug+".md")
|
||||
cleaned := filepath.Clean(filePath)
|
||||
if !strings.HasPrefix(cleaned, filepath.Clean(h.pagesDir)) {
|
||||
response.BadRequest(c, "Invalid page slug")
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(cleaned)
|
||||
if err != nil || info.IsDir() {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
|
||||
return
|
||||
}
|
||||
if info.Size() > maxPageFileSize {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "page too large"})
|
||||
return
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(cleaned)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read page"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/markdown; charset=utf-8", content)
|
||||
}
|
||||
|
||||
// ListPages returns available page slugs.
|
||||
// GET /api/v1/pages
|
||||
func (h *PageHandler) ListPages(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.pagesDir)
|
||||
if err != nil {
|
||||
response.Success(c, []string{})
|
||||
return
|
||||
}
|
||||
|
||||
slugs := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if strings.HasSuffix(name, ".md") {
|
||||
slugs = append(slugs, strings.TrimSuffix(name, ".md"))
|
||||
}
|
||||
}
|
||||
response.Success(c, slugs)
|
||||
}
|
||||
|
||||
// ServePageImage serves images from data/pages/{slug}/ directory.
|
||||
// GET /api/v1/pages/:slug/images/*filename
|
||||
// No JWT required (browser img tags can't carry tokens), but visibility is checked.
|
||||
func (h *PageHandler) ServePageImage(c *gin.Context) {
|
||||
slug := c.Param("slug")
|
||||
filename := c.Param("filename")
|
||||
filename = strings.TrimPrefix(filename, "/")
|
||||
|
||||
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if !h.checkImageSlugVisibility(c, slug) {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if filename == "" || strings.Contains(filename, "..") || strings.Contains(filename, "/") || strings.Contains(filename, "\\") {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
imagesDir := filepath.Join(h.pagesDir, slug)
|
||||
filePath := filepath.Join(imagesDir, filename)
|
||||
cleaned := filepath.Clean(filePath)
|
||||
if !strings.HasPrefix(cleaned, filepath.Clean(imagesDir)) {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(cleaned)
|
||||
if err != nil || info.IsDir() {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(cleaned)
|
||||
}
|
||||
|
||||
// findSlugVisibility looks up the slug in custom_menu_items and returns (visibility, found).
|
||||
func (h *PageHandler) findSlugVisibility(c *gin.Context, slug string) (string, bool) {
|
||||
if h.settingService == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw := h.settingService.GetCustomMenuItemsRaw(c.Request.Context())
|
||||
if raw == "" || raw == "[]" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var items []struct {
|
||||
URL string `json:"url"`
|
||||
PageSlug string `json:"page_slug"`
|
||||
Visibility string `json:"visibility"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
itemSlug := item.PageSlug
|
||||
if itemSlug == "" && strings.HasPrefix(item.URL, "md:") {
|
||||
itemSlug = strings.TrimPrefix(item.URL, "md:")
|
||||
}
|
||||
if itemSlug == slug {
|
||||
return item.Visibility, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// checkSlugVisibility verifies the slug is configured in custom_menu_items
|
||||
// and the authenticated user has permission to view it.
|
||||
func (h *PageHandler) checkSlugVisibility(c *gin.Context, slug string) bool {
|
||||
visibility, found := h.findSlugVisibility(c, slug)
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
if visibility == "admin" {
|
||||
role, _ := middleware2.GetUserRoleFromContext(c)
|
||||
return role == "admin"
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkImageSlugVisibility checks visibility for image requests (no JWT available).
|
||||
// Only allows user-visible pages; admin-only pages are blocked.
|
||||
func (h *PageHandler) checkImageSlugVisibility(c *gin.Context, slug string) bool {
|
||||
visibility, found := h.findSlugVisibility(c, slug)
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
return visibility != "admin"
|
||||
}
|
||||
|
||||
// RegisterPageRoutes registers page routes on a router group.
|
||||
func RegisterPageRoutes(v1 *gin.RouterGroup, dataDir string, jwtAuth gin.HandlerFunc, adminAuth gin.HandlerFunc, settingService *service.SettingService) {
|
||||
h := NewPageHandler(dataDir, settingService)
|
||||
|
||||
// Authenticated page content (JWT required + visibility check)
|
||||
pages := v1.Group("/pages")
|
||||
pages.Use(jwtAuth)
|
||||
{
|
||||
pages.GET("/:slug", h.GetPageContent)
|
||||
}
|
||||
|
||||
// Images: no JWT (browser img tags can't carry tokens), visibility check in handler
|
||||
pageImages := v1.Group("/pages")
|
||||
{
|
||||
pageImages.GET("/:slug/images/*filename", h.ServePageImage)
|
||||
}
|
||||
|
||||
// Admin-only: list all available pages
|
||||
adminPages := v1.Group("/pages")
|
||||
adminPages.Use(adminAuth)
|
||||
{
|
||||
adminPages.GET("", h.ListPages)
|
||||
}
|
||||
}
|
||||
@@ -87,6 +87,8 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
|
||||
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
|
||||
@@ -737,6 +737,37 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if value < 0 {
|
||||
value = 0
|
||||
}
|
||||
res, err := r.sql.ExecContext(ctx,
|
||||
"UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
|
||||
value, pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("batch set concurrency: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
return int(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
res, err := r.sql.ExecContext(ctx,
|
||||
"UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
|
||||
delta, pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("batch add concurrency: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
return int(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
@@ -1125,7 +1125,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil, nil)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
|
||||
settingRepo := newStubSettingRepo()
|
||||
@@ -1296,6 +1296,9 @@ func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (r *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -198,6 +198,9 @@ func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
panic("unexpected ExistsByEmail call")
|
||||
}
|
||||
|
||||
@@ -112,4 +112,6 @@ func registerRoutes(
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
|
||||
|
||||
handler.RegisterPageRoutes(v1, cfg.Pricing.DataDir, gin.HandlerFunc(jwtAuth), gin.HandlerFunc(adminAuth), settingService)
|
||||
}
|
||||
|
||||
@@ -245,6 +245,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
||||
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
||||
users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency)
|
||||
|
||||
// User attribute values
|
||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||
|
||||
@@ -33,6 +33,7 @@ type AdminService interface {
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
||||
@@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||
cleaned := make([]int64, 0, len(userIDs))
|
||||
for _, uid := range userIDs {
|
||||
if uid > 0 {
|
||||
cleaned = append(cleaned, uid)
|
||||
}
|
||||
}
|
||||
if len(cleaned) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var affected int
|
||||
var err error
|
||||
switch mode {
|
||||
case "set":
|
||||
affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value)
|
||||
case "add":
|
||||
affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value)
|
||||
default:
|
||||
return 0, errors.New("invalid mode: must be 'set' or 'add'")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, uid := range cleaned {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid)
|
||||
}
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
|
||||
@@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
@@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
if s.existsErr != nil {
|
||||
return false, s.existsErr
|
||||
|
||||
@@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||
|
||||
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
|
||||
@@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
||||
case redeemActionRedeem:
|
||||
// Code exists but unused — skip creation, proceed to redeem
|
||||
}
|
||||
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
|
||||
if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil {
|
||||
return fmt.Errorf("redeem balance: %w", err)
|
||||
}
|
||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||
|
||||
@@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
@@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
@@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
@@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -28,6 +29,15 @@ const (
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
)
|
||||
|
||||
type ctxKeySkipRedeemAffiliate struct{}
|
||||
|
||||
// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level
|
||||
// affiliate rebate. Used by payment fulfillment which handles rebate separately
|
||||
// via applyAffiliateRebateForOrder (with audit-log deduplication).
|
||||
func ContextSkipRedeemAffiliate(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true)
|
||||
}
|
||||
|
||||
// RedeemCache defines cache operations for redeem service
|
||||
type RedeemCache interface {
|
||||
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
@@ -80,6 +90,7 @@ type RedeemService struct {
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
affiliateService *AffiliateService
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
@@ -91,6 +102,7 @@ func NewRedeemService(
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
affiliateService *AffiliateService,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
@@ -100,6 +112,7 @@ func NewRedeemService(
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
affiliateService: affiliateService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
// 事务提交成功后失效缓存
|
||||
s.invalidateRedeemCaches(ctx, userID, redeemCode)
|
||||
|
||||
// 余额类正数兑换码触发邀请返利(best-effort,失败不影响兑换结果)
|
||||
if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 {
|
||||
s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value)
|
||||
}
|
||||
|
||||
// 重新获取更新后的兑换码
|
||||
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
||||
if err != nil {
|
||||
@@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) {
|
||||
if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil {
|
||||
return
|
||||
}
|
||||
if s.affiliateService == nil {
|
||||
return
|
||||
}
|
||||
if !s.affiliateService.IsEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err)
|
||||
return
|
||||
}
|
||||
if rebate > 0 {
|
||||
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取兑换码
|
||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||
|
||||
@@ -1542,6 +1542,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting.
|
||||
func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
|
||||
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
|
||||
|
||||
@@ -96,6 +96,8 @@ type UserRepository interface {
|
||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||
BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error)
|
||||
BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error)
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
|
||||
|
||||
@@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
|
||||
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||
out := make([]UserAuthIdentityRecord, len(m.identities))
|
||||
|
||||
Reference in New Issue
Block a user