feat(backend): add kiro account support
This commit is contained in:
@@ -0,0 +1,258 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RuntimeFingerprint struct {
|
||||
OIDCSDKVersion string
|
||||
RuntimeSDKVersion string
|
||||
StreamingSDKVersion string
|
||||
OSType string
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string
|
||||
}
|
||||
|
||||
type runtimeFingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*RuntimeFingerprint
|
||||
}
|
||||
|
||||
var (
|
||||
globalRuntimeFingerprintManager *runtimeFingerprintManager
|
||||
globalRuntimeFingerprintManagerOnce sync.Once
|
||||
|
||||
oidcSDKVersions = []string{"3.980.0", "3.975.0", "3.972.0", "3.808.0", "3.738.0", "3.737.0", "3.736.0", "3.735.0"}
|
||||
runtimeSDKVersions = []string{"1.0.0"}
|
||||
streamingSDKVersions = []string{"1.0.34"}
|
||||
osTypes = []string{"darwin", "win32"}
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"24.6.0"},
|
||||
"win32": {"10.0.22631"},
|
||||
}
|
||||
nodeVersions = []string{"22.22.0"}
|
||||
kiroVersions = []string{
|
||||
"0.11.132", "0.11.131", "0.11.130",
|
||||
}
|
||||
)
|
||||
|
||||
func globalRuntimeFingerprints() *runtimeFingerprintManager {
|
||||
globalRuntimeFingerprintManagerOnce.Do(func() {
|
||||
globalRuntimeFingerprintManager = &runtimeFingerprintManager{
|
||||
fingerprints: make(map[string]*RuntimeFingerprint),
|
||||
}
|
||||
})
|
||||
return globalRuntimeFingerprintManager
|
||||
}
|
||||
|
||||
func (m *runtimeFingerprintManager) Get(accountKey, machineID string) *RuntimeFingerprint {
|
||||
lookupKey := fingerprintLookupKey(accountKey, "runtime")
|
||||
machineID = normalizeMachineIDOrFallback(machineID, lookupKey)
|
||||
|
||||
m.mu.RLock()
|
||||
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||
m.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if fp, ok := m.fingerprints[lookupKey]; ok && fp.KiroHash == machineID {
|
||||
return fp
|
||||
}
|
||||
fp := generateRuntimeFingerprint(lookupKey, machineID)
|
||||
m.fingerprints[lookupKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
func generateRuntimeFingerprint(accountKey, machineID string) *RuntimeFingerprint {
|
||||
hash := sha256.Sum256([]byte(accountKey))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
osType := goOSToNodePlatform(runtime.GOOS)
|
||||
if !containsString(osTypes, osType) {
|
||||
osType = osTypes[rng.Intn(len(osTypes))]
|
||||
}
|
||||
osVersionPool := osVersions[osType]
|
||||
if len(osVersionPool) == 0 {
|
||||
osVersionPool = osVersions["darwin"]
|
||||
}
|
||||
|
||||
return &RuntimeFingerprint{
|
||||
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||
OSType: osType,
|
||||
OSVersion: osVersionPool[rng.Intn(len(osVersionPool))],
|
||||
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||
KiroHash: machineID,
|
||||
}
|
||||
}
|
||||
|
||||
func goOSToNodePlatform(goos string) string {
|
||||
switch strings.TrimSpace(goos) {
|
||||
case "windows":
|
||||
return "win32"
|
||||
default:
|
||||
return strings.TrimSpace(goos)
|
||||
}
|
||||
}
|
||||
|
||||
func containsString(items []string, target string) bool {
|
||||
for _, item := range items {
|
||||
if item == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func BuildAccountKey(clientID, clientIDHash, refreshToken, profileArn string, accountID int64) string {
|
||||
switch {
|
||||
case strings.TrimSpace(clientIDHash) != "":
|
||||
return clientIDHash
|
||||
case strings.TrimSpace(clientID) != "":
|
||||
return shortSHA(clientID)
|
||||
case strings.TrimSpace(refreshToken) != "":
|
||||
return shortSHA(refreshToken)
|
||||
case strings.TrimSpace(profileArn) != "":
|
||||
return shortSHA(profileArn)
|
||||
case accountID > 0:
|
||||
return shortSHA(fmt.Sprintf("account:%d", accountID))
|
||||
default:
|
||||
return shortSHA(uuid.NewString())
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeMachineID(machineID string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(machineID)
|
||||
if len(trimmed) == 64 && isHexString(trimmed) {
|
||||
return strings.ToLower(trimmed), true
|
||||
}
|
||||
withoutDashes := strings.ReplaceAll(trimmed, "-", "")
|
||||
if len(withoutDashes) == 32 && isHexString(withoutDashes) {
|
||||
normalized := strings.ToLower(withoutDashes)
|
||||
return normalized + normalized, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func BuildMachineID(refreshToken, apiKey, fallbackKey string) string {
|
||||
if refreshToken = strings.TrimSpace(refreshToken); refreshToken != "" {
|
||||
return sha256Hex("KotlinNativeAPI/" + refreshToken)
|
||||
}
|
||||
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
|
||||
return sha256Hex("KiroAPIKey/" + apiKey)
|
||||
}
|
||||
if fallbackKey = strings.TrimSpace(fallbackKey); fallbackKey != "" {
|
||||
return sha256Hex("KiroFallback/" + fallbackKey)
|
||||
}
|
||||
return sha256Hex("KiroFallback/default")
|
||||
}
|
||||
|
||||
func shortSHA(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func sha256Hex(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func isHexString(value string) bool {
|
||||
for _, c := range value {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeMachineIDOrFallback(machineID, fallbackKey string) string {
|
||||
if normalized, ok := NormalizeMachineID(machineID); ok {
|
||||
return normalized
|
||||
}
|
||||
return BuildMachineID("", "", fallbackKey)
|
||||
}
|
||||
|
||||
func fingerprintLookupKey(accountKey, fallback string) string {
|
||||
key := strings.TrimSpace(accountKey)
|
||||
if key != "" {
|
||||
return key
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func BuildRuntimeUserAgent(accountKey, machineID string) string {
|
||||
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildRuntimeAmzUserAgent(accountKey, machineID string) string {
|
||||
fp := globalRuntimeFingerprints().Get(accountKey, machineID)
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildOIDCHeaders(accountKey, machineID string) map[string]string {
|
||||
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "oidc-session"), machineID)
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"x-amz-user-agent": fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion),
|
||||
"User-Agent": fmt.Sprintf("aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/sso-oidc#%s m/E KiroIDE", fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.OIDCSDKVersion),
|
||||
"amz-sdk-invocation-id": uuid.NewString(),
|
||||
"amz-sdk-request": "attempt=1; max=4",
|
||||
}
|
||||
}
|
||||
|
||||
func BuildLoginHeaders(accountKey, machineID string) map[string]string {
|
||||
fp := globalRuntimeFingerprints().Get(fingerprintLookupKey(accountKey, "login"), machineID)
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
}
|
||||
}
|
||||
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := baseDelay << attempt
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
const jitterFactor = 0.3
|
||||
seed := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
jitter := 1 + ((seed.Float64()*2 - 1) * jitterFactor)
|
||||
return time.Duration(float64(delay) * jitter)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildLoginHeadersStable(t *testing.T) {
|
||||
headers1 := BuildLoginHeaders("", "")
|
||||
headers2 := BuildLoginHeaders("", "")
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.Equal(t, "application/json, text/plain, */*", headers1["Accept"])
|
||||
require.Equal(t, "application/json", headers1["Content-Type"])
|
||||
require.True(t, strings.HasPrefix(headers1["User-Agent"], "KiroIDE-"))
|
||||
require.Contains(t, headers1["User-Agent"], "KiroIDE-")
|
||||
}
|
||||
|
||||
func TestBuildLoginHeadersUsesProvidedMachineID(t *testing.T) {
|
||||
machineIDA := BuildMachineID("refresh-a", "", "")
|
||||
machineIDB := BuildMachineID("refresh-b", "", "")
|
||||
headers1 := BuildLoginHeaders("account-a", machineIDA)
|
||||
headers2 := BuildLoginHeaders("account-a", machineIDA)
|
||||
headers3 := BuildLoginHeaders("account-a", machineIDB)
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||
require.Contains(t, headers1["User-Agent"], "KiroIDE-0.11.")
|
||||
require.Contains(t, headers1["User-Agent"], machineIDA)
|
||||
}
|
||||
|
||||
func TestBuildOIDCHeadersUsesProvidedAccountKey(t *testing.T) {
|
||||
machineID := BuildMachineID("", "", "oidc-machine")
|
||||
headers1 := BuildOIDCHeaders("account-a", machineID)
|
||||
headers2 := BuildOIDCHeaders("account-a", machineID)
|
||||
headers3 := BuildOIDCHeaders("account-b", machineID)
|
||||
|
||||
require.Equal(t, headers1["User-Agent"], headers2["User-Agent"])
|
||||
require.NotEqual(t, headers1["User-Agent"], headers3["User-Agent"])
|
||||
require.Contains(t, headers1["User-Agent"], "api/sso-oidc#")
|
||||
}
|
||||
|
||||
func TestBuildAccountKeyFallsBackToAccountIDBeforeRandom(t *testing.T) {
|
||||
key1 := BuildAccountKey("", "", "", "", 42)
|
||||
key2 := BuildAccountKey("", "", "", "", 42)
|
||||
key3 := BuildAccountKey("", "", "", "", 43)
|
||||
|
||||
require.Equal(t, key1, key2)
|
||||
require.Equal(t, shortSHA(fmt.Sprintf("account:%d", 42)), key1)
|
||||
require.NotEqual(t, key1, key3)
|
||||
}
|
||||
|
||||
func TestBuildMachineID(t *testing.T) {
|
||||
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "", ""))
|
||||
require.Equal(t, expectedKiroMachineID("KiroAPIKey/key"), BuildMachineID("", "key", ""))
|
||||
require.Equal(t, expectedKiroMachineID("KotlinNativeAPI/token"), BuildMachineID("token", "key", "fallback"))
|
||||
|
||||
fallback1 := BuildMachineID("", "", "account:1")
|
||||
fallback2 := BuildMachineID("", "", "account:1")
|
||||
fallback3 := BuildMachineID("", "", "account:2")
|
||||
require.Equal(t, expectedKiroMachineID("KiroFallback/account:1"), fallback1)
|
||||
require.Equal(t, fallback1, fallback2)
|
||||
require.NotEqual(t, fallback1, fallback3)
|
||||
require.Len(t, fallback1, 64)
|
||||
}
|
||||
|
||||
func TestNormalizeMachineID(t *testing.T) {
|
||||
hex64 := strings.Repeat("A", 64)
|
||||
normalized, ok := NormalizeMachineID(hex64)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, strings.ToLower(hex64), normalized)
|
||||
|
||||
normalized, ok = NormalizeMachineID("2582956e-cc88-4669-b546-07adbffcb894")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894", normalized)
|
||||
|
||||
_, ok = NormalizeMachineID("not-a-machine-id")
|
||||
require.False(t, ok)
|
||||
_, ok = NormalizeMachineID(strings.Repeat("g", 64))
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func expectedKiroMachineID(seed string) string {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package kiro
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
var DefaultModels = []Model{
|
||||
{ID: "claude-opus-4-6", Type: "model", DisplayName: "Claude Opus 4.6"},
|
||||
{ID: "claude-opus-4-6-thinking", Type: "model", DisplayName: "Claude Opus 4.6 (Thinking)"},
|
||||
{ID: "claude-sonnet-4-6", Type: "model", DisplayName: "Claude Sonnet 4.6"},
|
||||
{ID: "claude-sonnet-4-6-thinking", Type: "model", DisplayName: "Claude Sonnet 4.6 (Thinking)"},
|
||||
{ID: "claude-opus-4-5-20251101", Type: "model", DisplayName: "Claude Opus 4.5"},
|
||||
{ID: "claude-opus-4-5-20251101-thinking", Type: "model", DisplayName: "Claude Opus 4.5 (Thinking)"},
|
||||
{ID: "claude-sonnet-4-5-20250929", Type: "model", DisplayName: "Claude Sonnet 4.5"},
|
||||
{ID: "claude-sonnet-4-5-20250929-thinking", Type: "model", DisplayName: "Claude Sonnet 4.5 (Thinking)"},
|
||||
{ID: "claude-haiku-4-5-20251001", Type: "model", DisplayName: "Claude Haiku 4.5"},
|
||||
{ID: "claude-haiku-4-5-20251001-thinking", Type: "model", DisplayName: "Claude Haiku 4.5 (Thinking)"},
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultModels_MatchesKiroReferenceModels(t *testing.T) {
|
||||
ids := make([]string, 0, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
|
||||
require.Equal(t, []string{
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-6-thinking",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6-thinking",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5-20251101-thinking",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5-20250929-thinking",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5-20251001-thinking",
|
||||
}, ids)
|
||||
|
||||
require.Contains(t, ids, "claude-sonnet-4-6")
|
||||
require.Contains(t, ids, "claude-haiku-4-5-20251001-thinking")
|
||||
require.NotContains(t, ids, "auto")
|
||||
require.NotContains(t, ids, "claude-sonnet-4")
|
||||
require.NotContains(t, ids, "gpt-4o")
|
||||
require.NotContains(t, ids, "deepseek-3-2")
|
||||
require.NotContains(t, ids, "minimax-m2-1")
|
||||
require.NotContains(t, ids, "qwen3-coder-next")
|
||||
require.NotContains(t, ids, "claude-opus-4-7")
|
||||
require.NotContains(t, ids, "claude-sonnet-4-6-chat")
|
||||
for _, id := range ids {
|
||||
require.NotContains(t, id, "kiro-")
|
||||
require.NotContains(t, id, "-agentic")
|
||||
require.NotContains(t, id, "-chat")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
socialAuthPortalURL = "https://app.kiro.dev"
|
||||
socialAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
defaultIDCRegion = "us-east-1"
|
||||
BuilderIDStartURL = "https://view.awsapps.com/start"
|
||||
sessionTTL = 10 * time.Minute
|
||||
sessionCleanupEvery = 32
|
||||
sessionCleanupMin = 32
|
||||
)
|
||||
|
||||
var (
|
||||
socialAuthEndpointURL = socialAuthEndpoint
|
||||
oidcEndpointOverride = ""
|
||||
)
|
||||
|
||||
type SocialProvider string
|
||||
|
||||
const (
|
||||
SocialProviderGoogle SocialProvider = "Google"
|
||||
SocialProviderGitHub SocialProvider = "Github"
|
||||
)
|
||||
|
||||
type AuthSession struct {
|
||||
State string
|
||||
CodeVerifier string
|
||||
ProxyURL string
|
||||
CreatedAt time.Time
|
||||
AuthType string
|
||||
Provider string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Region string
|
||||
StartURL string
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
data map[string]*AuthSession
|
||||
setCount uint64
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
return &SessionStore{data: make(map[string]*AuthSession)}
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(id string) (*AuthSession, bool) {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
session, ok := s.data[id]
|
||||
if ok && sessionExpired(session, now) {
|
||||
delete(s.data, id)
|
||||
return nil, false
|
||||
}
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(id string, session *AuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.setCount++
|
||||
if len(s.data) >= sessionCleanupMin && s.setCount%sessionCleanupEvery == 0 {
|
||||
s.pruneExpiredLocked(time.Now())
|
||||
}
|
||||
s.data[id] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.data, id)
|
||||
}
|
||||
|
||||
func (s *SessionStore) pruneExpiredLocked(now time.Time) {
|
||||
for id, session := range s.data {
|
||||
if sessionExpired(session, now) {
|
||||
delete(s.data, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sessionExpired(session *AuthSession, now time.Time) bool {
|
||||
if session == nil {
|
||||
return true
|
||||
}
|
||||
if session.CreatedAt.IsZero() {
|
||||
return true
|
||||
}
|
||||
return now.After(session.CreatedAt.Add(sessionTTL))
|
||||
}
|
||||
|
||||
type TokenData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
ExpiresAt string `json:"expiresAt,omitempty"`
|
||||
AuthMethod string `json:"authMethod,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
type socialTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
type registerClientResponse struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
type createTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
type userInfoResponse struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type deviceRegistration struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
type RefreshTokenInvalidError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *RefreshTokenInvalidError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
body := strings.TrimSpace(e.Body)
|
||||
if body == "" {
|
||||
return "kiro refresh token invalid (invalid_grant)"
|
||||
}
|
||||
return fmt.Sprintf("kiro refresh token invalid (invalid_grant, status %d): %s", e.StatusCode, body)
|
||||
}
|
||||
|
||||
func GenerateSessionID() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
return randomURLSafe(16)
|
||||
}
|
||||
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
return randomURLSafe(32)
|
||||
}
|
||||
|
||||
func randomURLSafe(n int) (string, error) {
|
||||
buf := make([]byte, n)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func BuildSocialSignInURL(redirectURI, codeChallenge, state string) string {
|
||||
params := url.Values{}
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("redirect_from", "KiroIDE")
|
||||
return fmt.Sprintf("%s/signin?%s", socialAuthPortalURL, params.Encode())
|
||||
}
|
||||
|
||||
func BuildSocialTokenRedirectURI(baseRedirectURI, callbackPath, loginOption string) string {
|
||||
redirectURI := strings.TrimRight(strings.TrimSpace(baseRedirectURI), "/")
|
||||
if redirectURI == "" {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSpace(callbackPath)
|
||||
if path == "" {
|
||||
path = "/oauth/callback"
|
||||
} else if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
fullRedirectURI := redirectURI + path
|
||||
if option := strings.TrimSpace(loginOption); option != "" {
|
||||
return fullRedirectURI + "?login_option=" + url.QueryEscape(option)
|
||||
}
|
||||
return fullRedirectURI
|
||||
}
|
||||
|
||||
func CreateSocialToken(ctx context.Context, proxyURL, code, codeVerifier, redirectURI string) (*TokenData, error) {
|
||||
payload := map[string]string{
|
||||
"code": code,
|
||||
"code_verifier": codeVerifier,
|
||||
"redirect_uri": redirectURI,
|
||||
}
|
||||
var resp socialTokenResponse
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/oauth/token", payload, &resp, BuildLoginHeaders(shortSHA(codeVerifier), BuildMachineID("", "", "codeVerifier:"+codeVerifier))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
return &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RefreshSocialToken(ctx context.Context, proxyURL, refreshToken, provider string) (*TokenData, error) {
|
||||
payload := map[string]string{
|
||||
"refreshToken": refreshToken,
|
||||
}
|
||||
var resp socialTokenResponse
|
||||
accountKey := BuildAccountKey("", "", refreshToken, "", 0)
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, socialAuthEndpointURL+"/refreshToken", payload, &resp, BuildLoginHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
return &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: provider,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RegisterIDCClient(ctx context.Context, proxyURL, redirectURI, issuerURL, region string) (*registerClientResponse, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]any{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||
"redirectUris": []string{redirectURI},
|
||||
"issuerUrl": issuerURL,
|
||||
}
|
||||
var resp registerClientResponse
|
||||
headers := oidcHeaders("", BuildMachineID("", "", "register-idc-client"))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/client/register", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func BuildIDCAuthURL(clientID, redirectURI, state, codeChallenge, region string) string {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scopes", strings.Join([]string{
|
||||
"codewhisperer:completions",
|
||||
"codewhisperer:analysis",
|
||||
"codewhisperer:conversations",
|
||||
"codewhisperer:transformations",
|
||||
"codewhisperer:taskassist",
|
||||
}, " "))
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
return fmt.Sprintf("%s/authorize?%s", getOIDCEndpoint(region), params.Encode())
|
||||
}
|
||||
|
||||
func ExchangeIDCAuthCode(ctx context.Context, proxyURL, clientID, clientSecret, code, codeVerifier, redirectURI, region, startURL string) (*TokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"code": code,
|
||||
"codeVerifier": codeVerifier,
|
||||
"redirectUri": redirectURI,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
var resp createTokenResponse
|
||||
accountKey := BuildAccountKey(clientID, "", "", "", 0)
|
||||
headers := oidcHeaders(accountKey, BuildMachineID("", "", "clientID:"+clientID))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
token := &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}
|
||||
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func RefreshIDCToken(ctx context.Context, proxyURL, clientID, clientSecret, refreshToken, region, startURL string) (*TokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"refreshToken": refreshToken,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
var resp createTokenResponse
|
||||
accountKey := BuildAccountKey(clientID, "", refreshToken, "", 0)
|
||||
headers := oidcHeaders(accountKey, BuildMachineID(refreshToken, "", accountKey))
|
||||
if err := doJSON(ctx, proxyURL, http.MethodPost, getOIDCEndpoint(region)+"/token", payload, &resp, headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiresIn := resp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
token := &TokenData{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ProfileArn: resp.ProfileArn,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}
|
||||
token.Email = FetchOIDCUserEmail(ctx, proxyURL, token.AccessToken, region)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func FetchOIDCUserEmail(ctx context.Context, proxyURL, accessToken, region string) string {
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return ""
|
||||
}
|
||||
var resp userInfoResponse
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
}
|
||||
if err := doJSON(ctx, proxyURL, http.MethodGet, getOIDCEndpoint(region)+"/userinfo", nil, &resp, headers); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(resp.Email)
|
||||
}
|
||||
|
||||
func ParseImportedToken(tokenJSON string, deviceRegistrationJSON string) (*TokenData, error) {
|
||||
var token TokenData
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse kiro token: %w", err)
|
||||
}
|
||||
token.AuthMethod = strings.ToLower(strings.TrimSpace(token.AuthMethod))
|
||||
if strings.TrimSpace(token.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
if token.ClientIDHash != "" && (token.ClientID == "" || token.ClientSecret == "") && strings.TrimSpace(deviceRegistrationJSON) != "" {
|
||||
var reg deviceRegistration
|
||||
if err := json.Unmarshal([]byte(deviceRegistrationJSON), ®); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse device registration: %w", err)
|
||||
}
|
||||
if reg.ClientID != "" {
|
||||
token.ClientID = reg.ClientID
|
||||
}
|
||||
if reg.ClientSecret != "" {
|
||||
token.ClientSecret = reg.ClientSecret
|
||||
}
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func getOIDCEndpoint(region string) string {
|
||||
if strings.TrimSpace(oidcEndpointOverride) != "" {
|
||||
return strings.TrimRight(strings.TrimSpace(oidcEndpointOverride), "/")
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||
}
|
||||
|
||||
func oidcHeaders(accountKey, machineID string) map[string]string {
|
||||
headers := BuildOIDCHeaders(accountKey, machineID)
|
||||
if headers["amz-sdk-invocation-id"] == "" {
|
||||
headers["amz-sdk-invocation-id"] = uuid.NewString()
|
||||
}
|
||||
if headers["amz-sdk-request"] == "" {
|
||||
headers["amz-sdk-request"] = "attempt=1; max=4"
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func doJSON(ctx context.Context, proxyURL, method, rawURL string, payload any, out any, extraHeaders map[string]string) error {
|
||||
client, err := newHTTPClient(proxyURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(encoded)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, rawURL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if payload != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
for key, value := range extraHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyText := strings.TrimSpace(string(respBody))
|
||||
if resp.StatusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(bodyText), "invalid_grant") {
|
||||
return &RefreshTokenInvalidError{StatusCode: resp.StatusCode, Body: bodyText}
|
||||
}
|
||||
return fmt.Errorf("upstream request failed (status %d): %s", resp.StatusCode, bodyText)
|
||||
}
|
||||
if out == nil || len(respBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(respBody, out)
|
||||
}
|
||||
|
||||
func newHTTPClient(rawProxyURL string) (*http.Client, error) {
|
||||
_, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport := &http.Transport{}
|
||||
if parsed != nil {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRefreshSocialTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/refreshToken", r.URL.Path)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := socialAuthEndpointURL
|
||||
socialAuthEndpointURL = server.URL
|
||||
t.Cleanup(func() { socialAuthEndpointURL = previous })
|
||||
|
||||
_, err := RefreshSocialToken(context.Background(), "", "revoked-refresh-token", "Google")
|
||||
require.Error(t, err)
|
||||
|
||||
var invalid *RefreshTokenInvalidError
|
||||
require.True(t, errors.As(err, &invalid))
|
||||
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||
require.Contains(t, invalid.Body, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestRefreshIDCTokenInvalidGrantReturnsTypedError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/token", r.URL.Path)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant","message":"Invalid refresh token provided"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
_, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "revoked-refresh-token", "us-east-1", BuilderIDStartURL)
|
||||
require.Error(t, err)
|
||||
|
||||
var invalid *RefreshTokenInvalidError
|
||||
require.True(t, errors.As(err, &invalid))
|
||||
require.Equal(t, http.StatusBadRequest, invalid.StatusCode)
|
||||
require.Contains(t, invalid.Body, "invalid_grant")
|
||||
}
|
||||
|
||||
func TestExchangeIDCAuthCodePreservesProfileArn(t *testing.T) {
|
||||
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/EXCHANGE"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
token, err := ExchangeIDCAuthCode(context.Background(), "", "client-id", "client-secret", "code", "verifier", "http://127.0.0.1:9876/oauth/callback", "us-east-1", BuilderIDStartURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, profileArn, token.ProfileArn)
|
||||
require.Equal(t, "kiro@example.com", token.Email)
|
||||
}
|
||||
|
||||
func TestRefreshIDCTokenPreservesProfileArn(t *testing.T) {
|
||||
const profileArn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/REFRESH"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"access-token","refreshToken":"refresh-token","profileArn":"` + profileArn + `","expiresIn":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"kiro@example.com"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
previous := oidcEndpointOverride
|
||||
oidcEndpointOverride = server.URL
|
||||
t.Cleanup(func() { oidcEndpointOverride = previous })
|
||||
|
||||
token, err := RefreshIDCToken(context.Background(), "", "client-id", "client-secret", "refresh-token", "us-east-1", BuilderIDStartURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, profileArn, token.ProfileArn)
|
||||
require.Equal(t, "kiro@example.com", token.Email)
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build unit
|
||||
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBuildSocialSignInURLUsesAppPortal(t *testing.T) {
|
||||
got := BuildSocialSignInURL("http://localhost:49153", "challenge123", "state456")
|
||||
want := "https://app.kiro.dev/signin?code_challenge=challenge123&code_challenge_method=S256&redirect_from=KiroIDE&redirect_uri=http%3A%2F%2Flocalhost%3A49153&state=state456"
|
||||
if got != want {
|
||||
t.Fatalf("BuildSocialSignInURL() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSocialTokenRedirectURI(t *testing.T) {
|
||||
got := BuildSocialTokenRedirectURI("http://localhost:49153", "/oauth/callback", "github")
|
||||
want := "http://localhost:49153/oauth/callback?login_option=github"
|
||||
if got != want {
|
||||
t.Fatalf("BuildSocialTokenRedirectURI() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreGetDeletesExpiredSession(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Set("expired", &AuthSession{CreatedAt: time.Now().Add(-2 * sessionTTL)})
|
||||
|
||||
session, ok := store.Get("expired")
|
||||
if ok || session != nil {
|
||||
t.Fatalf("Get(expired) = (%v, %v), want (nil, false)", session, ok)
|
||||
}
|
||||
if _, exists := store.data["expired"]; exists {
|
||||
t.Fatalf("expired session should be deleted from the store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreSetPrunesExpiredSessions(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
now := time.Now()
|
||||
for i := 0; i < sessionCleanupMin; i++ {
|
||||
store.data[fmt.Sprintf("expired-%d", i)] = &AuthSession{CreatedAt: now.Add(-2 * sessionTTL)}
|
||||
}
|
||||
store.setCount = sessionCleanupEvery - 1
|
||||
|
||||
store.Set("fresh", &AuthSession{CreatedAt: now})
|
||||
|
||||
if len(store.data) != 1 {
|
||||
t.Fatalf("store size = %d, want 1", len(store.data))
|
||||
}
|
||||
if _, ok := store.data["fresh"]; !ok {
|
||||
t.Fatalf("fresh session should remain after pruning")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,368 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const minimalWebSearchDescription = "Search the web for information. Use this tool again when the previous search results are insufficient or need refinement."
|
||||
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
|
||||
|
||||
var cachedWebSearchDescription atomic.Value // stores string
|
||||
|
||||
type MCPRequest struct {
|
||||
ID string `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type MCPResponse struct {
|
||||
Result *struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result,omitempty"`
|
||||
Error *struct {
|
||||
Code *int `json:"code,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type WebSearchResults struct {
|
||||
Results []WebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
type WebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Snippet *string `json:"snippet,omitempty"`
|
||||
PublishedDate *int64 `json:"publishedDate,omitempty"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Domain *string `json:"domain,omitempty"`
|
||||
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
|
||||
PublicDomain *bool `json:"publicDomain,omitempty"`
|
||||
}
|
||||
|
||||
type SearchIndicator struct {
|
||||
ToolUseID string
|
||||
Query string
|
||||
Results *WebSearchResults
|
||||
}
|
||||
|
||||
func GetCachedWebSearchDescription() string {
|
||||
if v := cachedWebSearchDescription.Load(); v != nil {
|
||||
return strings.TrimSpace(v.(string))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SetCachedWebSearchDescription(desc string) {
|
||||
cachedWebSearchDescription.Store(strings.TrimSpace(desc))
|
||||
}
|
||||
|
||||
func BuildMcpEndpoint(region string) string {
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
}
|
||||
|
||||
func ParseSearchResults(resp *MCPResponse) *WebSearchResults {
|
||||
if resp == nil || resp.Result == nil || len(resp.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, item := range resp.Result.Content {
|
||||
if item.Type != "" && item.Type != "text" {
|
||||
continue
|
||||
}
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(item.Text), &results); err == nil {
|
||||
return &results
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExtractSearchQuery(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
arr := messages.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
msg := arr[i]
|
||||
if msg.Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
text := extractSearchText(msg.Get("content"))
|
||||
const prefix = "Perform a web search for the query: "
|
||||
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||
if text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractSearchText(content gjson.Result) string {
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
if !content.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
if text := strings.TrimSpace(block.Get("text").String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func GenerateToolUseID() string {
|
||||
return strings.ReplaceAll(uuid.NewString(), "-", "")[:22]
|
||||
}
|
||||
|
||||
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body, err
|
||||
}
|
||||
rawTools, ok := payload["tools"].([]interface{})
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
replaced := make([]interface{}, 0, len(rawTools))
|
||||
for _, rawTool := range rawTools {
|
||||
tool, ok := rawTool.(map[string]interface{})
|
||||
if !ok {
|
||||
replaced = append(replaced, rawTool)
|
||||
continue
|
||||
}
|
||||
name := getInterfaceString(tool["name"])
|
||||
toolType := getInterfaceString(tool["type"])
|
||||
if !isWebSearchToolName(name, toolType) {
|
||||
replaced = append(replaced, rawTool)
|
||||
continue
|
||||
}
|
||||
replaced = append(replaced, map[string]interface{}{
|
||||
"name": "web_search",
|
||||
"description": minimalWebSearchDescription,
|
||||
"input_schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The search query to execute",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
payload["tools"] = replaced
|
||||
updated, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func InjectToolResultsClaude(claudePayload []byte, toolUseID, query string, results *WebSearchResults) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||
return claudePayload, fmt.Errorf("parse claude payload: %w", err)
|
||||
}
|
||||
|
||||
rawMessages, ok := payload["messages"].([]interface{})
|
||||
if !ok {
|
||||
return claudePayload, fmt.Errorf("claude payload missing messages array")
|
||||
}
|
||||
|
||||
assistantMsg := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": query},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userContent := []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": formatToolResultText(results),
|
||||
},
|
||||
}
|
||||
if guidance := searchGuidanceText(); guidance != "" {
|
||||
userContent = append(userContent, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": guidance,
|
||||
})
|
||||
}
|
||||
userMsg := map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": userContent,
|
||||
}
|
||||
|
||||
rawMessages = append(rawMessages, assistantMsg, userMsg)
|
||||
payload["messages"] = rawMessages
|
||||
updated, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return claudePayload, fmt.Errorf("marshal updated payload: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
|
||||
if len(searches) == 0 {
|
||||
return responsePayload, nil
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(responsePayload, &response); err != nil {
|
||||
return responsePayload, err
|
||||
}
|
||||
content, _ := response["content"].([]interface{})
|
||||
updated := make([]interface{}, 0, len(searches)*2+len(content))
|
||||
for _, search := range searches {
|
||||
updated = append(updated, map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": search.ToolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": search.Query},
|
||||
})
|
||||
updated = append(updated, map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"content": buildSearchResultContent(search.Results),
|
||||
})
|
||||
}
|
||||
updated = append(updated, content...)
|
||||
response["content"] = updated
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return responsePayload, err
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func buildSearchResultContent(results *WebSearchResults) []map[string]interface{} {
|
||||
content := make([]map[string]interface{}, 0)
|
||||
if results == nil {
|
||||
return content
|
||||
}
|
||||
for _, result := range results.Results {
|
||||
snippet := ""
|
||||
if result.Snippet != nil {
|
||||
snippet = strings.TrimSpace(*result.Snippet)
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": result.Title,
|
||||
"url": result.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func ExtractWebSearchToolUseFromResponse(responsePayload []byte) (toolUseID, query string, ok bool) {
|
||||
content := gjson.GetBytes(responsePayload, "content")
|
||||
if !content.IsArray() {
|
||||
return "", "", false
|
||||
}
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() != "tool_use" {
|
||||
continue
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if !isWebSearchToolName(name, "") {
|
||||
continue
|
||||
}
|
||||
query = strings.TrimSpace(block.Get("input.query").String())
|
||||
if query == "" {
|
||||
continue
|
||||
}
|
||||
return block.Get("id").String(), query, true
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func isWebSearchToolName(name, toolType string) bool {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
toolType = strings.ToLower(strings.TrimSpace(toolType))
|
||||
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
||||
return true
|
||||
}
|
||||
switch name {
|
||||
case "web_search", "web_search_20250305", "google_search", "remote_web_search":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getInterfaceString(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(val)
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(val))
|
||||
}
|
||||
}
|
||||
|
||||
func formatToolResultText(results *WebSearchResults) string {
|
||||
if results == nil || len(results.Results) == 0 {
|
||||
return "No search results found."
|
||||
}
|
||||
payload, err := json.MarshalIndent(results.Results, "", " ")
|
||||
if err != nil {
|
||||
return "Found search results, but failed to format them."
|
||||
}
|
||||
return fmt.Sprintf("Found %d search result(s):\n\n%s", len(results.Results), string(payload))
|
||||
}
|
||||
|
||||
func searchGuidanceText() string {
|
||||
now := time.Now()
|
||||
return fmt.Sprintf(`<search_guidance>
|
||||
Current date: %s (%s)
|
||||
|
||||
IMPORTANT: Evaluate the search results above carefully. If the results are:
|
||||
- Mostly spam, SEO junk, or unrelated websites
|
||||
- Missing actual information about the query topic
|
||||
- Outdated or not matching the requested time frame
|
||||
|
||||
Then you MUST use the web_search tool again with a refined query. Try:
|
||||
- Rephrasing in English for better coverage
|
||||
- Using more specific keywords
|
||||
- Adding date context
|
||||
|
||||
Do NOT apologize for bad results without first attempting a re-search.
|
||||
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BufferedStreamResult struct {
|
||||
StopReason string
|
||||
WebSearchQuery string
|
||||
WebSearchToolUseID string
|
||||
HasWebSearchToolUse bool
|
||||
WebSearchToolUseIndex int
|
||||
}
|
||||
|
||||
func GenerateSearchIndicatorEvents(query, toolUseID string, results *WebSearchResults, startIndex int) [][]byte {
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if results != nil {
|
||||
for _, result := range results.Results {
|
||||
snippet := ""
|
||||
if result.Snippet != nil {
|
||||
snippet = strings.TrimSpace(*result.Snippet)
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": result.Title,
|
||||
"url": result.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
|
||||
events := []map[string]interface{}{
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": toolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
}
|
||||
|
||||
result := make([][]byte, 0, len(events))
|
||||
for _, event := range events {
|
||||
eventType, _ := event["type"].(string)
|
||||
payload, _ := json.Marshal(event)
|
||||
result = append(result, []byte("event: "+eventType+"\ndata: "+string(payload)+"\n\n"))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
var currentToolName string
|
||||
currentToolIndex := -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "message_delta":
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if stopReason, ok := delta["stop_reason"].(string); ok && strings.TrimSpace(stopReason) != "" {
|
||||
result.StopReason = stopReason
|
||||
}
|
||||
}
|
||||
case "content_block_start":
|
||||
contentBlock, ok := event["content_block"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := contentBlock["type"].(string)
|
||||
if blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
currentToolName, _ = contentBlock["name"].(string)
|
||||
currentToolName = strings.ToLower(strings.TrimSpace(currentToolName))
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
if toolUseID, ok := contentBlock["id"].(string); ok && isWebSearchToolName(currentToolName, "") {
|
||||
result.WebSearchToolUseID = strings.TrimSpace(toolUseID)
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
case "content_block_delta":
|
||||
if currentToolName == "" {
|
||||
continue
|
||||
}
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
deltaType, _ := delta["type"].(string)
|
||||
if deltaType != "input_json_delta" {
|
||||
continue
|
||||
}
|
||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partialJSON)
|
||||
}
|
||||
case "content_block_stop":
|
||||
if !isWebSearchToolName(currentToolName, "") {
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
continue
|
||||
}
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(toolInputBuilder.String()), &input); err == nil {
|
||||
result.WebSearchQuery = strings.TrimSpace(input["query"])
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func FilterChunksForClient(chunks [][]byte, webSearchToolUseIndex, indexOffset int) [][]byte {
|
||||
filtered := make([][]byte, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
adjusted, shouldForward := filterSSEChunk(chunk, webSearchToolUseIndex, indexOffset)
|
||||
if shouldForward {
|
||||
filtered = append(filtered, adjusted)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||
return filterSSEChunk(chunk, -1, offset)
|
||||
}
|
||||
|
||||
func MaxContentBlockIndex(chunks [][]byte) int {
|
||||
maxIndex := -1
|
||||
for _, chunk := range chunks {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) > maxIndex {
|
||||
maxIndex = int(idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return maxIndex
|
||||
}
|
||||
|
||||
func filterSSEChunk(chunk []byte, webSearchToolUseIndex, indexOffset int) ([]byte, bool) {
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
var builder strings.Builder
|
||||
hasContent := false
|
||||
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
|
||||
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
}
|
||||
builder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if shouldSuppressEventPayload(payload, webSearchToolUseIndex) {
|
||||
continue
|
||||
}
|
||||
adjusted := adjustEventPayload(payload, indexOffset)
|
||||
if adjusted == "" {
|
||||
continue
|
||||
}
|
||||
builder.WriteString("data: " + adjusted + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasContent {
|
||||
return nil, false
|
||||
}
|
||||
return []byte(builder.String()), true
|
||||
}
|
||||
|
||||
func shouldSuppressEventPayload(payload string, webSearchToolUseIndex int) bool {
|
||||
if payload == "" {
|
||||
return false
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return false
|
||||
}
|
||||
eventType, _ := event["type"].(string)
|
||||
if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
|
||||
return true
|
||||
}
|
||||
if webSearchToolUseIndex < 0 {
|
||||
return false
|
||||
}
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) == webSearchToolUseIndex {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func adjustEventPayload(payload string, indexOffset int) string {
|
||||
if payload == "" || indexOffset == 0 {
|
||||
return payload
|
||||
}
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return payload
|
||||
}
|
||||
switch eventType, _ := event["type"].(string); eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
event["index"] = int(idx) + indexOffset
|
||||
if adjusted, err := json.Marshal(event); err == nil {
|
||||
return string(adjusted)
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateSearchIndicatorEvents_UsesInputJSONDelta(t *testing.T) {
|
||||
snippet := "result snippet"
|
||||
events := GenerateSearchIndicatorEvents("golang concurrency", "srvtoolu_test", &WebSearchResults{
|
||||
Results: []WebSearchResult{
|
||||
{Title: "Go", URL: "https://go.dev", Snippet: &snippet},
|
||||
},
|
||||
}, 0)
|
||||
|
||||
require.Len(t, events, 5)
|
||||
require.Contains(t, string(events[0]), `"type":"server_tool_use"`)
|
||||
require.Contains(t, string(events[0]), `"input":{}`)
|
||||
require.Contains(t, string(events[1]), `"type":"input_json_delta"`)
|
||||
require.Contains(t, string(events[1]), `"{\"query\":\"golang concurrency\"}"`)
|
||||
require.Contains(t, string(events[3]), `"type":"web_search_tool_result"`)
|
||||
require.NotContains(t, string(events[3]), `"tool_use_id"`)
|
||||
require.Contains(t, string(events[3]), `"encrypted_content":"result snippet"`)
|
||||
}
|
||||
|
||||
func TestAnalyzeBufferedStream_ExtractsWebSearchToolUse(t *testing.T) {
|
||||
chunks := [][]byte{
|
||||
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||
}
|
||||
|
||||
result := AnalyzeBufferedStream(chunks)
|
||||
require.True(t, result.HasWebSearchToolUse)
|
||||
require.Equal(t, "golang concurrency", result.WebSearchQuery)
|
||||
require.Equal(t, "srvtoolu_next", result.WebSearchToolUseID)
|
||||
require.Equal(t, 1, result.WebSearchToolUseIndex)
|
||||
require.Equal(t, "tool_use", result.StopReason)
|
||||
}
|
||||
|
||||
func TestFilterChunksForClient_RemovesInternalToolUseAndOffsetsIndices(t *testing.T) {
|
||||
chunks := [][]byte{
|
||||
[]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Searching...\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"),
|
||||
[]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"srvtoolu_next\",\"name\":\"web_search\",\"input\":{}}}\n\n"),
|
||||
[]byte("event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"query\\\":\\\"golang concurrency\\\"}\"}}\n\n"),
|
||||
[]byte("event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":1}\n\n"),
|
||||
[]byte("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"}}\n\n"),
|
||||
}
|
||||
|
||||
filtered := FilterChunksForClient(chunks, 1, 2)
|
||||
require.NotEmpty(t, filtered)
|
||||
joined := string(filtered[0]) + string(filtered[1]) + string(filtered[2])
|
||||
require.NotContains(t, joined, `"type":"message_start"`)
|
||||
require.NotContains(t, joined, `"type":"message_delta"`)
|
||||
require.NotContains(t, joined, `"name":"web_search"`)
|
||||
require.Contains(t, joined, `"index":2`)
|
||||
require.Equal(t, 2, MaxContentBlockIndex(filtered))
|
||||
}
|
||||
|
||||
func TestAdjustSSEChunk_OffsetsIndicesAndDropsMessageStart(t *testing.T) {
|
||||
_, shouldForward := AdjustSSEChunk([]byte("event: message_start\ndata: {\"type\":\"message_start\"}\n\n"), 2)
|
||||
require.False(t, shouldForward)
|
||||
|
||||
adjusted, shouldForward := AdjustSSEChunk([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"), 3)
|
||||
require.True(t, shouldForward)
|
||||
require.Contains(t, string(adjusted), `"index":3`)
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestReplaceWebSearchToolDescriptionUsesTypeFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools":[{"type":"web_search_20250305","description":"old"}],
|
||||
"messages":[{"role":"user","content":"golang"}]
|
||||
}`)
|
||||
|
||||
updated, err := ReplaceWebSearchToolDescription(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "web_search", gjson.GetBytes(updated, "tools.0.name").String())
|
||||
require.Equal(t, minimalWebSearchDescription, gjson.GetBytes(updated, "tools.0.description").String())
|
||||
require.Equal(t, "string", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.type").String())
|
||||
require.Equal(t, "The search query to execute", gjson.GetBytes(updated, "tools.0.input_schema.properties.query.description").String())
|
||||
require.Equal(t, "query", gjson.GetBytes(updated, "tools.0.input_schema.required.0").String())
|
||||
require.True(t, gjson.GetBytes(updated, "tools.0.input_schema.additionalProperties").Bool() == false)
|
||||
}
|
||||
|
||||
func TestInjectToolResultsClaudeAppendsMessages(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[{"role":"user","content":"what is golang"}]
|
||||
}`)
|
||||
results := &WebSearchResults{
|
||||
Results: []WebSearchResult{
|
||||
{Title: "Go", URL: "https://go.dev"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := InjectToolResultsClaude(body, "srvtoolu_test", "golang", results)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "assistant", gjson.GetBytes(updated, "messages.1.role").String())
|
||||
require.Equal(t, "tool_use", gjson.GetBytes(updated, "messages.1.content.0.type").String())
|
||||
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "messages.1.content.0.id").String())
|
||||
require.Equal(t, "user", gjson.GetBytes(updated, "messages.2.role").String())
|
||||
require.Equal(t, "tool_result", gjson.GetBytes(updated, "messages.2.content.0.type").String())
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), "https://go.dev")
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.0.content").String(), `"title": "Go"`)
|
||||
require.Contains(t, gjson.GetBytes(updated, "messages.2.content.1.text").String(), "<search_guidance>")
|
||||
}
|
||||
|
||||
func TestExtractWebSearchToolUseFromResponse(t *testing.T) {
|
||||
response := []byte(`{
|
||||
"content":[
|
||||
{"type":"text","text":"let me search"},
|
||||
{"type":"tool_use","id":"srvtoolu_next","name":"remote_web_search","input":{"query":"golang concurrency"}}
|
||||
]
|
||||
}`)
|
||||
|
||||
toolUseID, query, ok := ExtractWebSearchToolUseFromResponse(response)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "srvtoolu_next", toolUseID)
|
||||
require.Equal(t, "golang concurrency", query)
|
||||
}
|
||||
|
||||
func TestInjectSearchIndicatorsInResponse(t *testing.T) {
|
||||
response := []byte(`{
|
||||
"id":"msg_1",
|
||||
"type":"message",
|
||||
"role":"assistant",
|
||||
"model":"kiro",
|
||||
"content":[{"type":"text","text":"final"}],
|
||||
"stop_reason":"end_turn",
|
||||
"usage":{"input_tokens":1,"output_tokens":1}
|
||||
}`)
|
||||
|
||||
snippet := "result snippet"
|
||||
updated, err := InjectSearchIndicatorsInResponse(response, []SearchIndicator{
|
||||
{
|
||||
ToolUseID: "srvtoolu_test",
|
||||
Query: "golang",
|
||||
Results: &WebSearchResults{
|
||||
Results: []WebSearchResult{{Title: "Go", URL: "https://go.dev", Snippet: &snippet}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &decoded))
|
||||
require.Equal(t, "server_tool_use", gjson.GetBytes(updated, "content.0.type").String())
|
||||
require.Equal(t, "srvtoolu_test", gjson.GetBytes(updated, "content.0.id").String())
|
||||
require.Equal(t, "web_search_tool_result", gjson.GetBytes(updated, "content.1.type").String())
|
||||
require.False(t, gjson.GetBytes(updated, "content.1.tool_use_id").Exists())
|
||||
require.Equal(t, "result snippet", gjson.GetBytes(updated, "content.1.content.0.encrypted_content").String())
|
||||
require.Equal(t, "null", gjson.GetBytes(updated, "content.1.content.0.page_age").Raw)
|
||||
require.False(t, gjson.GetBytes(updated, "content.1.content.0.page_content").Exists())
|
||||
require.Equal(t, "text", gjson.GetBytes(updated, "content.2.type").String())
|
||||
}
|
||||
|
||||
func TestParseSearchResults_PreservesExtendedFields(t *testing.T) {
|
||||
resp := &MCPResponse{
|
||||
Result: &struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
}{
|
||||
Content: []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}{
|
||||
{
|
||||
Type: "text",
|
||||
Text: `{"results":[{"title":"Go","url":"https://go.dev","snippet":"snippet","publishedDate":1710000000,"id":"doc-1","domain":"go.dev","maxVerbatimWordLimit":25,"publicDomain":true}]}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
results := ParseSearchResults(resp)
|
||||
require.NotNil(t, results)
|
||||
require.Len(t, results.Results, 1)
|
||||
require.Equal(t, int64(1710000000), *results.Results[0].PublishedDate)
|
||||
require.Equal(t, "doc-1", *results.Results[0].ID)
|
||||
require.Equal(t, "go.dev", *results.Results[0].Domain)
|
||||
require.Equal(t, 25, *results.Results[0].MaxVerbatimWordLimit)
|
||||
require.True(t, *results.Results[0].PublicDomain)
|
||||
}
|
||||
|
||||
func TestSearchGuidanceText_IsStructured(t *testing.T) {
|
||||
guidance := searchGuidanceText()
|
||||
require.Contains(t, guidance, "<search_guidance>")
|
||||
require.Contains(t, guidance, "Current date:")
|
||||
require.Contains(t, guidance, "Then you MUST use the web_search tool again with a refined query.")
|
||||
require.Contains(t, guidance, "Rephrasing in English for better coverage")
|
||||
}
|
||||
@@ -0,0 +1,479 @@
|
||||
package kirocooldown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
MinRequestInterval = time.Second
|
||||
MaxRequestInterval = 2 * time.Second
|
||||
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
|
||||
ShortCooldown = time.Minute
|
||||
MaxCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
|
||||
redisTimeout = 3 * time.Second
|
||||
activeTTL = 10 * time.Second
|
||||
stateTTL = 25 * time.Hour
|
||||
keyPrefix = "kiro:cooldown:"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
|
||||
|
||||
reserveRequestScript = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
|
||||
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
|
||||
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
|
||||
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
|
||||
local interval_ms = tonumber(ARGV[1])
|
||||
local active_ttl_ms = tonumber(ARGV[2])
|
||||
local state_ttl_ms = tonumber(ARGV[3])
|
||||
|
||||
if cooldown_until_ms > now_ms then
|
||||
return {1, cooldown_until_ms - now_ms, cooldown_reason}
|
||||
end
|
||||
|
||||
if cooldown_until_ms > 0 then
|
||||
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
|
||||
end
|
||||
|
||||
local next_slot_ms = now_ms
|
||||
if last_request_ms > 0 then
|
||||
local candidate_ms = last_request_ms + interval_ms
|
||||
if candidate_ms > now_ms then
|
||||
next_slot_ms = candidate_ms
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
|
||||
if fail_count > 0 or cooldown_until_ms > now_ms then
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
else
|
||||
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
|
||||
end
|
||||
return {0, next_slot_ms - now_ms, ''}
|
||||
`)
|
||||
|
||||
mark429Script = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
|
||||
local short_cooldown_ms = tonumber(ARGV[1])
|
||||
local max_cooldown_ms = tonumber(ARGV[2])
|
||||
local state_ttl_ms = tonumber(ARGV[3])
|
||||
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
|
||||
if cooldown_ms > max_cooldown_ms then
|
||||
cooldown_ms = max_cooldown_ms
|
||||
end
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', fail_count,
|
||||
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||
'cooldown_reason', ARGV[4]
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
return cooldown_ms
|
||||
`)
|
||||
|
||||
markSuccessScript = redis.NewScript(`
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', 0,
|
||||
'cooldown_until_ms', 0,
|
||||
'cooldown_reason', ''
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
|
||||
return 1
|
||||
`)
|
||||
|
||||
markSuspendedScript = redis.NewScript(`
|
||||
local t = redis.call('TIME')
|
||||
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
|
||||
local cooldown_ms = tonumber(ARGV[1])
|
||||
local state_ttl_ms = tonumber(ARGV[2])
|
||||
redis.call('HSET', KEYS[1],
|
||||
'fail_count', 0,
|
||||
'cooldown_until_ms', now_ms + cooldown_ms,
|
||||
'cooldown_reason', ARGV[3]
|
||||
)
|
||||
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
|
||||
return cooldown_ms
|
||||
`)
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
remaining time.Duration
|
||||
reason string
|
||||
}
|
||||
|
||||
type State struct {
|
||||
Active bool
|
||||
Reason string
|
||||
CooldownUntil time.Time
|
||||
Remaining time.Duration
|
||||
FailCount int
|
||||
}
|
||||
|
||||
func NewError(remaining time.Duration, reason string) error {
|
||||
return &Error{remaining: remaining, reason: reason}
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.reason == "" {
|
||||
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
|
||||
}
|
||||
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
|
||||
}
|
||||
|
||||
func Calculate429Cooldown(retryCount int) time.Duration {
|
||||
if retryCount < 0 {
|
||||
retryCount = 0
|
||||
}
|
||||
cooldown := ShortCooldown * time.Duration(1<<retryCount)
|
||||
if cooldown > MaxCooldown {
|
||||
return MaxCooldown
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
rngMu sync.Mutex
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
func NewStore(client *redis.Client) *Store {
|
||||
return &Store{
|
||||
client: client,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
values, err := reserveRequestScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
s.nextInterval().Milliseconds(),
|
||||
activeTTL.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
|
||||
}
|
||||
parts, ok := values.([]interface{})
|
||||
if !ok || len(parts) != 3 {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
|
||||
}
|
||||
state, err := luaInt64(parts[0])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
|
||||
}
|
||||
waitMS, err := luaInt64(parts[1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
|
||||
}
|
||||
reason, err := luaString(parts[2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
|
||||
}
|
||||
if state == 1 {
|
||||
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
|
||||
}
|
||||
if waitMS <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return time.Duration(waitMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
|
||||
if err := s.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
if err := markSuccessScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
activeTTL.Milliseconds(),
|
||||
).Err(); err != nil {
|
||||
return fmt.Errorf("kiro cooldown mark success: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
result, err := mark429Script.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
ShortCooldown.Milliseconds(),
|
||||
MaxCooldown.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
CooldownReason429,
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||
}
|
||||
cooldownMS, err := luaInt64(result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
||||
}
|
||||
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
result, err := markSuspendedScript.Run(
|
||||
cacheCtx,
|
||||
s.client,
|
||||
[]string{RedisKey(tokenKey)},
|
||||
LongCooldown.Milliseconds(),
|
||||
stateTTL.Milliseconds(),
|
||||
CooldownReasonSuspended,
|
||||
).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||
}
|
||||
cooldownMS, err := luaInt64(result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
||||
}
|
||||
return time.Duration(cooldownMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
values, err := s.client.HMGet(
|
||||
cacheCtx,
|
||||
RedisKey(tokenKey),
|
||||
"cooldown_until_ms",
|
||||
"cooldown_reason",
|
||||
"fail_count",
|
||||
).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
|
||||
}
|
||||
if len(values) != 3 {
|
||||
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
|
||||
}
|
||||
|
||||
cooldownUntilMS, err := luaInt64(values[0])
|
||||
if err != nil && values[0] != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
|
||||
}
|
||||
reason, err := luaString(values[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
|
||||
}
|
||||
failCount, err := luaInt64(values[2])
|
||||
if err != nil && values[2] != nil {
|
||||
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
|
||||
}
|
||||
if cooldownUntilMS <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cooldownUntil := time.UnixMilli(cooldownUntilMS)
|
||||
remaining := time.Until(cooldownUntil)
|
||||
if remaining <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &State{
|
||||
Active: true,
|
||||
Reason: reason,
|
||||
CooldownUntil: cooldownUntil,
|
||||
Remaining: remaining,
|
||||
FailCount: int(failCount),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
|
||||
if err := s.validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
uniqueKeys := make([]string, 0, len(tokenKeys))
|
||||
seen := make(map[string]struct{}, len(tokenKeys))
|
||||
for _, tokenKey := range tokenKeys {
|
||||
tokenKey = strings.TrimSpace(tokenKey)
|
||||
if tokenKey == "" {
|
||||
continue
|
||||
}
|
||||
redisKey := RedisKey(tokenKey)
|
||||
if _, ok := seen[redisKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[redisKey] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, redisKey)
|
||||
}
|
||||
if len(uniqueKeys) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cacheCtx, cancel := withRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
type candidate struct {
|
||||
redisKey string
|
||||
cooldownUntilMS int64
|
||||
failCount int64
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
var best *candidate
|
||||
|
||||
pipe := s.client.Pipeline()
|
||||
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
|
||||
for _, redisKey := range uniqueKeys {
|
||||
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
|
||||
}
|
||||
if _, err := pipe.Exec(cacheCtx); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
|
||||
}
|
||||
|
||||
for i, cmd := range cmds {
|
||||
values, err := cmd.Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
|
||||
}
|
||||
if len(values) != 3 {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
|
||||
}
|
||||
cooldownUntilMS, err := luaInt64(values[0])
|
||||
if err != nil && values[0] != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
|
||||
}
|
||||
reason, err := luaString(values[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
|
||||
}
|
||||
failCount, err := luaInt64(values[2])
|
||||
if err != nil && values[2] != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
|
||||
}
|
||||
if cooldownUntilMS <= now || reason != CooldownReason429 {
|
||||
continue
|
||||
}
|
||||
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
|
||||
if best == nil ||
|
||||
current.cooldownUntilMS < best.cooldownUntilMS ||
|
||||
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
|
||||
best = current
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
|
||||
}
|
||||
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
|
||||
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func RedisKey(tokenKey string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
|
||||
digest := hex.EncodeToString(sum[:])
|
||||
return keyPrefix + "{" + digest + "}"
|
||||
}
|
||||
|
||||
func ActiveTTL() time.Duration {
|
||||
return activeTTL
|
||||
}
|
||||
|
||||
func StateTTL() time.Duration {
|
||||
return stateTTL
|
||||
}
|
||||
|
||||
func (s *Store) validate() error {
|
||||
if s == nil || s.client == nil {
|
||||
return ErrStoreUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) nextInterval() time.Duration {
|
||||
s.rngMu.Lock()
|
||||
defer s.rngMu.Unlock()
|
||||
if MaxRequestInterval <= MinRequestInterval {
|
||||
return MinRequestInterval
|
||||
}
|
||||
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
|
||||
}
|
||||
|
||||
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, redisTimeout)
|
||||
}
|
||||
|
||||
func luaInt64(v any) (int64, error) {
|
||||
switch n := v.(type) {
|
||||
case int64:
|
||||
return n, nil
|
||||
case int:
|
||||
return int64(n), nil
|
||||
case string:
|
||||
return strconv.ParseInt(strings.TrimSpace(n), 10, 64)
|
||||
case []byte:
|
||||
return strconv.ParseInt(strings.TrimSpace(string(n)), 10, 64)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported lua numeric type %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
func luaString(v any) (string, error) {
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return s, nil
|
||||
case []byte:
|
||||
return string(s), nil
|
||||
case nil:
|
||||
return "", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported lua string type %T", v)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package kirocooldown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func TestClearEarliestTransientCooldownEmptyKeysIsSafe(t *testing.T) {
|
||||
store := NewStore(redis.NewClient(&redis.Options{Addr: "127.0.0.1:0"}))
|
||||
|
||||
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ClearEarliestTransientCooldown(nil) error = %v", err)
|
||||
}
|
||||
if cleared {
|
||||
t.Fatal("ClearEarliestTransientCooldown(nil) cleared = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearEarliestTransientCooldownUnavailableStore(t *testing.T) {
|
||||
store := NewStore(nil)
|
||||
|
||||
cleared, err := store.ClearEarliestTransientCooldown(context.Background(), []string{"token"})
|
||||
if err == nil {
|
||||
t.Fatal("ClearEarliestTransientCooldown unavailable store error = nil")
|
||||
}
|
||||
if cleared {
|
||||
t.Fatal("ClearEarliestTransientCooldown unavailable store cleared = true, want false")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user