Merge remote-tracking branch 'upstream/main' into feat/payment-system-v2

# Conflicts:
#	frontend/src/api/admin/settings.ts
#	frontend/src/stores/app.ts
#	frontend/src/types/index.ts
#	frontend/src/views/admin/SettingsView.vue
This commit is contained in:
erio
2026-04-11 18:24:49 +08:00
115 changed files with 3422 additions and 396 deletions
+22 -4
View File
@@ -10,6 +10,7 @@ import (
"log/slog"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -359,7 +360,7 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "")
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "", "created_at", "desc")
if err != nil {
return nil, err
}
@@ -372,12 +373,12 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e
return out, nil
}
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) {
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string, groupID int64, privacyMode, sortBy, sortOrder string) ([]service.Account, error) {
page := 1
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil {
return nil, err
}
@@ -409,11 +410,28 @@ func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64,
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "name")
sortOrder := c.DefaultQuery("sort_order", "asc")
if len(search) > 100 {
search = search[:100]
}
return h.listAccountsFiltered(ctx, platform, accountType, status, search)
groupID := int64(0)
if groupIDStr := c.Query("group"); groupIDStr != "" {
if groupIDStr == accountListGroupUngroupedQueryValue {
groupID = service.AccountListGroupUngrouped
} else {
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
if parseErr != nil || parsedGroupID <= 0 {
return nil, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter")
}
groupID = parsedGroupID
}
}
return h.listAccountsFiltered(ctx, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
}
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {
@@ -172,6 +172,51 @@ func TestExportDataWithoutProxies(t *testing.T) {
require.Nil(t, resp.Data.Accounts[0].ProxyKey)
}
func TestExportDataPassesAccountFiltersAndSort(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
adminSvc.accounts = []service.Account{
{ID: 1, Name: "acc-1", Status: service.StatusActive},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/api/v1/admin/accounts/data?platform=openai&type=oauth&status=active&group=12&privacy_mode=blocked&search=keyword&sort_by=priority&sort_order=desc",
nil,
)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListAccounts.calls)
require.Equal(t, "openai", adminSvc.lastListAccounts.platform)
require.Equal(t, "oauth", adminSvc.lastListAccounts.accountType)
require.Equal(t, "active", adminSvc.lastListAccounts.status)
require.Equal(t, int64(12), adminSvc.lastListAccounts.groupID)
require.Equal(t, "blocked", adminSvc.lastListAccounts.privacyMode)
require.Equal(t, "keyword", adminSvc.lastListAccounts.search)
require.Equal(t, "priority", adminSvc.lastListAccounts.sortBy)
require.Equal(t, "desc", adminSvc.lastListAccounts.sortOrder)
}
func TestExportDataSelectedIDsOverrideFilters(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/api/v1/admin/accounts/data?ids=1,2&platform=openai&search=keyword&sort_by=priority&sort_order=desc",
nil,
)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Accounts, 2)
require.Equal(t, 0, adminSvc.lastListAccounts.calls)
}
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
@@ -221,6 +221,8 @@ func (h *AccountHandler) List(c *gin.Context) {
status := c.Query("status")
search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
sortBy := c.DefaultQuery("sort_by", "name")
sortOrder := c.DefaultQuery("sort_order", "asc")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
@@ -246,7 +248,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -2029,7 +2031,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "name", "asc")
if err != nil {
response.ErrorFrom(c, err)
return
@@ -31,6 +31,33 @@ type stubAdminService struct {
platform string
groupIDs []int64
}
lastListAccounts struct {
platform string
accountType string
status string
search string
groupID int64
privacyMode string
sortBy string
sortOrder string
calls int
}
lastListProxies struct {
protocol string
status string
search string
sortBy string
sortOrder string
calls int
}
lastListRedeemCodes struct {
codeType string
status string
search string
sortBy string
sortOrder string
calls int
}
mu sync.Mutex
}
@@ -99,7 +126,7 @@ func newStubAdminService() *stubAdminService {
}
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil
}
@@ -132,7 +159,7 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
return &user, nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
@@ -140,7 +167,7 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
@@ -187,7 +214,16 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
s.lastListAccounts.platform = platform
s.lastListAccounts.accountType = accountType
s.lastListAccounts.status = status
s.lastListAccounts.search = search
s.lastListAccounts.groupID = groupID
s.lastListAccounts.privacyMode = privacyMode
s.lastListAccounts.sortBy = sortBy
s.lastListAccounts.sortOrder = sortOrder
s.lastListAccounts.calls++
return s.accounts, int64(len(s.accounts)), nil
}
@@ -261,7 +297,13 @@ func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAcc
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.Proxy, int64, error) {
s.lastListProxies.protocol = protocol
s.lastListProxies.status = status
s.lastListProxies.search = search
s.lastListProxies.sortBy = sortBy
s.lastListProxies.sortOrder = sortOrder
s.lastListProxies.calls++
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies {
@@ -283,7 +325,7 @@ func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int,
return filtered, int64(len(filtered)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
@@ -384,7 +426,13 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]service.RedeemCode, int64, error) {
s.lastListRedeemCodes.codeType = codeType
s.lastListRedeemCodes.status = status
s.lastListRedeemCodes.search = search
s.lastListRedeemCodes.sortBy = sortBy
s.lastListRedeemCodes.sortOrder = sortOrder
s.lastListRedeemCodes.calls++
return s.redeems, int64(len(s.redeems)), nil
}
@@ -52,13 +52,17 @@ func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 200 {
search = search[:200]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
Page: page,
PageSize: pageSize,
SortBy: sortBy,
SortOrder: sortOrder,
}
items, paginationResult, err := h.announcementService.List(
@@ -227,8 +231,10 @@ func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "email"),
SortOrder: c.DefaultQuery("sort_order", "asc"),
}
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
@@ -0,0 +1,138 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type announcementRepoCapture struct {
service.AnnouncementRepository
listParams pagination.PaginationParams
}
func (r *announcementRepoCapture) List(ctx context.Context, params pagination.PaginationParams, filters service.AnnouncementListFilters) ([]service.Announcement, *pagination.PaginationResult, error) {
r.listParams = params
return []service.Announcement{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func (r *announcementRepoCapture) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
return &service.Announcement{
ID: id,
Title: "announcement",
Content: "content",
Status: service.AnnouncementStatusActive,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
}
type announcementUserRepoCapture struct {
service.UserRepository
listParams pagination.PaginationParams
}
func (r *announcementUserRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
r.listParams = params
return []service.User{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
type announcementReadRepoCapture struct {
service.AnnouncementReadRepository
}
func (r *announcementReadRepoCapture) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
return map[int64]time.Time{}, nil
}
type announcementUserSubRepoCapture struct {
service.UserSubscriptionRepository
}
func newAnnouncementSortTestRouter(announcementRepo *announcementRepoCapture, userRepo *announcementUserRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
svc := service.NewAnnouncementService(
announcementRepo,
&announcementReadRepoCapture{},
userRepo,
&announcementUserSubRepoCapture{},
)
handler := NewAnnouncementHandler(svc)
router := gin.New()
router.GET("/admin/announcements", handler.List)
router.GET("/admin/announcements/:id/read-status", handler.ListReadStatus)
return router
}
func TestAdminAnnouncementListSortParams(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements?sort_by=title&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "title", announcementRepo.listParams.SortBy)
require.Equal(t, "ASC", announcementRepo.listParams.SortOrder)
}
func TestAdminAnnouncementListSortDefaults(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", announcementRepo.listParams.SortBy)
require.Equal(t, "desc", announcementRepo.listParams.SortOrder)
}
func TestAdminAnnouncementReadStatusSortParams(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status?sort_by=balance&sort_order=DESC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "balance", userRepo.listParams.SortBy)
require.Equal(t, "DESC", userRepo.listParams.SortOrder)
}
func TestAdminAnnouncementReadStatusSortDefaults(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "email", userRepo.listParams.SortBy)
require.Equal(t, "asc", userRepo.listParams.SortOrder)
}
@@ -245,7 +245,12 @@ func (h *ChannelHandler) List(c *gin.Context) {
search = search[:100]
}
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -162,6 +162,8 @@ func (h *GroupHandler) List(c *gin.Context) {
search = search[:100]
}
isExclusiveStr := c.Query("is_exclusive")
sortBy := c.DefaultQuery("sort_by", "sort_order")
sortOrder := c.DefaultQuery("sort_order", "asc")
var isExclusive *bool
if isExclusiveStr != "" {
@@ -169,7 +171,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive = &val
}
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -55,8 +55,10 @@ func (h *PromoHandler) List(c *gin.Context) {
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
+27 -10
View File
@@ -33,11 +33,13 @@ func (h *ProxyHandler) ExportData(c *gin.Context) {
protocol := c.Query("protocol")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 100 {
search = search[:100]
}
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search)
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -89,7 +91,7 @@ func (h *ProxyHandler) ImportData(c *gin.Context) {
ctx := c.Request.Context()
result := DataImportResult{}
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "")
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "", "id", "desc")
if err != nil {
response.ErrorFrom(c, err)
return
@@ -220,18 +222,33 @@ func parseProxyIDs(c *gin.Context) ([]int64, error) {
return ids, nil
}
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) {
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search, sortBy, sortOrder string) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
sortBy = strings.TrimSpace(sortBy)
useAccountCountSort := strings.EqualFold(sortBy, "account_count")
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
if useAccountCountSort {
items, total, err := h.adminService.ListProxiesWithAccountCount(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil {
return nil, err
}
for i := range items {
out = append(out, items[i].Proxy)
}
if len(out) >= int(total) || len(items) == 0 {
break
}
} else {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
}
page++
}
@@ -74,6 +74,10 @@ func TestProxyExportDataRespectsFilters(t *testing.T) {
require.Len(t, resp.Data.Proxies, 1)
require.Len(t, resp.Data.Accounts, 0)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, 1, adminSvc.lastListProxies.calls)
require.Equal(t, "https", adminSvc.lastListProxies.protocol)
require.Equal(t, "id", adminSvc.lastListProxies.sortBy)
require.Equal(t, "desc", adminSvc.lastListProxies.sortOrder)
}
func TestProxyExportDataWithSelectedIDs(t *testing.T) {
@@ -113,6 +117,96 @@ func TestProxyExportDataWithSelectedIDs(t *testing.T) {
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
require.Equal(t, 0, adminSvc.lastListProxies.calls)
}
func TestProxyExportDataPassesSortParams(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=http&status=active&search=proxy&sort_by=name&sort_order=asc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListProxies.calls)
require.Equal(t, "http", adminSvc.lastListProxies.protocol)
require.Equal(t, "active", adminSvc.lastListProxies.status)
require.Equal(t, "proxy", adminSvc.lastListProxies.search)
require.Equal(t, "name", adminSvc.lastListProxies.sortBy)
require.Equal(t, "asc", adminSvc.lastListProxies.sortOrder)
}
func TestProxyExportDataSortByAccountCountUsesAccountCountListing(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-id-1",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-id-2",
Protocol: "http",
Host: "127.0.0.2",
Port: 8081,
Status: service.StatusActive,
},
}
adminSvc.proxyCounts = []service.ProxyWithAccountCount{
{
Proxy: service.Proxy{
ID: 2,
Name: "proxy-count-high",
Protocol: "http",
Host: "127.0.0.2",
Port: 8081,
Status: service.StatusActive,
},
AccountCount: 9,
},
{
Proxy: service.Proxy{
ID: 1,
Name: "proxy-count-low",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
},
AccountCount: 1,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?sort_by=account_count&sort_order=desc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 2)
require.Equal(t, "proxy-count-high", resp.Data.Proxies[0].Name)
require.Equal(t, "proxy-count-low", resp.Data.Proxies[1].Name)
require.Equal(t, 0, adminSvc.lastListProxies.calls)
}
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {
@@ -52,13 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol := c.Query("protocol")
status := c.Query("status")
search := c.Query("search")
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -0,0 +1,49 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupRedeemExportRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/redeem-codes/export", h.Export)
return router, adminSvc
}
func TestRedeemExportPassesSearchAndSort(t *testing.T) {
router, adminSvc := setupRedeemExportRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export?type=balance&status=unused&search=ABC&sort_by=value&sort_order=asc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls)
require.Equal(t, "balance", adminSvc.lastListRedeemCodes.codeType)
require.Equal(t, "unused", adminSvc.lastListRedeemCodes.status)
require.Equal(t, "ABC", adminSvc.lastListRedeemCodes.search)
require.Equal(t, "value", adminSvc.lastListRedeemCodes.sortBy)
require.Equal(t, "asc", adminSvc.lastListRedeemCodes.sortOrder)
}
func TestRedeemExportSortDefaults(t *testing.T) {
router, adminSvc := setupRedeemExportRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls)
require.Equal(t, "id", adminSvc.lastListRedeemCodes.sortBy)
require.Equal(t, "desc", adminSvc.lastListRedeemCodes.sortOrder)
}
@@ -59,13 +59,15 @@ func (h *RedeemHandler) List(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -300,9 +302,15 @@ func (h *RedeemHandler) GetStats(c *gin.Context) {
func (h *RedeemHandler) Export(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 100 {
search = search[:100]
}
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -150,6 +150,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
@@ -261,6 +263,8 @@ type UpdateSettingsRequest struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
@@ -345,6 +349,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
}
if req.TablePageSizeOptions == nil {
req.TablePageSizeOptions = previousSettings.TablePageSizeOptions
}
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
@@ -810,6 +821,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
TableDefaultPageSize: req.TableDefaultPageSize,
TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
@@ -989,6 +1002,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
@@ -1278,6 +1293,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
changed = append(changed, "purchase_subscription_url")
}
if before.TableDefaultPageSize != after.TableDefaultPageSize {
changed = append(changed, "table_default_page_size")
}
if !equalIntSlice(before.TablePageSizeOptions, after.TablePageSizeOptions) {
changed = append(changed, "table_page_size_options")
}
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
@@ -1334,6 +1355,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
return true
}
func equalIntSlice(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"`
@@ -165,7 +165,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
filters := usagestats.UsageLogFilters{
UserID: userID,
APIKeyID: apiKeyID,
@@ -339,7 +344,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}, "email", "asc")
if err != nil {
response.ErrorFrom(c, err)
return
@@ -15,11 +15,13 @@ import (
type adminUsageRepoCapture struct {
service.UsageLogRepository
listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters
statsFilters usagestats.UsageLogFilters
}
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listParams = params
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
@@ -0,0 +1,35 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestAdminUsageListSortParams(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?sort_by=model&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "model", repo.listParams.SortBy)
require.Equal(t, "ASC", repo.listParams.SortOrder)
}
func TestAdminUsageListSortDefaults(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", repo.listParams.SortBy)
require.Equal(t, "desc", repo.listParams.SortOrder)
}
@@ -91,12 +91,14 @@ func (h *UserHandler) List(c *gin.Context) {
GroupName: strings.TrimSpace(c.Query("group_name")),
Attributes: parseAttributeFilters(c),
}
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -290,8 +292,10 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
}
page, pageSize := response.ParsePagination(c)
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
+6 -1
View File
@@ -72,7 +72,12 @@ func (h *APIKeyHandler) List(c *gin.Context) {
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
// Parse filter parameters
var filters service.APIKeyListFilters
+4
View File
@@ -84,6 +84,8 @@ type SystemSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
@@ -170,6 +172,8 @@ type PublicSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
@@ -51,6 +51,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+6 -1
View File
@@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID,
@@ -16,10 +16,12 @@ import (
type userUsageRepoCapture struct {
service.UsageLogRepository
listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters
}
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listParams = params
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
@@ -0,0 +1,35 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestUserUsageListSortParams(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?sort_by=model&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "model", repo.listParams.SortBy)
require.Equal(t, "ASC", repo.listParams.SortOrder)
}
func TestUserUsageListSortDefaults(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", repo.listParams.SortBy)
require.Equal(t, "desc", repo.listParams.SortOrder)
}
+40 -6
View File
@@ -1,10 +1,19 @@
// Package pagination provides types and helpers for paginated responses.
package pagination
import "strings"
const (
SortOrderAsc = "asc"
SortOrderDesc = "desc"
)
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
Page int
PageSize int
SortBy string
SortOrder string
}
// PaginationResult 分页结果
@@ -18,8 +27,9 @@ type PaginationResult struct {
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
Page: 1,
PageSize: 20,
SortOrder: SortOrderDesc,
}
}
@@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
if p.PageSize > 1000 {
return 1000
}
return p.PageSize
}
// NormalizeSortOrder normalizes sort order to asc/desc and falls back to defaultOrder.
func NormalizeSortOrder(order string, defaultOrder string) string {
switch strings.ToLower(strings.TrimSpace(defaultOrder)) {
case SortOrderAsc:
defaultOrder = SortOrderAsc
default:
defaultOrder = SortOrderDesc
}
switch strings.ToLower(strings.TrimSpace(order)) {
case SortOrderAsc:
return SortOrderAsc
case SortOrderDesc:
return SortOrderDesc
default:
return defaultOrder
}
}
// NormalizedSortOrder returns the normalized sort order using defaultOrder as fallback.
func (p PaginationParams) NormalizedSortOrder(defaultOrder string) string {
return NormalizeSortOrder(p.SortOrder, defaultOrder)
}
@@ -0,0 +1,71 @@
package pagination
import "testing"
func TestNormalizeSortOrder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultOrder string
want string
}{
{name: "asc", input: "asc", defaultOrder: "desc", want: "asc"},
{name: "uppercase asc", input: "ASC", defaultOrder: "desc", want: "asc"},
{name: "desc", input: "desc", defaultOrder: "asc", want: "desc"},
{name: "trim spaces", input: " desc ", defaultOrder: "asc", want: "desc"},
{name: "invalid falls back", input: "sideways", defaultOrder: "asc", want: "asc"},
{name: "empty falls back", input: "", defaultOrder: "desc", want: "desc"},
{name: "invalid default falls back to desc", input: "", defaultOrder: "wat", want: "desc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeSortOrder(tt.input, tt.defaultOrder); got != tt.want {
t.Fatalf("NormalizeSortOrder(%q, %q) = %q, want %q", tt.input, tt.defaultOrder, got, tt.want)
}
})
}
}
func TestPaginationParamsNormalizedSortOrder(t *testing.T) {
t.Parallel()
params := PaginationParams{SortOrder: "ASC"}
if got := params.NormalizedSortOrder("desc"); got != "asc" {
t.Fatalf("NormalizedSortOrder = %q, want asc", got)
}
params = PaginationParams{SortOrder: "bad"}
if got := params.NormalizedSortOrder("asc"); got != "asc" {
t.Fatalf("NormalizedSortOrder invalid fallback = %q, want asc", got)
}
}
func TestPaginationParamsLimit(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pageSize int
want int
}{
{name: "non-positive falls back to default", pageSize: 0, want: 20},
{name: "negative falls back to default", pageSize: -1, want: 20},
{name: "normal value keeps", pageSize: 50, want: 50},
{name: "max value keeps", pageSize: 1000, want: 1000},
{name: "beyond max clamps to 1000", pageSize: 1500, want: 1000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := PaginationParams{PageSize: tt.pageSize}
if got := p.Limit(); got != tt.want {
t.Fatalf("Limit() for PageSize=%d = %d, want %d", tt.pageSize, got, tt.want)
}
})
}
}
+96 -12
View File
@@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
case service.StatusActive:
q = q.Where(
dbaccount.StatusEQ(status),
dbaccount.SchedulableEQ(true),
dbaccount.Or(
dbaccount.RateLimitResetAtIsNil(),
dbaccount.RateLimitResetAtLTE(time.Now()),
),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.RateLimitResetAtGT(time.Now()),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
case "temp_unschedulable":
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.And(
entsql.Not(entsql.IsNull(col)),
entsql.GT(col, entsql.Expr("NOW()")),
))
}))
q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.And(
entsql.Not(entsql.IsNull(col)),
entsql.GT(col, entsql.Expr("NOW()")),
))
}),
)
case "unschedulable":
q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(false),
dbaccount.Or(
dbaccount.RateLimitResetAtIsNil(),
dbaccount.RateLimitResetAtLTE(time.Now()),
),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
default:
q = q.Where(dbaccount.StatusEQ(status))
}
@@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err
}
accounts, err := q.
accountsQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(dbaccount.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range accountListOrder(params) {
accountsQuery = accountsQuery.Order(order)
}
accounts, err := accountsQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return outAccounts, paginationResultFromTotal(int64(total), params), nil
}
func accountListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
field := dbaccount.FieldName
defaultOrder := true
switch sortBy {
case "", "name":
field = dbaccount.FieldName
case "id":
field = dbaccount.FieldID
defaultOrder = false
case "status":
field = dbaccount.FieldStatus
defaultOrder = false
case "schedulable":
field = dbaccount.FieldSchedulable
defaultOrder = false
case "priority":
field = dbaccount.FieldPriority
defaultOrder = false
case "rate_multiplier":
field = dbaccount.FieldRateMultiplier
defaultOrder = false
case "last_used_at":
field = dbaccount.FieldLastUsedAt
defaultOrder = false
case "expires_at":
field = dbaccount.FieldExpiresAt
defaultOrder = false
case "created_at":
field = dbaccount.FieldCreatedAt
defaultOrder = false
}
if sortOrder == pagination.SortOrderDesc {
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbaccount.FieldID)}
}
if defaultOrder {
return []func(*entsql.Selector){dbent.Asc(dbaccount.FieldName), dbent.Asc(dbaccount.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbaccount.FieldID)}
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive,
@@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
},
{
name: "filter_by_status_active_excludes_rate_limited",
name: "filter_by_status_active_excludes_runtime_blocked_accounts",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive})
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
@@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() {
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
},
status: service.StatusActive,
wantCount: 1,
@@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Equal("active-normal", accounts[0].Name)
},
},
{
name: "filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive, Schedulable: true})
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err := client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
err = client.Account.UpdateOneID(rateLimited.ID).
SetSchedulable(false).
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetSchedulable(false).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
},
status: "unschedulable",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-unsched", accounts[0].Name)
},
},
{
name: "filter_by_status_rate_limited_excludes_temp_unschedulable",
setup: func(client *dbent.Client) {
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
err := client.Account.UpdateOneID(rateLimited.ID).
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetRateLimitResetAt(time.Now().Add(20 * time.Minute)).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
},
status: "rate_limited",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-rate-limited", accounts[0].Name)
},
},
{
name: "filter_by_status_temp_unschedulable_excludes_manually_unschedulable",
setup: func(client *dbent.Client) {
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive, Schedulable: true})
err := client.Account.UpdateOneID(tempUnsched.ID).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
},
status: "temp_unschedulable",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-temp-unsched", accounts[0].Name)
},
},
{
name: "filter_by_search",
setup: func(client *dbent.Client) {
@@ -0,0 +1,35 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *AccountRepoSuite) TestList_DefaultSortByNameAsc() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "z-account"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-account"})
accounts, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err)
s.Require().Len(accounts, 2)
s.Require().Equal("a-account", accounts[0].Name)
s.Require().Equal("z-account", accounts[1].Name)
}
func (s *AccountRepoSuite) TestListWithFilters_SortByPriorityDesc() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "low-priority", Priority: 10})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "high-priority", Priority: 90})
accounts, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "priority",
SortOrder: "desc",
}, "", "", "", "", 0, "")
s.Require().NoError(err)
s.Require().Len(accounts, 2)
s.Require().Equal("high-priority", accounts[0].Name)
s.Require().Equal("low-priority", accounts[1].Name)
}
@@ -2,12 +2,15 @@ package repository
import (
"context"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
)
type announcementRepository struct {
@@ -128,11 +131,14 @@ func (r *announcementRepository) List(
return nil, nil, err
}
items, err := q.
itemsQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(announcement.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range announcementListOrders(params) {
itemsQuery = itemsQuery.Order(order)
}
items, err := itemsQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -141,6 +147,56 @@ func (r *announcementRepository) List(
return out, paginationResultFromTotal(int64(total), params), nil
}
func announcementListOrder(params pagination.PaginationParams) (string, string) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
switch sortBy {
case "title":
return announcement.FieldTitle, sortOrder
case "status":
return announcement.FieldStatus, sortOrder
case "notify_mode":
return announcement.FieldNotifyMode, sortOrder
case "starts_at":
return announcement.FieldStartsAt, sortOrder
case "ends_at":
return announcement.FieldEndsAt, sortOrder
case "id":
return announcement.FieldID, sortOrder
case "", "created_at":
return announcement.FieldCreatedAt, sortOrder
default:
return announcement.FieldCreatedAt, pagination.SortOrderDesc
}
}
func announcementListOrders(params pagination.PaginationParams) []func(*entsql.Selector) {
field, sortOrder := announcementListOrder(params)
if sortOrder == pagination.SortOrderAsc {
if field == announcement.FieldID {
return []func(*entsql.Selector){
dbent.Asc(field),
}
}
return []func(*entsql.Selector){
dbent.Asc(field),
dbent.Asc(announcement.FieldID),
}
}
if field == announcement.FieldID {
return []func(*entsql.Selector){
dbent.Desc(field),
}
}
return []func(*entsql.Selector){
dbent.Desc(field),
dbent.Desc(announcement.FieldID),
}
}
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
q := r.client.Announcement.Query().
Where(
@@ -0,0 +1,63 @@
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
func TestAnnouncementListOrder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
params pagination.PaginationParams
wantBy string
want string
}{
{
name: "default created_at desc",
params: pagination.PaginationParams{},
wantBy: "created_at",
want: "desc",
},
{
name: "title asc",
params: pagination.PaginationParams{
SortBy: "title",
SortOrder: "ASC",
},
wantBy: "title",
want: "asc",
},
{
name: "status desc",
params: pagination.PaginationParams{
SortBy: "status",
SortOrder: "desc",
},
wantBy: "status",
want: "desc",
},
{
name: "invalid falls back",
params: pagination.PaginationParams{
SortBy: "sideways",
SortOrder: "wat",
},
wantBy: "created_at",
want: "desc",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotBy, gotOrder := announcementListOrder(tt.params)
if gotBy != tt.wantBy || gotOrder != tt.want {
t.Fatalf("announcementListOrder(%+v) = (%q, %q), want (%q, %q)", tt.params, gotBy, gotOrder, tt.wantBy, tt.want)
}
})
}
}
+45 -8
View File
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql "entgo.io/ent/dialect/sql"
)
type apiKeyRepository struct {
@@ -164,6 +167,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldSupportedModelScopes,
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig,
)
}).
Only(ctx)
@@ -309,12 +313,15 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
keys, err := q.
keysQuery := q.
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range apiKeyListOrder(params) {
keysQuery = keysQuery.Order(order)
}
keys, err := keysQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -359,12 +366,15 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
keys, err := q.
keysQuery := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range apiKeyListOrder(params) {
keysQuery = keysQuery.Order(order)
}
keys, err := keysQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -377,6 +387,32 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
func apiKeyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "name":
field = apikey.FieldName
case "status":
field = apikey.FieldStatus
case "expires_at":
field = apikey.FieldExpiresAt
case "last_used_at":
field = apikey.FieldLastUsedAt
case "created_at":
field = apikey.FieldCreatedAt
default:
field = apikey.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(apikey.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(apikey.FieldID)}
}
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery()
@@ -654,6 +690,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequireOAuthOnly: g.RequireOauthOnly,
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
@@ -86,6 +86,45 @@ func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
s.Require().Error(err, "expected error for non-existent key")
}
func (s *APIKeyRepoSuite) TestGetByKeyForAuth_PreservesMessagesDispatchModelConfig() {
user := s.mustCreateUser("getbykey-auth-dispatch@test.com")
group, err := s.client.Group.Create().
SetName("g-auth-dispatch").
SetPlatform(service.PlatformOpenAI).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1).
SetAllowMessagesDispatch(true).
SetDefaultMappedModel("gpt-5.4").
SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
}).
Save(s.ctx)
s.Require().NoError(err)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey-auth-dispatch",
Name: "Dispatch Key",
GroupID: &group.ID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
got, err := s.repo.GetByKeyForAuth(s.ctx, key.Key)
s.Require().NoError(err)
s.Require().NotNil(got.Group)
s.Require().True(got.Group.AllowMessagesDispatch)
s.Require().Equal("gpt-5.4", got.Group.DefaultMappedModel)
s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.OpusMappedModel)
s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.ExactModelMappings["claude-sonnet-4.5"])
}
// --- Update ---
func (s *APIKeyRepoSuite) TestUpdate() {
@@ -0,0 +1,74 @@
package repository
import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestGroupEntityToService_PreservesMessagesDispatchModelConfig(t *testing.T) {
group := &dbent.Group{
ID: 1,
Name: "openai-dispatch",
Platform: service.PlatformOpenAI,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
}
got := groupEntityToService(group)
require.NotNil(t, got)
require.Equal(t, group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
}
func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_SQLite(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "getbykey-auth-dispatch-unit@test.com")
group, err := client.Group.Create().
SetName("g-auth-dispatch-unit").
SetPlatform(service.PlatformOpenAI).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1).
SetAllowMessagesDispatch(true).
SetDefaultMappedModel("gpt-5.4").
SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
}).
Save(ctx)
require.NoError(t, err)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey-auth-dispatch-unit",
Name: "Dispatch Key Unit",
GroupID: &group.ID,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
got, err := repo.GetByKeyForAuth(ctx, key.Key)
require.NoError(t, err)
require.NotNil(t, got.Group)
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
}
@@ -0,0 +1,25 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *APIKeyRepoSuite) TestListByUserID_SortByNameAsc() {
user := s.mustCreateUser("sort-name@example.com")
s.mustCreateApiKey(user.ID, "sk-z", "z-key", nil)
s.mustCreateApiKey(user.ID, "sk-a", "a-key", nil)
keys, _, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "name",
SortOrder: "asc",
}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal("a-key", keys[0].Name)
s.Require().Equal("z-key", keys[1].Name)
}
+27 -2
View File
@@ -188,8 +188,8 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
@@ -246,6 +246,31 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
return channels, paginationResult, nil
}
func channelListOrderBy(params pagination.PaginationParams) string {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
var column string
switch sortBy {
case "":
column = "c.id"
sortOrder = "ASC"
case "id":
column = "c.id"
case "name":
column = "c.name"
case "status":
column = "c.status"
case "created_at":
column = "c.created_at"
default:
column = "c.id"
sortOrder = "ASC"
}
return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
}
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
@@ -8,6 +8,7 @@ import (
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
@@ -225,3 +226,12 @@ func TestIsUniqueViolation(t *testing.T) {
})
}
}
func TestChannelListOrderBy_AllowsDescendingIDSort(t *testing.T) {
params := pagination.PaginationParams{
SortBy: "id",
SortOrder: "desc",
}
require.Equal(t, "c.id DESC, c.id DESC", channelListOrderBy(params))
}
+113 -4
View File
@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
type sqlExecutor interface {
@@ -40,6 +43,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier).
SetSortOrder(groupIn.SortOrder).
SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType).
@@ -233,11 +237,18 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
groups, err := q.
if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
return r.listWithAccountCountSort(ctx, q, params, total)
}
groupsQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range groupListOrder(params) {
groupsQuery = groupsQuery.Order(order)
}
groups, err := groupsQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -263,6 +274,104 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return outGroups, paginationResultFromTotal(int64(total), params), nil
}
func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
groups, err := q.
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err != nil {
return nil, nil, err
}
for i := range outGroups {
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(outGroups, func(i, j int) bool {
if outGroups[i].AccountCount == outGroups[j].AccountCount {
if outGroups[i].SortOrder == outGroups[j].SortOrder {
return outGroups[i].ID < outGroups[j].ID
}
return outGroups[i].SortOrder < outGroups[j].SortOrder
}
if sortOrder == pagination.SortOrderAsc {
return outGroups[i].AccountCount < outGroups[j].AccountCount
}
return outGroups[i].AccountCount > outGroups[j].AccountCount
})
return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
}
func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
var field string
tieField := group.FieldID
defaultOrder := true
switch sortBy {
case "", "sort_order":
field = group.FieldSortOrder
case "name":
field = group.FieldName
defaultOrder = false
case "platform":
field = group.FieldPlatform
defaultOrder = false
case "billing_type", "subscription_type":
field = group.FieldSubscriptionType
defaultOrder = false
case "rate_multiplier":
field = group.FieldRateMultiplier
defaultOrder = false
case "is_exclusive":
field = group.FieldIsExclusive
defaultOrder = false
case "status":
field = group.FieldStatus
defaultOrder = false
case "created_at":
field = group.FieldCreatedAt
defaultOrder = false
case "id":
field = group.FieldID
defaultOrder = false
tieField = ""
default:
field = group.FieldSortOrder
}
if sortOrder == pagination.SortOrderDesc && sortBy != "" {
if tieField == "" {
return []func(*entsql.Selector){dbent.Desc(field)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(tieField)}
}
if defaultOrder {
return []func(*entsql.Selector){dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)}
}
if tieField == "" {
return []func(*entsql.Selector){dbent.Asc(field)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(tieField)}
}
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive)).
@@ -113,6 +113,33 @@ func (s *GroupRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_PreservesMessagesDispatchModelConfig() {
group := &service.Group{
Name: "openai-dispatch",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
}
s.Require().NoError(s.repo.Create(s.ctx, group))
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
}
func (s *GroupRepoSuite) TestDelete() {
group := &service.Group{
Name: "to-delete",
@@ -0,0 +1,50 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, g := range groups {
indexByID[g.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
// g2 has SortOrder=10, g1 has SortOrder=20; ascending means g2 comes first
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "sort_order",
SortOrder: "desc",
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, group := range groups {
indexByID[group.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
+19
View File
@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams)
Pages: pages,
}
}
func paginateSlice[T any](items []T, params pagination.PaginationParams) []T {
if len(items) == 0 {
return []T{}
}
offset := params.Offset()
if offset >= len(items) {
return []T{}
}
limit := params.Limit()
end := offset + limit
if end > len(items) {
end = len(items)
}
return items[offset:end]
}
+36 -4
View File
@@ -2,12 +2,15 @@ package repository
import (
"context"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
)
type promoCodeRepository struct {
@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return nil, nil, err
}
codes, err := q.
codesQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocode.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range promoCodeListOrder(params) {
codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "bonus_amount":
field = promocode.FieldBonusAmount
case "status":
field = promocode.FieldStatus
case "expires_at":
field = promocode.FieldExpiresAt
case "created_at":
field = promocode.FieldCreatedAt
case "code":
field = promocode.FieldCode
default:
field = promocode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)}
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
+76 -9
View File
@@ -3,12 +3,16 @@ package repository
import (
"context"
"database/sql"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql "entgo.io/ent/dialect/sql"
)
type sqlQuerier interface {
@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
proxies, err := q.
proxiesQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range proxyListOrder(params) {
proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
return nil, nil, err
}
proxies, err := q.
if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
return r.listWithAccountCountSort(ctx, q, params, total)
}
proxiesQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Limit(params.Limit())
for _, order := range proxyListOrder(params) {
proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
}
func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
proxies, err := q.
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
if err != nil {
return nil, nil, err
}
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(result, func(i, j int) bool {
if result[i].AccountCount == result[j].AccountCount {
return result[i].ID > result[j].ID
}
if sortOrder == pagination.SortOrderAsc {
return result[i].AccountCount < result[j].AccountCount
}
return result[i].AccountCount > result[j].AccountCount
})
return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
})
}
return result, paginationResultFromTotal(int64(total), params), nil
return result, paginationResultFromTotal(total, params), nil
}
func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "name":
field = proxy.FieldName
case "protocol":
field = proxy.FieldProtocol
case "status":
field = proxy.FieldStatus
case "created_at":
field = proxy.FieldCreatedAt
default:
field = proxy.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)}
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
@@ -0,0 +1,28 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() {
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "account_count",
SortOrder: "desc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 2)
s.Require().Equal(p1.ID, proxies[0].ID)
s.Require().Equal(int64(2), proxies[0].AccountCount)
s.Require().Equal(p2.ID, proxies[1].ID)
}
@@ -2,6 +2,7 @@ package repository
import (
"context"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -9,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
)
type redeemCodeRepository struct {
@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err
}
codes, err := q.
codesQuery := q.
WithUser().
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(redeemcode.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range redeemCodeListOrder(params) {
codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "type":
field = redeemcode.FieldType
case "value":
field = redeemcode.FieldValue
case "status":
field = redeemcode.FieldStatus
case "used_at":
field = redeemcode.FieldUsedAt
case "created_at":
field = redeemcode.FieldCreatedAt
case "code":
field = redeemcode.FieldCode
default:
field = redeemcode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)}
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code).
@@ -0,0 +1,24 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "value",
SortOrder: "asc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 2)
s.Require().Equal("VALUE-10", codes[0].Code)
s.Require().Equal("VALUE-20", codes[1].Code)
}
+22 -2
View File
@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
return nil, nil, err
@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
@@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
return logs, paginationResultFromTotal(total, params), nil
}
func usageLogOrderBy(params pagination.PaginationParams) string {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc))
var column string
switch sortBy {
case "model":
column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
case "created_at":
column = "created_at"
default:
column = "id"
}
if column == "id" {
return fmt.Sprintf("id %s", sortOrder)
}
return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder)
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
"total_account_cost",
"avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT CONCAT\\(").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err)
@@ -0,0 +1,61 @@
//go:build integration
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
)
func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"})
first := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "z-model",
RequestedModel: "z-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now(),
}
_, err := s.repo.Create(s.ctx, first)
s.Require().NoError(err)
second := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "a-model",
RequestedModel: "a-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().Add(time.Second),
}
_, err = s.repo.Create(s.ctx, second)
s.Require().NoError(err)
logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "model",
SortOrder: "asc",
}, usagestats.UsageLogFilters{UserID: user.ID})
s.Require().NoError(err)
s.Require().Len(logs, 2)
s.Require().Equal("a-model", logs[0].RequestedModel)
s.Require().Equal("z-model", logs[1].RequestedModel)
}
+53 -4
View File
@@ -17,6 +17,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
)
type userRepository struct {
@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return nil, nil, err
}
users, err := q.
usersQuery := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(dbuser.FieldID)).
All(ctx)
Limit(params.Limit())
for _, order := range userListOrder(params) {
usersQuery = usersQuery.Order(order)
}
users, err := usersQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
defaultField := true
switch sortBy {
case "email":
field = dbuser.FieldEmail
defaultField = false
case "username":
field = dbuser.FieldUsername
defaultField = false
case "role":
field = dbuser.FieldRole
defaultField = false
case "balance":
field = dbuser.FieldBalance
defaultField = false
case "concurrency":
field = dbuser.FieldConcurrency
defaultField = false
case "status":
field = dbuser.FieldStatus
defaultField = false
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
default:
field = dbuser.FieldID
}
if sortOrder == pagination.SortOrderAsc {
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
@@ -0,0 +1,39 @@
//go:build integration
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "email",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal("a-first@example.com", users[0].Email)
s.Require().Equal("z-last@example.com", users[1].Email)
}
func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
first := s.mustCreateUser(&service.User{Email: "first@example.com"})
second := s.mustCreateUser(&service.User{Email: "second@example.com"})
users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal(second.ID, users[0].ID)
s.Require().Equal(first.ID, users[1].ID)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
+6 -2
View File
@@ -491,8 +491,10 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
@@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) {
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50, 100],
"min_claude_code_version": "",
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
+21 -21
View File
@@ -21,13 +21,13 @@ import (
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
@@ -35,7 +35,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
@@ -55,7 +55,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error)
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@@ -77,8 +77,8 @@ type AdminService interface {
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
@@ -93,7 +93,7 @@ type AdminService interface {
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
@@ -485,8 +485,8 @@ func NewAdminService(
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, 0, err
@@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil {
return nil, 0, err
@@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil {
return nil, 0, err
@@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil {
return nil, 0, err
@@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
@@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil
}
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
@@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
}
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err
@@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return nil
}
func TestAdminService_ListGroups_PassesSortParams(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}},
}
svc := &adminServiceImpl{groupRepo: repo}
_, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 3,
PageSize: 25,
SortBy: "account_count",
SortOrder: "ASC",
}, repo.listWithFiltersParams)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
@@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
@@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "")
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
@@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "")
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
@@ -13,11 +13,13 @@ import (
type userRepoStubForListUsers struct {
userRepoStub
users []User
err error
users []User
err error
listWithFiltersParams pagination.PaginationParams
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
s.listWithFiltersParams = params
if s.err != nil {
return nil, nil, s.err
}
@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo: rateRepo,
}
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, users, 2)
@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22])
}
func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{{ID: 1, Email: "a@example.com"}},
}
svc := &adminServiceImpl{userRepo: userRepo}
_, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 2,
PageSize: 50,
SortBy: "email",
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "")
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked)
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct {
Version int `json:"version"`
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct {
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
@@ -13,6 +13,8 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 3
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
if entry.Snapshot == nil {
return nil, false, nil
}
if entry.Snapshot.Version != apiKeyAuthSnapshotVersion {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
return nil
}
snapshot := &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
}
}
return snapshot
@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
}
}
s.compileAPIKeyIPRules(apiKey)
@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
}
func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t *testing.T) {
svc := NewAPIKeyService(nil, nil, nil, nil, nil, nil, &config.Config{})
groupID := int64(9)
apiKey := &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}
func TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig(t *testing.T) {
cache := &authCacheStub{}
var repoCalls int32
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&repoCalls, 1)
groupID := int64(9)
return &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
Hydrated: true,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
},
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
},
},
}, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k-legacy")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&repoCalls))
require.NotNil(t, apiKey.Group)
require.Equal(t, "gpt-5.4-nano", apiKey.Group.MessagesDispatchModelConfig.OpusMappedModel)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
@@ -143,6 +143,8 @@ const (
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src
SettingKeyTableDefaultPageSize = "table_default_page_size" // 表格默认每页条数
SettingKeyTablePageSizeOptions = "table_page_size_options" // 表格可选每页条数(JSON 数组)
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
@@ -492,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "")
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "", "", "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Len(t, accounts, 1)
@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/url"
"sort"
"strconv"
"strings"
"sync/atomic"
@@ -161,6 +162,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeyTableDefaultPageSize,
SettingKeyTablePageSizeOptions,
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
@@ -201,6 +204,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist(
settings[SettingKeyRegistrationEmailSuffixWhitelist],
)
tableDefaultPageSize, tablePageSizeOptions := parseTablePreferences(
settings[SettingKeyTableDefaultPageSize],
settings[SettingKeyTablePageSizeOptions],
)
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
@@ -222,6 +229,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
TableDefaultPageSize: tableDefaultPageSize,
TablePageSizeOptions: tablePageSizeOptions,
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
@@ -272,6 +281,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
@@ -300,6 +311,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
@@ -526,6 +539,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
tableDefaultPageSize, tablePageSizeOptions := normalizeTablePreferences(
settings.TableDefaultPageSize,
settings.TablePageSizeOptions,
)
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
if err != nil {
return fmt.Errorf("marshal table page size options: %w", err)
}
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
@@ -879,6 +902,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
@@ -950,6 +975,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}
result.TableDefaultPageSize, result.TablePageSizeOptions = parseTablePreferences(
settings[SettingKeyTableDefaultPageSize],
settings[SettingKeyTablePageSizeOptions],
)
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
@@ -1225,6 +1254,50 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
defaultPageSize = v
}
var options []int
if strings.TrimSpace(optionsRaw) != "" {
_ = json.Unmarshal([]byte(optionsRaw), &options)
}
return normalizeTablePreferences(defaultPageSize, options)
}
func normalizeTablePreferences(defaultPageSize int, options []int) (int, []int) {
const minPageSize = 5
const maxPageSize = 1000
const fallbackPageSize = 20
seen := make(map[int]struct{}, len(options))
normalizedOptions := make([]int, 0, len(options))
for _, option := range options {
if option < minPageSize || option > maxPageSize {
continue
}
if _, ok := seen[option]; ok {
continue
}
seen[option] = struct{}{}
normalizedOptions = append(normalizedOptions, option)
}
sort.Ints(normalizedOptions)
if defaultPageSize < minPageSize || defaultPageSize > maxPageSize {
defaultPageSize = fallbackPageSize
}
if len(normalizedOptions) == 0 {
normalizedOptions = []int{10, 20, 50}
}
return defaultPageSize, normalizedOptions
}
// getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" {
@@ -62,3 +62,18 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
require.NoError(t, err)
require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist)
}
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {
repo := &settingPublicRepoStub{
values: map[string]string{
SettingKeyTableDefaultPageSize: "50",
SettingKeyTablePageSizeOptions: "[20,50,100]",
},
}
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}
@@ -202,3 +202,24 @@ func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
{GroupID: 12, ValidityDays: MaxValidityDays},
}, got)
}
func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
repo := &settingUpdateRepoStub{}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 50,
TablePageSizeOptions: []int{20, 50, 100},
})
require.NoError(t, err)
require.Equal(t, "50", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions])
err = svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 1000,
TablePageSizeOptions: []int{20, 100},
})
require.NoError(t, err)
require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions])
}
@@ -66,6 +66,8 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
TableDefaultPageSize int
TablePageSizeOptions []int
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
@@ -132,6 +134,8 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
TableDefaultPageSize int
TablePageSizeOptions []int
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints