530 lines
15 KiB
Go
530 lines
15 KiB
Go
package kiro
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Upstream endpoints for the social (Google/GitHub) login flow.
const (
	// socialAuthPortalURL hosts the browser sign-in page (see BuildSocialSignInURL).
	socialAuthPortalURL = "https://app.kiro.dev"
	// socialAuthEndpoint is the base URL for the social token exchange and refresh calls.
	socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
)

// Defaults for the IDC (AWS OIDC) login flow.
const (
	// defaultIDCRegion is used whenever a caller supplies no region.
	defaultIDCRegion = "us-east-1"
	// BuilderIDStartURL is the start URL for AWS Builder ID sign-in.
	BuilderIDStartURL = "https://view.awsapps.com/start"
)

// SessionStore tuning.
const (
	// sessionTTL is how long an AuthSession stays valid after CreatedAt.
	sessionTTL = 10 * time.Minute
	// sessionCleanupEvery: expired sessions are pruned on every Nth Set call...
	sessionCleanupEvery = 32
	// sessionCleanupMin: ...but only once the store holds at least this many entries.
	sessionCleanupMin = 32
)
|
|
|
|
var (
|
|
socialAuthEndpointURL = socialAuthEndpoint
|
|
oidcEndpointOverride = ""
|
|
)
|
|
|
|
// SocialProvider names the identity provider behind a Kiro social login.
type SocialProvider string
|
|
|
|
const (
|
|
SocialProviderGoogle SocialProvider = "Google"
|
|
SocialProviderGitHub SocialProvider = "Github"
|
|
)
|
|
|
|
// AuthSession carries the per-attempt state of a login kept in a
// SessionStore. It expires sessionTTL after CreatedAt.
type AuthSession struct {
	// State and CodeVerifier are the OAuth state and PKCE verifier generated
	// for this attempt.
	State, CodeVerifier string
	// ProxyURL is the outbound proxy setting used for upstream calls.
	ProxyURL string
	// CreatedAt starts the TTL clock; a zero value is treated as expired.
	CreatedAt time.Time
	// AuthType/Provider describe the login kind (presumably "social" or
	// "idc", mirroring TokenData.AuthMethod — TODO confirm); RedirectURI is
	// the callback registered for this attempt.
	AuthType, Provider, RedirectURI string
	// IDC client credentials plus the region and start URL they belong to.
	ClientID, ClientSecret, Region, StartURL string
}
|
|
|
|
type SessionStore struct {
|
|
mu sync.RWMutex
|
|
data map[string]*AuthSession
|
|
setCount uint64
|
|
}
|
|
|
|
func NewSessionStore() *SessionStore {
|
|
return &SessionStore{data: make(map[string]*AuthSession)}
|
|
}
|
|
|
|
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
|
|
now := time.Now()
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
session, ok := s.data[id]
|
|
if ok && sessionExpired(session, now) {
|
|
delete(s.data, id)
|
|
return nil, false
|
|
}
|
|
return session, ok
|
|
}
|
|
|
|
func (s *SessionStore) Set(id string, session *AuthSession) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.setCount++
|
|
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
|
|
s.pruneExpiredLocked(time.Now())
|
|
}
|
|
s.data[id] = session
|
|
}
|
|
|
|
func (s *SessionStore) Delete(id string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.data, id)
|
|
}
|
|
|
|
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
|
|
for id, session := range s.data {
|
|
if sessionExpired(session, now) {
|
|
delete(s.data, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func sessionExpired(session *AuthSession, now time.Time) bool {
|
|
if session == nil {
|
|
return true
|
|
}
|
|
if session.CreatedAt.IsZero() {
|
|
return true
|
|
}
|
|
return now.After(session.CreatedAt.Add(sessionTTL))
|
|
}
|
|
|
|
// TokenData is the Kiro credential record produced by the login/refresh flows
// and accepted by the import helpers. Field order is kept stable because it
// determines JSON output order.
type TokenData struct {
	AccessToken  string `json:"accessToken"`
	RefreshToken string `json:"refreshToken"`
	ProfileArn   string `json:"profileArn,omitempty"`
	// ExpiresAt is an RFC 3339 timestamp when produced by this package.
	ExpiresAt string `json:"expiresAt,omitempty"`
	// AuthMethod is "social" or "idc" for tokens built in this file.
	AuthMethod string `json:"authMethod,omitempty"`
	Provider   string `json:"provider,omitempty"`
	// ClientID/ClientSecret are the IDC client credentials used for refresh.
	ClientID     string `json:"clientId,omitempty"`
	ClientSecret string `json:"clientSecret,omitempty"`
	// ClientIDHash marks an imported token whose client credentials live in a
	// separate device-registration document (see ParseImportedToken).
	ClientIDHash string `json:"clientIdHash,omitempty"`
	Email        string `json:"email,omitempty"`
	StartURL     string `json:"startUrl,omitempty"`
	Region       string `json:"region,omitempty"`
}
|
|
|
|
// socialTokenResponse is the JSON body returned by the social token
// exchange and refresh endpoints. ExpiresIn is in seconds.
type socialTokenResponse struct {
	AccessToken  string `json:"accessToken"`
	RefreshToken string `json:"refreshToken"`
	ProfileArn   string `json:"profileArn"`
	ExpiresIn    int    `json:"expiresIn"`
}
|
|
|
|
// registerClientResponse is the JSON body of the OIDC /client/register call.
type registerClientResponse struct {
	ClientID     string `json:"clientId"`
	ClientSecret string `json:"clientSecret"`
}
|
|
|
|
// createTokenResponse is the JSON body of the OIDC /token call, used for both
// authorization-code and refresh-token grants. ExpiresIn is in seconds.
type createTokenResponse struct {
	AccessToken  string `json:"accessToken"`
	RefreshToken string `json:"refreshToken"`
	ProfileArn   string `json:"profileArn"`
	ExpiresIn    int    `json:"expiresIn"`
}
|
|
|
|
// userInfoResponse is the subset of the OIDC /userinfo body this package reads.
type userInfoResponse struct {
	Email string `json:"email"`
}
|
|
|
|
// deviceRegistration is the imported client-registration document whose
// credentials ParseImportedToken merges into a token.
type deviceRegistration struct {
	ClientID     string `json:"clientId"`
	ClientSecret string `json:"clientSecret"`
}
|
|
|
|
// RefreshTokenInvalidError is returned by doJSON when an upstream answers
// HTTP 400 with a body mentioning "invalid_grant".
type RefreshTokenInvalidError struct {
	StatusCode int    // HTTP status of the failing response
	Body       string // trimmed response body
}
|
|
|
|
func (e *RefreshTokenInvalidError) Error() string {
|
|
if e == nil {
|
|
return ""
|
|
}
|
|
body := strings.TrimSpace(e.Body)
|
|
if body == "" {
|
|
return "kiro refresh token invalid (invalid_grant)"
|
|
}
|
|
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
|
|
}
|
|
|
|
func GenerateSessionID() string {
|
|
return uuid.NewString()
|
|
}
|
|
|
|
func GenerateState() (string, error) {
|
|
return randomURLSafe(16)
|
|
}
|
|
|
|
func GenerateCodeVerifier() (string, error) {
|
|
return randomURLSafe(32)
|
|
}
|
|
|
|
// randomURLSafe reads n bytes from crypto/rand and returns them
// base64url-encoded without padding.
func randomURLSafe(n int) (string, error) {
	raw := make([]byte, n)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
|
|
|
|
func GenerateCodeChallenge(verifier string) string {
|
|
sum := sha256.Sum256([]byte(verifier))
|
|
return base64.RawURLEncoding.EncodeToString(sum[:])
|
|
}
|
|
|
|
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
|
|
params := url.Values{}
|
|
params.Set("state", state)
|
|
params.Set("code_challenge", codeChallenge)
|
|
params.Set("code_challenge_method", "S256")
|
|
params.Set("redirect_uri", redirectURI)
|
|
params.Set("redirect_from", "KiroIDE")
|
|
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
|
|
}
|
|
|
|
// BuildSocialTokenRedirectURI joins baseRedirectURI and callbackPath into a
// redirect URI, optionally appending a login_option query parameter.
//
// All arguments are whitespace-trimmed and trailing slashes are removed from
// the base. A blank base yields "". An empty callbackPath defaults to
// "/oauth/callback", and a path without a leading "/" gets one. If the joined
// URI already has a query string, login_option is added with "&" instead of a
// second "?", which would corrupt the existing query.
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
	base := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
	if base == "" {
		return ""
	}

	path := strings.TrimSpace(callbackPath)
	switch {
	case path == "":
		path = "/oauth/callback"
	case !strings.HasPrefix(path, "/"):
		path = "/" + path
	}
	fullRedirectURI := base + path

	option := strings.TrimSpace(loginOption)
	if option == "" {
		return fullRedirectURI
	}
	separator := "?"
	if strings.Contains(fullRedirectURI, "?") {
		separator = "&"
	}
	return fullRedirectURI + separator + "login_option=" + url.QueryEscape(option)
}
|
|
|
|
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
|
|
payload := map[string]string{
|
|
"code": code,
|
|
"code_verifier": codeVerifier,
|
|
"redirect_uri": redirectURI,
|
|
}
|
|
var resp socialTokenResponse
|
|
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
|
|
return nil, err
|
|
}
|
|
expiresIn := resp.ExpiresIn
|
|
if expiresIn <= 0 {
|
|
expiresIn = 3600
|
|
}
|
|
return &TokenData{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ProfileArn: resp.ProfileArn,
|
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
|
AuthMethod: "social",
|
|
Region: defaultIDCRegion,
|
|
}, nil
|
|
}
|
|
|
|
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
|
|
payload := map[string]string{
|
|
"refreshToken": refreshToken,
|
|
}
|
|
var resp socialTokenResponse
|
|
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
|
|
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
|
|
return nil, err
|
|
}
|
|
expiresIn := resp.ExpiresIn
|
|
if expiresIn <= 0 {
|
|
expiresIn = 3600
|
|
}
|
|
return &TokenData{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ProfileArn: resp.ProfileArn,
|
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
|
AuthMethod: "social",
|
|
Provider: provider,
|
|
Region: defaultIDCRegion,
|
|
}, nil
|
|
}
|
|
|
|
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
|
|
if region == "" {
|
|
region = defaultIDCRegion
|
|
}
|
|
payload := map[string]any{
|
|
"clientName": "Kiro IDE",
|
|
"clientType": "public",
|
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
|
"grantTypes": []string{"authorization_code", "refresh_token"},
|
|
"redirectUris": []string{redirectURI},
|
|
"issuerUrl": issuerURL,
|
|
}
|
|
var resp registerClientResponse
|
|
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
|
|
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
|
|
return nil, err
|
|
}
|
|
return &resp, nil
|
|
}
|
|
|
|
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
|
|
if region == "" {
|
|
region = defaultIDCRegion
|
|
}
|
|
params := url.Values{}
|
|
params.Set("response_type", "code")
|
|
params.Set("client_id", clientID)
|
|
params.Set("redirect_uri", redirectURI)
|
|
params.Set("scopes", strings.Join([]string{
|
|
"codewhisperer:completions",
|
|
"codewhisperer:analysis",
|
|
"codewhisperer:conversations",
|
|
"codewhisperer:transformations",
|
|
"codewhisperer:taskassist",
|
|
}, " "))
|
|
params.Set("state", state)
|
|
params.Set("code_challenge", codeChallenge)
|
|
params.Set("code_challenge_method", "S256")
|
|
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
|
|
}
|
|
|
|
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
|
|
if region == "" {
|
|
region = defaultIDCRegion
|
|
}
|
|
payload := map[string]string{
|
|
"clientId": clientID,
|
|
"clientSecret": clientSecret,
|
|
"code": code,
|
|
"codeVerifier": codeVerifier,
|
|
"redirectUri": redirectURI,
|
|
"grantType": "authorization_code",
|
|
}
|
|
var resp createTokenResponse
|
|
accountKey := BuildAccountKey(clientID, "", "", "", 0)
|
|
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
|
|
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
|
return nil, err
|
|
}
|
|
expiresIn := resp.ExpiresIn
|
|
if expiresIn <= 0 {
|
|
expiresIn = 3600
|
|
}
|
|
token := &TokenData{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ProfileArn: resp.ProfileArn,
|
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
|
AuthMethod: "idc",
|
|
Provider: "AWS",
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
StartURL: startURL,
|
|
Region: region,
|
|
}
|
|
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
|
return token, nil
|
|
}
|
|
|
|
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
|
|
if region == "" {
|
|
region = defaultIDCRegion
|
|
}
|
|
payload := map[string]string{
|
|
"clientId": clientID,
|
|
"clientSecret": clientSecret,
|
|
"refreshToken": refreshToken,
|
|
"grantType": "refresh_token",
|
|
}
|
|
var resp createTokenResponse
|
|
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
|
|
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
|
|
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
|
return nil, err
|
|
}
|
|
expiresIn := resp.ExpiresIn
|
|
if expiresIn <= 0 {
|
|
expiresIn = 3600
|
|
}
|
|
token := &TokenData{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ProfileArn: resp.ProfileArn,
|
|
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
|
AuthMethod: "idc",
|
|
Provider: "AWS",
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
StartURL: startURL,
|
|
Region: region,
|
|
}
|
|
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
|
return token, nil
|
|
}
|
|
|
|
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
|
|
if strings.TrimSpace(accessToken) == "" {
|
|
return ""
|
|
}
|
|
var resp userInfoResponse
|
|
headers := map[string]string{
|
|
"Authorization": "Bearer " + accessToken,
|
|
}
|
|
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(resp.Email)
|
|
}
|
|
|
|
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
|
|
var token TokenData
|
|
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
|
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
|
|
}
|
|
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
|
if strings.TrimSpace(token.AccessToken) == "" {
|
|
return nil, fmt.Errorf("access token is empty")
|
|
}
|
|
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
|
|
var reg deviceRegistration
|
|
if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil {
|
|
return nil, fmt.Errorf("failed to parse device registration: %w", err)
|
|
}
|
|
if reg.ClientID != "" {
|
|
token.ClientID = reg.ClientID
|
|
}
|
|
if reg.ClientSecret != "" {
|
|
token.ClientSecret = reg.ClientSecret
|
|
}
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
func ParseImportedRefreshToken(tokenJSON string) (*TokenData, bool, error) {
|
|
var token TokenData
|
|
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
|
return nil, false, fmt.Errorf("failed to parse kiro token: %w", err)
|
|
}
|
|
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
|
if token.AuthMethod == "" {
|
|
token.AuthMethod = "social"
|
|
}
|
|
if token.Provider == "" && token.AuthMethod == "social" {
|
|
token.Provider = string(SocialProviderGoogle)
|
|
}
|
|
if strings.TrimSpace(token.RefreshToken) == "" || strings.TrimSpace(token.AccessToken) != "" {
|
|
return &token, false, nil
|
|
}
|
|
return &token, true, nil
|
|
}
|
|
|
|
func getOIDCEndpoint(region string) string {
|
|
if strings.TrimSpace(oidcEndpointOverride) != "" {
|
|
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
|
|
}
|
|
if region == "" {
|
|
region = defaultIDCRegion
|
|
}
|
|
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
|
}
|
|
|
|
func oidcHeaders(accountKey, machineID string) map[string]string {
|
|
headers := BuildOIDCHeaders(accountKey, machineID)
|
|
if headers["amz-sdk-invocation-id"] == "" {
|
|
headers["amz-sdk-invocation-id"] = uuid.NewString()
|
|
}
|
|
if headers["amz-sdk-request"] == "" {
|
|
headers["amz-sdk-request"] = "attempt=1; max=4"
|
|
}
|
|
return headers
|
|
}
|
|
|
|
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
|
|
client, err := newHTTPClient(proxyURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var body io.Reader
|
|
if payload != nil {
|
|
encoded, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
body = bytes.NewReader(encoded)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if payload != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
for key, value := range extraHeaders {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
bodyText := strings.TrimSpace(string(respBody))
|
|
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
|
|
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
|
|
}
|
|
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
|
|
}
|
|
if out == nil || len(respBody) == 0 {
|
|
return nil
|
|
}
|
|
return json.Unmarshal(respBody, out)
|
|
}
|
|
|
|
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
|
|
_, parsed, err := proxyurl.Parse(rawProxyURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
transport := &http.Transport{}
|
|
if parsed != nil {
|
|
transport.Proxy = http.ProxyURL(parsed)
|
|
}
|
|
return &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
Transport: transport,
|
|
}, nil
|
|
}
|