Files
2026-05-16 15:31:08 +08:00

530 lines
15 KiB
Go

package kiro
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/google/uuid"
)
// Upstream endpoints and defaults used by the Kiro login flows.
const (
	// socialAuthPortalURL hosts the browser /signin page for social logins.
	socialAuthPortalURL = "https://app.kiro.dev"
	// socialAuthEndpoint serves the social /oauth/token and /refreshToken APIs.
	socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
	// defaultIDCRegion is the AWS region used whenever none is supplied.
	defaultIDCRegion = "us-east-1"
	// BuilderIDStartURL is the AWS Builder ID start URL.
	BuilderIDStartURL = "https://view.awsapps.com/start"
)

// SessionStore tuning.
const (
	// sessionTTL is how long a pending AuthSession stays valid after CreatedAt.
	sessionTTL = 10 * time.Minute
	// sessionCleanupEvery: an expiry sweep is attempted on every Nth Set call...
	sessionCleanupEvery = 32
	// sessionCleanupMin: ...but only once the store holds at least this many entries.
	sessionCleanupMin = 32
)
// socialAuthEndpointURL is the base URL for the social /oauth/token and
// /refreshToken calls. It starts out as socialAuthEndpoint but, unlike the
// constant, can be reassigned (presumably by tests — confirm before relying
// on it elsewhere).
var socialAuthEndpointURL = socialAuthEndpoint

// oidcEndpointOverride, when non-blank, is used by getOIDCEndpoint instead of
// the regional AWS OIDC URL. It is empty by default.
var oidcEndpointOverride string
// SocialProvider names the identity provider behind a Kiro social login.
type SocialProvider string

// Known social providers. The literal spellings are kept exactly as written
// (note "Github", not "GitHub") — presumably what the upstream expects, so do
// not normalize them without confirming.
const (
	SocialProviderGoogle = SocialProvider("Google")
	SocialProviderGitHub = SocialProvider("Github")
)
// AuthSession is the state of a login flow that has been started but not yet
// completed; SessionStore holds these keyed by session ID and expires them
// sessionTTL after CreatedAt.
type AuthSession struct {
	// OAuth state value and PKCE code verifier (see GenerateState and
	// GenerateCodeVerifier).
	State, CodeVerifier string
	// ProxyURL is presumably the outbound proxy for this flow's upstream calls.
	ProxyURL string
	// CreatedAt is when the session began; a zero value counts as expired.
	CreatedAt time.Time
	// Flow selection and OAuth client parameters.
	AuthType, Provider, RedirectURI string
	ClientID, ClientSecret          string
	// AWS region and start URL (presumably only meaningful for IDC flows).
	Region, StartURL string
}
// SessionStore keeps in-progress AuthSessions in memory, keyed by session ID.
// Every method takes mu before touching the fields. Create one with
// NewSessionStore.
type SessionStore struct {
	mu       sync.RWMutex            // guards data and setCount
	data     map[string]*AuthSession // session ID -> pending session
	setCount uint64                  // total Set calls; schedules pruning
}

// NewSessionStore returns an empty SessionStore that is ready for use.
func NewSessionStore() *SessionStore {
	store := &SessionStore{}
	store.data = map[string]*AuthSession{}
	return store
}
// Get returns the session stored under id and whether one was found. An entry
// that sessionExpired reports as expired is deleted here and reported as
// missing, so callers never see a session older than sessionTTL.
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
	now := time.Now()
	// Exclusive lock: an expired entry may be deleted below.
	s.mu.Lock()
	defer s.mu.Unlock()

	session, found := s.data[id]
	switch {
	case !found:
		return nil, false
	case sessionExpired(session, now):
		delete(s.data, id)
		return nil, false
	default:
		return session, true
	}
}
// Set stores session under id, replacing any previous entry.
//
// On every sessionCleanupEvery-th call, once the store already holds at least
// sessionCleanupMin entries, expired sessions are swept first so entries that
// are never looked up again do not accumulate.
//
// The map is allocated on first write, so a zero-value SessionStore is usable
// (Get and Delete already tolerate a nil map; previously Set panicked on it).
func (s *SessionStore) Set(id string, session *AuthSession) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.data == nil {
		s.data = make(map[string]*AuthSession)
	}
	s.setCount++
	if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
		s.pruneExpiredLocked(time.Now())
	}
	s.data[id] = session
}
// Delete removes the session stored under id; an unknown id is a no-op.
func (s *SessionStore) Delete(id string) {
	s.mu.Lock()
	delete(s.data, id) // cannot panic, so an explicit unlock is safe
	s.mu.Unlock()
}
// pruneExpiredLocked removes every session that sessionExpired reports as
// expired at now. The caller must already hold s.mu for writing (Set does).
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
	for id := range s.data {
		if !sessionExpired(s.data[id], now) {
			continue
		}
		// Deleting during range is permitted by the Go spec.
		delete(s.data, id)
	}
}
// sessionExpired reports whether session should no longer be served at now.
// A nil session or one without a CreatedAt timestamp always counts as
// expired; otherwise it expires once now is past CreatedAt + sessionTTL.
func sessionExpired(session *AuthSession, now time.Time) bool {
	if session == nil || session.CreatedAt.IsZero() {
		return true
	}
	deadline := session.CreatedAt.Add(sessionTTL)
	return now.After(deadline)
}
// JSON shapes exchanged with the Kiro and AWS OIDC endpoints, plus the
// normalized TokenData this package returns and accepts on import.
type (
	// TokenData is the credential set produced by the login/refresh helpers
	// and parsed from imported token JSON. Fields tagged omitempty are left
	// out of the encoded JSON when empty.
	TokenData struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ProfileArn   string `json:"profileArn,omitempty"`
		// ExpiresAt is RFC 3339 when set by this package.
		ExpiresAt string `json:"expiresAt,omitempty"`
		// AuthMethod is "social" or "idc" when set by this package.
		AuthMethod   string `json:"authMethod,omitempty"`
		Provider     string `json:"provider,omitempty"`
		ClientID     string `json:"clientId,omitempty"`
		ClientSecret string `json:"clientSecret,omitempty"`
		// ClientIDHash, on import, signals that client credentials may be
		// taken from a separate device-registration document.
		ClientIDHash string `json:"clientIdHash,omitempty"`
		Email        string `json:"email,omitempty"`
		StartURL     string `json:"startUrl,omitempty"`
		Region       string `json:"region,omitempty"`
	}

	// socialTokenResponse is the body of the social /oauth/token and
	// /refreshToken responses; ExpiresIn is in seconds.
	socialTokenResponse struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ProfileArn   string `json:"profileArn"`
		ExpiresIn    int    `json:"expiresIn"`
	}

	// registerClientResponse is the OIDC /client/register response.
	registerClientResponse struct {
		ClientID     string `json:"clientId"`
		ClientSecret string `json:"clientSecret"`
	}

	// createTokenResponse is the OIDC /token response; ExpiresIn is in seconds.
	createTokenResponse struct {
		AccessToken  string `json:"accessToken"`
		RefreshToken string `json:"refreshToken"`
		ProfileArn   string `json:"profileArn"`
		ExpiresIn    int    `json:"expiresIn"`
	}

	// userInfoResponse is the part of the OIDC /userinfo response read here.
	userInfoResponse struct {
		Email string `json:"email"`
	}

	// deviceRegistration carries client credentials from an imported
	// device-registration JSON document.
	deviceRegistration struct {
		ClientID     string `json:"clientId"`
		ClientSecret string `json:"clientSecret"`
	}
)
// RefreshTokenInvalidError reports an upstream rejection whose body mentions
// "invalid_grant" (doJSON returns it for such HTTP 400 responses). Callers can
// detect it with errors.As.
type RefreshTokenInvalidError struct {
	StatusCode int    // HTTP status of the rejected request
	Body       string // upstream response body
}

// Error implements the error interface. A nil receiver yields "". The status
// code and body are only included when the body is not blank.
func (e *RefreshTokenInvalidError) Error() string {
	if e == nil {
		return ""
	}
	if body := strings.TrimSpace(e.Body); body != "" {
		return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
	}
	return "kiro refresh token invalid (invalid_grant)"
}
// GenerateSessionID returns a freshly generated random UUID string, suitable
// as a SessionStore key.
func GenerateSessionID() string {
	id := uuid.NewString()
	return id
}

// GenerateState returns an OAuth state value made from 16 random bytes,
// base64url-encoded without padding.
func GenerateState() (string, error) {
	const stateBytes = 16
	return randomURLSafe(stateBytes)
}

// GenerateCodeVerifier returns a PKCE code verifier made from 32 random bytes,
// which encodes to 43 base64url characters (the RFC 7636 minimum length).
func GenerateCodeVerifier() (string, error) {
	const verifierBytes = 32
	return randomURLSafe(verifierBytes)
}
// randomURLSafe draws n bytes from crypto/rand and returns them as unpadded
// base64url text, which can be placed in URLs and query strings verbatim.
// On a read failure it returns "" and the error.
func randomURLSafe(n int) (encoded string, err error) {
	raw := make([]byte, n)
	if _, err = rand.Read(raw); err == nil {
		encoded = base64.RawURLEncoding.EncodeToString(raw)
	}
	return encoded, err
}
// GenerateCodeChallenge derives the PKCE "S256" challenge for verifier:
// BASE64URL(SHA-256(verifier)) with no padding, as in RFC 7636 §4.2.
func GenerateCodeChallenge(verifier string) string {
	digest := sha256.New()
	digest.Write([]byte(verifier)) // hash.Hash.Write never returns an error
	return base64.RawURLEncoding.EncodeToString(digest.Sum(nil))
}
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
params := url.Values{}
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("redirect_uri", redirectURI)
params.Set("redirect_from", "KiroIDE")
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
}
// BuildSocialTokenRedirectURI joins a base redirect URI and a callback path,
// optionally appending a query-escaped login_option parameter.
//
// Whitespace and trailing slashes are stripped from baseRedirectURI, and ""
// is returned when nothing remains. A blank callbackPath becomes
// "/oauth/callback"; otherwise a leading "/" is ensured.
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
	base := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
	if base == "" {
		return ""
	}

	path := "/oauth/callback"
	if p := strings.TrimSpace(callbackPath); p != "" {
		path = "/" + strings.TrimPrefix(p, "/")
	}
	uri := base + path

	option := strings.TrimSpace(loginOption)
	if option == "" {
		return uri
	}
	return uri + "?login_option=" + url.QueryEscape(option)
}
// CreateSocialToken exchanges a social-login authorization code, its PKCE
// verifier and the redirect URI used to obtain it for Kiro tokens via the
// social endpoint's /oauth/token API. It returns a TokenData with AuthMethod
// "social" and Region defaultIDCRegion.
//
// A successful HTTP response that carries no access token is reported as an
// error, so callers never persist an empty credential. When expiresIn is
// missing or non-positive, a one-hour lifetime is assumed.
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
	payload := map[string]string{
		"code":          code,
		"code_verifier": codeVerifier,
		"redirect_uri":  redirectURI,
	}
	// Login headers are keyed on the code verifier for this request.
	headers := BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))
	var resp socialTokenResponse
	if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, headers); err != nil {
		return nil, err
	}
	// doJSON leaves resp untouched on an empty 2xx body; do not treat that as success.
	if strings.TrimSpace(resp.AccessToken) == "" {
		return nil, fmt.Errorf("kiro social token response missing accessToken")
	}
	expiresIn := resp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	return &TokenData{
		AccessToken:  resp.AccessToken,
		RefreshToken: resp.RefreshToken,
		ProfileArn:   resp.ProfileArn,
		ExpiresAt:    time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
		AuthMethod:   "social",
		Region:       defaultIDCRegion,
	}, nil
}
// RefreshSocialToken trades a social-login refresh token for a new access
// token via the social endpoint's /refreshToken API. provider is copied into
// the result unchanged; AuthMethod is "social" and Region is defaultIDCRegion.
//
// Refresh-token rotation is optional for the server (RFC 6749 §6). When the
// response omits a refresh token, the one that was sent is kept, so callers do
// not overwrite their stored credential with an empty value. A response
// without an access token is an error. A missing or non-positive expiresIn is
// treated as one hour.
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
	payload := map[string]string{
		"refreshToken": refreshToken,
	}
	accountKey := BuildAccountKey("", "", refreshToken, "", 0)
	headers := BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
	var resp socialTokenResponse
	if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, headers); err != nil {
		return nil, err
	}
	if strings.TrimSpace(resp.AccessToken) == "" {
		return nil, fmt.Errorf("kiro social refresh response missing accessToken")
	}
	nextRefreshToken := resp.RefreshToken
	if strings.TrimSpace(nextRefreshToken) == "" {
		nextRefreshToken = refreshToken // not rotated: keep using the current one
	}
	expiresIn := resp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	return &TokenData{
		AccessToken:  resp.AccessToken,
		RefreshToken: nextRefreshToken,
		ProfileArn:   resp.ProfileArn,
		ExpiresAt:    time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
		AuthMethod:   "social",
		Provider:     provider,
		Region:       defaultIDCRegion,
	}, nil
}
// RegisterIDCClient registers a public OIDC client named "Kiro IDE" at the
// regional AWS OIDC /client/register endpoint. The client requests the
// CodeWhisperer scopes and the authorization_code and refresh_token grants
// for redirectURI and issuerURL. An empty region defaults to defaultIDCRegion.
//
// A successful response that lacks a clientId is reported as an error rather
// than returned as an unusable registration.
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
	if region == "" {
		region = defaultIDCRegion
	}
	payload := map[string]any{
		"clientName":   "Kiro IDE",
		"clientType":   "public",
		"scopes":       []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
		"grantTypes":   []string{"authorization_code", "refresh_token"},
		"redirectUris": []string{redirectURI},
		"issuerUrl":    issuerURL,
	}
	var resp registerClientResponse
	headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
	if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
		return nil, err
	}
	if strings.TrimSpace(resp.ClientID) == "" {
		return nil, fmt.Errorf("kiro idc client registration response missing clientId")
	}
	return &resp, nil
}
// BuildIDCAuthURL returns the regional AWS OIDC /authorize URL that starts an
// IDC authorization-code flow for clientID. It uses a PKCE S256 challenge and
// the CodeWhisperer scopes (space-separated). An empty region defaults to
// defaultIDCRegion.
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
	if region == "" {
		region = defaultIDCRegion
	}
	scopes := []string{
		"codewhisperer:completions",
		"codewhisperer:analysis",
		"codewhisperer:conversations",
		"codewhisperer:transformations",
		"codewhisperer:taskassist",
	}
	query := url.Values{
		"response_type":         {"code"},
		"client_id":             {clientID},
		"redirect_uri":          {redirectURI},
		"scopes":                {strings.Join(scopes, " ")},
		"state":                 {state},
		"code_challenge":        {codeChallenge},
		"code_challenge_method": {"S256"},
	}
	return getOIDCEndpoint(region) + "/authorize?" + query.Encode()
}
// ExchangeIDCAuthCode redeems an IDC authorization code at the regional AWS
// OIDC /token endpoint (grantType authorization_code). It sends the client
// credentials, the PKCE verifier and the redirect URI from the authorization
// request. An empty region defaults to defaultIDCRegion.
//
// The returned TokenData (AuthMethod "idc", Provider "AWS") records the client
// credentials, startURL and region — the same values RefreshIDCToken takes.
// Email is filled by a best-effort /userinfo lookup and stays empty if that
// fails. A successful response without an access token is reported as an
// error. A missing or non-positive expiresIn is treated as one hour.
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
	if region == "" {
		region = defaultIDCRegion
	}
	payload := map[string]string{
		"clientId":     clientID,
		"clientSecret": clientSecret,
		"code":         code,
		"codeVerifier": codeVerifier,
		"redirectUri":  redirectURI,
		"grantType":    "authorization_code",
	}
	var resp createTokenResponse
	accountKey := BuildAccountKey(clientID, "", "", "", 0)
	headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
	if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
		return nil, err
	}
	if strings.TrimSpace(resp.AccessToken) == "" {
		return nil, fmt.Errorf("kiro idc token response missing accessToken")
	}
	expiresIn := resp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	token := &TokenData{
		AccessToken:  resp.AccessToken,
		RefreshToken: resp.RefreshToken,
		ProfileArn:   resp.ProfileArn,
		ExpiresAt:    time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
		AuthMethod:   "idc",
		Provider:     "AWS",
		ClientID:     clientID,
		ClientSecret: clientSecret,
		StartURL:     startURL,
		Region:       region,
	}
	token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
	return token, nil
}
// RefreshIDCToken obtains a new access token from the regional AWS OIDC
// /token endpoint using the refresh_token grant and the client credentials.
// An empty region defaults to defaultIDCRegion.
//
// Refresh-token rotation is optional for the server (RFC 6749 §6). When the
// response omits a refresh token, the one that was sent is kept, so callers do
// not overwrite their stored credential with an empty value. A response
// without an access token is an error. Email is filled by a best-effort
// /userinfo lookup, and a missing or non-positive expiresIn is treated as one
// hour.
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
	if region == "" {
		region = defaultIDCRegion
	}
	payload := map[string]string{
		"clientId":     clientID,
		"clientSecret": clientSecret,
		"refreshToken": refreshToken,
		"grantType":    "refresh_token",
	}
	var resp createTokenResponse
	accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
	headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
	if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
		return nil, err
	}
	if strings.TrimSpace(resp.AccessToken) == "" {
		return nil, fmt.Errorf("kiro idc refresh response missing accessToken")
	}
	nextRefreshToken := resp.RefreshToken
	if strings.TrimSpace(nextRefreshToken) == "" {
		nextRefreshToken = refreshToken // not rotated: keep using the current one
	}
	expiresIn := resp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	token := &TokenData{
		AccessToken:  resp.AccessToken,
		RefreshToken: nextRefreshToken,
		ProfileArn:   resp.ProfileArn,
		ExpiresAt:    time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
		AuthMethod:   "idc",
		Provider:     "AWS",
		ClientID:     clientID,
		ClientSecret: clientSecret,
		StartURL:     startURL,
		Region:       region,
	}
	token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
	return token, nil
}
// FetchOIDCUserEmail looks up the account e-mail through the regional OIDC
// /userinfo endpoint using accessToken as a bearer token. It is best effort:
// a blank token or any request/decoding failure yields "" rather than an
// error. The returned address is whitespace-trimmed.
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
	if strings.TrimSpace(accessToken) == "" {
		return ""
	}
	var info userInfoResponse
	endpoint := getOIDCEndpoint(region) + "/userinfo"
	auth := map[string]string{"Authorization": "Bearer " + accessToken}
	if doJSON(ctx, proxyURL, http.MethodGet, endpoint, nil, &info, auth) != nil {
		return ""
	}
	return strings.TrimSpace(info.Email)
}
// ParseImportedToken decodes an exported Kiro token JSON document.
// AuthMethod is lower-cased and trimmed, and a non-blank access token is
// required.
//
// When the token has a ClientIDHash but is missing its client ID or secret,
// and deviceRegistrationJSON is not blank, the registration is decoded. Its
// non-empty clientId/clientSecret values are copied onto the token.
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
	var token TokenData
	if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
		return nil, fmt.Errorf("failed to parse kiro token: %w", err)
	}
	token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
	if strings.TrimSpace(token.AccessToken) == "" {
		return nil, fmt.Errorf("access token is empty")
	}

	clientIncomplete := token.ClientID == "" || token.ClientSecret == ""
	hasRegistration := strings.TrimSpace(deviceRegistrationJSON) != ""
	if token.ClientIDHash == "" || !clientIncomplete || !hasRegistration {
		return &token, nil
	}

	var reg deviceRegistration
	if err := json.Unmarshal([]byte(deviceRegistrationJSON), &reg); err != nil {
		return nil, fmt.Errorf("failed to parse device registration: %w", err)
	}
	if reg.ClientID != "" {
		token.ClientID = reg.ClientID
	}
	if reg.ClientSecret != "" {
		token.ClientSecret = reg.ClientSecret
	}
	return &token, nil
}
// ParseImportedRefreshToken decodes an imported Kiro token JSON document.
// AuthMethod is normalized to lower case and defaults to "social"; a social
// token with no Provider is assigned SocialProviderGoogle.
//
// The boolean is true only when the document has a non-blank refresh token
// and a blank access token, i.e. an access token still has to be obtained
// (presumably by refreshing).
func ParseImportedRefreshToken(tokenJSON string) (*TokenData, bool, error) {
	var token TokenData
	if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
		return nil, false, fmt.Errorf("failed to parse kiro token: %w", err)
	}

	method := strings.ToLower(strings.TrimSpace(token.AuthMethod))
	if method == "" {
		method = "social"
	}
	token.AuthMethod = method
	if method == "social" && token.Provider == "" {
		token.Provider = string(SocialProviderGoogle)
	}

	hasRefresh := strings.TrimSpace(token.RefreshToken) != ""
	hasAccess := strings.TrimSpace(token.AccessToken) != ""
	return &token, hasRefresh && !hasAccess, nil
}
// getOIDCEndpoint returns the AWS OIDC base URL (no trailing slash) for
// region. A non-blank oidcEndpointOverride takes precedence over the region.
//
// The region is interpolated into a hostname, and it can come from imported
// token JSON (TokenData.Region). OIDC requests carry client secrets and
// refresh tokens. So after trimming and lower-casing (hostnames are
// case-insensitive), a region that is blank or contains anything other than
// ASCII letters, digits and hyphens falls back to defaultIDCRegion. This stops
// a crafted value such as "evil.example/x#" from steering the request to an
// arbitrary host.
func getOIDCEndpoint(region string) string {
	if override := strings.TrimSpace(oidcEndpointOverride); override != "" {
		return strings.TrimRight(override, "/")
	}
	region = strings.ToLower(strings.TrimSpace(region))
	notRegionChar := func(r rune) bool {
		return !(r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-')
	}
	if region == "" || strings.IndexFunc(region, notRegionChar) >= 0 {
		region = defaultIDCRegion
	}
	return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
}
// oidcHeaders returns the headers for an AWS OIDC request: everything from
// BuildOIDCHeaders, plus a fresh amz-sdk-invocation-id and a default
// amz-sdk-request value when BuildOIDCHeaders did not supply them.
//
// The result is always a new map. The map returned by BuildOIDCHeaders is
// never mutated (it may be shared), and a nil result from it cannot trigger a
// nil-map write panic.
func oidcHeaders(accountKey, machineID string) map[string]string {
	base := BuildOIDCHeaders(accountKey, machineID)
	headers := make(map[string]string, len(base)+2)
	for key, value := range base {
		headers[key] = value
	}
	if headers["amz-sdk-invocation-id"] == "" {
		headers["amz-sdk-invocation-id"] = uuid.NewString()
	}
	if headers["amz-sdk-request"] == "" {
		headers["amz-sdk-request"] = "attempt=1; max=4"
	}
	return headers
}
// doJSON performs one HTTP request to rawURL and decodes a JSON response.
//
// When payload is non-nil it is JSON-encoded as the body and sent with
// Content-Type: application/json. extraHeaders are applied afterwards, so they
// may override that header. The request uses a client from newHTTPClient for
// proxyURL and honours ctx.
//
// Non-2xx responses become errors. A 400 whose body contains "invalid_grant"
// (case-insensitive) yields *RefreshTokenInvalidError; anything else yields an
// error carrying the status and body. On success the body is decoded into out,
// unless out is nil or the body is empty (out is then left untouched).
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
	// The JSON documents exchanged here are expected to be small. Bound what is
	// buffered so a misbehaving upstream or proxy cannot exhaust memory.
	const maxResponseBytes = 1 << 20

	client, err := newHTTPClient(proxyURL)
	if err != nil {
		return err
	}
	// This client and its transport are used for this single call only.
	// Release its pooled keep-alive connections instead of leaking them.
	// Deferred first, so it runs after resp.Body.Close below.
	defer client.CloseIdleConnections()

	var body io.Reader
	if payload != nil {
		encoded, err := json.Marshal(payload)
		if err != nil {
			return err
		}
		body = bytes.NewReader(encoded)
	}
	req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
	if err != nil {
		return err
	}
	if payload != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	for key, value := range extraHeaders {
		req.Header.Set(key, value)
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return err
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		bodyText := strings.TrimSpace(string(respBody))
		if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
			return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
		}
		return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
	}
	if out == nil || len(respBody) == 0 {
		return nil
	}
	return json.Unmarshal(respBody, out)
}
// newHTTPClient builds an HTTP client with a 30-second overall timeout. When
// proxyurl.Parse yields a proxy URL, traffic is routed through it. Otherwise
// Proxy stays nil, so the client connects directly and ignores
// environment proxy settings. Parse presumably returns a nil URL for blank
// input; confirm in proxyurl. Parse errors are returned unchanged.
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
	_, parsed, err := proxyurl.Parse(rawProxyURL)
	if err != nil {
		return nil, err
	}
	transport := &http.Transport{
		// Each call builds its own transport. With the zero-value
		// IdleConnTimeout (no limit), its keep-alive connections would stay
		// open on our side after the client is discarded, so bound them.
		IdleConnTimeout:     90 * time.Second,
		TLSHandshakeTimeout: 10 * time.Second,
	}
	if parsed != nil {
		transport.Proxy = http.ProxyURL(parsed)
	}
	return &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
	}, nil
}