feat: 增加 GitHub 和 Google 邮箱快捷登录
This commit is contained in:
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
var authProviderTypes = map[string]struct{}{
|
||||
"email": {},
|
||||
"github": {},
|
||||
"google": {},
|
||||
"linuxdo": {},
|
||||
"oidc": {},
|
||||
"wechat": {},
|
||||
|
||||
@@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
|
||||
field.String("signup_source").
|
||||
Validate(func(value string) error {
|
||||
switch value {
|
||||
case "email", "linuxdo", "wechat", "oidc":
|
||||
case "email", "linuxdo", "wechat", "oidc", "github", "google":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
|
||||
}
|
||||
}).
|
||||
Default("email"),
|
||||
|
||||
@@ -72,6 +72,8 @@ type Config struct {
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
|
||||
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
||||
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
|
||||
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
@@ -240,6 +242,19 @@ type OIDCConnectConfig struct {
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type EmailOAuthProviderConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
EmailsURL string `mapstructure:"emails_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"`
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultWeChatConnectMode = "open"
|
||||
defaultWeChatConnectScopes = "snsapi_login"
|
||||
|
||||
@@ -169,6 +169,16 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: settings.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecretConfigured: settings.GitHubOAuthClientSecretConfigured,
|
||||
GitHubOAuthRedirectURL: settings.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: settings.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: settings.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecretConfigured: settings.GoogleOAuthClientSecretConfigured,
|
||||
GoogleOAuthRedirectURL: settings.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: settings.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
@@ -368,6 +378,17 @@ type UpdateSettingsRequest struct {
|
||||
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
||||
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
||||
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GitHubOAuthClientID string `json:"github_oauth_client_id"`
|
||||
GitHubOAuthClientSecret string `json:"github_oauth_client_secret"`
|
||||
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
|
||||
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
GoogleOAuthClientID string `json:"google_oauth_client_id"`
|
||||
GoogleOAuthClientSecret string `json:"google_oauth_client_secret"`
|
||||
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
|
||||
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
@@ -413,6 +434,16 @@ type UpdateSettingsRequest struct {
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
|
||||
// Model fallback configuration
|
||||
@@ -1200,6 +1231,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
@@ -1396,6 +1437,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
||||
},
|
||||
GitHub: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultGitHubConcurrency, previousAuthSourceDefaults.GitHub.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
|
||||
},
|
||||
Google: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultGoogleConcurrency, previousAuthSourceDefaults.Google.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
||||
@@ -1538,6 +1593,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: updatedSettings.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: updatedSettings.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecretConfigured: updatedSettings.GitHubOAuthClientSecretConfigured,
|
||||
GitHubOAuthRedirectURL: updatedSettings.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: updatedSettings.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: updatedSettings.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: updatedSettings.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecretConfigured: updatedSettings.GoogleOAuthClientSecretConfigured,
|
||||
GoogleOAuthRedirectURL: updatedSettings.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: updatedSettings.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
@@ -2027,6 +2092,8 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
||||
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
|
||||
{name: "oidc", before: before.OIDC, after: after.OIDC},
|
||||
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
||||
{name: "github", before: before.GitHub, after: after.GitHub},
|
||||
{name: "google", before: before.Google, after: after.Google},
|
||||
}
|
||||
for _, field := range fields {
|
||||
if field.before.Balance != field.after.Balance {
|
||||
@@ -2141,6 +2208,16 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
||||
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
|
||||
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
|
||||
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
|
||||
data["auth_source_default_github_balance"] = authSourceDefaults.GitHub.Balance
|
||||
data["auth_source_default_github_concurrency"] = authSourceDefaults.GitHub.Concurrency
|
||||
data["auth_source_default_github_subscriptions"] = authSourceDefaults.GitHub.Subscriptions
|
||||
data["auth_source_default_github_grant_on_signup"] = authSourceDefaults.GitHub.GrantOnSignup
|
||||
data["auth_source_default_github_grant_on_first_bind"] = authSourceDefaults.GitHub.GrantOnFirstBind
|
||||
data["auth_source_default_google_balance"] = authSourceDefaults.Google.Balance
|
||||
data["auth_source_default_google_concurrency"] = authSourceDefaults.Google.Concurrency
|
||||
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
|
||||
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
|
||||
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
|
||||
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
||||
|
||||
return data
|
||||
|
||||
@@ -0,0 +1,549 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
emailOAuthCookiePath = "/api/v1/auth/oauth"
|
||||
emailOAuthStateCookieName = "email_oauth_state"
|
||||
emailOAuthRedirectCookie = "email_oauth_redirect"
|
||||
emailOAuthProviderCookie = "email_oauth_provider"
|
||||
emailOAuthAffiliateCookie = "email_oauth_affiliate"
|
||||
emailOAuthCookieMaxAgeSec = 10 * 60
|
||||
emailOAuthDefaultRedirect = "/dashboard"
|
||||
)
|
||||
|
||||
type emailOAuthTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type emailOAuthProfile struct {
|
||||
Subject string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
Username string
|
||||
DisplayName string
|
||||
AvatarURL string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
func (h *AuthHandler) GitHubOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "github") }
|
||||
func (h *AuthHandler) GoogleOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "google") }
|
||||
|
||||
func (h *AuthHandler) GitHubOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "github") }
|
||||
func (h *AuthHandler) GoogleOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "google") }
|
||||
func (h *AuthHandler) CompleteGitHubOAuthRegistration(c *gin.Context) {
|
||||
h.completeEmailOAuthRegistration(c, "github")
|
||||
}
|
||||
func (h *AuthHandler) CompleteGoogleOAuthRegistration(c *gin.Context) {
|
||||
h.completeEmailOAuthRegistration(c, "google")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthStart(c *gin.Context, provider string) {
|
||||
cfg, err := h.getEmailOAuthConfig(c.Request.Context(), provider)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||
return
|
||||
}
|
||||
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||
if redirectTo == "" {
|
||||
redirectTo = emailOAuthDefaultRedirect
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
emailOAuthSetCookie(c, emailOAuthStateCookieName, encodeCookieValue(state), secureCookie)
|
||||
emailOAuthSetCookie(c, emailOAuthRedirectCookie, encodeCookieValue(redirectTo), secureCookie)
|
||||
emailOAuthSetCookie(c, emailOAuthProviderCookie, encodeCookieValue(provider), secureCookie)
|
||||
if affCode := strings.TrimSpace(firstNonEmpty(c.Query("aff_code"), c.Query("aff"))); affCode != "" {
|
||||
emailOAuthSetCookie(c, emailOAuthAffiliateCookie, encodeCookieValue(affCode), secureCookie)
|
||||
} else {
|
||||
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
|
||||
}
|
||||
|
||||
authURL, err := buildEmailOAuthAuthorizeURL(cfg, state)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthCallback(c *gin.Context, provider string) {
|
||||
cfg, cfgErr := h.getEmailOAuthConfig(c.Request.Context(), provider)
|
||||
if cfgErr != nil {
|
||||
response.ErrorFrom(c, cfgErr)
|
||||
return
|
||||
}
|
||||
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||
if frontendCallback == "" {
|
||||
frontendCallback = "/auth/oauth/callback"
|
||||
}
|
||||
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if code == "" || state == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
defer func() {
|
||||
emailOAuthClearCookie(c, emailOAuthStateCookieName, secureCookie)
|
||||
emailOAuthClearCookie(c, emailOAuthRedirectCookie, secureCookie)
|
||||
emailOAuthClearCookie(c, emailOAuthProviderCookie, secureCookie)
|
||||
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
|
||||
}()
|
||||
expectedState, err := readCookieDecoded(c, emailOAuthStateCookieName)
|
||||
if err != nil || expectedState == "" || expectedState != state {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||
return
|
||||
}
|
||||
expectedProvider, _ := readCookieDecoded(c, emailOAuthProviderCookie)
|
||||
if !strings.EqualFold(strings.TrimSpace(expectedProvider), provider) {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth provider", "")
|
||||
return
|
||||
}
|
||||
redirectTo, _ := readCookieDecoded(c, emailOAuthRedirectCookie)
|
||||
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||
if redirectTo == "" {
|
||||
redirectTo = emailOAuthDefaultRedirect
|
||||
}
|
||||
|
||||
tokenResp, err := exchangeEmailOAuthCode(c.Request.Context(), cfg, code)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(err.Error()))
|
||||
return
|
||||
}
|
||||
profile, err := fetchEmailOAuthProfile(c.Request.Context(), provider, cfg, tokenResp)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch verified email", singleLine(err.Error()))
|
||||
return
|
||||
}
|
||||
h.emailOAuthCallbackWithProfile(c, provider, cfg, frontendCallback, redirectTo, profile)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthCallbackWithProfile(
|
||||
c *gin.Context,
|
||||
provider string,
|
||||
cfg config.EmailOAuthProviderConfig,
|
||||
frontendCallback string,
|
||||
redirectTo string,
|
||||
profile *emailOAuthProfile,
|
||||
) {
|
||||
input := service.EmailOAuthIdentityInput{
|
||||
ProviderType: provider,
|
||||
ProviderKey: provider,
|
||||
ProviderSubject: profile.Subject,
|
||||
Email: profile.Email,
|
||||
EmailVerified: profile.EmailVerified,
|
||||
Username: profile.Username,
|
||||
DisplayName: profile.DisplayName,
|
||||
AvatarURL: profile.AvatarURL,
|
||||
UpstreamMetadata: profile.Metadata,
|
||||
}
|
||||
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
if pendingErr := h.createEmailOAuthInvitationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
|
||||
return
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", tokenPair.AccessToken)
|
||||
fragment.Set("refresh_token", tokenPair.RefreshToken)
|
||||
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
if code, err := readCookieDecoded(c, emailOAuthAffiliateCookie); err == nil {
|
||||
return strings.TrimSpace(code)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
|
||||
c *gin.Context,
|
||||
provider string,
|
||||
frontendCallback string,
|
||||
redirectTo string,
|
||||
profile *emailOAuthProfile,
|
||||
) error {
|
||||
if h == nil || profile == nil {
|
||||
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
browserSessionKey, err := generateOAuthPendingBrowserSession()
|
||||
if err != nil {
|
||||
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
|
||||
}
|
||||
setOAuthPendingBrowserCookie(c, browserSessionKey, isRequestHTTPS(c))
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||
username := strings.TrimSpace(profile.Username)
|
||||
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||
upstreamClaims := map[string]any{
|
||||
"email": email,
|
||||
"email_verified": profile.EmailVerified,
|
||||
"username": username,
|
||||
"provider": provider,
|
||||
"provider_key": provider,
|
||||
"provider_subject": strings.TrimSpace(profile.Subject),
|
||||
}
|
||||
if strings.TrimSpace(profile.DisplayName) != "" {
|
||||
upstreamClaims["suggested_display_name"] = strings.TrimSpace(profile.DisplayName)
|
||||
}
|
||||
if strings.TrimSpace(profile.AvatarURL) != "" {
|
||||
upstreamClaims["suggested_avatar_url"] = strings.TrimSpace(profile.AvatarURL)
|
||||
}
|
||||
if affiliateCode != "" {
|
||||
upstreamClaims["aff_code"] = affiliateCode
|
||||
}
|
||||
for key, value := range profile.Metadata {
|
||||
if _, exists := upstreamClaims[key]; !exists {
|
||||
upstreamClaims[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
completionResponse := map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"error": "invitation_required",
|
||||
"choice_reason": "invitation_required",
|
||||
"adoption_required": false,
|
||||
"create_account_allowed": true,
|
||||
"existing_account_bindable": false,
|
||||
"force_email_on_signup": true,
|
||||
"email": email,
|
||||
"resolved_email": email,
|
||||
"provider": provider,
|
||||
"redirect": redirectTo,
|
||||
}
|
||||
if strings.TrimSpace(frontendCallback) != "" {
|
||||
completionResponse["frontend_callback"] = strings.TrimSpace(frontendCallback)
|
||||
}
|
||||
|
||||
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: service.PendingAuthIdentityKey{ProviderType: provider, ProviderKey: provider, ProviderSubject: strings.TrimSpace(profile.Subject)},
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: completionResponse,
|
||||
})
|
||||
}
|
||||
|
||||
type completeEmailOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
}
|
||||
|
||||
func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider string) {
|
||||
var req completeEmailOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
|
||||
response.BadRequest(c, "Pending oauth session provider mismatch")
|
||||
return
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
affiliateCode := strings.TrimSpace(req.AffCode)
|
||||
if affiliateCode == "" {
|
||||
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(
|
||||
c.Request.Context(),
|
||||
service.EmailOAuthIdentityInput{
|
||||
ProviderType: strings.TrimSpace(session.ProviderType),
|
||||
ProviderKey: strings.TrimSpace(session.ProviderKey),
|
||||
ProviderSubject: strings.TrimSpace(session.ProviderSubject),
|
||||
Email: strings.TrimSpace(session.ResolvedEmail),
|
||||
EmailVerified: true,
|
||||
Username: pendingSessionStringValue(session.UpstreamIdentityClaims, "username"),
|
||||
DisplayName: pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"),
|
||||
AvatarURL: pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url"),
|
||||
UpstreamMetadata: clonePendingMap(session.UpstreamIdentityClaims),
|
||||
},
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
affiliateCode,
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||
return
|
||||
}
|
||||
tx, err := client.Tx(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
|
||||
_ = tx.Rollback()
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||
return
|
||||
}
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getEmailOAuthConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetEmailOAuthProviderConfig(ctx, provider)
|
||||
}
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
|
||||
func buildEmailOAuthAuthorizeURL(cfg config.EmailOAuthProviderConfig, state string) (string, error) {
|
||||
u, err := url.Parse(cfg.AuthorizeURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", cfg.ClientID)
|
||||
q.Set("redirect_uri", cfg.RedirectURL)
|
||||
q.Set("state", state)
|
||||
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func exchangeEmailOAuthCode(ctx context.Context, cfg config.EmailOAuthProviderConfig, code string) (*emailOAuthTokenResponse, error) {
|
||||
resp, err := req.C().
|
||||
R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetFormData(map[string]string{
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": cfg.ClientID,
|
||||
"client_secret": cfg.ClientSecret,
|
||||
"code": code,
|
||||
"redirect_uri": cfg.RedirectURL,
|
||||
}).
|
||||
Post(cfg.TokenURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("token endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||
}
|
||||
var tokenResp emailOAuthTokenResponse
|
||||
if err := json.Unmarshal(resp.Bytes(), &tokenResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return nil, errors.New("missing access_token")
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func fetchEmailOAuthProfile(ctx context.Context, provider string, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse) (*emailOAuthProfile, error) {
|
||||
resp, err := req.C().
|
||||
R().
|
||||
SetContext(ctx).
|
||||
SetBearerAuthToken(token.AccessToken).
|
||||
SetHeader("Accept", "application/json").
|
||||
Get(cfg.UserInfoURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("userinfo endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "github":
|
||||
return parseGitHubOAuthProfile(ctx, cfg, token, resp.String())
|
||||
case "google":
|
||||
return parseGoogleOAuthProfile(resp.String())
|
||||
default:
|
||||
return nil, errors.New("unsupported oauth provider")
|
||||
}
|
||||
}
|
||||
|
||||
func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse, body string) (*emailOAuthProfile, error) {
|
||||
subject := strings.TrimSpace(gjson.Get(body, "id").String())
|
||||
if subject == "" {
|
||||
return nil, errors.New("github user id is missing")
|
||||
}
|
||||
email := strings.TrimSpace(gjson.Get(body, "email").String())
|
||||
emailVerified := false
|
||||
if email != "" {
|
||||
emailVerified = true
|
||||
}
|
||||
if strings.TrimSpace(cfg.EmailsURL) != "" {
|
||||
if verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, cfg.EmailsURL, token.AccessToken); err == nil && verifiedEmail != "" {
|
||||
email = verifiedEmail
|
||||
emailVerified = true
|
||||
} else if email == "" && err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if email == "" || !emailVerified {
|
||||
return nil, errors.New("github verified email is missing")
|
||||
}
|
||||
login := strings.TrimSpace(gjson.Get(body, "login").String())
|
||||
name := strings.TrimSpace(gjson.Get(body, "name").String())
|
||||
return &emailOAuthProfile{
|
||||
Subject: subject,
|
||||
Email: email,
|
||||
EmailVerified: true,
|
||||
Username: firstNonEmpty(login, name, "github_"+subject),
|
||||
DisplayName: firstNonEmpty(name, login),
|
||||
AvatarURL: strings.TrimSpace(gjson.Get(body, "avatar_url").String()),
|
||||
Metadata: map[string]any{
|
||||
"login": login,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fetchGitHubPrimaryVerifiedEmail(ctx context.Context, emailsURL string, accessToken string) (string, error) {
|
||||
resp, err := req.C().
|
||||
R().
|
||||
SetContext(ctx).
|
||||
SetBearerAuthToken(accessToken).
|
||||
SetHeader("Accept", "application/json").
|
||||
Get(emailsURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("github emails endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||
}
|
||||
items := gjson.Parse(resp.String()).Array()
|
||||
for _, item := range items {
|
||||
if item.Get("primary").Bool() && item.Get("verified").Bool() {
|
||||
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, item := range items {
|
||||
if item.Get("verified").Bool() {
|
||||
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", errors.New("github verified email is missing")
|
||||
}
|
||||
|
||||
func parseGoogleOAuthProfile(body string) (*emailOAuthProfile, error) {
|
||||
subject := strings.TrimSpace(gjson.Get(body, "sub").String())
|
||||
email := strings.TrimSpace(gjson.Get(body, "email").String())
|
||||
verified := gjson.Get(body, "email_verified").Bool()
|
||||
if subject == "" {
|
||||
return nil, errors.New("google subject is missing")
|
||||
}
|
||||
if email == "" || !verified {
|
||||
return nil, errors.New("google verified email is missing")
|
||||
}
|
||||
name := strings.TrimSpace(gjson.Get(body, "name").String())
|
||||
return &emailOAuthProfile{
|
||||
Subject: subject,
|
||||
Email: email,
|
||||
EmailVerified: true,
|
||||
Username: firstNonEmpty(strings.TrimSpace(gjson.Get(body, "given_name").String()), name, email),
|
||||
DisplayName: name,
|
||||
AvatarURL: strings.TrimSpace(gjson.Get(body, "picture").String()),
|
||||
Metadata: map[string]any{
|
||||
"email_verified": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func emailOAuthSetCookie(c *gin.Context, name, value string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: emailOAuthCookiePath,
|
||||
MaxAge: emailOAuthCookieMaxAgeSec,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func emailOAuthClearCookie(c *gin.Context, name string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: emailOAuthCookiePath,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,333 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
state := "github-oauth-state"
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil)
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)})
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")})
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")})
|
||||
c.Request = req
|
||||
|
||||
profile := &emailOAuthProfile{
|
||||
Subject: "github-123",
|
||||
Email: "fresh@example.com",
|
||||
EmailVerified: true,
|
||||
Username: "fresh",
|
||||
DisplayName: "Fresh User",
|
||||
AvatarURL: "https://cdn.example/fresh.png",
|
||||
Metadata: map[string]any{
|
||||
"login": "fresh",
|
||||
},
|
||||
}
|
||||
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
|
||||
Enabled: true,
|
||||
ClientID: "github-client",
|
||||
ClientSecret: "github-secret",
|
||||
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
|
||||
FrontendRedirectURL: "/auth/oauth/callback",
|
||||
}, "/auth/oauth/callback", "/dashboard", profile)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.Contains(t, location, "/auth/oauth/callback")
|
||||
require.NotContains(t, location, "access_token=")
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
session, err := client.PendingAuthSession.Query().Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "github", session.ProviderType)
|
||||
require.Equal(t, "github", session.ProviderKey)
|
||||
require.Equal(t, "github-123", session.ProviderSubject)
|
||||
require.Equal(t, "fresh@example.com", session.ResolvedEmail)
|
||||
require.Equal(t, "/dashboard", session.RedirectTo)
|
||||
require.Nil(t, session.TargetUserID)
|
||||
|
||||
completion, ok := readCompletionResponse(session.LocalFlowState)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||
require.Equal(t, "invitation_required", completion["error"])
|
||||
require.Equal(t, "fresh@example.com", completion["email"])
|
||||
require.Equal(t, "fresh@example.com", completion["resolved_email"])
|
||||
require.Equal(t, true, completion["create_account_allowed"])
|
||||
|
||||
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName))
|
||||
}
|
||||
|
||||
func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("existing@example.com").
|
||||
SetUsername("existing").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil)
|
||||
|
||||
handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{
|
||||
Enabled: true,
|
||||
ClientID: "google-client",
|
||||
ClientSecret: "google-secret",
|
||||
RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback",
|
||||
FrontendRedirectURL: "/auth/oauth/callback",
|
||||
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
|
||||
Subject: "google-123",
|
||||
Email: "existing@example.com",
|
||||
EmailVerified: true,
|
||||
Username: "existing",
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.Contains(t, location, "access_token=")
|
||||
require.Contains(t, location, "redirect=%252Fdashboard")
|
||||
|
||||
sessionCount, err := client.PendingAuthSession.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, sessionCount)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().Where(
|
||||
authidentity.ProviderTypeEQ("google"),
|
||||
authidentity.ProviderSubjectEQ("google-123"),
|
||||
).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, identityCount)
|
||||
_ = user
|
||||
}
|
||||
|
||||
func TestEmailOAuthCallbackAutoRegistrationAppliesAffiliateCode(t *testing.T) {
|
||||
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
settingValues: map[string]string{
|
||||
service.SettingKeyAffiliateEnabled: "true",
|
||||
},
|
||||
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
|
||||
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
|
||||
},
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil)
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")})
|
||||
c.Request = req
|
||||
|
||||
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
|
||||
Enabled: true,
|
||||
ClientID: "github-client",
|
||||
ClientSecret: "github-secret",
|
||||
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
|
||||
FrontendRedirectURL: "/auth/oauth/callback",
|
||||
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
|
||||
Subject: "github-aff-user",
|
||||
Email: "aff-user@example.com",
|
||||
EmailVerified: true,
|
||||
Username: "aff-user",
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.Contains(t, recorder.Header().Get("Location"), "access_token=")
|
||||
user, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{user.ID, user.ID}, affiliateRepo.ensureUserIDs)
|
||||
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 1001}}, affiliateRepo.bindCalls)
|
||||
}
|
||||
|
||||
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
|
||||
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002})
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
invitationEnabled: true,
|
||||
settingValues: map[string]string{
|
||||
service.SettingKeyAffiliateEnabled: "true",
|
||||
},
|
||||
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
|
||||
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
|
||||
},
|
||||
})
|
||||
ctx := context.Background()
|
||||
invitation, err := client.RedeemCode.Create().
|
||||
SetCode("INVITE456").
|
||||
SetType(service.RedeemTypeInvitation).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("email-oauth-aff-session-token").
|
||||
SetIntent(oauthIntentLogin).
|
||||
SetProviderType("google").
|
||||
SetProviderKey("google").
|
||||
SetProviderSubject("google-aff-user").
|
||||
SetResolvedEmail("pending-aff@example.com").
|
||||
SetRedirectTo("/dashboard").
|
||||
SetBrowserSessionKey("browser-aff-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"email": "pending-aff@example.com",
|
||||
"email_verified": true,
|
||||
"username": "pending-aff",
|
||||
"provider": "google",
|
||||
"provider_key": "google",
|
||||
"provider_subject": "google-aff-user",
|
||||
"aff_code": "AFF456",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"error": "invitation_required",
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"invitation_code":"INVITE456"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
|
||||
c.Request = req
|
||||
|
||||
handler.completeEmailOAuthRegistration(c, "google")
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
|
||||
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedInvitation.UsedBy)
|
||||
require.Equal(t, user.ID, *storedInvitation.UsedBy)
|
||||
}
|
||||
|
||||
type oauthEmailAffiliateBindCall struct {
|
||||
userID int64
|
||||
inviterID int64
|
||||
}
|
||||
|
||||
type oauthEmailAffiliateRepoStub struct {
|
||||
codeOwners map[string]int64
|
||||
ensureUserIDs []int64
|
||||
bindCalls []oauthEmailAffiliateBindCall
|
||||
}
|
||||
|
||||
func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub {
|
||||
return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners}
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) {
|
||||
r.ensureUserIDs = append(r.ensureUserIDs, userID)
|
||||
return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) {
|
||||
userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))]
|
||||
if !ok {
|
||||
return nil, service.ErrAffiliateProfileNotFound
|
||||
}
|
||||
return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) {
|
||||
r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID})
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) {
|
||||
panic("unexpected AccrueQuota call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
|
||||
panic("unexpected GetAccruedRebateFromInvitee call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) {
|
||||
panic("unexpected ThawFrozenQuota call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) {
|
||||
panic("unexpected TransferQuotaToBalance call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) {
|
||||
panic("unexpected ListInvitees call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error {
|
||||
panic("unexpected UpdateUserAffCode call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) {
|
||||
panic("unexpected ResetUserAffCode call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error {
|
||||
panic("unexpected SetUserRebateRate call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error {
|
||||
panic("unexpected BatchSetUserRebateRate call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
|
||||
panic("unexpected ListUsersWithCustomSettings call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
|
||||
panic("unexpected ListAffiliateInviteRecords call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
|
||||
panic("unexpected ListAffiliateRebateRecords call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
|
||||
panic("unexpected ListAffiliateTransferRecords call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) {
|
||||
panic("unexpected GetAffiliateUserOverview call")
|
||||
}
|
||||
|
||||
func findSetCookieValue(cookies []*http.Cookie, name string) string {
|
||||
for _, cookie := range cookies {
|
||||
if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 {
|
||||
return cookie.Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -2121,6 +2121,8 @@ type oauthPendingFlowTestHandlerOptions struct {
|
||||
emailCache service.EmailCache
|
||||
settingValues map[string]string
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||
affiliateService *service.AffiliateService
|
||||
affiliateFactory func(*dbent.Client, *service.SettingService) *service.AffiliateService
|
||||
totpCache service.TotpCache
|
||||
totpEncryptor service.SecretEncryptor
|
||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||
@@ -2160,6 +2162,21 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_affiliates (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
aff_code TEXT NOT NULL UNIQUE,
|
||||
aff_code_custom BOOLEAN NOT NULL DEFAULT false,
|
||||
aff_rebate_rate_percent REAL NULL,
|
||||
inviter_id INTEGER NULL,
|
||||
aff_count INTEGER NOT NULL DEFAULT 0,
|
||||
aff_quota REAL NOT NULL DEFAULT 0,
|
||||
aff_frozen_quota REAL NOT NULL DEFAULT 0,
|
||||
aff_history_quota REAL NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
@@ -2177,14 +2194,19 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
},
|
||||
}
|
||||
settingValues := map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
}
|
||||
for key, value := range options.settingValues {
|
||||
settingValues[key] = value
|
||||
}
|
||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
||||
affiliateService := options.affiliateService
|
||||
if affiliateService == nil && options.affiliateFactory != nil {
|
||||
affiliateService = options.affiliateFactory(client, settingSvc)
|
||||
}
|
||||
userRepo := &oauthPendingFlowUserRepo{
|
||||
client: client,
|
||||
options: options.userRepoOptions,
|
||||
@@ -2210,7 +2232,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
nil,
|
||||
nil,
|
||||
options.defaultSubAssigner,
|
||||
nil,
|
||||
affiliateService,
|
||||
)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
var totpSvc *service.TotpService
|
||||
|
||||
@@ -91,6 +91,17 @@ type SystemSettings struct {
|
||||
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
||||
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
||||
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GitHubOAuthClientID string `json:"github_oauth_client_id"`
|
||||
GitHubOAuthClientSecretConfigured bool `json:"github_oauth_client_secret_configured"`
|
||||
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
|
||||
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
GoogleOAuthClientID string `json:"google_oauth_client_id"`
|
||||
GoogleOAuthClientSecretConfigured bool `json:"google_oauth_client_secret_configured"`
|
||||
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
|
||||
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
@@ -241,6 +252,8 @@ type PublicSettings struct {
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
|
||||
@@ -63,6 +63,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: h.version,
|
||||
|
||||
@@ -40,6 +40,8 @@ func backendModeAllowsAuthPath(path string) bool {
|
||||
"/auth/oauth/wechat/callback",
|
||||
"/auth/oauth/wechat/payment/callback",
|
||||
"/auth/oauth/oidc/callback",
|
||||
"/auth/oauth/github/callback",
|
||||
"/auth/oauth/google/callback",
|
||||
"/auth/oauth/linuxdo/complete-registration",
|
||||
"/auth/oauth/wechat/complete-registration",
|
||||
"/auth/oauth/oidc/complete-registration",
|
||||
|
||||
@@ -246,6 +246,30 @@ func TestBackendModeAuthGuard(t *testing.T) {
|
||||
path: "/api/v1/auth/oauth/oidc/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_github_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/github/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_github_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/github/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_google_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/google/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_google_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/google/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_exchange",
|
||||
enabled: "true",
|
||||
|
||||
@@ -63,6 +63,22 @@ func RegisterAuthRoutes(
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/github/start", h.Auth.GitHubOAuthStart)
|
||||
auth.GET("/oauth/github/callback", h.Auth.GitHubOAuthCallback)
|
||||
auth.POST("/oauth/github/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-github-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteGitHubOAuthRegistration,
|
||||
)
|
||||
auth.GET("/oauth/google/start", h.Auth.GoogleOAuthStart)
|
||||
auth.GET("/oauth/google/callback", h.Auth.GoogleOAuthCallback)
|
||||
auth.POST("/oauth/google/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-google-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteGoogleOAuthRegistration,
|
||||
)
|
||||
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
|
||||
query := c.Request.URL.Query()
|
||||
query.Set("intent", "bind_current_user")
|
||||
|
||||
@@ -0,0 +1,274 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
type EmailOAuthIdentityInput struct {
|
||||
ProviderType string
|
||||
ProviderKey string
|
||||
ProviderSubject string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
Username string
|
||||
DisplayName string
|
||||
AvatarURL string
|
||||
UpstreamMetadata map[string]any
|
||||
}
|
||||
|
||||
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuth(ctx context.Context, input EmailOAuthIdentityInput) (*TokenPair, *User, error) {
|
||||
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, "", "")
|
||||
}
|
||||
|
||||
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuthWithInvitation(
|
||||
ctx context.Context,
|
||||
input EmailOAuthIdentityInput,
|
||||
invitationCode string,
|
||||
affiliateCode string,
|
||||
) (*TokenPair, *User, error) {
|
||||
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, invitationCode, affiliateCode)
|
||||
}
|
||||
|
||||
func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
|
||||
ctx context.Context,
|
||||
input EmailOAuthIdentityInput,
|
||||
invitationCode string,
|
||||
affiliateCode string,
|
||||
) (*TokenPair, *User, error) {
|
||||
if s == nil || s.userRepo == nil || s.entClient == nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||
if providerType != "github" && providerType != "google" {
|
||||
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
|
||||
}
|
||||
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||
if providerKey == "" {
|
||||
providerKey = providerType
|
||||
}
|
||||
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||
if providerSubject == "" {
|
||||
return nil, nil, infraerrors.BadRequest("OAUTH_SUBJECT_MISSING", "oauth subject is missing")
|
||||
}
|
||||
if !input.EmailVerified {
|
||||
return nil, nil, infraerrors.Forbidden("OAUTH_EMAIL_NOT_VERIFIED", "oauth email is not verified")
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(input.Email))
|
||||
if email == "" || len(email) > 255 {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if isReservedEmail(email) {
|
||||
return nil, nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
identityUser, err := s.findEmailOAuthIdentityOwner(ctx, providerType, providerKey, providerSubject)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if identityUser != nil && !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
|
||||
return nil, nil, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
|
||||
}
|
||||
|
||||
user := identityUser
|
||||
created := false
|
||||
if user == nil {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
user, err = s.createEmailOAuthUser(ctx, email, input.Username, providerType, invitationCode, affiliateCode)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
created = true
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during %s oauth login: %v", providerType, err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
if err := s.ensureEmailOAuthIdentity(ctx, user.ID, EmailOAuthIdentityInput{
|
||||
ProviderType: providerType,
|
||||
ProviderKey: providerKey,
|
||||
ProviderSubject: providerSubject,
|
||||
Email: email,
|
||||
EmailVerified: input.EmailVerified,
|
||||
Username: input.Username,
|
||||
DisplayName: input.DisplayName,
|
||||
AvatarURL: input.AvatarURL,
|
||||
UpstreamMetadata: input.UpstreamMetadata,
|
||||
}); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if user.Username == "" && strings.TrimSpace(input.Username) != "" {
|
||||
user.Username = strings.TrimSpace(input.Username)
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after %s oauth login: %v", providerType, err)
|
||||
}
|
||||
}
|
||||
if !created {
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, providerType); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply %s first bind defaults: %v", providerType, err)
|
||||
}
|
||||
}
|
||||
s.RecordSuccessfulLogin(ctx, user.ID)
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, providerType, invitationCode, affiliateCode string) (*User, error) {
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvitationCodeRequired) {
|
||||
return nil, ErrOAuthInvitationRequired
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, providerType)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
user := &User{
|
||||
Email: email,
|
||||
Username: strings.TrimSpace(username),
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: providerType,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
existing, loadErr := s.userRepo.GetByEmail(ctx, email)
|
||||
if loadErr != nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, providerType, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, invitationCode)
|
||||
return nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) findEmailOAuthIdentityOwner(ctx context.Context, providerType, providerKey, providerSubject string) (*User, error) {
|
||||
identity, err := s.entClient.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderSubjectEQ(providerSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||
}
|
||||
user, err := s.userRepo.GetByID(ctx, identity.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailOAuthIdentity(ctx context.Context, userID int64, input EmailOAuthIdentityInput) error {
|
||||
metadata := map[string]any{
|
||||
"email": strings.TrimSpace(strings.ToLower(input.Email)),
|
||||
"email_verified": input.EmailVerified,
|
||||
}
|
||||
for key, value := range input.UpstreamMetadata {
|
||||
metadata[key] = value
|
||||
}
|
||||
if strings.TrimSpace(input.Username) != "" {
|
||||
metadata["username"] = strings.TrimSpace(input.Username)
|
||||
}
|
||||
if strings.TrimSpace(input.DisplayName) != "" {
|
||||
metadata["display_name"] = strings.TrimSpace(input.DisplayName)
|
||||
}
|
||||
if strings.TrimSpace(input.AvatarURL) != "" {
|
||||
metadata["avatar_url"] = strings.TrimSpace(input.AvatarURL)
|
||||
}
|
||||
|
||||
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||
identity, err := s.entClient.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderSubjectEQ(providerSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||
}
|
||||
if identity != nil {
|
||||
if identity.UserID != userID {
|
||||
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
_, err = s.entClient.AuthIdentity.UpdateOneID(identity.ID).
|
||||
SetMetadata(metadata).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
_, err = s.entClient.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType(providerType).
|
||||
SetProviderKey(providerKey).
|
||||
SetProviderSubject(providerSubject).
|
||||
SetMetadata(metadata).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
@@ -17,7 +17,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
|
||||
switch signupSource {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc":
|
||||
case "linuxdo", "wechat", "oidc", "github", "google":
|
||||
return signupSource
|
||||
default:
|
||||
return "email"
|
||||
|
||||
@@ -775,6 +775,10 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
||||
return defaults.OIDC, true
|
||||
case "wechat":
|
||||
return defaults.WeChat, true
|
||||
case "github":
|
||||
return defaults.GitHub, true
|
||||
case "google":
|
||||
return defaults.Google, true
|
||||
default:
|
||||
return ProviderDefaultGrantSettings{}, false
|
||||
}
|
||||
|
||||
@@ -173,6 +173,18 @@ const (
|
||||
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
|
||||
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
|
||||
|
||||
// GitHub / Google 邮箱快捷登录设置
|
||||
SettingKeyGitHubOAuthEnabled = "github_oauth_enabled"
|
||||
SettingKeyGitHubOAuthClientID = "github_oauth_client_id"
|
||||
SettingKeyGitHubOAuthClientSecret = "github_oauth_client_secret"
|
||||
SettingKeyGitHubOAuthRedirectURL = "github_oauth_redirect_url"
|
||||
SettingKeyGitHubOAuthFrontendRedirectURL = "github_oauth_frontend_redirect_url"
|
||||
SettingKeyGoogleOAuthEnabled = "google_oauth_enabled"
|
||||
SettingKeyGoogleOAuthClientID = "google_oauth_client_id"
|
||||
SettingKeyGoogleOAuthClientSecret = "google_oauth_client_secret"
|
||||
SettingKeyGoogleOAuthRedirectURL = "google_oauth_redirect_url"
|
||||
SettingKeyGoogleOAuthFrontendRedirectURL = "google_oauth_frontend_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
@@ -216,6 +228,16 @@ const (
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
|
||||
SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
|
||||
SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
|
||||
SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
|
||||
SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
|
||||
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
|
||||
|
||||
// 管理员 API Key
|
||||
|
||||
@@ -129,6 +129,8 @@ type AuthSourceDefaultSettings struct {
|
||||
LinuxDo ProviderDefaultGrantSettings
|
||||
OIDC ProviderDefaultGrantSettings
|
||||
WeChat ProviderDefaultGrantSettings
|
||||
GitHub ProviderDefaultGrantSettings
|
||||
Google ProviderDefaultGrantSettings
|
||||
ForceEmailOnThirdPartySignup bool
|
||||
}
|
||||
|
||||
@@ -169,6 +171,20 @@ var (
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
}
|
||||
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultGitHubBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||
}
|
||||
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultGoogleBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -177,6 +193,17 @@ const (
|
||||
defaultWeChatConnectMode = "open"
|
||||
defaultWeChatConnectScopes = "snsapi_login"
|
||||
defaultWeChatConnectFrontend = "/auth/wechat/callback"
|
||||
defaultGitHubOAuthAuthorize = "https://github.com/login/oauth/authorize"
|
||||
defaultGitHubOAuthToken = "https://github.com/login/oauth/access_token"
|
||||
defaultGitHubOAuthUserInfo = "https://api.github.com/user"
|
||||
defaultGitHubOAuthEmails = "https://api.github.com/user/emails"
|
||||
defaultGitHubOAuthScopes = "read:user user:email"
|
||||
defaultGitHubOAuthFrontend = "/auth/oauth/callback"
|
||||
defaultGoogleOAuthAuthorize = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
defaultGoogleOAuthToken = "https://oauth2.googleapis.com/token"
|
||||
defaultGoogleOAuthUserInfo = "https://openidconnect.googleapis.com/v1/userinfo"
|
||||
defaultGoogleOAuthScopes = "openid email profile"
|
||||
defaultGoogleOAuthFrontend = "/auth/oauth/callback"
|
||||
)
|
||||
|
||||
func normalizeWeChatConnectModeSetting(raw string) string {
|
||||
@@ -448,6 +475,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingPaymentEnabled,
|
||||
SettingKeyOIDCConnectEnabled,
|
||||
SettingKeyOIDCConnectProviderName,
|
||||
SettingKeyGitHubOAuthEnabled,
|
||||
SettingKeyGitHubOAuthClientID,
|
||||
SettingKeyGitHubOAuthClientSecret,
|
||||
SettingKeyGoogleOAuthEnabled,
|
||||
SettingKeyGoogleOAuthClientID,
|
||||
SettingKeyGoogleOAuthClientSecret,
|
||||
SettingKeyBalanceLowNotifyEnabled,
|
||||
SettingKeyBalanceLowNotifyThreshold,
|
||||
SettingKeyBalanceLowNotifyRechargeURL,
|
||||
@@ -482,6 +515,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
if oidcProviderName == "" {
|
||||
oidcProviderName = "OIDC"
|
||||
}
|
||||
gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github")
|
||||
googleEnabled := s.emailOAuthPublicEnabled(settings, "google")
|
||||
weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
|
||||
|
||||
// Password reset requires email verification to be enabled
|
||||
@@ -534,6 +569,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
|
||||
OIDCOAuthEnabled: oidcEnabled,
|
||||
OIDCOAuthProviderName: oidcProviderName,
|
||||
GitHubOAuthEnabled: gitHubEnabled,
|
||||
GoogleOAuthEnabled: googleEnabled,
|
||||
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
|
||||
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
|
||||
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
|
||||
@@ -677,6 +714,8 @@ type PublicSettingsInjectionPayload struct {
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
@@ -733,6 +772,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: s.version,
|
||||
@@ -806,6 +847,98 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string
|
||||
return openReady || mpReady, openReady, mpReady, mobileReady
|
||||
}
|
||||
|
||||
func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "github":
|
||||
cfg := config.EmailOAuthProviderConfig{
|
||||
AuthorizeURL: defaultGitHubOAuthAuthorize,
|
||||
TokenURL: defaultGitHubOAuthToken,
|
||||
UserInfoURL: defaultGitHubOAuthUserInfo,
|
||||
EmailsURL: defaultGitHubOAuthEmails,
|
||||
Scopes: defaultGitHubOAuthScopes,
|
||||
FrontendRedirectURL: defaultGitHubOAuthFrontend,
|
||||
}
|
||||
if s != nil && s.cfg != nil {
|
||||
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GitHubOAuth)
|
||||
}
|
||||
return cfg
|
||||
case "google":
|
||||
cfg := config.EmailOAuthProviderConfig{
|
||||
AuthorizeURL: defaultGoogleOAuthAuthorize,
|
||||
TokenURL: defaultGoogleOAuthToken,
|
||||
UserInfoURL: defaultGoogleOAuthUserInfo,
|
||||
Scopes: defaultGoogleOAuthScopes,
|
||||
FrontendRedirectURL: defaultGoogleOAuthFrontend,
|
||||
}
|
||||
if s != nil && s.cfg != nil {
|
||||
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GoogleOAuth)
|
||||
}
|
||||
return cfg
|
||||
default:
|
||||
return config.EmailOAuthProviderConfig{}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) config.EmailOAuthProviderConfig {
|
||||
base.Enabled = override.Enabled
|
||||
if strings.TrimSpace(override.ClientID) != "" {
|
||||
base.ClientID = strings.TrimSpace(override.ClientID)
|
||||
}
|
||||
if strings.TrimSpace(override.ClientSecret) != "" {
|
||||
base.ClientSecret = strings.TrimSpace(override.ClientSecret)
|
||||
}
|
||||
if strings.TrimSpace(override.AuthorizeURL) != "" {
|
||||
base.AuthorizeURL = strings.TrimSpace(override.AuthorizeURL)
|
||||
}
|
||||
if strings.TrimSpace(override.TokenURL) != "" {
|
||||
base.TokenURL = strings.TrimSpace(override.TokenURL)
|
||||
}
|
||||
if strings.TrimSpace(override.UserInfoURL) != "" {
|
||||
base.UserInfoURL = strings.TrimSpace(override.UserInfoURL)
|
||||
}
|
||||
if strings.TrimSpace(override.EmailsURL) != "" {
|
||||
base.EmailsURL = strings.TrimSpace(override.EmailsURL)
|
||||
}
|
||||
if strings.TrimSpace(override.Scopes) != "" {
|
||||
base.Scopes = strings.TrimSpace(override.Scopes)
|
||||
}
|
||||
if strings.TrimSpace(override.RedirectURL) != "" {
|
||||
base.RedirectURL = strings.TrimSpace(override.RedirectURL)
|
||||
}
|
||||
if strings.TrimSpace(override.FrontendRedirectURL) != "" {
|
||||
base.FrontendRedirectURL = strings.TrimSpace(override.FrontendRedirectURL)
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool {
|
||||
cfg := s.effectiveEmailOAuthConfig(settings, provider)
|
||||
return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != ""
|
||||
}
|
||||
|
||||
func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig {
|
||||
cfg := s.emailOAuthBaseConfig(provider)
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "github":
|
||||
if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok {
|
||||
cfg.Enabled = raw == "true"
|
||||
}
|
||||
cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID)
|
||||
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret)
|
||||
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL)
|
||||
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend)
|
||||
case "google":
|
||||
if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok {
|
||||
cfg.Enabled = raw == "true"
|
||||
}
|
||||
cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID)
|
||||
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret)
|
||||
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL)
|
||||
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
|
||||
// array string, returning only items with visibility != "admin".
|
||||
func filterUserVisibleMenuItems(raw string) json.RawMessage {
|
||||
@@ -1052,6 +1185,16 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
if settings.WeChatConnectFrontendRedirectURL == "" {
|
||||
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
|
||||
}
|
||||
settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL)
|
||||
settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL)
|
||||
if settings.GitHubOAuthFrontendRedirectURL == "" {
|
||||
settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend
|
||||
}
|
||||
settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL)
|
||||
settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL)
|
||||
if settings.GoogleOAuthFrontendRedirectURL == "" {
|
||||
settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend
|
||||
}
|
||||
|
||||
updates := make(map[string]string)
|
||||
|
||||
@@ -1121,6 +1264,22 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
|
||||
}
|
||||
|
||||
// GitHub / Google 邮箱快捷登录
|
||||
updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled)
|
||||
updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID)
|
||||
updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL
|
||||
updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL
|
||||
if settings.GitHubOAuthClientSecret != "" {
|
||||
updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret)
|
||||
}
|
||||
updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled)
|
||||
updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID)
|
||||
updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL
|
||||
updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL
|
||||
if settings.GoogleOAuthClientSecret != "" {
|
||||
updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret)
|
||||
}
|
||||
|
||||
// WeChat Connect OAuth 登录
|
||||
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
|
||||
updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
|
||||
@@ -1273,17 +1432,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
|
||||
settings.LinuxDo.Subscriptions,
|
||||
settings.OIDC.Subscriptions,
|
||||
settings.WeChat.Subscriptions,
|
||||
settings.GitHub.Subscriptions,
|
||||
settings.Google.Subscriptions,
|
||||
} {
|
||||
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
updates := make(map[string]string, 21)
|
||||
updates := make(map[string]string, 31)
|
||||
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
||||
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
||||
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
||||
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
||||
writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
|
||||
writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
|
||||
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
||||
return updates, nil
|
||||
}
|
||||
@@ -1362,6 +1525,61 @@ func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider != "github" && provider != "google" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_PROVIDER_NOT_FOUND", "oauth provider not found")
|
||||
}
|
||||
keys := []string{
|
||||
SettingKeyGitHubOAuthEnabled,
|
||||
SettingKeyGitHubOAuthClientID,
|
||||
SettingKeyGitHubOAuthClientSecret,
|
||||
SettingKeyGitHubOAuthRedirectURL,
|
||||
SettingKeyGitHubOAuthFrontendRedirectURL,
|
||||
SettingKeyGoogleOAuthEnabled,
|
||||
SettingKeyGoogleOAuthClientID,
|
||||
SettingKeyGoogleOAuthClientSecret,
|
||||
SettingKeyGoogleOAuthRedirectURL,
|
||||
SettingKeyGoogleOAuthFrontendRedirectURL,
|
||||
}
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, fmt.Errorf("get email oauth settings: %w", err)
|
||||
}
|
||||
cfg := s.effectiveEmailOAuthConfig(settings, provider)
|
||||
if !cfg.Enabled {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientID) == "" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
for label, rawURL := range map[string]string{
|
||||
"authorize": cfg.AuthorizeURL,
|
||||
"token": cfg.TokenURL,
|
||||
"userinfo": cfg.UserInfoURL,
|
||||
"redirect": cfg.RedirectURL,
|
||||
} {
|
||||
if strings.TrimSpace(rawURL) == "" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url not configured")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(rawURL); err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url invalid")
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(cfg.EmailsURL) != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(cfg.EmailsURL); err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth emails url invalid")
|
||||
}
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
@@ -1711,6 +1929,16 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultGitHubBalance,
|
||||
SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||
SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultGoogleBalance,
|
||||
SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||
SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||
SettingKeyForceEmailOnThirdPartySignup,
|
||||
}
|
||||
|
||||
@@ -1724,6 +1952,8 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
||||
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
|
||||
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
|
||||
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
|
||||
GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys),
|
||||
Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys),
|
||||
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
|
||||
}, nil
|
||||
}
|
||||
@@ -1824,6 +2054,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyWeChatConnectScopes: "snsapi_login",
|
||||
SettingKeyWeChatConnectRedirectURL: "",
|
||||
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
|
||||
SettingKeyGitHubOAuthEnabled: "false",
|
||||
SettingKeyGitHubOAuthClientID: "",
|
||||
SettingKeyGitHubOAuthClientSecret: "",
|
||||
SettingKeyGitHubOAuthRedirectURL: "",
|
||||
SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
|
||||
SettingKeyGoogleOAuthEnabled: "false",
|
||||
SettingKeyGoogleOAuthClientID: "",
|
||||
SettingKeyGoogleOAuthClientSecret: "",
|
||||
SettingKeyGoogleOAuthRedirectURL: "",
|
||||
SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
|
||||
SettingKeyOIDCConnectEnabled: "false",
|
||||
SettingKeyOIDCConnectProviderName: "OIDC",
|
||||
SettingKeyOIDCConnectClientID: "",
|
||||
@@ -1874,6 +2114,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultGitHubBalance: "0",
|
||||
SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultGoogleBalance: "0",
|
||||
SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
|
||||
SettingKeyForceEmailOnThirdPartySignup: "false",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
@@ -2173,6 +2423,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
}
|
||||
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
|
||||
|
||||
gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github")
|
||||
result.GitHubOAuthEnabled = gitHubEffective.Enabled
|
||||
result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID)
|
||||
result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret)
|
||||
result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != ""
|
||||
result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL)
|
||||
result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL)
|
||||
|
||||
googleEffective := s.effectiveEmailOAuthConfig(settings, "google")
|
||||
result.GoogleOAuthEnabled = googleEffective.Enabled
|
||||
result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID)
|
||||
result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret)
|
||||
result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != ""
|
||||
result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL)
|
||||
result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL)
|
||||
|
||||
// WeChat Connect 设置:
|
||||
// - 优先读取 DB 系统设置
|
||||
// - 缺失时回退到 config/env,保持升级兼容
|
||||
|
||||
@@ -89,6 +89,20 @@ type SystemSettings struct {
|
||||
OIDCConnectUserInfoIDPath string
|
||||
OIDCConnectUserInfoUsernamePath string
|
||||
|
||||
// GitHub / Google 邮箱快捷登录
|
||||
GitHubOAuthEnabled bool
|
||||
GitHubOAuthClientID string
|
||||
GitHubOAuthClientSecret string
|
||||
GitHubOAuthClientSecretConfigured bool
|
||||
GitHubOAuthRedirectURL string
|
||||
GitHubOAuthFrontendRedirectURL string
|
||||
GoogleOAuthEnabled bool
|
||||
GoogleOAuthClientID string
|
||||
GoogleOAuthClientSecret string
|
||||
GoogleOAuthClientSecretConfigured bool
|
||||
GoogleOAuthRedirectURL string
|
||||
GoogleOAuthFrontendRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
@@ -217,6 +231,8 @@ type PublicSettings struct {
|
||||
PaymentEnabled bool
|
||||
OIDCOAuthEnabled bool
|
||||
OIDCOAuthProviderName string
|
||||
GitHubOAuthEnabled bool
|
||||
GoogleOAuthEnabled bool
|
||||
Version string
|
||||
|
||||
BalanceLowNotifyEnabled bool
|
||||
|
||||
Reference in New Issue
Block a user