release: prepare v0.1.132
This commit is contained in:
@@ -175,6 +175,11 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) RefreshUserRegistrationIPLocation(ctx context.Context, userID int64) (*service.User, error) {
|
||||
user := service.User{ID: userID, Email: "user@example.com", Status: service.StatusActive, RegisterIPAddress: "8.8.8.8", RegisterIPLocation: "美国"}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||
return len(userIDs), nil
|
||||
}
|
||||
|
||||
@@ -204,6 +204,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
|
||||
AffiliateInviteBalanceReward: settings.AffiliateInviteBalanceReward,
|
||||
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -452,6 +453,7 @@ type UpdateSettingsRequest struct {
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
AffiliateInviteBalanceReward *float64 `json:"affiliate_invite_balance_reward"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
@@ -641,6 +643,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if affiliateRebatePerInviteeCap < 0 {
|
||||
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
affiliateInviteBalanceReward := previousSettings.AffiliateInviteBalanceReward
|
||||
if req.AffiliateInviteBalanceReward != nil {
|
||||
affiliateInviteBalanceReward = *req.AffiliateInviteBalanceReward
|
||||
}
|
||||
if affiliateInviteBalanceReward < 0 {
|
||||
affiliateInviteBalanceReward = service.AffiliateInviteBalanceRewardDefault
|
||||
}
|
||||
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
|
||||
if req.TableDefaultPageSize <= 0 {
|
||||
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
|
||||
@@ -1374,6 +1383,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
AffiliateInviteBalanceReward: affiliateInviteBalanceReward,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -1758,6 +1768,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
|
||||
AffiliateInviteBalanceReward: updatedSettings.AffiliateInviteBalanceReward,
|
||||
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -2099,6 +2110,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
|
||||
changed = append(changed, "affiliate_rebate_per_invitee_cap")
|
||||
}
|
||||
if before.AffiliateInviteBalanceReward != after.AffiliateInviteBalanceReward {
|
||||
changed = append(changed, "affiliate_invite_balance_reward")
|
||||
}
|
||||
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
|
||||
changed = append(changed, "default_subscriptions")
|
||||
}
|
||||
|
||||
@@ -341,6 +341,24 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshRegistrationIPLocation handles refreshing signup IP location.
|
||||
// POST /api/v1/admin/users/:id/register-ip-location
|
||||
func (h *UserHandler) RefreshRegistrationIPLocation(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.RefreshUserRegistrationIPLocation(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
// GET /api/v1/admin/users/:id/api-keys
|
||||
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
|
||||
@@ -362,7 +362,7 @@ func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider st
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
|
||||
c.Request.Context(),
|
||||
registrationIPContext(c),
|
||||
strings.TrimSpace(session.ResolvedEmail),
|
||||
req.Password,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
|
||||
@@ -352,6 +352,10 @@ func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64,
|
||||
panic("unexpected AccrueQuota call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) CreditInviteBalanceReward(context.Context, int64, int64, float64) (float64, error) {
|
||||
panic("unexpected CreditInviteBalanceReward call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
|
||||
panic("unexpected GetAccruedRebateFromInvitee call")
|
||||
}
|
||||
|
||||
@@ -150,6 +150,15 @@ func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
|
||||
return h.settingSvc.IsBackendModeEnabled(ctx)
|
||||
}
|
||||
|
||||
func registrationIPContext(c *gin.Context) context.Context {
|
||||
base := c.Request.Context()
|
||||
clientIP := strings.TrimSpace(ip.GetClientIP(c))
|
||||
if clientIP == "" {
|
||||
return base
|
||||
}
|
||||
return service.WithRegistrationIPInfo(base, service.RegistrationIPInfo{IPAddress: clientIP})
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
// POST /api/v1/auth/register
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
@@ -166,7 +175,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
|
||||
_, user, err := h.authService.RegisterWithVerification(
|
||||
c.Request.Context(),
|
||||
registrationIPContext(c),
|
||||
req.Email,
|
||||
req.Password,
|
||||
req.VerifyCode,
|
||||
|
||||
@@ -519,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(registrationIPContext(c), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1673,7 +1673,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
|
||||
c.Request.Context(),
|
||||
registrationIPContext(c),
|
||||
email,
|
||||
req.Password,
|
||||
strings.TrimSpace(req.VerifyCode),
|
||||
|
||||
@@ -666,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(registrationIPContext(c), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(registrationIPContext(c), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -21,6 +21,12 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
AllowedGroups: u.AllowedGroups,
|
||||
RegisterIPAddress: u.RegisterIPAddress,
|
||||
RegisterIPCountry: u.RegisterIPCountry,
|
||||
RegisterIPCountryCode: u.RegisterIPCountryCode,
|
||||
RegisterIPRegion: u.RegisterIPRegion,
|
||||
RegisterIPCity: u.RegisterIPCity,
|
||||
RegisterIPLocation: u.RegisterIPLocation,
|
||||
LastActiveAt: u.LastActiveAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
|
||||
@@ -129,6 +129,7 @@ type SystemSettings struct {
|
||||
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
AffiliateInviteBalanceReward float64 `json:"affiliate_invite_balance_reward"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
|
||||
@@ -7,17 +7,23 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Status string `json:"status"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Status string `json:"status"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
RegisterIPAddress string `json:"register_ip_address,omitempty"`
|
||||
RegisterIPCountry string `json:"register_ip_country,omitempty"`
|
||||
RegisterIPCountryCode string `json:"register_ip_country_code,omitempty"`
|
||||
RegisterIPRegion string `json:"register_ip_region,omitempty"`
|
||||
RegisterIPCity string `json:"register_ip_city,omitempty"`
|
||||
RegisterIPLocation string `json:"register_ip_location,omitempty"`
|
||||
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// 余额不足通知
|
||||
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
// Package ipgeo provides best-effort IP geolocation lookup for audit display.
|
||||
package ipgeo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Info struct {
|
||||
IPAddress string
|
||||
Country string
|
||||
CountryCode string
|
||||
Region string
|
||||
City string
|
||||
Location string
|
||||
}
|
||||
|
||||
func Lookup(ctx context.Context, rawIP string) (*Info, error) {
|
||||
ip := strings.TrimSpace(rawIP)
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return nil, fmt.Errorf("invalid ip")
|
||||
}
|
||||
if isLocalIP(parsed) {
|
||||
return &Info{IPAddress: ip}, nil
|
||||
}
|
||||
|
||||
endpoint := "http://ip-api.com/json/" + url.PathEscape(ip) + "?lang=zh-CN&fields=status,message,country,countryCode,regionName,region,city,query"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("lookup failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Query string `json:"query"`
|
||||
Country string `json:"country"`
|
||||
CountryCode string `json:"countryCode"`
|
||||
Region string `json:"region"`
|
||||
RegionName string `json:"regionName"`
|
||||
City string `json:"city"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ToLower(body.Status) != "success" {
|
||||
if body.Message == "" {
|
||||
body.Message = "ip lookup failed"
|
||||
}
|
||||
return nil, errors.New(body.Message)
|
||||
}
|
||||
|
||||
region := strings.TrimSpace(body.RegionName)
|
||||
if region == "" {
|
||||
region = strings.TrimSpace(body.Region)
|
||||
}
|
||||
info := &Info{
|
||||
IPAddress: firstNonEmpty(body.Query, ip),
|
||||
Country: strings.TrimSpace(body.Country),
|
||||
CountryCode: strings.TrimSpace(body.CountryCode),
|
||||
Region: region,
|
||||
City: strings.TrimSpace(body.City),
|
||||
}
|
||||
info.Location = formatLocation(info.Country, info.Region, info.City)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func isLocalIP(ip net.IP) bool {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func formatLocation(parts ...string) string {
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(part)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, part)
|
||||
}
|
||||
return strings.Join(out, " ")
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -162,6 +162,57 @@ VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUser
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) CreditInviteBalanceReward(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (float64, error) {
|
||||
if amount <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var newBalance float64
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
affected, err := txClient.User.Update().
|
||||
Where(user.IDEQ(inviterID)).
|
||||
AddBalance(amount).
|
||||
AddTotalRecharged(amount).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credit invite balance reward: %w", err)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
newBalance, err = queryUserBalance(txCtx, txClient, inviterID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (
|
||||
user_id,
|
||||
action,
|
||||
amount,
|
||||
source_user_id,
|
||||
balance_after,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
VALUES ($1, 'signup_reward', $2, $3, $4, NOW(), NOW())`,
|
||||
inviterID,
|
||||
amount,
|
||||
inviteeUserID,
|
||||
newBalance,
|
||||
); err != nil {
|
||||
return fmt.Errorf("insert affiliate signup reward ledger: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return newBalance, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx,
|
||||
|
||||
@@ -98,6 +98,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
if info := service.RegistrationIPInfoFromContext(ctx); strings.TrimSpace(info.IPAddress) != "" {
|
||||
if err := updateUserRegistrationIPInfo(txCtx, txClient, created.ID, info); err != nil {
|
||||
return err
|
||||
}
|
||||
userIn.RegisterIPAddress = strings.TrimSpace(info.IPAddress)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
@@ -116,6 +122,10 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateRegistrationIPInfo(ctx context.Context, userID int64, info service.RegistrationIPInfo) error {
|
||||
return updateUserRegistrationIPInfo(ctx, r.client, userID, info)
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
|
||||
if err != nil {
|
||||
@@ -123,6 +133,9 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
if err := r.loadRegistrationIPInfo(ctx, map[int64]*service.User{id: out}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,6 +163,9 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service
|
||||
m := matches[0]
|
||||
|
||||
out := userEntityToService(m)
|
||||
if err := r.loadRegistrationIPInfo(ctx, map[int64]*service.User{m.ID: out}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -474,6 +490,9 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
outUsers = append(outUsers, *u)
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
if err := r.loadRegistrationIPInfo(ctx, userMap); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
|
||||
if shouldLoadSubscriptions {
|
||||
@@ -509,6 +528,90 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func updateUserRegistrationIPInfo(ctx context.Context, client *dbent.Client, userID int64, info service.RegistrationIPInfo) error {
|
||||
if userID <= 0 || strings.TrimSpace(info.IPAddress) == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := client.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET register_ip_address = $1,
|
||||
register_ip_country = $2,
|
||||
register_ip_country_code = $3,
|
||||
register_ip_region = $4,
|
||||
register_ip_city = $5,
|
||||
register_ip_location = $6,
|
||||
updated_at = updated_at
|
||||
WHERE id = $7`,
|
||||
strings.TrimSpace(info.IPAddress),
|
||||
strings.TrimSpace(info.Country),
|
||||
strings.TrimSpace(info.CountryCode),
|
||||
strings.TrimSpace(info.Region),
|
||||
strings.TrimSpace(info.City),
|
||||
strings.TrimSpace(info.Location),
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update user registration ip info: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) loadRegistrationIPInfo(ctx context.Context, users map[int64]*service.User) error {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]int64, 0, len(users))
|
||||
for id := range users {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
return fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
rows, err := exec.QueryContext(ctx, `
|
||||
SELECT id,
|
||||
register_ip_address,
|
||||
register_ip_country,
|
||||
register_ip_country_code,
|
||||
register_ip_region,
|
||||
register_ip_city,
|
||||
register_ip_location
|
||||
FROM users
|
||||
WHERE id = ANY($1)`, pq.Array(ids))
|
||||
if err != nil {
|
||||
if isRegistrationIPInfoSchemaMissing(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("load user registration ip info: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var ipAddress, country, countryCode, region, city, location string
|
||||
if err := rows.Scan(&id, &ipAddress, &country, &countryCode, ®ion, &city, &location); err != nil {
|
||||
return err
|
||||
}
|
||||
if u, ok := users[id]; ok && u != nil {
|
||||
u.RegisterIPAddress = ipAddress
|
||||
u.RegisterIPCountry = country
|
||||
u.RegisterIPCountryCode = countryCode
|
||||
u.RegisterIPRegion = region
|
||||
u.RegisterIPCity = city
|
||||
u.RegisterIPLocation = location
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func isRegistrationIPInfoSchemaMissing(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "no such column: register_ip_")
|
||||
}
|
||||
|
||||
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
@@ -751,6 +751,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"affiliate_invite_balance_reward": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@@ -969,6 +970,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"affiliate_invite_balance_reward": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
|
||||
@@ -240,6 +240,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users.PUT("/:id", h.Admin.User.Update)
|
||||
users.DELETE("/:id", h.Admin.User.Delete)
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.POST("/:id/register-ip-location", h.Admin.User.RefreshRegistrationIPLocation)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||
|
||||
@@ -33,6 +33,7 @@ type AdminService interface {
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
RefreshUserRegistrationIPLocation(ctx context.Context, userID int64) (*User, error)
|
||||
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
@@ -916,6 +917,25 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) RefreshUserRegistrationIPLocation(ctx context.Context, userID int64) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipAddress := strings.TrimSpace(user.RegisterIPAddress)
|
||||
if ipAddress == "" {
|
||||
return nil, infraerrors.BadRequest("REGISTER_IP_EMPTY", "registration IP is empty")
|
||||
}
|
||||
info, err := lookupRegistrationIPInfo(ctx, ipAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := updateUserRegistrationIPInfoWithRepo(ctx, s.userRepo, userID, info); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.GetUser(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
|
||||
@@ -1133,7 +1153,7 @@ SELECT id,
|
||||
created_at
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'
|
||||
AND action IN ('transfer', 'signup_reward')
|
||||
ORDER BY created_at DESC, id DESC
|
||||
OFFSET $2
|
||||
LIMIT $3`, userID, params.Offset(), params.Limit())
|
||||
@@ -1179,7 +1199,7 @@ func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, use
|
||||
SELECT COUNT(*)
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'`, userID)
|
||||
AND action IN ('transfer', 'signup_reward')`, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -87,6 +87,8 @@ type AffiliateDetail struct {
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
// InviteBalanceReward 是新用户通过邀请注册并绑定成功后,直接进入邀请人余额的固定金额。
|
||||
InviteBalanceReward float64 `json:"invite_balance_reward"`
|
||||
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
|
||||
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
|
||||
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
|
||||
@@ -99,6 +101,7 @@ type AffiliateRepository interface {
|
||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
|
||||
CreditInviteBalanceReward(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (float64, error)
|
||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||
@@ -261,11 +264,23 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
|
||||
AffQuota: summary.AffQuota,
|
||||
AffFrozenQuota: summary.AffFrozenQuota,
|
||||
AffHistoryQuota: summary.AffHistoryQuota,
|
||||
InviteBalanceReward: s.resolveInviteBalanceReward(ctx),
|
||||
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
|
||||
Invitees: invitees,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AffiliateService) resolveInviteBalanceReward(ctx context.Context) float64 {
|
||||
if s == nil || s.settingService == nil {
|
||||
return AffiliateInviteBalanceRewardDefault
|
||||
}
|
||||
amount := s.settingService.GetAffiliateInviteBalanceReward(ctx)
|
||||
if amount <= 0 || math.IsNaN(amount) || math.IsInf(amount, 0) {
|
||||
return AffiliateInviteBalanceRewardDefault
|
||||
}
|
||||
return amount
|
||||
}
|
||||
|
||||
func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
|
||||
code := strings.ToUpper(strings.TrimSpace(rawCode))
|
||||
if code == "" {
|
||||
@@ -308,9 +323,27 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
||||
if !bound {
|
||||
return ErrAffiliateAlreadyBound
|
||||
}
|
||||
s.creditInviteBalanceReward(ctx, inviterSummary.UserID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AffiliateService) creditInviteBalanceReward(ctx context.Context, inviterID, inviteeUserID int64) {
|
||||
if s == nil || s.repo == nil || s.settingService == nil {
|
||||
return
|
||||
}
|
||||
amount := s.settingService.GetAffiliateInviteBalanceReward(ctx)
|
||||
if amount <= 0 || math.IsNaN(amount) || math.IsInf(amount, 0) {
|
||||
return
|
||||
}
|
||||
newBalance, err := s.repo.CreditInviteBalanceReward(ctx, inviterID, inviteeUserID, amount)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to credit invite balance reward: inviter=%d invitee=%d amount=%.8f err=%v", inviterID, inviteeUserID, amount, err)
|
||||
return
|
||||
}
|
||||
s.invalidateAffiliateCaches(ctx, inviterID)
|
||||
logger.LegacyPrintf("service.affiliate", "[Affiliate] Invite balance reward credited: inviter=%d invitee=%d amount=%.8f balance=%.8f", inviterID, inviteeUserID, amount, newBalance)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
||||
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
|
||||
}
|
||||
|
||||
@@ -813,6 +813,9 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
|
||||
if touchLogin {
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
}
|
||||
if info := RegistrationIPInfoFromContext(ctx); strings.TrimSpace(info.IPAddress) != "" {
|
||||
s.refreshRegistrationIPLocationInBackground(user.ID, info.IPAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
|
||||
|
||||
@@ -30,6 +30,7 @@ const (
|
||||
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
|
||||
AffiliateRebateDurationDaysMax = 3650 // ~10 年
|
||||
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
|
||||
AffiliateInviteBalanceRewardDefault = 0.0 // 邀请注册后直接进入邀请人余额;0 = 关闭
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
@@ -108,6 +109,7 @@ const (
|
||||
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
||||
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
||||
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
||||
SettingKeyAffiliateInviteBalanceReward = "affiliate_invite_balance_reward" // 邀请注册奖励,直接进入邀请人余额(0=关闭)
|
||||
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
|
||||
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON)
|
||||
SettingKeyLoginAgreementEnabled = "login_agreement_enabled" // 登录前是否要求同意条款
|
||||
|
||||
@@ -205,37 +205,23 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
if promptCacheKey != "" {
|
||||
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||
}
|
||||
|
||||
// 7. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
if promptCacheKey != "" {
|
||||
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||
}
|
||||
return upstreamReq, nil
|
||||
})
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
zap.Bool("stream", clientStream),
|
||||
)
|
||||
|
||||
// 5. Build upstream request
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||
@@ -129,53 +128,41 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
}
|
||||
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if clientStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
upstreamReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
upstreamReq.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
upstreamReq.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 6. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if clientStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
upstreamReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
upstreamReq.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
upstreamReq.Header.Set("user-agent", customUA)
|
||||
}
|
||||
return upstreamReq, nil
|
||||
})
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -31,6 +32,40 @@ func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
type sequentialHTTPUpstreamRecorder struct {
|
||||
responses []*http.Response
|
||||
errs []error
|
||||
requests []*http.Request
|
||||
bodies [][]byte
|
||||
}
|
||||
|
||||
func (u *sequentialHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.requests = append(u.requests, req)
|
||||
if req != nil && req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
u.bodies = append(u.bodies, append([]byte(nil), b...))
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(b))
|
||||
}
|
||||
if len(u.errs) > 0 {
|
||||
err := u.errs[0]
|
||||
u.errs = u.errs[1:]
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(u.responses) > 0 {
|
||||
resp := u.responses[0]
|
||||
u.responses = u.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
return nil, errors.New("no response configured")
|
||||
}
|
||||
|
||||
func (u *sequentialHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -133,6 +168,112 @@ func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *te
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_RequestErrorRetriesBeforeSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_retry","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &sequentialHTTPUpstreamRecorder{
|
||||
errs: []error{
|
||||
errors.New("Post \"https://chatgpt.com/backend-api/codex/responses\": read tcp 172.18.0.4:60076->42.193.179.21:1081: read: connection reset by peer"),
|
||||
errors.New("connection reset by peer"),
|
||||
errors.New("unexpected EOF"),
|
||||
nil,
|
||||
},
|
||||
responses: []*http.Response{{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_retry_success"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
require.Len(t, upstream.bodies, 4)
|
||||
require.Equal(t, upstream.bodies[0], upstream.bodies[3], "retry must rebuild the same upstream body")
|
||||
|
||||
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 3)
|
||||
require.Equal(t, "request_error", events[0].Kind)
|
||||
require.Contains(t, events[0].Message, "connection reset by peer")
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_RequestErrorExhaustionReturnsFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &sequentialHTTPUpstreamRecorder{
|
||||
errs: []error{
|
||||
errors.New("connection reset by peer"),
|
||||
errors.New("connection reset by peer"),
|
||||
errors.New("connection reset by peer"),
|
||||
errors.New("connection reset by peer"),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
|
||||
require.Nil(t, result)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||
require.False(t, c.Writer.Written(), "forward should not write a 502 before handler failover")
|
||||
require.Len(t, upstream.requests, 4)
|
||||
|
||||
rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 4)
|
||||
require.Equal(t, "request_error:retry_exhausted", events[3].Kind)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -243,57 +243,44 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// Override session_id with a deterministic UUID derived from the isolated
|
||||
// session key, ensuring different API keys produce different upstream sessions.
|
||||
if promptCacheKey != "" {
|
||||
isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))
|
||||
upstreamReq.Header.Set("session_id", isolatedSessionID)
|
||||
if upstreamReq.Header.Get("conversation_id") != "" {
|
||||
upstreamReq.Header.Set("conversation_id", isolatedSessionID)
|
||||
}
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint.
|
||||
// Match airgate-openai's request shape: the SSE endpoint does not need
|
||||
// the Responses experimental beta header, and forcing originator can make
|
||||
// ChatGPT select a different internal continuation path.
|
||||
upstreamReq.Header.Del("OpenAI-Beta")
|
||||
upstreamReq.Header.Del("originator")
|
||||
}
|
||||
if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" {
|
||||
upstreamReq.Header.Del("conversation_id")
|
||||
}
|
||||
if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" {
|
||||
upstreamReq.Header.Set("x-codex-turn-state", compatTurnState)
|
||||
}
|
||||
|
||||
// 7. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// Override session_id with a deterministic UUID derived from the isolated
|
||||
// session key, ensuring different API keys produce different upstream sessions.
|
||||
if promptCacheKey != "" {
|
||||
isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))
|
||||
upstreamReq.Header.Set("session_id", isolatedSessionID)
|
||||
if upstreamReq.Header.Get("conversation_id") != "" {
|
||||
upstreamReq.Header.Set("conversation_id", isolatedSessionID)
|
||||
}
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint.
|
||||
// Match airgate-openai's request shape: the SSE endpoint does not need
|
||||
// the Responses experimental beta header, and forcing originator can make
|
||||
// ChatGPT select a different internal continuation path.
|
||||
upstreamReq.Header.Del("OpenAI-Beta")
|
||||
upstreamReq.Header.Del("originator")
|
||||
}
|
||||
if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" {
|
||||
upstreamReq.Header.Del("conversation_id")
|
||||
}
|
||||
if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" {
|
||||
upstreamReq.Header.Set("x-codex-turn-state", compatTurnState)
|
||||
}
|
||||
return upstreamReq, nil
|
||||
})
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -2681,14 +2681,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
@@ -2696,28 +2688,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, false, func(upstreamCtx context.Context) (*http.Request, error) {
|
||||
return s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
})
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
@@ -2972,13 +2947,6 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
@@ -2989,28 +2957,11 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
c.Set("openai_passthrough", true)
|
||||
}
|
||||
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
resp, err := s.doOpenAIUpstreamWithRequestRetry(ctx, c, account, proxyURL, true, func(upstreamCtx context.Context) (*http.Request, error) {
|
||||
return s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
})
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIHTTPRequestRetryMaxRetries = 3
|
||||
openAIHTTPRequestRetryBaseDelay = 150 * time.Millisecond
|
||||
)
|
||||
|
||||
type openAIUpstreamRequestBuilder func(context.Context) (*http.Request, error)
|
||||
|
||||
func (s *OpenAIGatewayService) doOpenAIUpstreamWithRequestRetry(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
proxyURL string,
|
||||
passthrough bool,
|
||||
buildReq openAIUpstreamRequestBuilder,
|
||||
) (*http.Response, error) {
|
||||
if buildReq == nil {
|
||||
return nil, errors.New("missing upstream request builder")
|
||||
}
|
||||
if account == nil {
|
||||
return nil, errors.New("missing account")
|
||||
}
|
||||
|
||||
attempts := openAIHTTPRequestRetryMaxRetries + 1
|
||||
var lastErr error
|
||||
startedAt := time.Now()
|
||||
for attempt := 1; attempt <= attempts; attempt++ {
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
req, err := buildReq(upstreamCtx)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(startedAt).Milliseconds())
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
s.recordOpenAIHTTPRequestErrorAttempt(c, account, passthrough, attempt, attempts, err)
|
||||
if !isRetryableOpenAIHTTPRequestError(err) || attempt >= attempts {
|
||||
break
|
||||
}
|
||||
time.Sleep(openAIHTTPRequestRetryDelay(attempt))
|
||||
}
|
||||
|
||||
if isRetryableOpenAIHTTPRequestError(lastErr) {
|
||||
return nil, newOpenAIHTTPRequestFailoverError(lastErr)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream request failed: %s", sanitizeUpstreamErrorMessage(lastErr.Error()))
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) recordOpenAIHTTPRequestErrorAttempt(c *gin.Context, account *Account, passthrough bool, attempt, attempts int, err error) {
|
||||
if c == nil || err == nil {
|
||||
return
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
kind := "request_error"
|
||||
if attempt >= attempts {
|
||||
kind = "request_error:retry_exhausted"
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Passthrough: passthrough,
|
||||
Kind: kind,
|
||||
Message: safeErr,
|
||||
Detail: fmt.Sprintf("attempt=%d max_retries=%d", attempt, openAIHTTPRequestRetryMaxRetries),
|
||||
})
|
||||
}
|
||||
|
||||
func openAIHTTPRequestRetryDelay(attempt int) time.Duration {
|
||||
if attempt <= 0 {
|
||||
return openAIHTTPRequestRetryBaseDelay
|
||||
}
|
||||
delay := openAIHTTPRequestRetryBaseDelay << (attempt - 1)
|
||||
if delay > time.Second {
|
||||
return time.Second
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func isRetryableOpenAIHTTPRequestError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ETIMEDOUT) {
|
||||
return true
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
retryableMarkers := []string{
|
||||
"connection reset by peer",
|
||||
"connection refused",
|
||||
"unexpected eof",
|
||||
"server closed idle connection",
|
||||
"broken pipe",
|
||||
"connection aborted",
|
||||
"tls: use of closed connection",
|
||||
"http2: client connection lost",
|
||||
}
|
||||
for _, marker := range retryableMarkers {
|
||||
if strings.Contains(msg, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newOpenAIHTTPRequestFailoverError(err error) *UpstreamFailoverError {
|
||||
message := "Upstream request failed"
|
||||
if err != nil {
|
||||
message = sanitizeUpstreamErrorMessage(err.Error())
|
||||
}
|
||||
body, _ := json.Marshal(gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: body,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ipgeo"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
type registrationIPInfoUpdater interface {
|
||||
UpdateRegistrationIPInfo(ctx context.Context, userID int64, info RegistrationIPInfo) error
|
||||
}
|
||||
|
||||
func (s *AuthService) refreshRegistrationIPLocationInBackground(userID int64, rawIP string) {
|
||||
if s == nil || s.userRepo == nil || userID <= 0 || strings.TrimSpace(rawIP) == "" {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
info, err := lookupRegistrationIPInfo(ctx, rawIP)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to lookup registration IP location: user_id=%d ip=%s err=%v", userID, rawIP, err)
|
||||
return
|
||||
}
|
||||
if err := updateUserRegistrationIPInfoWithRepo(ctx, s.userRepo, userID, info); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update registration IP location: user_id=%d ip=%s err=%v", userID, rawIP, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func updateUserRegistrationIPInfoWithRepo(ctx context.Context, repo UserRepository, userID int64, info RegistrationIPInfo) error {
|
||||
updater, ok := repo.(registrationIPInfoUpdater)
|
||||
if !ok {
|
||||
return errors.New("registration ip updater is not configured")
|
||||
}
|
||||
return updater.UpdateRegistrationIPInfo(ctx, userID, info)
|
||||
}
|
||||
|
||||
func lookupRegistrationIPInfo(ctx context.Context, rawIP string) (RegistrationIPInfo, error) {
|
||||
ipAddress := strings.TrimSpace(rawIP)
|
||||
if ipAddress == "" {
|
||||
return RegistrationIPInfo{}, infraerrors.BadRequest("REGISTER_IP_EMPTY", "registration IP is empty")
|
||||
}
|
||||
geo, err := ipgeo.Lookup(ctx, ipAddress)
|
||||
if err != nil {
|
||||
return RegistrationIPInfo{}, err
|
||||
}
|
||||
if geo == nil {
|
||||
return RegistrationIPInfo{IPAddress: ipAddress}, nil
|
||||
}
|
||||
return RegistrationIPInfo{
|
||||
IPAddress: firstNonEmptyRegistrationIP(geo.IPAddress, ipAddress),
|
||||
Country: geo.Country,
|
||||
CountryCode: geo.CountryCode,
|
||||
Region: geo.Region,
|
||||
City: geo.City,
|
||||
Location: geo.Location,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func firstNonEmptyRegistrationIP(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1617,6 +1617,10 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
|
||||
if settings.AffiliateInviteBalanceReward < 0 {
|
||||
settings.AffiliateInviteBalanceReward = AffiliateInviteBalanceRewardDefault
|
||||
}
|
||||
updates[SettingKeyAffiliateInviteBalanceReward] = strconv.FormatFloat(settings.AffiliateInviteBalanceReward, 'f', 8, 64)
|
||||
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
@@ -2136,6 +2140,20 @@ func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) fl
|
||||
return cap
|
||||
}
|
||||
|
||||
// GetAffiliateInviteBalanceReward returns the fixed reward credited directly to
|
||||
// the inviter's account balance when an invitee binds successfully.
|
||||
func (s *SettingService) GetAffiliateInviteBalanceReward(ctx context.Context) float64 {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateInviteBalanceReward)
|
||||
if err != nil {
|
||||
return AffiliateInviteBalanceRewardDefault
|
||||
}
|
||||
amount, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
|
||||
if err != nil || amount < 0 || math.IsNaN(amount) || math.IsInf(amount, 0) {
|
||||
return AffiliateInviteBalanceRewardDefault
|
||||
}
|
||||
return amount
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证
|
||||
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
@@ -2412,6 +2430,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
|
||||
SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
|
||||
SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
|
||||
SettingKeyAffiliateInviteBalanceReward: strconv.FormatFloat(AffiliateInviteBalanceRewardDefault, 'f', 2, 64),
|
||||
SettingKeyDefaultUserRPMLimit: "0",
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||
@@ -2587,6 +2606,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
|
||||
result.AffiliateRebatePerInviteeCap = perInviteeCap
|
||||
}
|
||||
if inviteReward, err := strconv.ParseFloat(settings[SettingKeyAffiliateInviteBalanceReward], 64); err == nil && inviteReward >= 0 {
|
||||
result.AffiliateInviteBalanceReward = inviteReward
|
||||
}
|
||||
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
|
||||
@@ -130,6 +130,7 @@ type SystemSettings struct {
|
||||
AffiliateRebateFreezeHours int
|
||||
AffiliateRebateDurationDays int
|
||||
AffiliateRebatePerInviteeCap float64
|
||||
AffiliateInviteBalanceReward float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -25,13 +26,19 @@ type User struct {
|
||||
TokenVersion int64 // Incremented on password change to invalidate existing tokens
|
||||
// TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived
|
||||
// value expected in JWT claims and refresh-token state.
|
||||
TokenVersionResolved bool
|
||||
SignupSource string
|
||||
LastLoginAt *time.Time
|
||||
LastActiveAt *time.Time
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
TokenVersionResolved bool
|
||||
SignupSource string
|
||||
RegisterIPAddress string
|
||||
RegisterIPCountry string
|
||||
RegisterIPCountryCode string
|
||||
RegisterIPRegion string
|
||||
RegisterIPCity string
|
||||
RegisterIPLocation string
|
||||
LastLoginAt *time.Time
|
||||
LastActiveAt *time.Time
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
@@ -62,6 +69,34 @@ type User struct {
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
type registrationIPContextKey struct{}
|
||||
|
||||
type RegistrationIPInfo struct {
|
||||
IPAddress string
|
||||
Country string
|
||||
CountryCode string
|
||||
Region string
|
||||
City string
|
||||
Location string
|
||||
}
|
||||
|
||||
func WithRegistrationIPInfo(ctx context.Context, info RegistrationIPInfo) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, registrationIPContextKey{}, info)
|
||||
}
|
||||
|
||||
func RegistrationIPInfoFromContext(ctx context.Context) RegistrationIPInfo {
|
||||
if ctx == nil {
|
||||
return RegistrationIPInfo{}
|
||||
}
|
||||
if info, ok := ctx.Value(registrationIPContextKey{}).(RegistrationIPInfo); ok {
|
||||
return info
|
||||
}
|
||||
return RegistrationIPInfo{}
|
||||
}
|
||||
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.Role == RoleAdmin || u.Role == RoleUserAdmin
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user