package kiro import ( "bytes" "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/google/uuid" ) const ( socialAuthPortalURL = "https://app.kiro.dev" socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" defaultIDCRegion = "us-east-1" BuilderIDStartURL = "https://view.awsapps.com/start" sessionTTL = 10 * time.Minute sessionCleanupEvery = 32 sessionCleanupMin = 32 ) var ( socialAuthEndpointURL = socialAuthEndpoint oidcEndpointOverride = "" ) type SocialProvider string const ( SocialProviderGoogle SocialProvider = "Google" SocialProviderGitHub SocialProvider = "Github" ) type AuthSession struct { State string CodeVerifier string ProxyURL string CreatedAt time.Time AuthType string Provider string RedirectURI string ClientID string ClientSecret string Region string StartURL string } type SessionStore struct { mu sync.RWMutex data map[string]*AuthSession setCount uint64 } func NewSessionStore() *SessionStore { return &SessionStore{data: make(map[string]*AuthSession)} } func (s *SessionStore) Get(id string) (*AuthSession, bool) { now := time.Now() s.mu.Lock() defer s.mu.Unlock() session, ok := s.data[id] if ok && sessionExpired(session, now) { delete(s.data, id) return nil, false } return session, ok } func (s *SessionStore) Set(id string, session *AuthSession) { s.mu.Lock() defer s.mu.Unlock() s.setCount++ if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 { s.pruneExpiredLocked(time.Now()) } s.data[id] = session } func (s *SessionStore) Delete(id string) { s.mu.Lock() defer s.mu.Unlock() delete(s.data, id) } func (s *SessionStore) pruneExpiredLocked(now time.Time) { for id, session := range s.data { if sessionExpired(session, now) { delete(s.data, id) } } } func sessionExpired(session *AuthSession, now time.Time) bool { if session == nil { return true } if session.CreatedAt.IsZero() { return true } return now.After(session.CreatedAt.Add(sessionTTL)) } type TokenData struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ProfileArn string `json:"profileArn,omitempty"` ExpiresAt string `json:"expiresAt,omitempty"` AuthMethod string `json:"authMethod,omitempty"` Provider string `json:"provider,omitempty"` ClientID string `json:"clientId,omitempty"` ClientSecret string `json:"clientSecret,omitempty"` ClientIDHash string `json:"clientIdHash,omitempty"` Email string `json:"email,omitempty"` StartURL string `json:"startUrl,omitempty"` Region string `json:"region,omitempty"` } type socialTokenResponse struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ProfileArn string `json:"profileArn"` ExpiresIn int `json:"expiresIn"` } type registerClientResponse struct { ClientID string `json:"clientId"` ClientSecret string `json:"clientSecret"` } type createTokenResponse struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ProfileArn string `json:"profileArn"` ExpiresIn int `json:"expiresIn"` } type userInfoResponse struct { Email string `json:"email"` } type deviceRegistration struct { ClientID string `json:"clientId"` ClientSecret string `json:"clientSecret"` } type RefreshTokenInvalidError struct { StatusCode int Body string } func (e *RefreshTokenInvalidError) Error() string { if e == nil { return "" } body := strings.TrimSpace(e.Body) if body == "" { return "kiro refresh token invalid (invalid_grant)" } return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body) } func GenerateSessionID() string { return uuid.NewString() } func GenerateState() (string, error) { return randomURLSafe(16) } func GenerateCodeVerifier() (string, error) { return randomURLSafe(32) } func randomURLSafe(n int) (string, error) { buf := make([]byte, n) if _, err := rand.Read(buf); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(buf), nil } func GenerateCodeChallenge(verifier string) string { sum := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(sum[:]) } func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string { params := url.Values{} params.Set("state", state) params.Set("code_challenge", codeChallenge) params.Set("code_challenge_method", "S256") params.Set("redirect_uri", redirectURI) params.Set("redirect_from", "KiroIDE") return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode()) } func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string { redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/") if redirectURI == "" { return "" } path := strings.TrimSpace(callbackPath) if path == "" { path = "/oauth/callback" } else if !strings.HasPrefix(path, "/") { path = "/" + path } fullRedirectURI := redirectURI + path if option := strings.TrimSpace(loginOption); option != "" { return fullRedirectURI + "?login_option=" + url.QueryEscape(option) } return fullRedirectURI } func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) { payload := map[string]string{ "code": code, "code_verifier": codeVerifier, "redirect_uri": redirectURI, } var resp socialTokenResponse if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil { return nil, err } expiresIn := resp.ExpiresIn if expiresIn <= 0 { expiresIn = 3600 } return &TokenData{ AccessToken: resp.AccessToken, RefreshToken: resp.RefreshToken, ProfileArn: resp.ProfileArn, ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339), AuthMethod: "social", Region: defaultIDCRegion, }, nil } func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) { payload := map[string]string{ "refreshToken": refreshToken, } var resp socialTokenResponse accountKey := BuildAccountKey("", "", refreshToken, "", 0) if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil { return nil, err } expiresIn := resp.ExpiresIn if expiresIn <= 0 { expiresIn = 3600 } return &TokenData{ AccessToken: resp.AccessToken, RefreshToken: resp.RefreshToken, ProfileArn: resp.ProfileArn, ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339), AuthMethod: "social", Provider: provider, Region: defaultIDCRegion, }, nil } func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) { if region == "" { region = defaultIDCRegion } payload := map[string]any{ "clientName": "Kiro IDE", "clientType": "public", "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, "grantTypes": []string{"authorization_code", "refresh_token"}, "redirectUris": []string{redirectURI}, "issuerUrl": issuerURL, } var resp registerClientResponse headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client")) if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil { return nil, err } return &resp, nil } func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string { if region == "" { region = defaultIDCRegion } params := url.Values{} params.Set("response_type", "code") params.Set("client_id", clientID) params.Set("redirect_uri", redirectURI) params.Set("scopes", strings.Join([]string{ "codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist", }, " ")) params.Set("state", state) params.Set("code_challenge", codeChallenge) params.Set("code_challenge_method", "S256") return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode()) } func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) { if region == "" { region = defaultIDCRegion } payload := map[string]string{ "clientId": clientID, "clientSecret": clientSecret, "code": code, "codeVerifier": codeVerifier, "redirectUri": redirectURI, "grantType": "authorization_code", } var resp createTokenResponse accountKey := BuildAccountKey(clientID, "", "", "", 0) headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID)) if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil { return nil, err } expiresIn := resp.ExpiresIn if expiresIn <= 0 { expiresIn = 3600 } token := &TokenData{ AccessToken: resp.AccessToken, RefreshToken: resp.RefreshToken, ProfileArn: resp.ProfileArn, ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339), AuthMethod: "idc", Provider: "AWS", ClientID: clientID, ClientSecret: clientSecret, StartURL: startURL, Region: region, } token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region) return token, nil } func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) { if region == "" { region = defaultIDCRegion } payload := map[string]string{ "clientId": clientID, "clientSecret": clientSecret, "refreshToken": refreshToken, "grantType": "refresh_token", } var resp createTokenResponse accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0) headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey)) if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil { return nil, err } expiresIn := resp.ExpiresIn if expiresIn <= 0 { expiresIn = 3600 } token := &TokenData{ AccessToken: resp.AccessToken, RefreshToken: resp.RefreshToken, ProfileArn: resp.ProfileArn, ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339), AuthMethod: "idc", Provider: "AWS", ClientID: clientID, ClientSecret: clientSecret, StartURL: startURL, Region: region, } token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region) return token, nil } func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string { if strings.TrimSpace(accessToken) == "" { return "" } var resp userInfoResponse headers := map[string]string{ "Authorization": "Bearer " + accessToken, } if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil { return "" } return strings.TrimSpace(resp.Email) } func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) { var token TokenData if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil { return nil, fmt.Errorf("failed to parse kiro token: %w", err) } token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod)) if strings.TrimSpace(token.AccessToken) == "" { return nil, fmt.Errorf("access token is empty") } if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" { var reg deviceRegistration if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil { return nil, fmt.Errorf("failed to parse device registration: %w", err) } if reg.ClientID != "" { token.ClientID = reg.ClientID } if reg.ClientSecret != "" { token.ClientSecret = reg.ClientSecret } } return &token, nil } func getOIDCEndpoint(region string) string { if strings.TrimSpace(oidcEndpointOverride) != "" { return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/") } if region == "" { region = defaultIDCRegion } return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) } func oidcHeaders(accountKey, machineID string) map[string]string { headers := BuildOIDCHeaders(accountKey, machineID) if headers["amz-sdk-invocation-id"] == "" { headers["amz-sdk-invocation-id"] = uuid.NewString() } if headers["amz-sdk-request"] == "" { headers["amz-sdk-request"] = "attempt=1; max=4" } return headers } func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error { client, err := newHTTPClient(proxyURL) if err != nil { return err } var body io.Reader if payload != nil { encoded, err := json.Marshal(payload) if err != nil { return err } body = bytes.NewReader(encoded) } req, err := http.NewRequestWithContext(ctx, method, rawURL, body) if err != nil { return err } if payload != nil { req.Header.Set("Content-Type", "application/json") } for key, value := range extraHeaders { req.Header.Set(key, value) } resp, err := client.Do(req) if err != nil { return err } defer func() { _ = resp.Body.Close() }() respBody, err := io.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { bodyText := strings.TrimSpace(string(respBody)) if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") { return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText} } return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText) } if out == nil || len(respBody) == 0 { return nil } return json.Unmarshal(respBody, out) } func newHTTPClient(rawProxyURL string) (*http.Client, error) { _, parsed, err := proxyurl.Parse(rawProxyURL) if err != nil { return nil, err } transport := &http.Transport{} if parsed != nil { transport.Proxy = http.ProxyURL(parsed) } return &http.Client{ Timeout: 30 * time.Second, Transport: transport, }, nil }