feat(backend): add kiro account support

This commit is contained in:
nianzs
2026-04-29 16:29:21 +08:00
parent 9d801595c9
commit 05bc424c9a
60 changed files with 11916 additions and 38 deletions
+258
View File
@@ -0,0 +1,258 @@
package kiro
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"runtime"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
type RuntimeFingerprint struct {
OIDCSDKVersion string
RuntimeSDKVersion string
StreamingSDKVersion string
OSType string
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string
}
type runtimeFingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*RuntimeFingerprint
}
var (
globalRuntimeFingerprintManager *runtimeFingerprintManager
globalRuntimeFingerprintManagerOnce sync.Once
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
runtimeSDKVersions = []string{"1.0.0"}
streamingSDKVersions = []string{"1.0.34"}
osTypes = []string{"darwin", "win32"}
osVersions = map[string][]string{
"darwin": {"24.6.0"},
"win32": {"10.0.22631"},
}
nodeVersions = []string{"22.22.0"}
kiroVersions = []string{
"0.11.132", "0.11.131", "0.11.130",
}
)
func globalRuntimeFingerprints() *runtimeFingerprintManager {
globalRuntimeFingerprintManagerOnce.Do(func() {
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
fingerprints: make(map[string]*RuntimeFingerprint),
}
})
return globalRuntimeFingerprintManager
}
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
lookupKey := fingerprintLookupKey(accountKey, "runtime")
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
m.mu.RLock()
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
m.mu.RUnlock()
return fp
}
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
return fp
}
fp := generateRuntimeFingerprint(lookupKey, machineID)
m.fingerprints[lookupKey] = fp
return fp
}
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
hash := sha256.Sum256([]byte(accountKey))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
osType := goOSToNodePlatform(runtime.GOOS)
if !containsString(osTypes, osType) {
osType = osTypes[rng.Intn(len(osTypes))]
}
osVersionPool := osVersions[osType]
if len(osVersionPool) == 0 {
osVersionPool = osVersions["darwin"]
}
return &RuntimeFingerprint{
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
OSType: osType,
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
KiroHash: machineID,
}
}
func goOSToNodePlatform(goos string) string {
switch strings.TrimSpace(goos) {
case "windows":
return "win32"
default:
return strings.TrimSpace(goos)
}
}
func containsString(items []string, target string) bool {
for _, item := range items {
if item == target {
return true
}
}
return false
}
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
switch {
case strings.TrimSpace(clientIDHash) != "":
return clientIDHash
case strings.TrimSpace(clientID) != "":
return shortSHA(clientID)
case strings.TrimSpace(refreshToken) != "":
return shortSHA(refreshToken)
case strings.TrimSpace(profileArn) != "":
return shortSHA(profileArn)
case accountID > 0:
return shortSHA(fmt.Sprintf("account:%d", accountID))
default:
return shortSHA(uuid.NewString())
}
}
func NormalizeMachineID(machineID string) (string, bool) {
trimmed := strings.TrimSpace(machineID)
if len(trimmed) == 64 && isHexString(trimmed) {
return strings.ToLower(trimmed), true
}
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
normalized := strings.ToLower(withoutDashes)
return normalized + normalized, true
}
return "", false
}
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
return sha256Hex("KotlinNativeAPI/" + refreshToken)
}
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
return sha256Hex("KiroAPIKey/" + apiKey)
}
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
return sha256Hex("KiroFallback/" + fallbackKey)
}
return sha256Hex("KiroFallback/default")
}
func shortSHA(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
func sha256Hex(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:])
}
func isHexString(value string) bool {
for _, c := range value {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
return false
}
}
return true
}
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
if normalized, ok := NormalizeMachineID(machineID); ok {
return normalized
}
return BuildMachineID("", "", fallbackKey)
}
func fingerprintLookupKey(accountKey, fallback string) string {
key := strings.TrimSpace(accountKey)
if key != "" {
return key
}
return fallback
}
func BuildRuntimeUserAgent(accountKey, machineID string) string {
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
return map[string]string{
"Content-Type": "application/json",
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
"amz-sdk-invocation-id": uuid.NewString(),
"amz-sdk-request": "attempt=1; max=4",
}
}
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
return map[string]string{
"Content-Type": "application/json",
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
"Accept": "application/json, text/plain, */*",
}
}
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := baseDelay << attempt
if delay > maxDelay {
delay = maxDelay
}
const jitterFactor = 0.3
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
return time.Duration(float64(delay) * jitter)
}
@@ -0,0 +1,91 @@
package kiro
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildLoginHeadersStable(t *testing.T) {
headers1 := BuildLoginHeaders("", "")
headers2 := BuildLoginHeaders("", "")
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
require.Equal(t, "application/json", headers1["Content-Type"])
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
}
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
machineIDA := BuildMachineID("refresh-a", "", "")
machineIDB := BuildMachineID("refresh-b", "", "")
headers1 := BuildLoginHeaders("account-a", machineIDA)
headers2 := BuildLoginHeaders("account-a", machineIDA)
headers3 := BuildLoginHeaders("account-a", machineIDB)
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
require.Contains(t, headers1["User-Agent"], machineIDA)
}
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
machineID := BuildMachineID("", "", "oidc-machine")
headers1 := BuildOIDCHeaders("account-a", machineID)
headers2 := BuildOIDCHeaders("account-a", machineID)
headers3 := BuildOIDCHeaders("account-b", machineID)
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
}
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
key1 := BuildAccountKey("", "", "", "", 42)
key2 := BuildAccountKey("", "", "", "", 42)
key3 := BuildAccountKey("", "", "", "", 43)
require.Equal(t, key1, key2)
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
require.NotEqual(t, key1, key3)
}
func TestBuildMachineID(t *testing.T) {
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
fallback1 := BuildMachineID("", "", "account:1")
fallback2 := BuildMachineID("", "", "account:1")
fallback3 := BuildMachineID("", "", "account:2")
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
require.Equal(t, fallback1, fallback2)
require.NotEqual(t, fallback1, fallback3)
require.Len(t, fallback1, 64)
}
func TestNormalizeMachineID(t *testing.T) {
hex64 := strings.Repeat("A", 64)
normalized, ok := NormalizeMachineID(hex64)
require.True(t, ok)
require.Equal(t, strings.ToLower(hex64), normalized)
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
require.True(t, ok)
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
_, ok = NormalizeMachineID("not-a-machine-id")
require.False(t, ok)
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
require.False(t, ok)
}
func expectedKiroMachineID(seed string) string {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:])
}
+21
View File
@@ -0,0 +1,21 @@
package kiro
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
var DefaultModels = []Model{
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
}
+43
View File
@@ -0,0 +1,43 @@
package kiro
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
ids := make([]string, 0, len(DefaultModels))
for _, model := range DefaultModels {
ids = append(ids, model.ID)
}
require.Equal(t, []string{
"claude-opus-4-6",
"claude-opus-4-6-thinking",
"claude-sonnet-4-6",
"claude-sonnet-4-6-thinking",
"claude-opus-4-5-20251101",
"claude-opus-4-5-20251101-thinking",
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-5-20250929-thinking",
"claude-haiku-4-5-20251001",
"claude-haiku-4-5-20251001-thinking",
}, ids)
require.Contains(t, ids, "claude-sonnet-4-6")
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
require.NotContains(t, ids, "auto")
require.NotContains(t, ids, "claude-sonnet-4")
require.NotContains(t, ids, "gpt-4o")
require.NotContains(t, ids, "deepseek-3-2")
require.NotContains(t, ids, "minimax-m2-1")
require.NotContains(t, ids, "qwen3-coder-next")
require.NotContains(t, ids, "claude-opus-4-7")
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
for _, id := range ids {
require.NotContains(t, id, "kiro-")
require.NotContains(t, id, "-agentic")
require.NotContains(t, id, "-chat")
}
}
+511
View File
@@ -0,0 +1,511 @@
package kiro
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/google/uuid"
)
const (
socialAuthPortalURL = "https://app.kiro.dev"
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
defaultIDCRegion = "us-east-1"
BuilderIDStartURL = "https://view.awsapps.com/start"
sessionTTL = 10 * time.Minute
sessionCleanupEvery = 32
sessionCleanupMin = 32
)
var (
socialAuthEndpointURL = socialAuthEndpoint
oidcEndpointOverride = ""
)
type SocialProvider string
const (
SocialProviderGoogle SocialProvider = "Google"
SocialProviderGitHub SocialProvider = "Github"
)
type AuthSession struct {
State string
CodeVerifier string
ProxyURL string
CreatedAt time.Time
AuthType string
Provider string
RedirectURI string
ClientID string
ClientSecret string
Region string
StartURL string
}
type SessionStore struct {
mu sync.RWMutex
data map[string]*AuthSession
setCount uint64
}
func NewSessionStore() *SessionStore {
return &SessionStore{data: make(map[string]*AuthSession)}
}
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
session, ok := s.data[id]
if ok && sessionExpired(session, now) {
delete(s.data, id)
return nil, false
}
return session, ok
}
func (s *SessionStore) Set(id string, session *AuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.setCount++
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
s.pruneExpiredLocked(time.Now())
}
s.data[id] = session
}
func (s *SessionStore) Delete(id string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, id)
}
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
for id, session := range s.data {
if sessionExpired(session, now) {
delete(s.data, id)
}
}
}
func sessionExpired(session *AuthSession, now time.Time) bool {
if session == nil {
return true
}
if session.CreatedAt.IsZero() {
return true
}
return now.After(session.CreatedAt.Add(sessionTTL))
}
type TokenData struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn,omitempty"`
ExpiresAt string `json:"expiresAt,omitempty"`
AuthMethod string `json:"authMethod,omitempty"`
Provider string `json:"provider,omitempty"`
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
ClientIDHash string `json:"clientIdHash,omitempty"`
Email string `json:"email,omitempty"`
StartURL string `json:"startUrl,omitempty"`
Region string `json:"region,omitempty"`
}
type socialTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
type registerClientResponse struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
}
type createTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
type userInfoResponse struct {
Email string `json:"email"`
}
type deviceRegistration struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
}
type RefreshTokenInvalidError struct {
StatusCode int
Body string
}
func (e *RefreshTokenInvalidError) Error() string {
if e == nil {
return ""
}
body := strings.TrimSpace(e.Body)
if body == "" {
return "kiro refresh token invalid (invalid_grant)"
}
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
}
func GenerateSessionID() string {
return uuid.NewString()
}
func GenerateState() (string, error) {
return randomURLSafe(16)
}
func GenerateCodeVerifier() (string, error) {
return randomURLSafe(32)
}
func randomURLSafe(n int) (string, error) {
buf := make([]byte, n)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func GenerateCodeChallenge(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
params := url.Values{}
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("redirect_uri", redirectURI)
params.Set("redirect_from", "KiroIDE")
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
}
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
if redirectURI == "" {
return ""
}
path := strings.TrimSpace(callbackPath)
if path == "" {
path = "/oauth/callback"
} else if !strings.HasPrefix(path, "/") {
path = "/" + path
}
fullRedirectURI := redirectURI + path
if option := strings.TrimSpace(loginOption); option != "" {
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
}
return fullRedirectURI
}
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
payload := map[string]string{
"code": code,
"code_verifier": codeVerifier,
"redirect_uri": redirectURI,
}
var resp socialTokenResponse
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
return &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "social",
Region: defaultIDCRegion,
}, nil
}
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
payload := map[string]string{
"refreshToken": refreshToken,
}
var resp socialTokenResponse
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
return &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "social",
Provider: provider,
Region: defaultIDCRegion,
}, nil
}
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]any{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"authorization_code", "refresh_token"},
"redirectUris": []string{redirectURI},
"issuerUrl": issuerURL,
}
var resp registerClientResponse
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
return nil, err
}
return &resp, nil
}
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
if region == "" {
region = defaultIDCRegion
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scopes", strings.Join([]string{
"codewhisperer:completions",
"codewhisperer:analysis",
"codewhisperer:conversations",
"codewhisperer:transformations",
"codewhisperer:taskassist",
}, " "))
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
}
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"code": code,
"codeVerifier": codeVerifier,
"redirectUri": redirectURI,
"grantType": "authorization_code",
}
var resp createTokenResponse
accountKey := BuildAccountKey(clientID, "", "", "", 0)
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
token := &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
return token, nil
}
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
if region == "" {
region = defaultIDCRegion
}
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"refreshToken": refreshToken,
"grantType": "refresh_token",
}
var resp createTokenResponse
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
return nil, err
}
expiresIn := resp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
token := &TokenData{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
ProfileArn: resp.ProfileArn,
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
return token, nil
}
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
if strings.TrimSpace(accessToken) == "" {
return ""
}
var resp userInfoResponse
headers := map[string]string{
"Authorization": "Bearer " + accessToken,
}
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
return ""
}
return strings.TrimSpace(resp.Email)
}
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
var token TokenData
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
}
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
if strings.TrimSpace(token.AccessToken) == "" {
return nil, fmt.Errorf("access token is empty")
}
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
var reg deviceRegistration
if err := json.Unmarshal([]byte(deviceRegistrationJSON), &reg); err != nil {
return nil, fmt.Errorf("failed to parse device registration: %w", err)
}
if reg.ClientID != "" {
token.ClientID = reg.ClientID
}
if reg.ClientSecret != "" {
token.ClientSecret = reg.ClientSecret
}
}
return &token, nil
}
func getOIDCEndpoint(region string) string {
if strings.TrimSpace(oidcEndpointOverride) != "" {
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
}
if region == "" {
region = defaultIDCRegion
}
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
}
func oidcHeaders(accountKey, machineID string) map[string]string {
headers := BuildOIDCHeaders(accountKey, machineID)
if headers["amz-sdk-invocation-id"] == "" {
headers["amz-sdk-invocation-id"] = uuid.NewString()
}
if headers["amz-sdk-request"] == "" {
headers["amz-sdk-request"] = "attempt=1; max=4"
}
return headers
}
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
client, err := newHTTPClient(proxyURL)
if err != nil {
return err
}
var body io.Reader
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
body = bytes.NewReader(encoded)
}
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
if err != nil {
return err
}
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
for key, value := range extraHeaders {
req.Header.Set(key, value)
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
bodyText := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
}
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
}
if out == nil || len(respBody) == 0 {
return nil
}
return json.Unmarshal(respBody, out)
}
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
_, parsed, err := proxyurl.Parse(rawProxyURL)
if err != nil {
return nil, err
}
transport := &http.Transport{}
if parsed != nil {
transport.Proxy = http.ProxyURL(parsed)
}
return &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}, nil
}
@@ -0,0 +1,105 @@
package kiro
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/refreshToken", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
}))
defer server.Close()
previous := socialAuthEndpointURL
socialAuthEndpointURL = server.URL
t.Cleanup(func() { socialAuthEndpointURL = previous })
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
require.Error(t, err)
var invalid *RefreshTokenInvalidError
require.True(t, errors.As(err, &invalid))
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
require.Contains(t, invalid.Body, "invalid_grant")
}
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/token", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
require.Error(t, err)
var invalid *RefreshTokenInvalidError
require.True(t, errors.As(err, &invalid))
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
require.Contains(t, invalid.Body, "invalid_grant")
}
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
default:
t.Fatalf("unexpected path: %s", r.URL.Path)
}
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
require.NoError(t, err)
require.Equal(t, profileArn, token.ProfileArn)
require.Equal(t, "kiro@example.com", token.Email)
}
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
default:
t.Fatalf("unexpected path: %s", r.URL.Path)
}
}))
defer server.Close()
previous := oidcEndpointOverride
oidcEndpointOverride = server.URL
t.Cleanup(func() { oidcEndpointOverride = previous })
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
require.NoError(t, err)
require.Equal(t, profileArn, token.ProfileArn)
require.Equal(t, "kiro@example.com", token.Email)
}
+56
View File
@@ -0,0 +1,56 @@
//go:build unit
package kiro
import (
"fmt"
"testing"
"time"
)
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
if got != want {
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
}
}
func TestBuildSocialTokenRedirectURI(t *testing.T) {
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
want := "http://localhost:49153/oauth/callback?login_option=github"
if got != want {
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
}
}
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
store := NewSessionStore()
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
session, ok := store.Get("expired")
if ok || session != nil {
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
}
if _, exists := store.data["expired"]; exists {
t.Fatalf("expired session should be deleted from the store")
}
}
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
store := NewSessionStore()
now := time.Now()
for i := 0; i < sessionCleanupMin; i++ {
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
}
store.setCount = sessionCleanupEvery - 1
store.Set("fresh", &AuthSession{CreatedAt: now})
if len(store.data) != 1 {
t.Fatalf("store size = %d, want 1", len(store.data))
}
if _, ok := store.data["fresh"]; !ok {
t.Fatalf("fresh session should remain after pruning")
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+368
View File
@@ -0,0 +1,368 @@
package kiro
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
var cachedWebSearchDescription atomic.Value // stores string
type MCPRequest struct {
ID string `json:"id"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}
type MCPResponse struct {
Result *struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
} `json:"result,omitempty"`
Error *struct {
Code *int `json:"code,omitempty"`
Message *string `json:"message,omitempty"`
} `json:"error,omitempty"`
}
type WebSearchResults struct {
Results []WebSearchResult `json:"results"`
}
type WebSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Snippet *string `json:"snippet,omitempty"`
PublishedDate *int64 `json:"publishedDate,omitempty"`
ID *string `json:"id,omitempty"`
Domain *string `json:"domain,omitempty"`
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
PublicDomain *bool `json:"publicDomain,omitempty"`
}
type SearchIndicator struct {
ToolUseID string
Query string
Results *WebSearchResults
}
func GetCachedWebSearchDescription() string {
if v := cachedWebSearchDescription.Load(); v != nil {
return strings.TrimSpace(v.(string))
}
return ""
}
func SetCachedWebSearchDescription(desc string) {
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
}
func BuildMcpEndpoint(region string) string {
if strings.TrimSpace(region) == "" {
region = "us-east-1"
}
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
}
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
return nil
}
for _, item := range resp.Result.Content {
if item.Type != "" && item.Type != "text" {
continue
}
var results WebSearchResults
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
return &results
}
}
return nil
}
func ExtractSearchQuery(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
arr := messages.Array()
for i := len(arr) - 1; i >= 0; i-- {
msg := arr[i]
if msg.Get("role").String() != "user" {
continue
}
text := extractSearchText(msg.Get("content"))
const prefix = "Perform a web search for the query: "
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
if text != "" {
return text
}
}
return ""
}
func extractSearchText(content gjson.Result) string {
if content.Type == gjson.String {
return content.String()
}
if !content.IsArray() {
return ""
}
for _, block := range content.Array() {
if block.Get("type").String() == "text" {
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
return text
}
}
}
return ""
}
func GenerateToolUseID() string {
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
}
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err != nil {
return body, err
}
rawTools, ok := payload["tools"].([]interface{})
if !ok {
return body, nil
}
replaced := make([]interface{}, 0, len(rawTools))
for _, rawTool := range rawTools {
tool, ok := rawTool.(map[string]interface{})
if !ok {
replaced = append(replaced, rawTool)
continue
}
name := getInterfaceString(tool["name"])
toolType := getInterfaceString(tool["type"])
if !isWebSearchToolName(name, toolType) {
replaced = append(replaced, rawTool)
continue
}
replaced = append(replaced, map[string]interface{}{
"name": "web_search",
"description": minimalWebSearchDescription,
"input_schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "The search query to execute",
},
},
"required": []string{"query"},
"additionalProperties": false,
},
})
}
payload["tools"] = replaced
updated, err := json.Marshal(payload)
if err != nil {
return body, err
}
return updated, nil
}
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(claudePayload, &payload); err != nil {
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
}
rawMessages, ok := payload["messages"].([]interface{})
if !ok {
return claudePayload, fmt.Errorf("claude payload missing messages array")
}
assistantMsg := map[string]interface{}{
"role": "assistant",
"content": []interface{}{
map[string]interface{}{
"type": "tool_use",
"id": toolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": query},
},
},
}
userContent := []interface{}{
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolUseID,
"content": formatToolResultText(results),
},
}
if guidance := searchGuidanceText(); guidance != "" {
userContent = append(userContent, map[string]interface{}{
"type": "text",
"text": guidance,
})
}
userMsg := map[string]interface{}{
"role": "user",
"content": userContent,
}
rawMessages = append(rawMessages, assistantMsg, userMsg)
payload["messages"] = rawMessages
updated, err := json.Marshal(payload)
if err != nil {
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
}
return updated, nil
}
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
if len(searches) == 0 {
return responsePayload, nil
}
var response map[string]interface{}
if err := json.Unmarshal(responsePayload, &response); err != nil {
return responsePayload, err
}
content, _ := response["content"].([]interface{})
updated := make([]interface{}, 0, len(searches)*2+len(content))
for _, search := range searches {
updated = append(updated, map[string]interface{}{
"type": "server_tool_use",
"id": search.ToolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": search.Query},
})
updated = append(updated, map[string]interface{}{
"type": "web_search_tool_result",
"content": buildSearchResultContent(search.Results),
})
}
updated = append(updated, content...)
response["content"] = updated
encoded, err := json.Marshal(response)
if err != nil {
return responsePayload, err
}
return encoded, nil
}
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
content := make([]map[string]interface{}, 0)
if results == nil {
return content
}
for _, result := range results.Results {
snippet := ""
if result.Snippet != nil {
snippet = strings.TrimSpace(*result.Snippet)
}
content = append(content, map[string]interface{}{
"type": "web_search_result",
"title": result.Title,
"url": result.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
return content
}
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
content := gjson.GetBytes(responsePayload, "content")
if !content.IsArray() {
return "", "", false
}
for _, block := range content.Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if !isWebSearchToolName(name, "") {
continue
}
query = strings.TrimSpace(block.Get("input.query").String())
if query == "" {
continue
}
return block.Get("id").String(), query, true
}
return "", "", false
}
func isWebSearchToolName(name, toolType string) bool {
name = strings.ToLower(strings.TrimSpace(name))
toolType = strings.ToLower(strings.TrimSpace(toolType))
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
return true
}
switch name {
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
return true
default:
return false
}
}
func getInterfaceString(v interface{}) string {
if v == nil {
return ""
}
switch val := v.(type) {
case string:
return strings.TrimSpace(val)
default:
return strings.TrimSpace(fmt.Sprint(val))
}
}
func formatToolResultText(results *WebSearchResults) string {
if results == nil || len(results.Results) == 0 {
return "No search results found."
}
payload, err := json.MarshalIndent(results.Results, "", " ")
if err != nil {
return "Found search results, but failed to format them."
}
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
}
func searchGuidanceText() string {
now := time.Now()
return fmt.Sprintf(`<search_guidance>
Current date: %s (%s)
IMPORTANT: Evaluate the search results above carefully. If the results are:
- Mostly spam, SEO junk, or unrelated websites
- Missing actual information about the query topic
- Outdated or not matching the requested time frame
Then you MUST use the web_search tool again with a refined query. Try:
- Rephrasing in English for better coverage
- Using more specific keywords
- Adding date context
Do NOT apologize for bad results without first attempting a re-search.
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
}
@@ -0,0 +1,297 @@
package kiro
import (
"encoding/json"
"strings"
)
type BufferedStreamResult struct {
StopReason string
WebSearchQuery string
WebSearchToolUseID string
HasWebSearchToolUse bool
WebSearchToolUseIndex int
}
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
searchContent := make([]map[string]interface{}, 0)
if results != nil {
for _, result := range results.Results {
snippet := ""
if result.Snippet != nil {
snippet = strings.TrimSpace(*result.Snippet)
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": result.Title,
"url": result.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
inputJSON, _ := json.Marshal(map[string]string{"query": query})
events := []map[string]interface{}{
{
"type": "content_block_start",
"index": startIndex,
"content_block": map[string]interface{}{
"type": "server_tool_use",
"id": toolUseID,
"name": "web_search",
"input": map[string]interface{}{},
},
},
{
"type": "content_block_delta",
"index": startIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
},
{
"type": "content_block_stop",
"index": startIndex,
},
{
"type": "content_block_start",
"index": startIndex + 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"content": searchContent,
},
},
{
"type": "content_block_stop",
"index": startIndex + 1,
},
}
result := make([][]byte, 0, len(events))
for _, event := range events {
eventType, _ := event["type"].(string)
payload, _ := json.Marshal(event)
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
}
return result
}
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
var currentToolName string
currentToolIndex := -1
var toolInputBuilder strings.Builder
for _, chunk := range chunks {
lines := strings.Split(string(chunk), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "" || payload == "[DONE]" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
switch eventType, _ := event["type"].(string); eventType {
case "message_delta":
if delta, ok := event["delta"].(map[string]interface{}); ok {
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
result.StopReason = stopReason
}
}
case "content_block_start":
contentBlock, ok := event["content_block"].(map[string]interface{})
if !ok {
continue
}
blockType, _ := contentBlock["type"].(string)
if blockType != "tool_use" {
continue
}
currentToolName, _ = contentBlock["name"].(string)
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
if idx, ok := event["index"].(float64); ok {
currentToolIndex = int(idx)
}
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
}
toolInputBuilder.Reset()
case "content_block_delta":
if currentToolName == "" {
continue
}
delta, ok := event["delta"].(map[string]interface{})
if !ok {
continue
}
deltaType, _ := delta["type"].(string)
if deltaType != "input_json_delta" {
continue
}
if partialJSON, ok := delta["partial_json"].(string); ok {
toolInputBuilder.WriteString(partialJSON)
}
case "content_block_stop":
if !isWebSearchToolName(currentToolName, "") {
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
continue
}
result.HasWebSearchToolUse = true
result.WebSearchToolUseIndex = currentToolIndex
var input map[string]string
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
result.WebSearchQuery = strings.TrimSpace(input["query"])
}
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
}
}
}
return result
}
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
filtered := make([][]byte, 0, len(chunks))
for _, chunk := range chunks {
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
if shouldForward {
filtered = append(filtered, adjusted)
}
}
return filtered
}
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
return filterSSEChunk(chunk, -1, offset)
}
func MaxContentBlockIndex(chunks [][]byte) int {
maxIndex := -1
for _, chunk := range chunks {
lines := strings.Split(string(chunk), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "" || payload == "[DONE]" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
switch eventType, _ := event["type"].(string); eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
maxIndex = int(idx)
}
}
}
}
return maxIndex
}
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
lines := strings.Split(string(chunk), "\n")
var builder strings.Builder
hasContent := false
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "event: ") {
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
i++
continue
}
}
builder.WriteString(line + "\n")
hasContent = true
continue
}
if strings.HasPrefix(line, "data: ") {
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if payload == "[DONE]" {
continue
}
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
continue
}
adjusted := adjustEventPayload(payload, indexOffset)
if adjusted == "" {
continue
}
builder.WriteString("data: " + adjusted + "\n")
hasContent = true
continue
}
builder.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
if !hasContent {
return nil, false
}
return []byte(builder.String()), true
}
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
if payload == "" {
return false
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
return false
}
eventType, _ := event["type"].(string)
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
return true
}
if webSearchToolUseIndex < 0 {
return false
}
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
return true
}
return false
}
func adjustEventPayload(payload string, indexOffset int) string {
if payload == "" || indexOffset == 0 {
return payload
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(payload), &event); err != nil {
return payload
}
switch eventType, _ := event["type"].(string); eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + indexOffset
if adjusted, err := json.Marshal(event); err == nil {
return string(adjusted)
}
}
}
return payload
}
@@ -0,0 +1,73 @@
package kiro
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
snippet := "result snippet"
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
Results: []WebSearchResult{
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
},
}, 0)
require.Len(t, events, 5)
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
require.Contains(t, string(events[0]), `"input":{}`)
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
require.NotContains(t, string(events[3]), `"tool_use_id"`)
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
}
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
chunks := [][]byte{
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
}
result := AnalyzeBufferedStream(chunks)
require.True(t, result.HasWebSearchToolUse)
require.Equal(t, "golang concurrency", result.WebSearchQuery)
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
require.Equal(t, 1, result.WebSearchToolUseIndex)
require.Equal(t, "tool_use", result.StopReason)
}
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
chunks := [][]byte{
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
}
filtered := FilterChunksForClient(chunks, 1, 2)
require.NotEmpty(t, filtered)
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
require.NotContains(t, joined, `"type":"message_start"`)
require.NotContains(t, joined, `"type":"message_delta"`)
require.NotContains(t, joined, `"name":"web_search"`)
require.Contains(t, joined, `"index":2`)
require.Equal(t, 2, MaxContentBlockIndex(filtered))
}
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
require.False(t, shouldForward)
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
require.True(t, shouldForward)
require.Contains(t, string(adjusted), `"index":3`)
}
+138
View File
@@ -0,0 +1,138 @@
package kiro
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
body := []byte(`{
"tools":[{"type":"web_search_20250305","description":"old"}],
"messages":[{"role":"user","content":"golang"}]
}`)
updated, err := ReplaceWebSearchToolDescription(body)
require.NoError(t, err)
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
}
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
body := []byte(`{
"messages":[{"role":"user","content":"what is golang"}]
}`)
results := &WebSearchResults{
Results: []WebSearchResult{
{Title: "Go", URL: "https://go.dev"},
},
}
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
require.NoError(t, err)
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
}
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
response := []byte(`{
"content":[
{"type":"text","text":"let me search"},
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
]
}`)
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
require.True(t, ok)
require.Equal(t, "srvtoolu_next", toolUseID)
require.Equal(t, "golang concurrency", query)
}
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
response := []byte(`{
"id":"msg_1",
"type":"message",
"role":"assistant",
"model":"kiro",
"content":[{"type":"text","text":"final"}],
"stop_reason":"end_turn",
"usage":{"input_tokens":1,"output_tokens":1}
}`)
snippet := "result snippet"
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
{
ToolUseID: "srvtoolu_test",
Query: "golang",
Results: &WebSearchResults{
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
},
},
})
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(updated, &decoded))
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
}
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
resp := &MCPResponse{
Result: &struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
}{
Content: []struct {
Type string `json:"type"`
Text string `json:"text"`
}{
{
Type: "text",
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
},
},
},
}
results := ParseSearchResults(resp)
require.NotNil(t, results)
require.Len(t, results.Results, 1)
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
require.Equal(t, "doc-1", *results.Results[0].ID)
require.Equal(t, "go.dev", *results.Results[0].Domain)
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
require.True(t, *results.Results[0].PublicDomain)
}
func TestSearchGuidanceText_IsStructured(t *testing.T) {
guidance := searchGuidanceText()
require.Contains(t, guidance, "<search_guidance>")
require.Contains(t, guidance, "Current date:")
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
require.Contains(t, guidance, "Rephrasing in English for better coverage")
}
+479
View File
@@ -0,0 +1,479 @@
package kirocooldown
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
const (
MinRequestInterval = time.Second
MaxRequestInterval = 2 * time.Second
CooldownReason429 = "rate_limit_exceeded"
CooldownReasonSuspended = "account_suspended"
ShortCooldown = time.Minute
MaxCooldown = 5 * time.Minute
LongCooldown = 24 * time.Hour
redisTimeout = 3 * time.Second
activeTTL = 10 * time.Second
stateTTL = 25 * time.Hour
keyPrefix = "kiro:cooldown:"
)
var (
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
reserveRequestScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
local interval_ms = tonumber(ARGV[1])
local active_ttl_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
if cooldown_until_ms > now_ms then
return {1, cooldown_until_ms - now_ms, cooldown_reason}
end
if cooldown_until_ms > 0 then
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
end
local next_slot_ms = now_ms
if last_request_ms > 0 then
local candidate_ms = last_request_ms + interval_ms
if candidate_ms > now_ms then
next_slot_ms = candidate_ms
end
end
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
if fail_count > 0 or cooldown_until_ms > now_ms then
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
else
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
end
return {0, next_slot_ms - now_ms, ''}
`)
mark429Script = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
local short_cooldown_ms = tonumber(ARGV[1])
local max_cooldown_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
if cooldown_ms > max_cooldown_ms then
cooldown_ms = max_cooldown_ms
end
redis.call('HSET', KEYS[1],
'fail_count', fail_count,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[4]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
markSuccessScript = redis.NewScript(`
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', 0,
'cooldown_reason', ''
)
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
return 1
`)
markSuspendedScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local cooldown_ms = tonumber(ARGV[1])
local state_ttl_ms = tonumber(ARGV[2])
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[3]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
)
type Error struct {
remaining time.Duration
reason string
}
type State struct {
Active bool
Reason string
CooldownUntil time.Time
Remaining time.Duration
FailCount int
}
func NewError(remaining time.Duration, reason string) error {
return &Error{remaining: remaining, reason: reason}
}
func (e *Error) Error() string {
if e == nil {
return ""
}
if e.reason == "" {
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
}
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
}
func Calculate429Cooldown(retryCount int) time.Duration {
if retryCount < 0 {
retryCount = 0
}
cooldown := ShortCooldown * time.Duration(1<<retryCount)
if cooldown > MaxCooldown {
return MaxCooldown
}
return cooldown
}
type Store struct {
client *redis.Client
rngMu sync.Mutex
rng *rand.Rand
}
func NewStore(client *redis.Client) *Store {
return &Store{
client: client,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
values, err := reserveRequestScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
s.nextInterval().Milliseconds(),
activeTTL.Milliseconds(),
stateTTL.Milliseconds(),
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
}
parts, ok := values.([]interface{})
if !ok || len(parts) != 3 {
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
}
state, err := luaInt64(parts[0])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
}
waitMS, err := luaInt64(parts[1])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
}
reason, err := luaString(parts[2])
if err != nil {
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
}
if state == 1 {
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
}
if waitMS <= 0 {
return 0, nil
}
return time.Duration(waitMS) * time.Millisecond, nil
}
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
if err := s.validate(); err != nil {
return err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
if err := markSuccessScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
activeTTL.Milliseconds(),
).Err(); err != nil {
return fmt.Errorf("kiro cooldown mark success: %w", err)
}
return nil
}
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
result, err := mark429Script.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
ShortCooldown.Milliseconds(),
MaxCooldown.Milliseconds(),
stateTTL.Milliseconds(),
CooldownReason429,
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
}
cooldownMS, err := luaInt64(result)
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
}
return time.Duration(cooldownMS) * time.Millisecond, nil
}
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
if err := s.validate(); err != nil {
return 0, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
result, err := markSuspendedScript.Run(
cacheCtx,
s.client,
[]string{RedisKey(tokenKey)},
LongCooldown.Milliseconds(),
stateTTL.Milliseconds(),
CooldownReasonSuspended,
).Result()
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
}
cooldownMS, err := luaInt64(result)
if err != nil {
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
}
return time.Duration(cooldownMS) * time.Millisecond, nil
}
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
if err := s.validate(); err != nil {
return nil, err
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
values, err := s.client.HMGet(
cacheCtx,
RedisKey(tokenKey),
"cooldown_until_ms",
"cooldown_reason",
"fail_count",
).Result()
if err != nil {
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
}
if len(values) != 3 {
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
}
cooldownUntilMS, err := luaInt64(values[0])
if err != nil && values[0] != nil {
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
}
reason, err := luaString(values[1])
if err != nil {
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
}
failCount, err := luaInt64(values[2])
if err != nil && values[2] != nil {
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
}
if cooldownUntilMS <= 0 {
return nil, nil
}
cooldownUntil := time.UnixMilli(cooldownUntilMS)
remaining := time.Until(cooldownUntil)
if remaining <= 0 {
return nil, nil
}
return &State{
Active: true,
Reason: reason,
CooldownUntil: cooldownUntil,
Remaining: remaining,
FailCount: int(failCount),
}, nil
}
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
if err := s.validate(); err != nil {
return false, err
}
uniqueKeys := make([]string, 0, len(tokenKeys))
seen := make(map[string]struct{}, len(tokenKeys))
for _, tokenKey := range tokenKeys {
tokenKey = strings.TrimSpace(tokenKey)
if tokenKey == "" {
continue
}
redisKey := RedisKey(tokenKey)
if _, ok := seen[redisKey]; ok {
continue
}
seen[redisKey] = struct{}{}
uniqueKeys = append(uniqueKeys, redisKey)
}
if len(uniqueKeys) == 0 {
return false, nil
}
cacheCtx, cancel := withRedisTimeout(ctx)
defer cancel()
type candidate struct {
redisKey string
cooldownUntilMS int64
failCount int64
}
now := time.Now().UnixMilli()
var best *candidate
pipe := s.client.Pipeline()
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
for _, redisKey := range uniqueKeys {
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
}
if _, err := pipe.Exec(cacheCtx); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
}
for i, cmd := range cmds {
values, err := cmd.Result()
if err != nil {
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
}
if len(values) != 3 {
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
}
cooldownUntilMS, err := luaInt64(values[0])
if err != nil && values[0] != nil {
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
}
reason, err := luaString(values[1])
if err != nil {
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
}
failCount, err := luaInt64(values[2])
if err != nil && values[2] != nil {
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
}
if cooldownUntilMS <= now || reason != CooldownReason429 {
continue
}
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
if best == nil ||
current.cooldownUntilMS < best.cooldownUntilMS ||
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
best = current
}
}
if best == nil {
return false, nil
}
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
}
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
}
return true, nil
}
func RedisKey(tokenKey string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
digest := hex.EncodeToString(sum[:])
return keyPrefix + "{" + digest + "}"
}
func ActiveTTL() time.Duration {
return activeTTL
}
func StateTTL() time.Duration {
return stateTTL
}
func (s *Store) validate() error {
if s == nil || s.client == nil {
return ErrStoreUnavailable
}
return nil
}
func (s *Store) nextInterval() time.Duration {
s.rngMu.Lock()
defer s.rngMu.Unlock()
if MaxRequestInterval <= MinRequestInterval {
return MinRequestInterval
}
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
}
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
return context.WithTimeout(ctx, redisTimeout)
}
func luaInt64(v any) (int64, error) {
switch n := v.(type) {
case int64:
return n, nil
case int:
return int64(n), nil
case string:
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
case []byte:
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
default:
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
}
}
func luaString(v any) (string, error) {
switch s := v.(type) {
case string:
return s, nil
case []byte:
return string(s), nil
case nil:
return "", nil
default:
return "", fmt.Errorf("unsupported lua string type %T", v)
}
}
@@ -0,0 +1,32 @@
package kirocooldown
import (
"context"
"testing"
"github.com/redis/go-redis/v9"
)
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
if err != nil {
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
}
if cleared {
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
}
}
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
store := NewStore(nil)
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
if err == nil {
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
}
if cleared {
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
}
}