Files
sub2api/backend/internal/service/kiro_oauth_service.go
T
2026-05-16 15:31:08 +08:00

393 lines
11 KiB
Go

package service
import (
"context"
"fmt"
"strings"
"time"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
)
// kiroSocialRedirectURI is the loopback redirect URI used for Kiro social
// (Google/GitHub) sign-in. Kiro desktop social auth only accepts localhost
// loopback callbacks from a fixed allowlist of ports; 49153 is one of the
// ports bundled with the official client.
const kiroSocialRedirectURI = "http://localhost:49153"

// kiroIDCRedirectURI is the redirect URI used for AWS IAM Identity Center
// authorization-code flows (including Builder ID). Native/public clients there
// require an explicit loopback IP address rather than "localhost".
const kiroIDCRedirectURI = "http://127.0.0.1:9876/oauth/callback"
// KiroOAuthService drives the Kiro OAuth flows: social (Google/GitHub) sign-in,
// AWS IAM Identity Center authorization-code sign-in, token refresh and token
// import.
//
// sessionStore holds in-flight authorization sessions keyed by session ID
// between GenerateAuthURL/GenerateIDCAuthURL and ExchangeCode. proxyRepo
// resolves optional proxy IDs to proxy URLs and may be nil, in which case all
// flows run without a proxy.
type KiroOAuthService struct {
	sessionStore *kiropkg.SessionStore
	proxyRepo    ProxyRepository
}
// NewKiroOAuthService builds a KiroOAuthService with a fresh authorization
// session store. proxyRepo may be nil; proxy IDs then resolve to no proxy.
func NewKiroOAuthService(proxyRepo ProxyRepository) *KiroOAuthService {
	svc := &KiroOAuthService{proxyRepo: proxyRepo}
	svc.sessionStore = kiropkg.NewSessionStore()
	return svc
}

// Stop is a lifecycle hook that currently does nothing.
// NOTE(review): if kiropkg.SessionStore runs any background work, it is not
// stopped here — confirm against kiropkg.
func (*KiroOAuthService) Stop() {}
// KiroAuthURLResult is returned by GenerateAuthURL: the social sign-in URL plus
// the session ID and state that must be passed back to ExchangeCode.
type KiroAuthURLResult struct {
	AuthURL   string `json:"auth_url"`
	SessionID string `json:"session_id"`
	State     string `json:"state"`
}

// KiroIDCAuthURLResult is returned by GenerateIDCAuthURL. Besides the
// authorization URL, session ID and state (to be passed back to ExchangeCode),
// it reports the ID of the OIDC client registered for this session and the
// effective region and start URL after defaults were applied.
type KiroIDCAuthURLResult struct {
	AuthURL   string `json:"auth_url"`
	SessionID string `json:"session_id"`
	State     string `json:"state"`
	ClientID  string `json:"client_id"`
	Region    string `json:"region"`
	StartURL  string `json:"start_url"`
}
// KiroTokenInfo is the service-level view of a Kiro token. It mirrors
// kiropkg.TokenData field by field (see toKiroTokenInfo), and
// BuildAccountCredentials turns its non-empty fields into account credentials.
//
// It carries secrets (access/refresh tokens, client secret) and every field is
// serialized to JSON when non-empty, so be careful where it is exposed.
type KiroTokenInfo struct {
	AccessToken  string `json:"access_token,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ProfileArn   string `json:"profile_arn,omitempty"`
	ExpiresAt    string `json:"expires_at,omitempty"`
	AuthMethod   string `json:"auth_method,omitempty"`
	Provider     string `json:"provider,omitempty"`
	ClientID     string `json:"client_id,omitempty"`
	ClientSecret string `json:"client_secret,omitempty"`
	ClientIDHash string `json:"client_id_hash,omitempty"`
	Email        string `json:"email,omitempty"`
	StartURL     string `json:"start_url,omitempty"`
	Region       string `json:"region,omitempty"`
}
// KiroGenerateAuthURLInput configures GenerateAuthURL. Provider selects the
// social identity provider (empty means Google); ProxyID optionally selects a
// proxy whose URL is stored in the auth session for the later code exchange.
type KiroGenerateAuthURLInput struct {
	ProxyID  *int64
	Provider string
}

// KiroExchangeCodeInput configures ExchangeCode. SessionID and State must match
// a session created by GenerateAuthURL or GenerateIDCAuthURL, and Code is the
// authorization code returned to the redirect URI. CallbackPath and
// LoginOption are only used for social sessions, to rebuild the redirect URI
// sent with the token request (an empty LoginOption is derived from the
// session's provider). ProxyID, when set, overrides the proxy stored in the
// session.
type KiroExchangeCodeInput struct {
	SessionID    string
	State        string
	Code         string
	CallbackPath string
	LoginOption  string
	ProxyID      *int64
}

// KiroGenerateIDCAuthURLInput configures GenerateIDCAuthURL. An empty StartURL
// defaults to kiropkg.BuilderIDStartURL and an empty Region to "us-east-1".
type KiroGenerateIDCAuthURLInput struct {
	ProxyID  *int64
	StartURL string
	Region   string
}

// KiroRefreshTokenInput configures RefreshToken. AuthMethod "idc"
// (case-insensitive) refreshes through IAM Identity Center using ClientID,
// ClientSecret, Region and StartURL; any other value, including empty, uses
// the social refresh with Provider. ProfileArn and the client/location fields
// are copied into the result when the refresh response leaves them empty.
type KiroRefreshTokenInput struct {
	RefreshToken string
	AuthMethod   string
	Provider     string
	ClientID     string
	ClientSecret string
	StartURL     string
	Region       string
	ProfileArn   string
	ProxyID      *int64
}

// KiroImportTokenInput configures ImportToken. TokenJSON is the exported token
// document; if it carries only a refresh token it is refreshed immediately
// (through ProxyID's proxy, if any). DeviceRegistrationJSON is passed together
// with TokenJSON to kiropkg.ParseImportedToken (presumably the IDC client
// registration — confirm against kiropkg).
type KiroImportTokenInput struct {
	TokenJSON              string
	DeviceRegistrationJSON string
	ProxyID                *int64
}
// GenerateAuthURL starts a Kiro social (Google or GitHub) sign-in with PKCE.
//
// input.Provider is trimmed and matched case-insensitively against the
// supported providers, then normalized to the canonical kiropkg value; an
// empty provider defaults to Google. A nil input is treated as all defaults.
//
// A new auth session holding the state, PKCE verifier, provider and proxy URL
// is stored under the returned SessionID. The caller must later hand
// SessionID, State and the authorization code to ExchangeCode. Proxy lookup is
// best-effort: if input.ProxyID cannot be resolved, the session simply has no
// proxy.
func (s *KiroOAuthService) GenerateAuthURL(ctx context.Context, input *KiroGenerateAuthURLInput) (*KiroAuthURLResult, error) {
	if input == nil {
		input = &KiroGenerateAuthURLInput{}
	}

	google := string(kiropkg.SocialProviderGoogle)
	github := string(kiropkg.SocialProviderGitHub)
	provider := strings.TrimSpace(input.Provider)
	switch {
	case provider == "", strings.EqualFold(provider, google):
		provider = google
	case strings.EqualFold(provider, github):
		provider = github
	default:
		return nil, fmt.Errorf("unsupported kiro social provider: %s", provider)
	}

	state, err := kiropkg.GenerateState()
	if err != nil {
		return nil, fmt.Errorf("generate state failed: %w", err)
	}
	codeVerifier, err := kiropkg.GenerateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("generate code verifier failed: %w", err)
	}
	sessionID := kiropkg.GenerateSessionID()

	// Best-effort, like every flow in this service: an unresolvable proxy
	// falls back to a direct connection.
	proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)

	s.sessionStore.Set(sessionID, &kiropkg.AuthSession{
		State:        state,
		CodeVerifier: codeVerifier,
		ProxyURL:     proxyURL,
		CreatedAt:    time.Now(),
		AuthType:     "social",
		Provider:     provider,
		RedirectURI:  kiroSocialRedirectURI,
	})
	return &KiroAuthURLResult{
		AuthURL:   kiropkg.BuildSocialSignInURL(kiroSocialRedirectURI, kiropkg.GenerateCodeChallenge(codeVerifier), state),
		SessionID: sessionID,
		State:     state,
	}, nil
}
// ExchangeCode completes a flow started by GenerateAuthURL or
// GenerateIDCAuthURL by exchanging the authorization code for tokens.
//
// The session identified by input.SessionID must still exist, input.State
// must match the state generated for it, and input.Code must be non-blank.
// input.ProxyID, when set, overrides the proxy captured at session creation
// (best-effort: a lookup failure means no proxy). On success the session is
// deleted; on any failure it is left in the store.
//
// Fields that later refreshes depend on are filled from the session when the
// token endpoint leaves them empty. This covers the auth method for both flows
// and, for IDC sessions, the registered client credentials, region and start
// URL.
func (s *KiroOAuthService) ExchangeCode(ctx context.Context, input *KiroExchangeCodeInput) (*KiroTokenInfo, error) {
	if input == nil {
		return nil, fmt.Errorf("exchange code input is required")
	}
	session, ok := s.sessionStore.Get(input.SessionID)
	if !ok {
		return nil, fmt.Errorf("session not found or expired")
	}
	if strings.TrimSpace(input.State) == "" || input.State != session.State {
		return nil, fmt.Errorf("state invalid")
	}
	// Fail fast instead of sending an empty code to the token endpoint.
	if strings.TrimSpace(input.Code) == "" {
		return nil, fmt.Errorf("authorization code is required")
	}

	proxyURL := session.ProxyURL
	if input.ProxyID != nil {
		// Best-effort override: a lookup failure yields a direct connection.
		proxyURL, _ = s.resolveProxyURL(ctx, input.ProxyID)
	}

	switch session.AuthType {
	case "social":
		token, err := kiropkg.CreateSocialToken(
			ctx,
			proxyURL,
			input.Code,
			session.CodeVerifier,
			buildKiroSocialExchangeRedirectURI(session.RedirectURI, session.Provider, input.CallbackPath, input.LoginOption),
		)
		if err != nil {
			return nil, err
		}
		if token == nil {
			return nil, fmt.Errorf("kiro social code exchange returned no token")
		}
		token.Provider = session.Provider
		if token.AuthMethod == "" {
			token.AuthMethod = "social"
		}
		s.sessionStore.Delete(input.SessionID)
		return toKiroTokenInfo(token), nil
	case "idc":
		token, err := kiropkg.ExchangeIDCAuthCode(ctx, proxyURL, session.ClientID, session.ClientSecret, input.Code, session.CodeVerifier, session.RedirectURI, session.Region, session.StartURL)
		if err != nil {
			return nil, err
		}
		if token == nil {
			return nil, fmt.Errorf("kiro idc code exchange returned no token")
		}
		// RefreshToken dispatches on the auth method (empty means "social") and
		// needs the client credentials, region and start URL to refresh an IDC
		// token, so persist them from the session when the response omits them.
		if token.AuthMethod == "" {
			token.AuthMethod = "idc"
		}
		if token.ClientID == "" {
			token.ClientID = session.ClientID
		}
		if token.ClientSecret == "" {
			token.ClientSecret = session.ClientSecret
		}
		if token.Region == "" {
			token.Region = session.Region
		}
		if token.StartURL == "" {
			token.StartURL = session.StartURL
		}
		s.sessionStore.Delete(input.SessionID)
		return toKiroTokenInfo(token), nil
	default:
		return nil, fmt.Errorf("unsupported auth session type: %s", session.AuthType)
	}
}
// buildKiroSocialExchangeRedirectURI returns the redirect URI to send with a
// social authorization-code exchange.
//
// loginOption is trimmed and lower-cased. If that leaves it empty, it is
// derived from provider: "github" for kiropkg.SocialProviderGitHub, "google"
// for kiropkg.SocialProviderGoogle, and empty otherwise. The final URI is
// assembled by kiropkg.BuildSocialTokenRedirectURI.
func buildKiroSocialExchangeRedirectURI(baseRedirectURI, provider, callbackPath, loginOption string) string {
	opt := strings.ToLower(strings.TrimSpace(loginOption))
	if opt != "" {
		return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, opt)
	}

	switch provider {
	case string(kiropkg.SocialProviderGitHub):
		opt = "github"
	case string(kiropkg.SocialProviderGoogle):
		opt = "google"
	}
	return kiropkg.BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, opt)
}
// GenerateIDCAuthURL starts an AWS IAM Identity Center authorization-code
// flow with PKCE for Kiro.
//
// An empty StartURL defaults to kiropkg.BuilderIDStartURL and an empty Region
// to "us-east-1". Every call registers a new OIDC client for
// kiroIDCRedirectURI via kiropkg.RegisterIDCClient. Its credentials are stored
// in a new "idc" session together with the state, PKCE verifier and proxy URL,
// so ExchangeCode can finish the flow. Proxy lookup is best-effort: errors are
// ignored and the calls go out without a proxy.
func (s *KiroOAuthService) GenerateIDCAuthURL(ctx context.Context, input *KiroGenerateIDCAuthURLInput) (*KiroIDCAuthURLResult, error) {
	startURL, region := strings.TrimSpace(input.StartURL), strings.TrimSpace(input.Region)
	if startURL == "" {
		startURL = kiropkg.BuilderIDStartURL
	}
	if region == "" {
		region = "us-east-1"
	}

	state, err := kiropkg.GenerateState()
	if err != nil {
		return nil, fmt.Errorf("generate state failed: %w", err)
	}
	verifier, err := kiropkg.GenerateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("generate code verifier failed: %w", err)
	}

	proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)
	client, err := kiropkg.RegisterIDCClient(ctx, proxyURL, kiroIDCRedirectURI, startURL, region)
	if err != nil {
		return nil, err
	}

	id := kiropkg.GenerateSessionID()
	s.sessionStore.Set(id, &kiropkg.AuthSession{
		AuthType:     "idc",
		State:        state,
		CodeVerifier: verifier,
		RedirectURI:  kiroIDCRedirectURI,
		ProxyURL:     proxyURL,
		ClientID:     client.ClientID,
		ClientSecret: client.ClientSecret,
		Region:       region,
		StartURL:     startURL,
		CreatedAt:    time.Now(),
	})

	challenge := kiropkg.GenerateCodeChallenge(verifier)
	result := &KiroIDCAuthURLResult{
		AuthURL:   kiropkg.BuildIDCAuthURL(client.ClientID, kiroIDCRedirectURI, state, challenge, region),
		SessionID: id,
		State:     state,
		ClientID:  client.ClientID,
		Region:    region,
		StartURL:  startURL,
	}
	return result, nil
}
// RefreshToken exchanges a refresh token for a fresh access token.
//
// input.AuthMethod selects the flow (trimmed, case-insensitive). "idc" uses
// the IAM Identity Center refresh with input's client credentials, region and
// start URL. Anything else, including empty, uses the Kiro social refresh with
// input.Provider. The proxy identified by input.ProxyID is best-effort: if it
// cannot be resolved, the request goes out without a proxy.
//
// Fields the refresh response leaves empty are carried over from input, so
// the result can be persisted without losing what the next refresh needs. In
// particular, per RFC 6749 §6 the server MAY omit a new refresh token, in
// which case the existing one remains valid and is returned unchanged.
//
// It returns an error when input is nil, the refresh token is blank, or the
// upstream refresh fails.
func (s *KiroOAuthService) RefreshToken(ctx context.Context, input *KiroRefreshTokenInput) (*KiroTokenInfo, error) {
	if input == nil {
		return nil, fmt.Errorf("refresh token input is required")
	}
	if strings.TrimSpace(input.RefreshToken) == "" {
		return nil, fmt.Errorf("refresh token is required")
	}
	// Best-effort, matching the other flows: a lookup failure means no proxy.
	proxyURL, _ := s.resolveProxyURL(ctx, input.ProxyID)

	authMethod := strings.ToLower(strings.TrimSpace(input.AuthMethod))
	if authMethod == "" {
		authMethod = "social"
	}

	var (
		token *kiropkg.TokenData
		err   error
	)
	switch authMethod {
	case "idc":
		token, err = kiropkg.RefreshIDCToken(ctx, proxyURL, input.ClientID, input.ClientSecret, input.RefreshToken, input.Region, input.StartURL)
	default:
		token, err = kiropkg.RefreshSocialToken(ctx, proxyURL, input.RefreshToken, input.Provider)
	}
	if err != nil {
		return nil, err
	}
	if token == nil {
		return nil, fmt.Errorf("kiro token refresh returned no token")
	}

	// Carry forward anything the refresh response did not include.
	if token.RefreshToken == "" {
		token.RefreshToken = input.RefreshToken
	}
	if token.AuthMethod == "" {
		token.AuthMethod = authMethod
	}
	if token.Provider == "" {
		token.Provider = input.Provider
	}
	if token.ProfileArn == "" {
		token.ProfileArn = input.ProfileArn
	}
	if token.ClientID == "" {
		token.ClientID = input.ClientID
	}
	if token.ClientSecret == "" {
		token.ClientSecret = input.ClientSecret
	}
	if token.StartURL == "" {
		token.StartURL = input.StartURL
	}
	if token.Region == "" {
		token.Region = input.Region
	}
	return toKiroTokenInfo(token), nil
}
// RefreshAccountToken refreshes the token stored on a Kiro OAuth account.
//
// Accounts whose platform is not Kiro or whose type is not OAuth are
// rejected. Otherwise, the account's stored credential fields and its proxy ID
// are forwarded to RefreshToken.
func (s *KiroOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*KiroTokenInfo, error) {
	isKiroOAuth := account.Platform == PlatformKiro && account.Type == AccountTypeOAuth
	if !isKiroOAuth {
		return nil, fmt.Errorf("not a kiro oauth account")
	}

	cred := account.GetCredential
	req := &KiroRefreshTokenInput{
		RefreshToken: cred("refresh_token"),
		AuthMethod:   cred("auth_method"),
		Provider:     cred("provider"),
		ClientID:     cred("client_id"),
		ClientSecret: cred("client_secret"),
		StartURL:     cred("start_url"),
		Region:       cred("region"),
		ProfileArn:   cred("profile_arn"),
		ProxyID:      account.ProxyID,
	}
	return s.RefreshToken(ctx, req)
}
// ImportToken imports a Kiro token exported elsewhere.
//
// If input.TokenJSON carries only a refresh token, that token is refreshed
// immediately via RefreshToken (through input.ProxyID's proxy, if any).
// Otherwise TokenJSON and DeviceRegistrationJSON are parsed with
// kiropkg.ParseImportedToken.
//
// ImportToken runs with context.Background(), so the refresh cannot be
// cancelled. Callers that have a request context should use
// ImportTokenWithContext.
func (s *KiroOAuthService) ImportToken(input *KiroImportTokenInput) (*KiroTokenInfo, error) {
	return s.ImportTokenWithContext(context.Background(), input)
}

// ImportTokenWithContext behaves like ImportToken, but uses ctx for the
// refresh performed on refresh-token-only imports, so it can be cancelled or
// time-limited. It returns an error when input is nil.
func (s *KiroOAuthService) ImportTokenWithContext(ctx context.Context, input *KiroImportTokenInput) (*KiroTokenInfo, error) {
	if input == nil {
		return nil, fmt.Errorf("import token input is required")
	}
	seed, refreshOnly, err := kiropkg.ParseImportedRefreshToken(input.TokenJSON)
	if err != nil {
		return nil, err
	}
	if refreshOnly {
		// Only a refresh token was supplied: trade it for a full token set,
		// carrying over whatever client/location metadata the import had.
		return s.RefreshToken(ctx, &KiroRefreshTokenInput{
			RefreshToken: seed.RefreshToken,
			AuthMethod:   seed.AuthMethod,
			Provider:     seed.Provider,
			ClientID:     seed.ClientID,
			ClientSecret: seed.ClientSecret,
			StartURL:     seed.StartURL,
			Region:       seed.Region,
			ProfileArn:   seed.ProfileArn,
			ProxyID:      input.ProxyID,
		})
	}

	token, err := kiropkg.ParseImportedToken(input.TokenJSON, input.DeviceRegistrationJSON)
	if err != nil {
		return nil, err
	}
	return toKiroTokenInfo(token), nil
}
// BuildAccountCredentials converts tokenInfo into an account credentials map.
//
// Only non-empty fields are included, keyed by their snake_case credential
// names (the same keys RefreshAccountToken reads back). A nil tokenInfo yields
// an empty, non-nil map.
func (s *KiroOAuthService) BuildAccountCredentials(tokenInfo *KiroTokenInfo) map[string]any {
	creds := map[string]any{}
	if tokenInfo == nil {
		return creds
	}

	entries := []struct {
		key   string
		value string
	}{
		{"access_token", tokenInfo.AccessToken},
		{"refresh_token", tokenInfo.RefreshToken},
		{"profile_arn", tokenInfo.ProfileArn},
		{"expires_at", tokenInfo.ExpiresAt},
		{"auth_method", tokenInfo.AuthMethod},
		{"provider", tokenInfo.Provider},
		{"client_id", tokenInfo.ClientID},
		{"client_secret", tokenInfo.ClientSecret},
		{"client_id_hash", tokenInfo.ClientIDHash},
		{"email", tokenInfo.Email},
		{"start_url", tokenInfo.StartURL},
		{"region", tokenInfo.Region},
	}
	for _, e := range entries {
		if e.value == "" {
			continue
		}
		creds[e.key] = e.value
	}
	return creds
}
// toKiroTokenInfo copies a kiropkg.TokenData into the service-level
// KiroTokenInfo field by field. A nil token maps to nil.
func toKiroTokenInfo(token *kiropkg.TokenData) *KiroTokenInfo {
	if token == nil {
		return nil
	}

	info := new(KiroTokenInfo)
	// Tokens and expiry.
	info.AccessToken, info.RefreshToken, info.ExpiresAt = token.AccessToken, token.RefreshToken, token.ExpiresAt
	// How the token was obtained.
	info.AuthMethod, info.Provider = token.AuthMethod, token.Provider
	// Client credentials.
	info.ClientID, info.ClientSecret, info.ClientIDHash = token.ClientID, token.ClientSecret, token.ClientIDHash
	// Identity and location metadata.
	info.ProfileArn, info.Email = token.ProfileArn, token.Email
	info.StartURL, info.Region = token.StartURL, token.Region
	return info
}
// resolveProxyURL looks up the proxy with the given ID and returns its URL.
//
// It returns "" with a nil error when proxyID is nil, when no proxy repository
// is configured, or when the repository returns a nil proxy without an error.
// A repository error is returned unchanged, alongside "".
func (s *KiroOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
	if s.proxyRepo == nil || proxyID == nil {
		return "", nil
	}

	p, err := s.proxyRepo.GetByID(ctx, *proxyID)
	switch {
	case err != nil:
		return "", err
	case p == nil:
		return "", nil
	default:
		return p.URL(), nil
	}
}