Merge pull request #2224 from lyen1688/feat-email-oauth-github-google

feat: 增加 GitHub 和 Google 邮箱快捷登录
This commit is contained in:
Wesley Liddick
2026-05-07 10:07:28 +08:00
committed by GitHub
40 changed files with 3219 additions and 77 deletions
+15
View File
@@ -72,6 +72,8 @@ type Config struct {
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
@@ -240,6 +242,19 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type EmailOAuthProviderConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
EmailsURL string `mapstructure:"emails_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
}
const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
@@ -169,6 +169,16 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
GitHubOAuthClientID: settings.GitHubOAuthClientID,
GitHubOAuthClientSecretConfigured: settings.GitHubOAuthClientSecretConfigured,
GitHubOAuthRedirectURL: settings.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: settings.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
GoogleOAuthClientID: settings.GoogleOAuthClientID,
GoogleOAuthClientSecretConfigured: settings.GoogleOAuthClientSecretConfigured,
GoogleOAuthRedirectURL: settings.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: settings.GoogleOAuthFrontendRedirectURL,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
@@ -369,6 +379,17 @@ type UpdateSettingsRequest struct {
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GitHubOAuthClientID string `json:"github_oauth_client_id"`
GitHubOAuthClientSecret string `json:"github_oauth_client_secret"`
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
GoogleOAuthClientID string `json:"google_oauth_client_id"`
GoogleOAuthClientSecret string `json:"google_oauth_client_secret"`
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
@@ -414,6 +435,16 @@ type UpdateSettingsRequest struct {
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
@@ -1214,6 +1245,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
GitHubOAuthClientID: req.GitHubOAuthClientID,
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
GoogleOAuthClientID: req.GoogleOAuthClientID,
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
@@ -1416,6 +1457,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
},
GitHub: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultGitHubConcurrency, previousAuthSourceDefaults.GitHub.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
},
Google: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultGoogleConcurrency, previousAuthSourceDefaults.Google.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
@@ -1558,6 +1613,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: updatedSettings.GitHubOAuthEnabled,
GitHubOAuthClientID: updatedSettings.GitHubOAuthClientID,
GitHubOAuthClientSecretConfigured: updatedSettings.GitHubOAuthClientSecretConfigured,
GitHubOAuthRedirectURL: updatedSettings.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: updatedSettings.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: updatedSettings.GoogleOAuthEnabled,
GoogleOAuthClientID: updatedSettings.GoogleOAuthClientID,
GoogleOAuthClientSecretConfigured: updatedSettings.GoogleOAuthClientSecretConfigured,
GoogleOAuthRedirectURL: updatedSettings.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: updatedSettings.GoogleOAuthFrontendRedirectURL,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
@@ -2052,6 +2117,8 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
{name: "oidc", before: before.OIDC, after: after.OIDC},
{name: "wechat", before: before.WeChat, after: after.WeChat},
{name: "github", before: before.GitHub, after: after.GitHub},
{name: "google", before: before.Google, after: after.Google},
}
for _, field := range fields {
if field.before.Balance != field.after.Balance {
@@ -2166,6 +2233,16 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
data["auth_source_default_github_balance"] = authSourceDefaults.GitHub.Balance
data["auth_source_default_github_concurrency"] = authSourceDefaults.GitHub.Concurrency
data["auth_source_default_github_subscriptions"] = authSourceDefaults.GitHub.Subscriptions
data["auth_source_default_github_grant_on_signup"] = authSourceDefaults.GitHub.GrantOnSignup
data["auth_source_default_github_grant_on_first_bind"] = authSourceDefaults.GitHub.GrantOnFirstBind
data["auth_source_default_google_balance"] = authSourceDefaults.Google.Balance
data["auth_source_default_google_concurrency"] = authSourceDefaults.Google.Concurrency
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
@@ -0,0 +1,621 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
emailOAuthCookiePath = "/api/v1/auth/oauth"
emailOAuthStateCookieName = "email_oauth_state"
emailOAuthRedirectCookie = "email_oauth_redirect"
emailOAuthProviderCookie = "email_oauth_provider"
emailOAuthAffiliateCookie = "email_oauth_affiliate"
emailOAuthCookieMaxAgeSec = 10 * 60
emailOAuthDefaultRedirect = "/dashboard"
)
type emailOAuthTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
type emailOAuthProfile struct {
Subject string
Email string
EmailVerified bool
Username string
DisplayName string
AvatarURL string
Metadata map[string]any
}
func (h *AuthHandler) GitHubOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "github") }
func (h *AuthHandler) GoogleOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "google") }
func (h *AuthHandler) GitHubOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "github") }
func (h *AuthHandler) GoogleOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "google") }
func (h *AuthHandler) CompleteGitHubOAuthRegistration(c *gin.Context) {
h.completeEmailOAuthRegistration(c, "github")
}
func (h *AuthHandler) CompleteGoogleOAuthRegistration(c *gin.Context) {
h.completeEmailOAuthRegistration(c, "google")
}
func (h *AuthHandler) emailOAuthStart(c *gin.Context, provider string) {
cfg, err := h.getEmailOAuthConfig(c.Request.Context(), provider)
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = emailOAuthDefaultRedirect
}
secureCookie := isRequestHTTPS(c)
emailOAuthSetCookie(c, emailOAuthStateCookieName, encodeCookieValue(state), secureCookie)
emailOAuthSetCookie(c, emailOAuthRedirectCookie, encodeCookieValue(redirectTo), secureCookie)
emailOAuthSetCookie(c, emailOAuthProviderCookie, encodeCookieValue(provider), secureCookie)
if affCode := strings.TrimSpace(firstNonEmpty(c.Query("aff_code"), c.Query("aff"))); affCode != "" {
emailOAuthSetCookie(c, emailOAuthAffiliateCookie, encodeCookieValue(affCode), secureCookie)
} else {
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
}
authURL, err := buildEmailOAuthAuthorizeURL(cfg, state)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
func (h *AuthHandler) emailOAuthCallback(c *gin.Context, provider string) {
cfg, cfgErr := h.getEmailOAuthConfig(c.Request.Context(), provider)
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = "/auth/oauth/callback"
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
emailOAuthClearCookie(c, emailOAuthStateCookieName, secureCookie)
emailOAuthClearCookie(c, emailOAuthRedirectCookie, secureCookie)
emailOAuthClearCookie(c, emailOAuthProviderCookie, secureCookie)
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, emailOAuthStateCookieName)
if err != nil || expectedState == "" || expectedState != state {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
expectedProvider, _ := readCookieDecoded(c, emailOAuthProviderCookie)
if !strings.EqualFold(strings.TrimSpace(expectedProvider), provider) {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth provider", "")
return
}
redirectTo, _ := readCookieDecoded(c, emailOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = emailOAuthDefaultRedirect
}
tokenResp, err := exchangeEmailOAuthCode(c.Request.Context(), cfg, code)
if err != nil {
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(err.Error()))
return
}
profile, err := fetchEmailOAuthProfile(c.Request.Context(), provider, cfg, tokenResp)
if err != nil {
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch verified email", singleLine(err.Error()))
return
}
h.emailOAuthCallbackWithProfile(c, provider, cfg, frontendCallback, redirectTo, profile)
}
func (h *AuthHandler) emailOAuthCallbackWithProfile(
c *gin.Context,
provider string,
cfg config.EmailOAuthProviderConfig,
frontendCallback string,
redirectTo string,
profile *emailOAuthProfile,
) {
input := service.EmailOAuthIdentityInput{
ProviderType: provider,
ProviderKey: provider,
ProviderSubject: profile.Subject,
Email: profile.Email,
EmailVerified: profile.EmailVerified,
Username: profile.Username,
DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL,
UpstreamMetadata: profile.Metadata,
}
affiliateCode := h.emailOAuthAffiliateCode(c)
if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
return
} else if shouldCreate {
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) {
client := h.entClient()
if client == nil {
return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{
ProviderType: strings.TrimSpace(input.ProviderType),
ProviderKey: strings.TrimSpace(input.ProviderKey),
ProviderSubject: strings.TrimSpace(input.ProviderSubject),
})
if err != nil {
return false, err
}
email := strings.TrimSpace(strings.ToLower(input.Email))
if identityUser != nil {
if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
}
return false, nil
}
if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil {
if errors.Is(err, service.ErrUserNotFound) {
return true, nil
}
return false, err
}
return false, nil
}
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
if c == nil {
return ""
}
if code, err := readCookieDecoded(c, emailOAuthAffiliateCookie); err == nil {
return strings.TrimSpace(code)
}
return ""
}
func (h *AuthHandler) createEmailOAuthRegistrationPendingSession(
c *gin.Context,
provider string,
frontendCallback string,
redirectTo string,
profile *emailOAuthProfile,
) error {
if h == nil || profile == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
browserSessionKey, err := generateOAuthPendingBrowserSession()
if err != nil {
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
}
setOAuthPendingBrowserCookie(c, browserSessionKey, isRequestHTTPS(c))
email := strings.TrimSpace(strings.ToLower(profile.Email))
username := strings.TrimSpace(profile.Username)
affiliateCode := h.emailOAuthAffiliateCode(c)
upstreamClaims := map[string]any{
"email": email,
"email_verified": profile.EmailVerified,
"username": username,
"provider": provider,
"provider_key": provider,
"provider_subject": strings.TrimSpace(profile.Subject),
}
if strings.TrimSpace(profile.DisplayName) != "" {
upstreamClaims["suggested_display_name"] = strings.TrimSpace(profile.DisplayName)
}
if strings.TrimSpace(profile.AvatarURL) != "" {
upstreamClaims["suggested_avatar_url"] = strings.TrimSpace(profile.AvatarURL)
}
if affiliateCode != "" {
upstreamClaims["aff_code"] = affiliateCode
}
for key, value := range profile.Metadata {
if _, exists := upstreamClaims[key]; !exists {
upstreamClaims[key] = value
}
}
invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context())
pendingError := "registration_completion_required"
choiceReason := "registration_completion_required"
if invitationRequired {
pendingError = "invitation_required"
choiceReason = "invitation_required"
}
completionResponse := map[string]any{
"step": oauthPendingChoiceStep,
"error": pendingError,
"choice_reason": choiceReason,
"adoption_required": false,
"create_account_allowed": true,
"existing_account_bindable": false,
"force_email_on_signup": true,
"invitation_required": invitationRequired,
"email": email,
"resolved_email": email,
"provider": provider,
"redirect": redirectTo,
}
if strings.TrimSpace(frontendCallback) != "" {
completionResponse["frontend_callback"] = strings.TrimSpace(frontendCallback)
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: service.PendingAuthIdentityKey{ProviderType: provider, ProviderKey: provider, ProviderSubject: strings.TrimSpace(profile.Subject)},
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: completionResponse,
})
}
type completeEmailOAuthRequest struct {
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AffCode string `json:"aff_code,omitempty"`
}
func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider string) {
var req completeEmailOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
affiliateCode := strings.TrimSpace(req.AffCode)
if affiliateCode == "" {
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
}
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
c.Request.Context(),
strings.TrimSpace(session.ResolvedEmail),
req.Password,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
tx, err := client.Tx(c.Request.Context())
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
return
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
sessionForBinding := *session
sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims)
if strings.TrimSpace(req.InvitationCode) != "" {
sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode)
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{})
if err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
respondPendingOAuthBindingApplyError(c, err)
return
}
if err := h.authService.FinalizeOAuthEmailAccount(
txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
affiliateCode,
); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, err)
return
}
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := tx.Commit(); err != nil {
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
func (h *AuthHandler) getEmailOAuthConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetEmailOAuthProviderConfig(ctx, provider)
}
return config.EmailOAuthProviderConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
func buildEmailOAuthAuthorizeURL(cfg config.EmailOAuthProviderConfig, state string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", cfg.RedirectURL)
q.Set("state", state)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func exchangeEmailOAuthCode(ctx context.Context, cfg config.EmailOAuthProviderConfig, code string) (*emailOAuthTokenResponse, error) {
resp, err := req.C().
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetFormData(map[string]string{
"grant_type": "authorization_code",
"client_id": cfg.ClientID,
"client_secret": cfg.ClientSecret,
"code": code,
"redirect_uri": cfg.RedirectURL,
}).
Post(cfg.TokenURL)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("token endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
}
var tokenResp emailOAuthTokenResponse
if err := json.Unmarshal(resp.Bytes(), &tokenResp); err != nil {
return nil, err
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, errors.New("missing access_token")
}
return &tokenResp, nil
}
func fetchEmailOAuthProfile(ctx context.Context, provider string, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse) (*emailOAuthProfile, error) {
resp, err := req.C().
R().
SetContext(ctx).
SetBearerAuthToken(token.AccessToken).
SetHeader("Accept", "application/json").
Get(cfg.UserInfoURL)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("userinfo endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
}
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github":
return parseGitHubOAuthProfile(ctx, cfg, token, resp.String())
case "google":
return parseGoogleOAuthProfile(resp.String())
default:
return nil, errors.New("unsupported oauth provider")
}
}
func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse, body string) (*emailOAuthProfile, error) {
subject := strings.TrimSpace(gjson.Get(body, "id").String())
if subject == "" {
return nil, errors.New("github user id is missing")
}
email := ""
emailsURL := strings.TrimSpace(cfg.EmailsURL)
if emailsURL == "" {
return nil, errors.New("github verified email is missing")
}
verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken)
if err != nil {
return nil, err
}
email = verifiedEmail
if email == "" {
return nil, errors.New("github verified email is missing")
}
login := strings.TrimSpace(gjson.Get(body, "login").String())
name := strings.TrimSpace(gjson.Get(body, "name").String())
return &emailOAuthProfile{
Subject: subject,
Email: email,
EmailVerified: true,
Username: firstNonEmpty(login, name, "github_"+subject),
DisplayName: firstNonEmpty(name, login),
AvatarURL: strings.TrimSpace(gjson.Get(body, "avatar_url").String()),
Metadata: map[string]any{
"login": login,
},
}, nil
}
func fetchGitHubPrimaryVerifiedEmail(ctx context.Context, emailsURL string, accessToken string) (string, error) {
resp, err := req.C().
R().
SetContext(ctx).
SetBearerAuthToken(accessToken).
SetHeader("Accept", "application/json").
Get(emailsURL)
if err != nil {
return "", err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("github emails endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
}
items := gjson.Parse(resp.String()).Array()
for _, item := range items {
if item.Get("primary").Bool() && item.Get("verified").Bool() {
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
return email, nil
}
}
}
for _, item := range items {
if item.Get("verified").Bool() {
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
return email, nil
}
}
}
return "", errors.New("github verified email is missing")
}
func parseGoogleOAuthProfile(body string) (*emailOAuthProfile, error) {
subject := strings.TrimSpace(gjson.Get(body, "sub").String())
email := strings.TrimSpace(gjson.Get(body, "email").String())
verified := gjson.Get(body, "email_verified").Bool()
if subject == "" {
return nil, errors.New("google subject is missing")
}
if email == "" || !verified {
return nil, errors.New("google verified email is missing")
}
name := strings.TrimSpace(gjson.Get(body, "name").String())
return &emailOAuthProfile{
Subject: subject,
Email: email,
EmailVerified: true,
Username: firstNonEmpty(strings.TrimSpace(gjson.Get(body, "given_name").String()), name, email),
DisplayName: name,
AvatarURL: strings.TrimSpace(gjson.Get(body, "picture").String()),
Metadata: map[string]any{
"email_verified": true,
},
}, nil
}
func emailOAuthSetCookie(c *gin.Context, name, value string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: emailOAuthCookiePath,
MaxAge: emailOAuthCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func emailOAuthClearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: emailOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
@@ -0,0 +1,414 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
state := "github-oauth-state"
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil)
req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)})
req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")})
req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")})
c.Request = req
profile := &emailOAuthProfile{
Subject: "github-123",
Email: "fresh@example.com",
EmailVerified: true,
Username: "fresh",
DisplayName: "Fresh User",
AvatarURL: "https://cdn.example/fresh.png",
Metadata: map[string]any{
"login": "fresh",
},
}
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
Enabled: true,
ClientID: "github-client",
ClientSecret: "github-secret",
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
FrontendRedirectURL: "/auth/oauth/callback",
}, "/auth/oauth/callback", "/dashboard", profile)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "/auth/oauth/callback")
require.NotContains(t, location, "access_token=")
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
session, err := client.PendingAuthSession.Query().Only(ctx)
require.NoError(t, err)
require.Equal(t, "github", session.ProviderType)
require.Equal(t, "github", session.ProviderKey)
require.Equal(t, "github-123", session.ProviderSubject)
require.Equal(t, "fresh@example.com", session.ResolvedEmail)
require.Equal(t, "/dashboard", session.RedirectTo)
require.Nil(t, session.TargetUserID)
completion, ok := readCompletionResponse(session.LocalFlowState)
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "invitation_required", completion["error"])
require.Equal(t, true, completion["invitation_required"])
require.Equal(t, "fresh@example.com", completion["email"])
require.Equal(t, "fresh@example.com", completion["resolved_email"])
require.Equal(t, true, completion["create_account_allowed"])
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName))
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName))
}
func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("existing@example.com").
SetUsername("existing").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil)
handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{
Enabled: true,
ClientID: "google-client",
ClientSecret: "google-secret",
RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback",
FrontendRedirectURL: "/auth/oauth/callback",
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
Subject: "google-123",
Email: "existing@example.com",
EmailVerified: true,
Username: "existing",
})
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "access_token=")
require.Contains(t, location, "redirect=%252Fdashboard")
sessionCount, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, sessionCount)
identityCount, err := client.AuthIdentity.Query().Where(
authidentity.ProviderTypeEQ("google"),
authidentity.ProviderSubjectEQ("google-123"),
).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
_ = user
}
func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) {
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyAffiliateEnabled: "true",
},
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
},
})
ctx := context.Background()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil)
req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")})
c.Request = req
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
Enabled: true,
ClientID: "github-client",
ClientSecret: "github-secret",
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
FrontendRedirectURL: "/auth/oauth/callback",
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
Subject: "github-aff-user",
Email: "aff-user@example.com",
EmailVerified: true,
Username: "aff-user",
})
require.Equal(t, http.StatusFound, recorder.Code)
require.NotContains(t, recorder.Header().Get("Location"), "access_token=")
userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
require.Empty(t, affiliateRepo.ensureUserIDs)
require.Empty(t, affiliateRepo.bindCalls)
session, err := client.PendingAuthSession.Query().Only(ctx)
require.NoError(t, err)
require.Equal(t, "aff-user@example.com", session.ResolvedEmail)
require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code"))
completion, ok := readCompletionResponse(session.LocalFlowState)
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "registration_completion_required", completion["error"])
require.Equal(t, false, completion["invitation_required"])
require.Equal(t, true, completion["create_account_allowed"])
require.Equal(t, true, completion["force_email_on_signup"])
require.Equal(t, "aff-user@example.com", completion["resolved_email"])
}
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002})
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
invitationEnabled: true,
settingValues: map[string]string{
service.SettingKeyAffiliateEnabled: "true",
},
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
},
})
ctx := context.Background()
invitation, err := client.RedeemCode.Create().
SetCode("INVITE456").
SetType(service.RedeemTypeInvitation).
SetStatus(service.StatusUnused).
SetValue(0).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("email-oauth-aff-session-token").
SetIntent(oauthIntentLogin).
SetProviderType("google").
SetProviderKey("google").
SetProviderSubject("google-aff-user").
SetResolvedEmail("pending-aff@example.com").
SetRedirectTo("/dashboard").
SetBrowserSessionKey("browser-aff-key").
SetUpstreamIdentityClaims(map[string]any{
"email": "pending-aff@example.com",
"email_verified": true,
"username": "pending-aff",
"provider": "google",
"provider_key": "google",
"provider_subject": "google-aff-user",
"aff_code": "AFF456",
}).
SetLocalFlowState(map[string]any{
"step": oauthPendingChoiceStep,
"error": "invitation_required",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
c.Request = req
handler.completeEmailOAuthRegistration(c, "google")
require.Equal(t, http.StatusOK, recorder.Code)
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
require.NoError(t, err)
require.NotEmpty(t, user.PasswordHash)
require.NotEqual(t, "secret-123", user.PasswordHash)
tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, tamperedCount)
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedInvitation.UsedBy)
require.Equal(t, user.ID, *storedInvitation.UsedBy)
}
func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("email-oauth-password-session-token").
SetIntent(oauthIntentLogin).
SetProviderType("github").
SetProviderKey("github").
SetProviderSubject("github-password-user").
SetResolvedEmail("password-required@example.com").
SetRedirectTo("/dashboard").
SetBrowserSessionKey("browser-password-key").
SetUpstreamIdentityClaims(map[string]any{
"email": "password-required@example.com",
"email_verified": true,
"username": "password-required",
"provider": "github",
"provider_key": "github",
"provider_subject": "github-password-user",
}).
SetLocalFlowState(map[string]any{
"step": oauthPendingChoiceStep,
"error": "registration_completion_required",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")})
c.Request = req
handler.completeEmailOAuthRegistration(c, "github")
require.Equal(t, http.StatusBadRequest, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
}
func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) {
emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "missing scope", http.StatusForbidden)
}))
t.Cleanup(emailServer.Close)
profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{
EmailsURL: emailServer.URL,
}, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`)
require.Error(t, err)
require.Nil(t, profile)
require.Contains(t, err.Error(), "github emails endpoint status 403")
}
type oauthEmailAffiliateBindCall struct {
userID int64
inviterID int64
}
type oauthEmailAffiliateRepoStub struct {
codeOwners map[string]int64
ensureUserIDs []int64
bindCalls []oauthEmailAffiliateBindCall
}
func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub {
return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners}
}
func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) {
r.ensureUserIDs = append(r.ensureUserIDs, userID)
return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil
}
func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) {
userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))]
if !ok {
return nil, service.ErrAffiliateProfileNotFound
}
return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil
}
func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) {
r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID})
return true, nil
}
func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) {
panic("unexpected AccrueQuota call")
}
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
panic("unexpected GetAccruedRebateFromInvitee call")
}
func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) {
panic("unexpected ThawFrozenQuota call")
}
func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) {
panic("unexpected TransferQuotaToBalance call")
}
func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) {
panic("unexpected ListInvitees call")
}
func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error {
panic("unexpected UpdateUserAffCode call")
}
func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) {
panic("unexpected ResetUserAffCode call")
}
func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error {
panic("unexpected SetUserRebateRate call")
}
func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error {
panic("unexpected BatchSetUserRebateRate call")
}
func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
panic("unexpected ListUsersWithCustomSettings call")
}
func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
panic("unexpected ListAffiliateInviteRecords call")
}
func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
panic("unexpected ListAffiliateRebateRecords call")
}
func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
panic("unexpected ListAffiliateTransferRecords call")
}
func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) {
panic("unexpected GetAffiliateUserOverview call")
}
func findSetCookieValue(cookies []*http.Cookie, name string) string {
for _, cookie := range cookies {
if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 {
return cookie.Value
}
}
return ""
}
@@ -2121,6 +2121,8 @@ type oauthPendingFlowTestHandlerOptions struct {
emailCache service.EmailCache
settingValues map[string]string
defaultSubAssigner service.DefaultSubscriptionAssigner
affiliateService *service.AffiliateService
affiliateFactory func(*dbent.Client, *service.SettingService) *service.AffiliateService
totpCache service.TotpCache
totpEncryptor service.SecretEncryptor
userRepoOptions oauthPendingFlowUserRepoOptions
@@ -2160,6 +2162,21 @@ CREATE TABLE IF NOT EXISTS user_avatars (
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_affiliates (
user_id INTEGER PRIMARY KEY,
aff_code TEXT NOT NULL UNIQUE,
aff_code_custom BOOLEAN NOT NULL DEFAULT false,
aff_rebate_rate_percent REAL NULL,
inviter_id INTEGER NULL,
aff_count INTEGER NOT NULL DEFAULT 0,
aff_quota REAL NOT NULL DEFAULT 0,
aff_frozen_quota REAL NOT NULL DEFAULT 0,
aff_history_quota REAL NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@@ -2177,14 +2194,19 @@ CREATE TABLE IF NOT EXISTS user_avatars (
},
}
settingValues := map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
}
for key, value := range options.settingValues {
settingValues[key] = value
}
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
affiliateService := options.affiliateService
if affiliateService == nil && options.affiliateFactory != nil {
affiliateService = options.affiliateFactory(client, settingSvc)
}
userRepo := &oauthPendingFlowUserRepo{
client: client,
options: options.userRepoOptions,
@@ -2210,7 +2232,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
nil,
nil,
options.defaultSubAssigner,
nil,
affiliateService,
)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
var totpSvc *service.TotpService
+13
View File
@@ -92,6 +92,17 @@ type SystemSettings struct {
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GitHubOAuthClientID string `json:"github_oauth_client_id"`
GitHubOAuthClientSecretConfigured bool `json:"github_oauth_client_secret_configured"`
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
GoogleOAuthClientID string `json:"google_oauth_client_id"`
GoogleOAuthClientSecretConfigured bool `json:"google_oauth_client_secret_configured"`
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
@@ -245,6 +256,8 @@ type PublicSettings struct {
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
@@ -63,6 +63,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: h.version,
@@ -685,6 +685,16 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"github_oauth_enabled": false,
"github_oauth_client_id": "",
"github_oauth_client_secret_configured": false,
"github_oauth_redirect_url": "",
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
"google_oauth_enabled": false,
"google_oauth_client_id": "",
"google_oauth_client_secret_configured": false,
"google_oauth_redirect_url": "",
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
@@ -700,6 +710,16 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_github_balance": 0,
"auth_source_default_github_concurrency": 5,
"auth_source_default_github_subscriptions": [],
"auth_source_default_github_grant_on_signup": false,
"auth_source_default_github_grant_on_first_bind": false,
"auth_source_default_google_balance": 0,
"auth_source_default_google_concurrency": 5,
"auth_source_default_google_subscriptions": [],
"auth_source_default_google_grant_on_signup": false,
"auth_source_default_google_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
@@ -899,6 +919,16 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"github_oauth_enabled": false,
"github_oauth_client_id": "",
"github_oauth_client_secret_configured": false,
"github_oauth_redirect_url": "",
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
"google_oauth_enabled": false,
"google_oauth_client_id": "",
"google_oauth_client_secret_configured": false,
"google_oauth_redirect_url": "",
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subscription to API Conversion Platform",
@@ -1007,6 +1037,16 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_github_balance": 0,
"auth_source_default_github_concurrency": 5,
"auth_source_default_github_subscriptions": [],
"auth_source_default_github_grant_on_signup": false,
"auth_source_default_github_grant_on_first_bind": false,
"auth_source_default_google_balance": 0,
"auth_source_default_google_concurrency": 5,
"auth_source_default_google_subscriptions": [],
"auth_source_default_google_grant_on_signup": false,
"auth_source_default_google_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
@@ -40,6 +40,8 @@ func backendModeAllowsAuthPath(path string) bool {
"/auth/oauth/wechat/callback",
"/auth/oauth/wechat/payment/callback",
"/auth/oauth/oidc/callback",
"/auth/oauth/github/callback",
"/auth/oauth/google/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",
@@ -246,6 +246,30 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/oauth/oidc/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_github_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/github/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_github_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/github/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_google_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/google/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_google_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/google/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",
+16
View File
@@ -63,6 +63,22 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/github/start", h.Auth.GitHubOAuthStart)
auth.GET("/oauth/github/callback", h.Auth.GitHubOAuthCallback)
auth.POST("/oauth/github/complete-registration",
rateLimiter.LimitWithOptions("oauth-github-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteGitHubOAuthRegistration,
)
auth.GET("/oauth/google/start", h.Auth.GoogleOAuthStart)
auth.GET("/oauth/google/callback", h.Auth.GoogleOAuthCallback)
auth.POST("/oauth/google/complete-registration",
rateLimiter.LimitWithOptions("oauth-google-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteGoogleOAuthRegistration,
)
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
@@ -0,0 +1,274 @@
package service
import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
type EmailOAuthIdentityInput struct {
ProviderType string
ProviderKey string
ProviderSubject string
Email string
EmailVerified bool
Username string
DisplayName string
AvatarURL string
UpstreamMetadata map[string]any
}
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuth(ctx context.Context, input EmailOAuthIdentityInput) (*TokenPair, *User, error) {
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, "", "")
}
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuthWithInvitation(
ctx context.Context,
input EmailOAuthIdentityInput,
invitationCode string,
affiliateCode string,
) (*TokenPair, *User, error) {
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, invitationCode, affiliateCode)
}
func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
ctx context.Context,
input EmailOAuthIdentityInput,
invitationCode string,
affiliateCode string,
) (*TokenPair, *User, error) {
if s == nil || s.userRepo == nil || s.entClient == nil {
return nil, nil, ErrServiceUnavailable
}
providerType := normalizeOAuthSignupSource(input.ProviderType)
if providerType != "github" && providerType != "google" {
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
}
providerKey := strings.TrimSpace(input.ProviderKey)
if providerKey == "" {
providerKey = providerType
}
providerSubject := strings.TrimSpace(input.ProviderSubject)
if providerSubject == "" {
return nil, nil, infraerrors.BadRequest("OAUTH_SUBJECT_MISSING", "oauth subject is missing")
}
if !input.EmailVerified {
return nil, nil, infraerrors.Forbidden("OAUTH_EMAIL_NOT_VERIFIED", "oauth email is not verified")
}
email := strings.TrimSpace(strings.ToLower(input.Email))
if email == "" || len(email) > 255 {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
identityUser, err := s.findEmailOAuthIdentityOwner(ctx, providerType, providerKey, providerSubject)
if err != nil {
return nil, nil, err
}
if identityUser != nil && !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
return nil, nil, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
}
user := identityUser
created := false
if user == nil {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
user, err = s.createEmailOAuthUser(ctx, email, input.Username, providerType, invitationCode, affiliateCode)
if err != nil {
return nil, nil, err
}
created = true
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during %s oauth login: %v", providerType, err)
return nil, nil, ErrServiceUnavailable
}
}
}
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
if err := s.ensureEmailOAuthIdentity(ctx, user.ID, EmailOAuthIdentityInput{
ProviderType: providerType,
ProviderKey: providerKey,
ProviderSubject: providerSubject,
Email: email,
EmailVerified: input.EmailVerified,
Username: input.Username,
DisplayName: input.DisplayName,
AvatarURL: input.AvatarURL,
UpstreamMetadata: input.UpstreamMetadata,
}); err != nil {
return nil, nil, err
}
if user.Username == "" && strings.TrimSpace(input.Username) != "" {
user.Username = strings.TrimSpace(input.Username)
if err := s.userRepo.Update(ctx, user); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after %s oauth login: %v", providerType, err)
}
}
if !created {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, providerType); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply %s first bind defaults: %v", providerType, err)
}
}
s.RecordSuccessfulLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, providerType, invitationCode, affiliateCode string) (*User, error) {
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, ErrRegDisabled
}
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrInvitationCodeRequired) {
return nil, ErrOAuthInvitationRequired
}
return nil, err
}
randomPassword, err := randomHexString(32)
if err != nil {
return nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
grantPlan := s.resolveSignupGrantPlan(ctx, providerType)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
user := &User{
Email: email,
Username: strings.TrimSpace(username),
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: providerType,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
existing, loadErr := s.userRepo.GetByEmail(ctx, email)
if loadErr != nil {
return nil, ErrServiceUnavailable
}
return existing, nil
}
return nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, providerType, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, invitationCode)
return nil, ErrInvitationCodeInvalid
}
}
return user, nil
}
func (s *AuthService) findEmailOAuthIdentityOwner(ctx context.Context, providerType, providerKey, providerSubject string) (*User, error) {
identity, err := s.entClient.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
user, err := s.userRepo.GetByID(ctx, identity.UserID)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return nil, nil
}
return nil, ErrServiceUnavailable
}
return user, nil
}
func (s *AuthService) ensureEmailOAuthIdentity(ctx context.Context, userID int64, input EmailOAuthIdentityInput) error {
metadata := map[string]any{
"email": strings.TrimSpace(strings.ToLower(input.Email)),
"email_verified": input.EmailVerified,
}
for key, value := range input.UpstreamMetadata {
metadata[key] = value
}
if strings.TrimSpace(input.Username) != "" {
metadata["username"] = strings.TrimSpace(input.Username)
}
if strings.TrimSpace(input.DisplayName) != "" {
metadata["display_name"] = strings.TrimSpace(input.DisplayName)
}
if strings.TrimSpace(input.AvatarURL) != "" {
metadata["avatar_url"] = strings.TrimSpace(input.AvatarURL)
}
providerType := normalizeOAuthSignupSource(input.ProviderType)
providerKey := strings.TrimSpace(input.ProviderKey)
providerSubject := strings.TrimSpace(input.ProviderSubject)
identity, err := s.entClient.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil {
if identity.UserID != userID {
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
_, err = s.entClient.AuthIdentity.UpdateOneID(identity.ID).
SetMetadata(metadata).
Save(ctx)
return err
}
_, err = s.entClient.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetMetadata(metadata).
Save(ctx)
return err
}
@@ -10,6 +10,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func normalizeOAuthSignupSource(signupSource string) string {
@@ -17,7 +18,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
switch signupSource {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
case "linuxdo", "wechat", "oidc", "github", "google":
return signupSource
default:
return "email"
@@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return tokenPair, user, nil
}
// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth
// provider that has already returned a verified email address.
func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
ctx context.Context,
email string,
password string,
invitationCode string,
signupSource string,
) (*TokenPair, *User, error) {
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
email = strings.TrimSpace(strings.ToLower(email))
if email == "" || len(email) > 255 {
return nil, nil, ErrEmailVerifyRequired
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, ErrEmailVerifyRequired
}
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
if strings.TrimSpace(password) == "" {
return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
}
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
return nil, nil, ErrEmailExists
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = normalizeOAuthSignupSource(signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
return nil, nil, ErrServiceUnavailable
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
// only after the pending OAuth flow has fully reached its last reversible step.
func (s *AuthService) FinalizeOAuthEmailAccount(
@@ -229,6 +229,67 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
}
func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) {
tests := []struct {
name string
email string
signupSource string
want string
}{
{
name: "github",
email: "github@example.com",
signupSource: " GitHub ",
want: "github",
},
{
name: "google",
email: "google@example.com",
signupSource: " Google ",
want: "google",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
tt.email,
"secret-123",
"246810",
"",
tt.signupSource,
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, tt.want, userRepo.created[0].SignupSource)
})
}
}
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
@@ -256,7 +317,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing
"secret-123",
"246810",
"",
"github",
"unknown-provider",
)
require.NoError(t, err)
+4
View File
@@ -775,6 +775,10 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
return defaults.OIDC, true
case "wechat":
return defaults.WeChat, true
case "github":
return defaults.GitHub, true
case "google":
return defaults.Google, true
default:
return ProviderDefaultGrantSettings{}, false
}
@@ -175,6 +175,18 @@ const (
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
// GitHub / Google 邮箱快捷登录设置
SettingKeyGitHubOAuthEnabled = "github_oauth_enabled"
SettingKeyGitHubOAuthClientID = "github_oauth_client_id"
SettingKeyGitHubOAuthClientSecret = "github_oauth_client_secret"
SettingKeyGitHubOAuthRedirectURL = "github_oauth_redirect_url"
SettingKeyGitHubOAuthFrontendRedirectURL = "github_oauth_frontend_redirect_url"
SettingKeyGoogleOAuthEnabled = "google_oauth_enabled"
SettingKeyGoogleOAuthClientID = "google_oauth_client_id"
SettingKeyGoogleOAuthClientSecret = "google_oauth_client_secret"
SettingKeyGoogleOAuthRedirectURL = "google_oauth_redirect_url"
SettingKeyGoogleOAuthFrontendRedirectURL = "google_oauth_frontend_redirect_url"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
@@ -218,6 +230,16 @@ const (
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key
+267 -1
View File
@@ -129,6 +129,8 @@ type AuthSourceDefaultSettings struct {
LinuxDo ProviderDefaultGrantSettings
OIDC ProviderDefaultGrantSettings
WeChat ProviderDefaultGrantSettings
GitHub ProviderDefaultGrantSettings
Google ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup bool
}
@@ -169,6 +171,20 @@ var (
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
}
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultGitHubBalance,
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
}
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultGoogleBalance,
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
}
)
const (
@@ -177,6 +193,17 @@ const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
defaultWeChatConnectFrontend = "/auth/wechat/callback"
defaultGitHubOAuthAuthorize = "https://github.com/login/oauth/authorize"
defaultGitHubOAuthToken = "https://github.com/login/oauth/access_token"
defaultGitHubOAuthUserInfo = "https://api.github.com/user"
defaultGitHubOAuthEmails = "https://api.github.com/user/emails"
defaultGitHubOAuthScopes = "read:user user:email"
defaultGitHubOAuthFrontend = "/auth/oauth/callback"
defaultGoogleOAuthAuthorize = "https://accounts.google.com/o/oauth2/v2/auth"
defaultGoogleOAuthToken = "https://oauth2.googleapis.com/token"
defaultGoogleOAuthUserInfo = "https://openidconnect.googleapis.com/v1/userinfo"
defaultGoogleOAuthScopes = "openid email profile"
defaultGoogleOAuthFrontend = "/auth/oauth/callback"
)
func normalizeWeChatConnectModeSetting(raw string) string {
@@ -448,6 +475,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingKeyGitHubOAuthEnabled,
SettingKeyGitHubOAuthClientID,
SettingKeyGitHubOAuthClientSecret,
SettingKeyGoogleOAuthEnabled,
SettingKeyGoogleOAuthClientID,
SettingKeyGoogleOAuthClientSecret,
SettingKeyBalanceLowNotifyEnabled,
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
@@ -483,6 +516,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github")
googleEnabled := s.emailOAuthPublicEnabled(settings, "google")
weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
// Password reset requires email verification to be enabled
@@ -535,6 +570,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
GitHubOAuthEnabled: gitHubEnabled,
GoogleOAuthEnabled: googleEnabled,
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
@@ -680,6 +717,8 @@ type PublicSettingsInjectionPayload struct {
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
@@ -737,6 +776,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version,
@@ -811,6 +852,98 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string
return openReady || mpReady, openReady, mpReady, mobileReady
}
func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github":
cfg := config.EmailOAuthProviderConfig{
AuthorizeURL: defaultGitHubOAuthAuthorize,
TokenURL: defaultGitHubOAuthToken,
UserInfoURL: defaultGitHubOAuthUserInfo,
EmailsURL: defaultGitHubOAuthEmails,
Scopes: defaultGitHubOAuthScopes,
FrontendRedirectURL: defaultGitHubOAuthFrontend,
}
if s != nil && s.cfg != nil {
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GitHubOAuth)
}
return cfg
case "google":
cfg := config.EmailOAuthProviderConfig{
AuthorizeURL: defaultGoogleOAuthAuthorize,
TokenURL: defaultGoogleOAuthToken,
UserInfoURL: defaultGoogleOAuthUserInfo,
Scopes: defaultGoogleOAuthScopes,
FrontendRedirectURL: defaultGoogleOAuthFrontend,
}
if s != nil && s.cfg != nil {
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GoogleOAuth)
}
return cfg
default:
return config.EmailOAuthProviderConfig{}
}
}
func mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) config.EmailOAuthProviderConfig {
base.Enabled = override.Enabled
if strings.TrimSpace(override.ClientID) != "" {
base.ClientID = strings.TrimSpace(override.ClientID)
}
if strings.TrimSpace(override.ClientSecret) != "" {
base.ClientSecret = strings.TrimSpace(override.ClientSecret)
}
if strings.TrimSpace(override.AuthorizeURL) != "" {
base.AuthorizeURL = strings.TrimSpace(override.AuthorizeURL)
}
if strings.TrimSpace(override.TokenURL) != "" {
base.TokenURL = strings.TrimSpace(override.TokenURL)
}
if strings.TrimSpace(override.UserInfoURL) != "" {
base.UserInfoURL = strings.TrimSpace(override.UserInfoURL)
}
if strings.TrimSpace(override.EmailsURL) != "" {
base.EmailsURL = strings.TrimSpace(override.EmailsURL)
}
if strings.TrimSpace(override.Scopes) != "" {
base.Scopes = strings.TrimSpace(override.Scopes)
}
if strings.TrimSpace(override.RedirectURL) != "" {
base.RedirectURL = strings.TrimSpace(override.RedirectURL)
}
if strings.TrimSpace(override.FrontendRedirectURL) != "" {
base.FrontendRedirectURL = strings.TrimSpace(override.FrontendRedirectURL)
}
return base
}
func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool {
cfg := s.effectiveEmailOAuthConfig(settings, provider)
return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != ""
}
func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig {
cfg := s.emailOAuthBaseConfig(provider)
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github":
if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok {
cfg.Enabled = raw == "true"
}
cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID)
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret)
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL)
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend)
case "google":
if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok {
cfg.Enabled = raw == "true"
}
cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID)
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret)
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL)
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend)
}
return cfg
}
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
// array string, returning only items with visibility != "admin".
func filterUserVisibleMenuItems(raw string) json.RawMessage {
@@ -1057,6 +1190,16 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
if settings.WeChatConnectFrontendRedirectURL == "" {
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
}
settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL)
settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL)
if settings.GitHubOAuthFrontendRedirectURL == "" {
settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend
}
settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL)
settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL)
if settings.GoogleOAuthFrontendRedirectURL == "" {
settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend
}
updates := make(map[string]string)
@@ -1126,6 +1269,22 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
// GitHub / Google 邮箱快捷登录
updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled)
updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID)
updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL
updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL
if settings.GitHubOAuthClientSecret != "" {
updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret)
}
updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled)
updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID)
updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL
updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL
if settings.GoogleOAuthClientSecret != "" {
updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret)
}
// WeChat Connect OAuth 登录
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
@@ -1281,17 +1440,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
settings.GitHub.Subscriptions,
settings.Google.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return nil, err
}
}
updates := make(map[string]string, 21)
updates := make(map[string]string, 31)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
return updates, nil
}
@@ -1370,6 +1533,61 @@ func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context,
return nil
}
func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider != "github" && provider != "google" {
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_PROVIDER_NOT_FOUND", "oauth provider not found")
}
keys := []string{
SettingKeyGitHubOAuthEnabled,
SettingKeyGitHubOAuthClientID,
SettingKeyGitHubOAuthClientSecret,
SettingKeyGitHubOAuthRedirectURL,
SettingKeyGitHubOAuthFrontendRedirectURL,
SettingKeyGoogleOAuthEnabled,
SettingKeyGoogleOAuthClientID,
SettingKeyGoogleOAuthClientSecret,
SettingKeyGoogleOAuthRedirectURL,
SettingKeyGoogleOAuthFrontendRedirectURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return config.EmailOAuthProviderConfig{}, fmt.Errorf("get email oauth settings: %w", err)
}
cfg := s.effectiveEmailOAuthConfig(settings, provider)
if !cfg.Enabled {
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
if strings.TrimSpace(cfg.ClientID) == "" {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
}
if strings.TrimSpace(cfg.ClientSecret) == "" {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
for label, rawURL := range map[string]string{
"authorize": cfg.AuthorizeURL,
"token": cfg.TokenURL,
"userinfo": cfg.UserInfoURL,
"redirect": cfg.RedirectURL,
} {
if strings.TrimSpace(rawURL) == "" {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(rawURL); err != nil {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url invalid")
}
}
if strings.TrimSpace(cfg.EmailsURL) != "" {
if err := config.ValidateAbsoluteHTTPURL(cfg.EmailsURL); err != nil {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth emails url invalid")
}
}
if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
}
return cfg, nil
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
@@ -1728,6 +1946,16 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
SettingKeyAuthSourceDefaultWeChatSubscriptions,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
SettingKeyAuthSourceDefaultGitHubBalance,
SettingKeyAuthSourceDefaultGitHubConcurrency,
SettingKeyAuthSourceDefaultGitHubSubscriptions,
SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
SettingKeyAuthSourceDefaultGoogleBalance,
SettingKeyAuthSourceDefaultGoogleConcurrency,
SettingKeyAuthSourceDefaultGoogleSubscriptions,
SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
SettingKeyForceEmailOnThirdPartySignup,
}
@@ -1741,6 +1969,8 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys),
Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys),
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
}, nil
}
@@ -1841,6 +2071,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyWeChatConnectScopes: "snsapi_login",
SettingKeyWeChatConnectRedirectURL: "",
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
SettingKeyGitHubOAuthEnabled: "false",
SettingKeyGitHubOAuthClientID: "",
SettingKeyGitHubOAuthClientSecret: "",
SettingKeyGitHubOAuthRedirectURL: "",
SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
SettingKeyGoogleOAuthEnabled: "false",
SettingKeyGoogleOAuthClientID: "",
SettingKeyGoogleOAuthClientSecret: "",
SettingKeyGoogleOAuthRedirectURL: "",
SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyOIDCConnectClientID: "",
@@ -1891,6 +2131,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultGitHubBalance: "0",
SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultGoogleBalance: "0",
SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
SettingKeyForceEmailOnThirdPartySignup: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
@@ -2193,6 +2443,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github")
result.GitHubOAuthEnabled = gitHubEffective.Enabled
result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID)
result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret)
result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != ""
result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL)
result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL)
googleEffective := s.effectiveEmailOAuthConfig(settings, "google")
result.GoogleOAuthEnabled = googleEffective.Enabled
result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID)
result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret)
result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != ""
result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL)
result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL)
// WeChat Connect 设置:
// - 优先读取 DB 系统设置
// - 缺失时回退到 config/env,保持升级兼容
+16
View File
@@ -89,6 +89,20 @@ type SystemSettings struct {
OIDCConnectUserInfoIDPath string
OIDCConnectUserInfoUsernamePath string
// GitHub / Google 邮箱快捷登录
GitHubOAuthEnabled bool
GitHubOAuthClientID string
GitHubOAuthClientSecret string
GitHubOAuthClientSecretConfigured bool
GitHubOAuthRedirectURL string
GitHubOAuthFrontendRedirectURL string
GoogleOAuthEnabled bool
GoogleOAuthClientID string
GoogleOAuthClientSecret string
GoogleOAuthClientSecretConfigured bool
GoogleOAuthRedirectURL string
GoogleOAuthFrontendRedirectURL string
SiteName string
SiteLogo string
SiteSubtitle string
@@ -218,6 +232,8 @@ type PublicSettings struct {
PaymentEnabled bool
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
GitHubOAuthEnabled bool
GoogleOAuthEnabled bool
Version string
BalanceLowNotifyEnabled bool