release: prepare v0.1.132

This commit is contained in:
kone
2026-05-15 22:33:43 +08:00
parent 41e60b20d6
commit b430cd4aa9
47 changed files with 1107 additions and 213 deletions
@@ -175,6 +175,11 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
return &user, nil
}
func (s *stubAdminService) RefreshUserRegistrationIPLocation(ctx context.Context, userID int64) (*service.User, error) {
user := service.User{ID: userID, Email: "user@example.com", Status: service.StatusActive, RegisterIPAddress: "8.8.8.8", RegisterIPLocation: "美国"}
return &user, nil
}
func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
return len(userIDs), nil
}
@@ -204,6 +204,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
AffiliateInviteBalanceReward: settings.AffiliateInviteBalanceReward,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
@@ -452,6 +453,7 @@ type UpdateSettingsRequest struct {
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
AffiliateInviteBalanceReward *float64 `json:"affiliate_invite_balance_reward"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
@@ -641,6 +643,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if affiliateRebatePerInviteeCap < 0 {
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
}
affiliateInviteBalanceReward := previousSettings.AffiliateInviteBalanceReward
if req.AffiliateInviteBalanceReward != nil {
affiliateInviteBalanceReward = *req.AffiliateInviteBalanceReward
}
if affiliateInviteBalanceReward < 0 {
affiliateInviteBalanceReward = service.AffiliateInviteBalanceRewardDefault
}
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
@@ -1374,6 +1383,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
AffiliateInviteBalanceReward: affiliateInviteBalanceReward,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
@@ -1758,6 +1768,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
AffiliateInviteBalanceReward: updatedSettings.AffiliateInviteBalanceReward,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -2099,6 +2110,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
changed = append(changed, "affiliate_rebate_per_invitee_cap")
}
if before.AffiliateInviteBalanceReward != after.AffiliateInviteBalanceReward {
changed = append(changed, "affiliate_invite_balance_reward")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
@@ -341,6 +341,24 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
})
}
// RefreshRegistrationIPLocation handles refreshing signup IP location.
// POST /api/v1/admin/users/:id/register-ip-location
func (h *UserHandler) RefreshRegistrationIPLocation(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
user, err := h.adminService.RefreshUserRegistrationIPLocation(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromServiceAdmin(user))
}
// GetUserAPIKeys handles getting user's API keys
// GET /api/v1/admin/users/:id/api-keys
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
+1 -1
View File
@@ -362,7 +362,7 @@ func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider st
}
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
c.Request.Context(),
registrationIPContext(c),
strings.TrimSpace(session.ResolvedEmail),
req.Password,
strings.TrimSpace(req.InvitationCode),
@@ -352,6 +352,10 @@ func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64,
panic("unexpected AccrueQuota call")
}
func (r *oauthEmailAffiliateRepoStub) CreditInviteBalanceReward(context.Context, int64, int64, float64) (float64, error) {
panic("unexpected CreditInviteBalanceReward call")
}
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
panic("unexpected GetAccruedRebateFromInvitee call")
}
+10 -1
View File
@@ -150,6 +150,15 @@ func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
return h.settingSvc.IsBackendModeEnabled(ctx)
}
func registrationIPContext(c *gin.Context) context.Context {
base := c.Request.Context()
clientIP := strings.TrimSpace(ip.GetClientIP(c))
if clientIP == "" {
return base
}
return service.WithRegistrationIPInfo(base, service.RegistrationIPInfo{IPAddress: clientIP})
}
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
@@ -166,7 +175,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
_, user, err := h.authService.RegisterWithVerification(
c.Request.Context(),
registrationIPContext(c),
req.Email,
req.Password,
req.VerifyCode,
@@ -519,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(registrationIPContext(c), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -1673,7 +1673,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(),
registrationIPContext(c),
email,
req.Password,
strings.TrimSpace(req.VerifyCode),
+1 -1
View File
@@ -666,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(registrationIPContext(c), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(registrationIPContext(c), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
+6
View File
@@ -21,6 +21,12 @@ func UserFromServiceShallow(u *service.User) *User {
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
RegisterIPAddress: u.RegisterIPAddress,
RegisterIPCountry: u.RegisterIPCountry,
RegisterIPCountryCode: u.RegisterIPCountryCode,
RegisterIPRegion: u.RegisterIPRegion,
RegisterIPCity: u.RegisterIPCity,
RegisterIPLocation: u.RegisterIPLocation,
LastActiveAt: u.LastActiveAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
+1
View File
@@ -129,6 +129,7 @@ type SystemSettings struct {
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
AffiliateInviteBalanceReward float64 `json:"affiliate_invite_balance_reward"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
+17 -11
View File
@@ -7,17 +7,23 @@ import (
)
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
RegisterIPAddress string `json:"register_ip_address,omitempty"`
RegisterIPCountry string `json:"register_ip_country,omitempty"`
RegisterIPCountryCode string `json:"register_ip_country_code,omitempty"`
RegisterIPRegion string `json:"register_ip_region,omitempty"`
RegisterIPCity string `json:"register_ip_city,omitempty"`
RegisterIPLocation string `json:"register_ip_location,omitempty"`
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+112
View File
@@ -0,0 +1,112 @@
// Package ipgeo provides best-effort IP geolocation lookup for audit display.
package ipgeo
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
)
type Info struct {
IPAddress string
Country string
CountryCode string
Region string
City string
Location string
}
func Lookup(ctx context.Context, rawIP string) (*Info, error) {
ip := strings.TrimSpace(rawIP)
parsed := net.ParseIP(ip)
if parsed == nil {
return nil, fmt.Errorf("invalid ip")
}
if isLocalIP(parsed) {
return &Info{IPAddress: ip}, nil
}
endpoint := "http://ip-api.com/json/" + url.PathEscape(ip) + "?lang=zh-CN&fields=status,message,country,countryCode,regionName,region,city,query"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("lookup failed: status %d", resp.StatusCode)
}
var body struct {
Status string `json:"status"`
Message string `json:"message"`
Query string `json:"query"`
Country string `json:"country"`
CountryCode string `json:"countryCode"`
Region string `json:"region"`
RegionName string `json:"regionName"`
City string `json:"city"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
return nil, err
}
if strings.ToLower(body.Status) != "success" {
if body.Message == "" {
body.Message = "ip lookup failed"
}
return nil, errors.New(body.Message)
}
region := strings.TrimSpace(body.RegionName)
if region == "" {
region = strings.TrimSpace(body.Region)
}
info := &Info{
IPAddress: firstNonEmpty(body.Query, ip),
Country: strings.TrimSpace(body.Country),
CountryCode: strings.TrimSpace(body.CountryCode),
Region: region,
City: strings.TrimSpace(body.City),
}
info.Location = formatLocation(info.Country, info.Region, info.City)
return info, nil
}
func isLocalIP(ip net.IP) bool {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified()
}
func formatLocation(parts ...string) string {
out := make([]string, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
key := strings.ToLower(part)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, part)
}
return strings.Join(out, " ")
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
@@ -162,6 +162,57 @@ VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUser
return applied, nil
}
func (r *affiliateRepository) CreditInviteBalanceReward(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (float64, error) {
if amount <= 0 {
return 0, nil
}
var newBalance float64
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
affected, err := txClient.User.Update().
Where(user.IDEQ(inviterID)).
AddBalance(amount).
AddTotalRecharged(amount).
Save(txCtx)
if err != nil {
return fmt.Errorf("credit invite balance reward: %w", err)
}
if affected == 0 {
return service.ErrUserNotFound
}
newBalance, err = queryUserBalance(txCtx, txClient, inviterID)
if err != nil {
return err
}
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (
user_id,
action,
amount,
source_user_id,
balance_after,
created_at,
updated_at
)
VALUES ($1, 'signup_reward', $2, $3, $4, NOW(), NOW())`,
inviterID,
amount,
inviteeUserID,
newBalance,
); err != nil {
return fmt.Errorf("insert affiliate signup reward ledger: %w", err)
}
return nil
})
if err != nil {
return 0, err
}
return newBalance, nil
}
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx,
+103
View File
@@ -98,6 +98,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if info := service.RegistrationIPInfoFromContext(ctx); strings.TrimSpace(info.IPAddress) != "" {
if err := updateUserRegistrationIPInfo(txCtx, txClient, created.ID, info); err != nil {
return err
}
userIn.RegisterIPAddress = strings.TrimSpace(info.IPAddress)
}
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
@@ -116,6 +122,10 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
return nil
}
func (r *userRepository) UpdateRegistrationIPInfo(ctx context.Context, userID int64, info service.RegistrationIPInfo) error {
return updateUserRegistrationIPInfo(ctx, r.client, userID, info)
}
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
if err != nil {
@@ -123,6 +133,9 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
out := userEntityToService(m)
if err := r.loadRegistrationIPInfo(ctx, map[int64]*service.User{id: out}); err != nil {
return nil, err
}
groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err != nil {
return nil, err
@@ -150,6 +163,9 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service
m := matches[0]
out := userEntityToService(m)
if err := r.loadRegistrationIPInfo(ctx, map[int64]*service.User{m.ID: out}); err != nil {
return nil, err
}
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err != nil {
return nil, err
@@ -474,6 +490,9 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
if err := r.loadRegistrationIPInfo(ctx, userMap); err != nil {
return nil, nil, err
}
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
if shouldLoadSubscriptions {
@@ -509,6 +528,90 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
func updateUserRegistrationIPInfo(ctx context.Context, client *dbent.Client, userID int64, info service.RegistrationIPInfo) error {
if userID <= 0 || strings.TrimSpace(info.IPAddress) == "" {
return nil
}
_, err := client.ExecContext(ctx, `
UPDATE users
SET register_ip_address = $1,
register_ip_country = $2,
register_ip_country_code = $3,
register_ip_region = $4,
register_ip_city = $5,
register_ip_location = $6,
updated_at = updated_at
WHERE id = $7`,
strings.TrimSpace(info.IPAddress),
strings.TrimSpace(info.Country),
strings.TrimSpace(info.CountryCode),
strings.TrimSpace(info.Region),
strings.TrimSpace(info.City),
strings.TrimSpace(info.Location),
userID,
)
if err != nil {
return fmt.Errorf("update user registration ip info: %w", err)
}
return nil
}
func (r *userRepository) loadRegistrationIPInfo(ctx context.Context, users map[int64]*service.User) error {
if len(users) == 0 {
return nil
}
ids := make([]int64, 0, len(users))
for id := range users {
ids = append(ids, id)
}
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return fmt.Errorf("sql executor is not configured")
}
rows, err := exec.QueryContext(ctx, `
SELECT id,
register_ip_address,
register_ip_country,
register_ip_country_code,
register_ip_region,
register_ip_city,
register_ip_location
FROM users
WHERE id = ANY($1)`, pq.Array(ids))
if err != nil {
if isRegistrationIPInfoSchemaMissing(err) {
return nil
}
return fmt.Errorf("load user registration ip info: %w", err)
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
var ipAddress, country, countryCode, region, city, location string
if err := rows.Scan(&id, &ipAddress, &country, &countryCode, &region, &city, &location); err != nil {
return err
}
if u, ok := users[id]; ok && u != nil {
u.RegisterIPAddress = ipAddress
u.RegisterIPCountry = country
u.RegisterIPCountryCode = countryCode
u.RegisterIPRegion = region
u.RegisterIPCity = city
u.RegisterIPLocation = location
}
}
return rows.Err()
}
func isRegistrationIPInfoSchemaMissing(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "no such column: register_ip_")
}
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
@@ -751,6 +751,7 @@ func TestAPIContracts(t *testing.T) {
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"affiliate_invite_balance_reward": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -969,6 +970,7 @@ func TestAPIContracts(t *testing.T) {
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"affiliate_invite_balance_reward": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
+1
View File
@@ -240,6 +240,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.POST("/:id/register-ip-location", h.Admin.User.RefreshRegistrationIPLocation)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
+22 -2
View File
@@ -33,6 +33,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
RefreshUserRegistrationIPLocation(ctx context.Context, userID int64) (*User, error)
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
@@ -916,6 +917,25 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) RefreshUserRegistrationIPLocation(ctx context.Context, userID int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
ipAddress := strings.TrimSpace(user.RegisterIPAddress)
if ipAddress == "" {
return nil, infraerrors.BadRequest("REGISTER_IP_EMPTY", "registration IP is empty")
}
info, err := lookupRegistrationIPInfo(ctx, ipAddress)
if err != nil {
return nil, err
}
if err := updateUserRegistrationIPInfoWithRepo(ctx, s.userRepo, userID, info); err != nil {
return nil, err
}
return s.GetUser(ctx, userID)
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
@@ -1133,7 +1153,7 @@ SELECT id,
created_at
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'
AND action IN ('transfer', 'signup_reward')
ORDER BY created_at DESC, id DESC
OFFSET $2
LIMIT $3`, userID, params.Offset(), params.Limit())
@@ -1179,7 +1199,7 @@ func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, use
SELECT COUNT(*)
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'`, userID)
AND action IN ('transfer', 'signup_reward')`, userID)
if err != nil {
return 0, err
}
@@ -87,6 +87,8 @@ type AffiliateDetail struct {
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
// InviteBalanceReward 是新用户通过邀请注册并绑定成功后,直接进入邀请人余额的固定金额。
InviteBalanceReward float64 `json:"invite_balance_reward"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
@@ -99,6 +101,7 @@ type AffiliateRepository interface {
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
CreditInviteBalanceReward(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (float64, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
@@ -261,11 +264,23 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
AffQuota: summary.AffQuota,
AffFrozenQuota: summary.AffFrozenQuota,
AffHistoryQuota: summary.AffHistoryQuota,
InviteBalanceReward: s.resolveInviteBalanceReward(ctx),
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
Invitees: invitees,
}, nil
}
func (s *AffiliateService) resolveInviteBalanceReward(ctx context.Context) float64 {
if s == nil || s.settingService == nil {
return AffiliateInviteBalanceRewardDefault
}
amount := s.settingService.GetAffiliateInviteBalanceReward(ctx)
if amount <= 0 || math.IsNaN(amount) || math.IsInf(amount, 0) {
return AffiliateInviteBalanceRewardDefault
}
return amount
}
func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
code := strings.ToUpper(strings.TrimSpace(rawCode))
if code == "" {
@@ -308,9 +323,27 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
if !bound {
return ErrAffiliateAlreadyBound
}
s.creditInviteBalanceReward(ctx, inviterSummary.UserID, userID)
return nil
}
func (s *AffiliateService) creditInviteBalanceReward(ctx context.Context, inviterID, inviteeUserID int64) {
if s == nil || s.repo == nil || s.settingService == nil {
return
}
amount := s.settingService.GetAffiliateInviteBalanceReward(ctx)
if amount <= 0 || math.IsNaN(amount) || math.IsInf(amount, 0) {
return
}
newBalance, err := s.repo.CreditInviteBalanceReward(ctx, inviterID, inviteeUserID, amount)
if err != nil {
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to credit invite balance reward: inviter=%d invitee=%d amount=%.8f err=%v", inviterID, inviteeUserID, amount, err)
return
}
s.invalidateAffiliateCaches(ctx, inviterID)
logger.LegacyPrintf("service.affiliate", "[Affiliate] Invite balance reward credited: inviter=%d invitee=%d amount=%.8f balance=%.8f", inviterID, inviteeUserID, amount, newBalance)
}
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
}
+3
View File
@@ -813,6 +813,9 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
if info := RegistrationIPInfoFromContext(ctx); strings.TrimSpace(info.IPAddress) != "" {
s.refreshRegistrationIPLocationInBackground(user.ID, info.IPAddress)
}
}
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
@@ -30,6 +30,7 @@ const (
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
AffiliateRebateDurationDaysMax = 3650 // ~10 年
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
AffiliateInviteBalanceRewardDefault = 0.0 // 邀请注册后直接进入邀请人余额;0 = 关闭
)
// Platform constants
@@ -108,6 +109,7 @@ const (
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
SettingKeyAffiliateInviteBalanceReward = "affiliate_invite_balance_reward" // 邀请注册奖励,直接进入邀请人余额(0=关闭)
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON
SettingKeyLoginAgreementEnabled = "login_agreement_enabled" // 登录前是否要求同意条款
@@ -205,37 +205,23 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
return upstreamReq, nil
})
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
return nil, err
}
defer func() { _ = resp.Body.Close() }()
@@ -114,7 +114,6 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
zap.Bool("stream", clientStream),
)
// 5. Build upstream request
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return nil, fmt.Errorf("account %d missing api_key", account.ID)
@@ -129,53 +128,41 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
}
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
if clientStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
} else {
upstreamReq.Header.Set("Accept", "application/json")
}
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
// 6. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
if clientStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
} else {
upstreamReq.Header.Set("Accept", "application/json")
}
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
return upstreamReq, nil
})
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
return nil, err
}
defer func() { _ = resp.Body.Close() }()
@@ -12,6 +12,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@@ -31,6 +32,40 @@ func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
return w.ResponseWriter.Write(p)
}
type sequentialHTTPUpstreamRecorder struct {
responses []*http.Response
errs []error
requests []*http.Request
bodies [][]byte
}
func (u *sequentialHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.requests = append(u.requests, req)
if req != nil && req.Body != nil {
b, _ := io.ReadAll(req.Body)
u.bodies = append(u.bodies, append([]byte(nil), b...))
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewReader(b))
}
if len(u.errs) > 0 {
err := u.errs[0]
u.errs = u.errs[1:]
if err != nil {
return nil, err
}
}
if len(u.responses) > 0 {
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
return nil, errors.New("no response configured")
}
func (u *sequentialHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
t.Parallel()
@@ -133,6 +168,112 @@ func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *te
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestForwardAsChatCompletions_RequestErrorRetriesBeforeSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_retry","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &sequentialHTTPUpstreamRecorder{
errs: []error{
errors.New("Post \"https://chatgpt.com/backend-api/codex/responses\": read tcp 172.18.0.4:60076->42.193.179.21:1081: read: connection reset by peer"),
errors.New("connection reset by peer"),
errors.New("unexpected EOF"),
nil,
},
responses: []*http.Response{{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_retry_success"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}},
}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, upstream.requests, 4)
require.Len(t, upstream.bodies, 4)
require.Equal(t, upstream.bodies[0], upstream.bodies[3], "retry must rebuild the same upstream body")
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 3)
require.Equal(t, "request_error", events[0].Kind)
require.Contains(t, events[0].Message, "connection reset by peer")
}
func TestForwardAsChatCompletions_RequestErrorExhaustionReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &sequentialHTTPUpstreamRecorder{
errs: []error{
errors.New("connection reset by peer"),
errors.New("connection reset by peer"),
errors.New("connection reset by peer"),
errors.New("connection reset by peer"),
},
}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
require.Nil(t, result)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.False(t, c.Writer.Written(), "forward should not write a 502 before handler failover")
require.Len(t, upstream.requests, 4)
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 4)
require.Equal(t, "request_error:retry_exhausted", events[3].Kind)
}
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -243,57 +243,44 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// Override session_id with a deterministic UUID derived from the isolated
// session key, ensuring different API keys produce different upstream sessions.
if promptCacheKey != "" {
isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))
upstreamReq.Header.Set("session_id", isolatedSessionID)
if upstreamReq.Header.Get("conversation_id") != "" {
upstreamReq.Header.Set("conversation_id", isolatedSessionID)
}
}
if account.Type == AccountTypeOAuth {
// Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint.
// Match airgate-openai's request shape: the SSE endpoint does not need
// the Responses experimental beta header, and forcing originator can make
// ChatGPT select a different internal continuation path.
upstreamReq.Header.Del("OpenAI-Beta")
upstreamReq.Header.Del("originator")
}
if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" {
upstreamReq.Header.Del("conversation_id")
}
if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" {
upstreamReq.Header.Set("x-codex-turn-state", compatTurnState)
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// Override session_id with a deterministic UUID derived from the isolated
// session key, ensuring different API keys produce different upstream sessions.
if promptCacheKey != "" {
isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))
upstreamReq.Header.Set("session_id", isolatedSessionID)
if upstreamReq.Header.Get("conversation_id") != "" {
upstreamReq.Header.Set("conversation_id", isolatedSessionID)
}
}
if account.Type == AccountTypeOAuth {
// Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint.
// Match airgate-openai's request shape: the SSE endpoint does not need
// the Responses experimental beta header, and forcing originator can make
// ChatGPT select a different internal continuation path.
upstreamReq.Header.Del("OpenAI-Beta")
upstreamReq.Header.Del("originator")
}
if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" {
upstreamReq.Header.Del("conversation_id")
}
if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" {
upstreamReq.Header.Set("x-codex-turn-state", compatTurnState)
}
return upstreamReq, nil
})
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
return nil, err
}
defer func() { _ = resp.Body.Close() }()
@@ -2681,14 +2681,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
httpInvalidEncryptedContentRetryTried := false
for {
// Build upstream request
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
@@ -2696,28 +2688,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
return s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
})
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
return nil, err
}
// Handle error response
@@ -2972,13 +2947,6 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err
}
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
@@ -2989,28 +2957,11 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
c.Set("openai_passthrough", true)
}
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, true, func(upstreamCtx context.Context) (*http.Request, error) {
return s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
})
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Passthrough: true,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
return nil, err
}
defer func() { _ = resp.Body.Close() }()
@@ -0,0 +1,152 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"syscall"
"time"
"github.com/gin-gonic/gin"
)
const (
openAIHTTPRequestRetryMaxRetries = 3
openAIHTTPRequestRetryBaseDelay = 150 * time.Millisecond
)
type openAIUpstreamRequestBuilder func(context.Context) (*http.Request, error)
func (s *OpenAIGatewayService) doOpenAIUpstreamWithRequestRetry(
ctx context.Context,
c *gin.Context,
account *Account,
proxyURL string,
passthrough bool,
buildReq openAIUpstreamRequestBuilder,
) (*http.Response, error) {
if buildReq == nil {
return nil, errors.New("missing upstream request builder")
}
if account == nil {
return nil, errors.New("missing account")
}
attempts := openAIHTTPRequestRetryMaxRetries + 1
var lastErr error
startedAt := time.Now()
for attempt := 1; attempt <= attempts; attempt++ {
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
req, err := buildReq(upstreamCtx)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(startedAt).Milliseconds())
if err == nil {
return resp, nil
}
lastErr = err
s.recordOpenAIHTTPRequestErrorAttempt(c, account, passthrough, attempt, attempts, err)
if !isRetryableOpenAIHTTPRequestError(err) || attempt >= attempts {
break
}
time.Sleep(openAIHTTPRequestRetryDelay(attempt))
}
if isRetryableOpenAIHTTPRequestError(lastErr) {
return nil, newOpenAIHTTPRequestFailoverError(lastErr)
}
return nil, fmt.Errorf("upstream request failed: %s", sanitizeUpstreamErrorMessage(lastErr.Error()))
}
func (s *OpenAIGatewayService) recordOpenAIHTTPRequestErrorAttempt(c *gin.Context, account *Account, passthrough bool, attempt, attempts int, err error) {
if c == nil || err == nil {
return
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
kind := "request_error"
if attempt >= attempts {
kind = "request_error:retry_exhausted"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Passthrough: passthrough,
Kind: kind,
Message: safeErr,
Detail: fmt.Sprintf("attempt=%d max_retries=%d", attempt, openAIHTTPRequestRetryMaxRetries),
})
}
func openAIHTTPRequestRetryDelay(attempt int) time.Duration {
if attempt <= 0 {
return openAIHTTPRequestRetryBaseDelay
}
delay := openAIHTTPRequestRetryBaseDelay << (attempt - 1)
if delay > time.Second {
return time.Second
}
return delay
}
func isRetryableOpenAIHTTPRequestError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
return true
}
if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ETIMEDOUT) {
return true
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
msg := strings.ToLower(err.Error())
retryableMarkers := []string{
"connection reset by peer",
"connection refused",
"unexpected eof",
"server closed idle connection",
"broken pipe",
"connection aborted",
"tls: use of closed connection",
"http2: client connection lost",
}
for _, marker := range retryableMarkers {
if strings.Contains(msg, marker) {
return true
}
}
return false
}
func newOpenAIHTTPRequestFailoverError(err error) *UpstreamFailoverError {
message := "Upstream request failed"
if err != nil {
message = sanitizeUpstreamErrorMessage(err.Error())
}
body, _ := json.Marshal(gin.H{
"error": gin.H{
"type": "upstream_error",
"message": message,
},
})
return &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: body,
}
}
@@ -0,0 +1,73 @@
package service
import (
"context"
"errors"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ipgeo"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
type registrationIPInfoUpdater interface {
UpdateRegistrationIPInfo(ctx context.Context, userID int64, info RegistrationIPInfo) error
}
func (s *AuthService) refreshRegistrationIPLocationInBackground(userID int64, rawIP string) {
if s == nil || s.userRepo == nil || userID <= 0 || strings.TrimSpace(rawIP) == "" {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
info, err := lookupRegistrationIPInfo(ctx, rawIP)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to lookup registration IP location: user_id=%d ip=%s err=%v", userID, rawIP, err)
return
}
if err := updateUserRegistrationIPInfoWithRepo(ctx, s.userRepo, userID, info); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update registration IP location: user_id=%d ip=%s err=%v", userID, rawIP, err)
}
}()
}
func updateUserRegistrationIPInfoWithRepo(ctx context.Context, repo UserRepository, userID int64, info RegistrationIPInfo) error {
updater, ok := repo.(registrationIPInfoUpdater)
if !ok {
return errors.New("registration ip updater is not configured")
}
return updater.UpdateRegistrationIPInfo(ctx, userID, info)
}
func lookupRegistrationIPInfo(ctx context.Context, rawIP string) (RegistrationIPInfo, error) {
ipAddress := strings.TrimSpace(rawIP)
if ipAddress == "" {
return RegistrationIPInfo{}, infraerrors.BadRequest("REGISTER_IP_EMPTY", "registration IP is empty")
}
geo, err := ipgeo.Lookup(ctx, ipAddress)
if err != nil {
return RegistrationIPInfo{}, err
}
if geo == nil {
return RegistrationIPInfo{IPAddress: ipAddress}, nil
}
return RegistrationIPInfo{
IPAddress: firstNonEmptyRegistrationIP(geo.IPAddress, ipAddress),
Country: geo.Country,
CountryCode: geo.CountryCode,
Region: geo.Region,
City: geo.City,
Location: geo.Location,
}, nil
}
func firstNonEmptyRegistrationIP(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
@@ -1617,6 +1617,10 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
}
updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
if settings.AffiliateInviteBalanceReward < 0 {
settings.AffiliateInviteBalanceReward = AffiliateInviteBalanceRewardDefault
}
updates[SettingKeyAffiliateInviteBalanceReward] = strconv.FormatFloat(settings.AffiliateInviteBalanceReward, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
@@ -2136,6 +2140,20 @@ func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) fl
return cap
}
// GetAffiliateInviteBalanceReward returns the fixed reward credited directly to
// the inviter's account balance when an invitee binds successfully.
func (s *SettingService) GetAffiliateInviteBalanceReward(ctx context.Context) float64 {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateInviteBalanceReward)
if err != nil {
return AffiliateInviteBalanceRewardDefault
}
amount, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
if err != nil || amount < 0 || math.IsNaN(amount) || math.IsInf(amount, 0) {
return AffiliateInviteBalanceRewardDefault
}
return amount
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
@@ -2412,6 +2430,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
SettingKeyAffiliateInviteBalanceReward: strconv.FormatFloat(AffiliateInviteBalanceRewardDefault, 'f', 2, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
@@ -2587,6 +2606,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
result.AffiliateRebatePerInviteeCap = perInviteeCap
}
if inviteReward, err := strconv.ParseFloat(settings[SettingKeyAffiliateInviteBalanceReward], 64); err == nil && inviteReward >= 0 {
result.AffiliateInviteBalanceReward = inviteReward
}
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
@@ -130,6 +130,7 @@ type SystemSettings struct {
AffiliateRebateFreezeHours int
AffiliateRebateDurationDays int
AffiliateRebatePerInviteeCap float64
AffiliateInviteBalanceReward float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting
+42 -7
View File
@@ -1,6 +1,7 @@
package service
import (
"context"
"time"
"golang.org/x/crypto/bcrypt"
@@ -25,13 +26,19 @@ type User struct {
TokenVersion int64 // Incremented on password change to invalidate existing tokens
// TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived
// value expected in JWT claims and refresh-token state.
TokenVersionResolved bool
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
TokenVersionResolved bool
SignupSource string
RegisterIPAddress string
RegisterIPCountry string
RegisterIPCountryCode string
RegisterIPRegion string
RegisterIPCity string
RegisterIPLocation string
LastLoginAt *time.Time
LastActiveAt *time.Time
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
@@ -62,6 +69,34 @@ type User struct {
Subscriptions []UserSubscription
}
type registrationIPContextKey struct{}
type RegistrationIPInfo struct {
IPAddress string
Country string
CountryCode string
Region string
City string
Location string
}
func WithRegistrationIPInfo(ctx context.Context, info RegistrationIPInfo) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, registrationIPContextKey{}, info)
}
func RegistrationIPInfoFromContext(ctx context.Context) RegistrationIPInfo {
if ctx == nil {
return RegistrationIPInfo{}
}
if info, ok := ctx.Value(registrationIPContextKey{}).(RegistrationIPInfo); ok {
return info
}
return RegistrationIPInfo{}
}
func (u *User) IsAdmin() bool {
return u.Role == RoleAdmin || u.Role == RoleUserAdmin
}