480 lines
12 KiB
Go
480 lines
12 KiB
Go
package kirocooldown
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Bounds of the randomized spacing enforced between two consecutive requests
// made with the same token. The actual interval is drawn per reservation from
// the half-open range [MinRequestInterval, MaxRequestInterval).
const (
	MinRequestInterval = time.Second
	MaxRequestInterval = 2 * time.Second
)

// Reason labels stored next to a cooldown deadline.
const (
	CooldownReason429       = "rate_limit_exceeded"
	CooldownReasonSuspended = "account_suspended"
)

// Cooldown durations. ShortCooldown is the base duration applied on a 429 and
// MaxCooldown is the cap for the escalated 429 cooldown; LongCooldown is the
// duration applied when a token is marked as suspended.
const (
	ShortCooldown = time.Minute
	MaxCooldown   = 5 * time.Minute
	LongCooldown  = 24 * time.Hour
)

// Redis plumbing: the per-operation timeout, the two key TTLs handed to the
// Lua scripts, and the prefix of every key built by RedisKey.
const (
	redisTimeout = 3 * time.Second
	activeTTL    = 10 * time.Second
	stateTTL     = 25 * time.Hour
	keyPrefix    = "kiro:cooldown:"
)
|
|
|
|
var (
	// ErrStoreUnavailable is returned by Store methods when the store or its
	// Redis client is nil.
	ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")

	// reserveRequestScript reserves the next request slot for the hash at
	// KEYS[1], using the Redis server clock (TIME) as "now".
	//
	// ARGV[1] is the minimum spacing to the previous slot in milliseconds,
	// ARGV[2] the TTL applied when the key holds no failure state and ARGV[3]
	// the TTL applied when fail_count is positive.
	//
	// While cooldown_until_ms is in the future the script changes nothing and
	// returns {1, remaining_ms, cooldown_reason}. Otherwise it removes an
	// expired cooldown, stores the reserved slot in last_request_ms, refreshes
	// the TTL and returns {0, wait_ms, ''}, where wait_ms is the distance from
	// now to the reserved slot.
	reserveRequestScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
local interval_ms = tonumber(ARGV[1])
local active_ttl_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])

if cooldown_until_ms > now_ms then
return {1, cooldown_until_ms - now_ms, cooldown_reason}
end

if cooldown_until_ms > 0 then
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
end

local next_slot_ms = now_ms
if last_request_ms > 0 then
local candidate_ms = last_request_ms + interval_ms
if candidate_ms > now_ms then
next_slot_ms = candidate_ms
end
end

redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
if fail_count > 0 or cooldown_until_ms > now_ms then
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
else
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
end
return {0, next_slot_ms - now_ms, ''}
`)

	// mark429Script increments fail_count on the hash at KEYS[1] and starts a
	// cooldown of ARGV[1] * 2^(fail_count-1) milliseconds, capped at ARGV[2].
	// ARGV[3] is the key TTL in milliseconds and ARGV[4] the reason to store.
	// It returns the applied cooldown in milliseconds.
	mark429Script = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
local short_cooldown_ms = tonumber(ARGV[1])
local max_cooldown_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
if cooldown_ms > max_cooldown_ms then
cooldown_ms = max_cooldown_ms
end
redis.call('HSET', KEYS[1],
'fail_count', fail_count,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[4]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)

	// markSuccessScript resets fail_count, cooldown_until_ms and
	// cooldown_reason on the hash at KEYS[1] and sets the key TTL to ARGV[1]
	// milliseconds. It always returns 1.
	markSuccessScript = redis.NewScript(`
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', 0,
'cooldown_reason', ''
)
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
return 1
`)

	// markSuspendedScript resets fail_count on the hash at KEYS[1] and starts
	// a cooldown of ARGV[1] milliseconds with the reason ARGV[3]. ARGV[2] is
	// the key TTL in milliseconds. It returns the applied cooldown in
	// milliseconds.
	markSuspendedScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local cooldown_ms = tonumber(ARGV[1])
local state_ttl_ms = tonumber(ARGV[2])
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[3]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
)
|
|
|
|
// Error reports that a token is currently in cooldown. Its message includes
// the remaining time and, when one is set, the reason.
type Error struct {
	remaining time.Duration // time left until the cooldown ends
	reason    string        // cooldown reason; may be empty
}
|
|
|
|
// State is the snapshot of an active cooldown as returned by Store.GetState.
type State struct {
	Active        bool          // set to true by GetState for every State it returns
	Reason        string        // value of the cooldown_reason hash field
	CooldownUntil time.Time     // deadline decoded from the cooldown_until_ms hash field
	Remaining     time.Duration // time left until CooldownUntil when the state was read
	FailCount     int           // value of the fail_count hash field
}
|
|
|
|
func NewError(remaining time.Duration, reason string) error {
|
|
return &Error{remaining: remaining, reason: reason}
|
|
}
|
|
|
|
func (e *Error) Error() string {
|
|
if e == nil {
|
|
return ""
|
|
}
|
|
if e.reason == "" {
|
|
return fmt.Sprintf("kiro token is in cooldown for %v", e.remaining.Round(time.Second))
|
|
}
|
|
return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", e.remaining.Round(time.Second), e.reason)
|
|
}
|
|
|
|
func Calculate429Cooldown(retryCount int) time.Duration {
|
|
if retryCount < 0 {
|
|
retryCount = 0
|
|
}
|
|
cooldown := ShortCooldown * time.Duration(1<<retryCount)
|
|
if cooldown > MaxCooldown {
|
|
return MaxCooldown
|
|
}
|
|
return cooldown
|
|
}
|
|
|
|
// Store keeps per-token pacing and cooldown state in Redis.
type Store struct {
	client *redis.Client // Redis client used for every operation; validate rejects nil
	rngMu  sync.Mutex    // held by nextInterval while it uses rng
	rng    *rand.Rand    // source of the randomized request interval
}
|
|
|
|
func NewStore(client *redis.Client) *Store {
|
|
return &Store{
|
|
client: client,
|
|
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
|
}
|
|
}
|
|
|
|
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
|
|
if err := s.validate(); err != nil {
|
|
return 0, err
|
|
}
|
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
|
defer cancel()
|
|
|
|
values, err := reserveRequestScript.Run(
|
|
cacheCtx,
|
|
s.client,
|
|
[]string{RedisKey(tokenKey)},
|
|
s.nextInterval().Milliseconds(),
|
|
activeTTL.Milliseconds(),
|
|
stateTTL.Milliseconds(),
|
|
).Result()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
|
|
}
|
|
parts, ok := values.([]any)
|
|
if !ok || len(parts) != 3 {
|
|
return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", values)
|
|
}
|
|
state, err := luaInt64(parts[0])
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
|
|
}
|
|
waitMS, err := luaInt64(parts[1])
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
|
|
}
|
|
reason, err := luaString(parts[2])
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
|
|
}
|
|
if state == 1 {
|
|
return 0, NewError(time.Duration(waitMS)*time.Millisecond, reason)
|
|
}
|
|
if waitMS <= 0 {
|
|
return 0, nil
|
|
}
|
|
return time.Duration(waitMS) * time.Millisecond, nil
|
|
}
|
|
|
|
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
|
|
if err := s.validate(); err != nil {
|
|
return err
|
|
}
|
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
|
defer cancel()
|
|
if err := markSuccessScript.Run(
|
|
cacheCtx,
|
|
s.client,
|
|
[]string{RedisKey(tokenKey)},
|
|
activeTTL.Milliseconds(),
|
|
).Err(); err != nil {
|
|
return fmt.Errorf("kiro cooldown mark success: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
|
|
if err := s.validate(); err != nil {
|
|
return 0, err
|
|
}
|
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
|
defer cancel()
|
|
result, err := mark429Script.Run(
|
|
cacheCtx,
|
|
s.client,
|
|
[]string{RedisKey(tokenKey)},
|
|
ShortCooldown.Milliseconds(),
|
|
MaxCooldown.Milliseconds(),
|
|
stateTTL.Milliseconds(),
|
|
CooldownReason429,
|
|
).Result()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
|
}
|
|
cooldownMS, err := luaInt64(result)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown mark 429: %w", err)
|
|
}
|
|
return time.Duration(cooldownMS) * time.Millisecond, nil
|
|
}
|
|
|
|
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
|
|
if err := s.validate(); err != nil {
|
|
return 0, err
|
|
}
|
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
|
defer cancel()
|
|
result, err := markSuspendedScript.Run(
|
|
cacheCtx,
|
|
s.client,
|
|
[]string{RedisKey(tokenKey)},
|
|
LongCooldown.Milliseconds(),
|
|
stateTTL.Milliseconds(),
|
|
CooldownReasonSuspended,
|
|
).Result()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
|
}
|
|
cooldownMS, err := luaInt64(result)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("kiro cooldown mark suspended: %w", err)
|
|
}
|
|
return time.Duration(cooldownMS) * time.Millisecond, nil
|
|
}
|
|
|
|
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
|
|
if err := s.validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
|
defer cancel()
|
|
|
|
values, err := s.client.HMGet(
|
|
cacheCtx,
|
|
RedisKey(tokenKey),
|
|
"cooldown_until_ms",
|
|
"cooldown_reason",
|
|
"fail_count",
|
|
).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("kiro cooldown get state: %w", err)
|
|
}
|
|
if len(values) != 3 {
|
|
return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", len(values))
|
|
}
|
|
|
|
cooldownUntilMS, err := luaInt64(values[0])
|
|
if err != nil && values[0] != nil {
|
|
return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
|
|
}
|
|
reason, err := luaString(values[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
|
|
}
|
|
failCount, err := luaInt64(values[2])
|
|
if err != nil && values[2] != nil {
|
|
return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
|
|
}
|
|
if cooldownUntilMS <= 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
cooldownUntil := time.UnixMilli(cooldownUntilMS)
|
|
remaining := time.Until(cooldownUntil)
|
|
if remaining <= 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return &State{
|
|
Active: true,
|
|
Reason: reason,
|
|
CooldownUntil: cooldownUntil,
|
|
Remaining: remaining,
|
|
FailCount: int(failCount),
|
|
}, nil
|
|
}
|
|
|
|
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
|
|
if err := s.validate(); err != nil {
|
|
return false, err
|
|
}
|
|
uniqueKeys := make([]string, 0, len(tokenKeys))
|
|
seen := make(map[string]struct{}, len(tokenKeys))
|
|
for _, tokenKey := range tokenKeys {
|
|
tokenKey = strings.TrimSpace(tokenKey)
|
|
if tokenKey == "" {
|
|
continue
|
|
}
|
|
redisKey := RedisKey(tokenKey)
|
|
if _, ok := seen[redisKey]; ok {
|
|
continue
|
|
}
|
|
seen[redisKey] = struct{}{}
|
|
uniqueKeys = append(uniqueKeys, redisKey)
|
|
}
|
|
if len(uniqueKeys) == 0 {
|
|
return false, nil
|
|
}
|
|
|
|
cacheCtx, cancel := withRedisTimeout(ctx)
|
|
defer cancel()
|
|
|
|
type candidate struct {
|
|
redisKey string
|
|
cooldownUntilMS int64
|
|
failCount int64
|
|
}
|
|
now := time.Now().UnixMilli()
|
|
var best *candidate
|
|
|
|
pipe := s.client.Pipeline()
|
|
cmds := make([]*redis.SliceCmd, 0, len(uniqueKeys))
|
|
for _, redisKey := range uniqueKeys {
|
|
cmds = append(cmds, pipe.HMGet(cacheCtx, redisKey, "cooldown_until_ms", "cooldown_reason", "fail_count"))
|
|
}
|
|
if _, err := pipe.Exec(cacheCtx); err != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
|
|
}
|
|
|
|
for i, cmd := range cmds {
|
|
values, err := cmd.Result()
|
|
if err != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
|
|
}
|
|
if len(values) != 3 {
|
|
return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", len(values))
|
|
}
|
|
cooldownUntilMS, err := luaInt64(values[0])
|
|
if err != nil && values[0] != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
|
|
}
|
|
reason, err := luaString(values[1])
|
|
if err != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
|
|
}
|
|
failCount, err := luaInt64(values[2])
|
|
if err != nil && values[2] != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
|
|
}
|
|
if cooldownUntilMS <= now || reason != CooldownReason429 {
|
|
continue
|
|
}
|
|
current := &candidate{redisKey: uniqueKeys[i], cooldownUntilMS: cooldownUntilMS, failCount: failCount}
|
|
if best == nil ||
|
|
current.cooldownUntilMS < best.cooldownUntilMS ||
|
|
(current.cooldownUntilMS == best.cooldownUntilMS && current.failCount < best.failCount) {
|
|
best = current
|
|
}
|
|
}
|
|
if best == nil {
|
|
return false, nil
|
|
}
|
|
|
|
if err := s.client.HDel(cacheCtx, best.redisKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
|
|
}
|
|
if err := s.client.Expire(cacheCtx, best.redisKey, activeTTL).Err(); err != nil {
|
|
return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func RedisKey(tokenKey string) string {
|
|
sum := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))
|
|
digest := hex.EncodeToString(sum[:])
|
|
return keyPrefix + "{" + digest + "}"
|
|
}
|
|
|
|
// ActiveTTL returns the TTL applied to a token's Redis key when it holds no
// failure or cooldown state.
func ActiveTTL() time.Duration {
	return activeTTL
}
|
|
|
|
// StateTTL returns the TTL applied to a token's Redis key while it carries
// failure or cooldown state.
func StateTTL() time.Duration {
	return stateTTL
}
|
|
|
|
func (s *Store) validate() error {
|
|
if s == nil || s.client == nil {
|
|
return ErrStoreUnavailable
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) nextInterval() time.Duration {
|
|
s.rngMu.Lock()
|
|
defer s.rngMu.Unlock()
|
|
if MaxRequestInterval <= MinRequestInterval {
|
|
return MinRequestInterval
|
|
}
|
|
return MinRequestInterval + time.Duration(s.rng.Int63n(int64(MaxRequestInterval-MinRequestInterval)))
|
|
}
|
|
|
|
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
return context.WithTimeout(ctx, redisTimeout)
|
|
}
|
|
|
|
// luaInt64 converts a value received from Redis into an int64. Integer values
// are returned as they are; strings and byte slices are parsed as base-10
// integers after trimming surrounding whitespace. Any other type, including
// nil, is an error.
func luaInt64(v any) (int64, error) {
	var text string
	switch typed := v.(type) {
	case int64:
		return typed, nil
	case int:
		return int64(typed), nil
	case string:
		text = typed
	case []byte:
		text = string(typed)
	default:
		return 0, fmt.Errorf("unsupported lua numeric type %T", v)
	}
	return strconv.ParseInt(strings.TrimSpace(text), 10, 64)
}
|
|
|
|
// luaString converts a value received from Redis into a string. Strings and
// byte slices are returned as text and nil becomes the empty string; any other
// type is an error.
func luaString(v any) (string, error) {
	if v == nil {
		return "", nil
	}
	if text, isString := v.(string); isString {
		return text, nil
	}
	if raw, isBytes := v.([]byte); isBytes {
		return string(raw), nil
	}
	return "", fmt.Errorf("unsupported lua string type %T", v)
}
|