// File: sub2api/backend/internal/pkg/kirocooldown/store.go
// Snapshot: 2026-05-17 06:45:35 +08:00 · 480 lines · 12 KiB · Go

package kirocooldown
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// Tunables for per-token request pacing and cooldown bookkeeping.
const (
// MinRequestInterval and MaxRequestInterval bound the randomized spacing that
// nextInterval draws and ReserveRequest passes to the reserve script.
MinRequestInterval = time.Second
MaxRequestInterval = 2 * time.Second
// CooldownReason429 is the reason Mark429 stores alongside a cooldown;
// ClearEarliestTransientCooldown only clears cooldowns carrying this reason.
CooldownReason429 = "rate_limit_exceeded"
// CooldownReasonSuspended is the reason MarkSuspended stores alongside a cooldown.
CooldownReasonSuspended = "account_suspended"
// ShortCooldown is the base 429 cooldown; it is doubled per recorded failure
// and capped at MaxCooldown.
ShortCooldown = time.Minute
MaxCooldown = 5 * time.Minute
// LongCooldown is the cooldown length MarkSuspended applies.
LongCooldown = 24 * time.Hour
// redisTimeout is the deadline withRedisTimeout puts on each Redis round trip.
redisTimeout = 3 * time.Second
// activeTTL is the key TTL used when no failure is being tracked (reserve with
// fail_count == 0, MarkSuccess, and after a transient cooldown is cleared).
activeTTL = 10 * time.Second
// stateTTL is the key TTL used while failures or a cooldown are recorded.
stateTTL = 25 * time.Hour
// keyPrefix is prepended to every Redis key built by RedisKey.
keyPrefix = "kiro:cooldown:"
)
// Package-level sentinel error and the Lua scripts that mutate a token's hash
// atomically on the Redis server. Every script operates on the single hash at
// KEYS[1] and, where it needs the current time, reads it from Redis TIME.
var (
// ErrStoreUnavailable is returned by validate when the Store or its client is nil.
ErrStoreUnavailable = errors.New("kiro cooldown store unavailable")
// reserveRequestScript reserves the next request slot.
// ARGV[1] = spacing between requests (ms), ARGV[2] = TTL used when fail_count
// is 0 (ms), ARGV[3] = TTL used when fail_count > 0 (ms).
// If cooldown_until_ms is still in the future it returns
// {1, remaining_ms, reason} and writes nothing. Otherwise it deletes any stale
// cooldown fields, sets last_request_ms to the later of now and
// last_request_ms + spacing, refreshes the key TTL and returns {0, wait_ms, ''}.
reserveRequestScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local last_request_ms = tonumber(redis.call('HGET', KEYS[1], 'last_request_ms') or '0')
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0')
local cooldown_until_ms = tonumber(redis.call('HGET', KEYS[1], 'cooldown_until_ms') or '0')
local cooldown_reason = redis.call('HGET', KEYS[1], 'cooldown_reason') or ''
local interval_ms = tonumber(ARGV[1])
local active_ttl_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
if cooldown_until_ms > now_ms then
return {1, cooldown_until_ms - now_ms, cooldown_reason}
end
if cooldown_until_ms > 0 then
redis.call('HDEL', KEYS[1], 'cooldown_until_ms', 'cooldown_reason')
end
local next_slot_ms = now_ms
if last_request_ms > 0 then
local candidate_ms = last_request_ms + interval_ms
if candidate_ms > now_ms then
next_slot_ms = candidate_ms
end
end
redis.call('HSET', KEYS[1], 'last_request_ms', next_slot_ms)
if fail_count > 0 or cooldown_until_ms > now_ms then
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
else
redis.call('PEXPIRE', KEYS[1], active_ttl_ms)
end
return {0, next_slot_ms - now_ms, ''}
`)
// mark429Script records one more rate-limit failure.
// ARGV[1] = base cooldown (ms), ARGV[2] = cooldown cap (ms), ARGV[3] = key TTL
// (ms), ARGV[4] = reason to store.
// It increments fail_count, computes base * 2^(fail_count-1) capped at the
// maximum, stores the new cooldown deadline and reason, refreshes the TTL and
// returns the cooldown length in milliseconds.
mark429Script = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local fail_count = tonumber(redis.call('HGET', KEYS[1], 'fail_count') or '0') + 1
local short_cooldown_ms = tonumber(ARGV[1])
local max_cooldown_ms = tonumber(ARGV[2])
local state_ttl_ms = tonumber(ARGV[3])
local cooldown_ms = short_cooldown_ms * (2 ^ (fail_count - 1))
if cooldown_ms > max_cooldown_ms then
cooldown_ms = max_cooldown_ms
end
redis.call('HSET', KEYS[1],
'fail_count', fail_count,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[4]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
// markSuccessScript resets fail_count and cooldown_until_ms to 0, blanks
// cooldown_reason, sets the key TTL to ARGV[1] (ms) and returns 1.
markSuccessScript = redis.NewScript(`
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', 0,
'cooldown_reason', ''
)
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
return 1
`)
// markSuspendedScript starts a fixed-length cooldown.
// ARGV[1] = cooldown length (ms), ARGV[2] = key TTL (ms), ARGV[3] = reason.
// It resets fail_count to 0, stores the cooldown deadline and reason, refreshes
// the TTL and returns the cooldown length in milliseconds.
markSuspendedScript = redis.NewScript(`
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local cooldown_ms = tonumber(ARGV[1])
local state_ttl_ms = tonumber(ARGV[2])
redis.call('HSET', KEYS[1],
'fail_count', 0,
'cooldown_until_ms', now_ms + cooldown_ms,
'cooldown_reason', ARGV[3]
)
redis.call('PEXPIRE', KEYS[1], state_ttl_ms)
return cooldown_ms
`)
)
// Error reports that a token is currently cooling down. It is the value built
// by NewError, which ReserveRequest returns while a cooldown is active.
type Error struct {
// remaining is the cooldown time that was left when the error was created.
remaining time.Duration
// reason is the stored cooldown reason; it may be empty.
reason string
}
// State is the snapshot of an active cooldown returned by GetState. GetState
// returns nil instead of a State when no cooldown is active, so every State it
// produces has Active set to true.
type State struct {
Active bool
// Reason is the stored cooldown_reason field.
Reason string
// CooldownUntil is the stored cooldown deadline (Unix milliseconds) as a time.
CooldownUntil time.Time
// Remaining is the time left until CooldownUntil, measured on the local clock.
Remaining time.Duration
// FailCount is the stored fail_count field.
FailCount int
}
// NewError builds a cooldown error carrying the time left on the cooldown and
// the reason it was imposed. The returned error is always a non-nil *Error.
func NewError(remaining time.Duration, reason string) error {
	cooldownErr := new(Error)
	cooldownErr.remaining = remaining
	cooldownErr.reason = reason
	return cooldownErr
}
// Error renders the cooldown as a human-readable message. The remaining time
// is rounded to whole seconds, and the reason is appended only when one was
// recorded. A nil receiver yields the empty string.
func (e *Error) Error() string {
	if e == nil {
		return ""
	}
	left := e.remaining.Round(time.Second)
	if e.reason != "" {
		return fmt.Sprintf("kiro token is in cooldown for %v (reason: %s)", left, e.reason)
	}
	return fmt.Sprintf("kiro token is in cooldown for %v", left)
}
// Calculate429Cooldown returns the cooldown to apply after retryCount earlier
// rate-limit failures: ShortCooldown doubled retryCount times, capped at
// MaxCooldown. Negative counts are treated as zero.
//
// The doubling is done step by step and stops as soon as the cap is reached.
// Computing ShortCooldown * (1 << retryCount) directly overflows int64 for
// large counts and yields a negative or zero duration that slips past the cap.
func Calculate429Cooldown(retryCount int) time.Duration {
	cooldown := ShortCooldown
	for i := 0; i < retryCount; i++ {
		// Doubling a value above half the cap would exceed the cap, so stop
		// here; this also keeps the multiplication below from overflowing.
		if cooldown > MaxCooldown/2 {
			return MaxCooldown
		}
		cooldown *= 2
	}
	if cooldown > MaxCooldown {
		return MaxCooldown
	}
	return cooldown
}
// Store keeps per-token request pacing and cooldown state in Redis.
type Store struct {
// client is the Redis connection used for every script and command; validate
// rejects a Store whose client is nil.
client *redis.Client
// rngMu guards rng, from which nextInterval draws the randomized interval.
rngMu sync.Mutex
rng *rand.Rand
}
// NewStore returns a Store backed by client, with its private random source
// seeded from the current wall-clock time. A nil client is accepted here;
// the Store's methods then fail with ErrStoreUnavailable.
func NewStore(client *redis.Client) *Store {
	seed := time.Now().UnixNano()
	store := &Store{client: client}
	store.rng = rand.New(rand.NewSource(seed))
	return store
}
// ReserveRequest claims the next request slot for tokenKey and reports how
// long the caller should wait before using it.
//
// It returns a zero wait when the slot is available now. While the token is in
// cooldown it returns an *Error carrying the remaining time and the reason.
// ErrStoreUnavailable is returned when the Store has no client; Redis and
// decoding failures are wrapped with context.
func (s *Store) ReserveRequest(ctx context.Context, tokenKey string) (time.Duration, error) {
	if err := s.validate(); err != nil {
		return 0, err
	}

	redisCtx, cancel := withRedisTimeout(ctx)
	defer cancel()

	keys := []string{RedisKey(tokenKey)}
	intervalMS := s.nextInterval().Milliseconds()
	reply, err := reserveRequestScript.Run(
		redisCtx,
		s.client,
		keys,
		intervalMS,
		activeTTL.Milliseconds(),
		stateTTL.Milliseconds(),
	).Result()
	if err != nil {
		return 0, fmt.Errorf("kiro cooldown reserve request: %w", err)
	}

	// The script always answers with a {status, milliseconds, reason} triple.
	fields, isList := reply.([]any)
	if !isList || len(fields) != 3 {
		return 0, fmt.Errorf("kiro cooldown reserve request: unexpected response %T", reply)
	}

	status, err := luaInt64(fields[0])
	if err != nil {
		return 0, fmt.Errorf("kiro cooldown reserve request state: %w", err)
	}
	waitMS, err := luaInt64(fields[1])
	if err != nil {
		return 0, fmt.Errorf("kiro cooldown reserve request wait: %w", err)
	}
	reason, err := luaString(fields[2])
	if err != nil {
		return 0, fmt.Errorf("kiro cooldown reserve request reason: %w", err)
	}

	wait := time.Duration(waitMS) * time.Millisecond
	switch {
	case status == 1:
		// Status 1 means the token is cooling down; wait is the time left.
		return 0, NewError(wait, reason)
	case waitMS <= 0:
		return 0, nil
	default:
		return wait, nil
	}
}
// MarkSuccess records a successful request for tokenKey: the success script
// zeroes the failure count, clears the cooldown and sets the key TTL to
// activeTTL. It returns ErrStoreUnavailable when the Store has no client and a
// wrapped error when the script fails.
func (s *Store) MarkSuccess(ctx context.Context, tokenKey string) error {
	if err := s.validate(); err != nil {
		return err
	}

	redisCtx, cancel := withRedisTimeout(ctx)
	defer cancel()

	keys := []string{RedisKey(tokenKey)}
	cmd := markSuccessScript.Run(redisCtx, s.client, keys, activeTTL.Milliseconds())
	if runErr := cmd.Err(); runErr != nil {
		return fmt.Errorf("kiro cooldown mark success: %w", runErr)
	}
	return nil
}
// Mark429 records a rate-limit failure for tokenKey and returns the cooldown
// the script applied. The script increments the failure count and doubles the
// cooldown from ShortCooldown per failure, up to MaxCooldown. It returns
// ErrStoreUnavailable when the Store has no client; script and decoding
// failures are wrapped.
func (s *Store) Mark429(ctx context.Context, tokenKey string) (time.Duration, error) {
	if err := s.validate(); err != nil {
		return 0, err
	}

	redisCtx, cancel := withRedisTimeout(ctx)
	defer cancel()

	keys := []string{RedisKey(tokenKey)}
	args := []any{
		ShortCooldown.Milliseconds(),
		MaxCooldown.Milliseconds(),
		stateTTL.Milliseconds(),
		CooldownReason429,
	}
	reply, runErr := mark429Script.Run(redisCtx, s.client, keys, args...).Result()
	if runErr != nil {
		return 0, fmt.Errorf("kiro cooldown mark 429: %w", runErr)
	}

	appliedMS, parseErr := luaInt64(reply)
	if parseErr != nil {
		return 0, fmt.Errorf("kiro cooldown mark 429: %w", parseErr)
	}
	return time.Duration(appliedMS) * time.Millisecond, nil
}
// MarkSuspended puts tokenKey into a LongCooldown with the suspended reason
// and returns the cooldown the script applied. The script also resets the
// failure count. It returns ErrStoreUnavailable when the Store has no client;
// script and decoding failures are wrapped.
func (s *Store) MarkSuspended(ctx context.Context, tokenKey string) (time.Duration, error) {
	if err := s.validate(); err != nil {
		return 0, err
	}

	redisCtx, cancel := withRedisTimeout(ctx)
	defer cancel()

	keys := []string{RedisKey(tokenKey)}
	args := []any{
		LongCooldown.Milliseconds(),
		stateTTL.Milliseconds(),
		CooldownReasonSuspended,
	}
	reply, runErr := markSuspendedScript.Run(redisCtx, s.client, keys, args...).Result()
	if runErr != nil {
		return 0, fmt.Errorf("kiro cooldown mark suspended: %w", runErr)
	}

	appliedMS, parseErr := luaInt64(reply)
	if parseErr != nil {
		return 0, fmt.Errorf("kiro cooldown mark suspended: %w", parseErr)
	}
	return time.Duration(appliedMS) * time.Millisecond, nil
}
// GetState reads the cooldown fields stored for tokenKey.
//
// It returns (nil, nil) when no cooldown is recorded or the recorded deadline
// has already passed on the local clock; otherwise it returns a State
// describing the active cooldown. ErrStoreUnavailable is returned when the
// Store has no client; Redis and decoding failures are wrapped.
func (s *Store) GetState(ctx context.Context, tokenKey string) (*State, error) {
	if err := s.validate(); err != nil {
		return nil, err
	}

	redisCtx, cancel := withRedisTimeout(ctx)
	defer cancel()

	fields, err := s.client.HMGet(
		redisCtx,
		RedisKey(tokenKey),
		"cooldown_until_ms",
		"cooldown_reason",
		"fail_count",
	).Result()
	if err != nil {
		return nil, fmt.Errorf("kiro cooldown get state: %w", err)
	}
	if got := len(fields); got != 3 {
		return nil, fmt.Errorf("kiro cooldown get state: unexpected response length %d", got)
	}
	rawUntil, rawReason, rawFails := fields[0], fields[1], fields[2]

	// A missing hash field arrives as nil; that is not an error, it simply
	// leaves the numeric value at zero.
	untilMS, err := luaInt64(rawUntil)
	if rawUntil != nil && err != nil {
		return nil, fmt.Errorf("kiro cooldown get state cooldown_until_ms: %w", err)
	}
	reason, err := luaString(rawReason)
	if err != nil {
		return nil, fmt.Errorf("kiro cooldown get state reason: %w", err)
	}
	fails, err := luaInt64(rawFails)
	if rawFails != nil && err != nil {
		return nil, fmt.Errorf("kiro cooldown get state fail_count: %w", err)
	}

	if untilMS <= 0 {
		return nil, nil
	}
	deadline := time.UnixMilli(untilMS)
	left := time.Until(deadline)
	if left <= 0 {
		return nil, nil
	}

	state := &State{
		Active:        true,
		Reason:        reason,
		CooldownUntil: deadline,
		Remaining:     left,
		FailCount:     int(fails),
	}
	return state, nil
}
// ClearEarliestTransientCooldown looks at the given tokens, picks the one whose
// rate-limit (429) cooldown ends soonest, and clears that cooldown.
//
// Blank and duplicate token keys are ignored. Only cooldowns that are still in
// the future on the local clock and carry CooldownReason429 are candidates;
// ties on the deadline go to the token with the lower failure count. The
// chosen key has its cooldown fields deleted and its TTL set to activeTTL.
//
// It returns true when a cooldown was cleared and false when there was no
// candidate. ErrStoreUnavailable is returned when the Store has no client;
// Redis and decoding failures are wrapped.
func (s *Store) ClearEarliestTransientCooldown(ctx context.Context, tokenKeys []string) (bool, error) {
	if err := s.validate(); err != nil {
		return false, err
	}

	// Collapse the input to distinct Redis keys, preserving first-seen order.
	redisKeys := make([]string, 0, len(tokenKeys))
	known := make(map[string]struct{}, len(tokenKeys))
	for _, raw := range tokenKeys {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			continue
		}
		key := RedisKey(trimmed)
		if _, dup := known[key]; dup {
			continue
		}
		known[key] = struct{}{}
		redisKeys = append(redisKeys, key)
	}
	if len(redisKeys) == 0 {
		return false, nil
	}

	redisCtx, cancel := withRedisTimeout(ctx)
	defer cancel()

	nowMS := time.Now().UnixMilli()

	// Fetch every token's cooldown fields in one round trip.
	pipe := s.client.Pipeline()
	lookups := make([]*redis.SliceCmd, 0, len(redisKeys))
	for _, key := range redisKeys {
		lookups = append(lookups, pipe.HMGet(redisCtx, key, "cooldown_until_ms", "cooldown_reason", "fail_count"))
	}
	if _, err := pipe.Exec(redisCtx); err != nil {
		return false, fmt.Errorf("kiro cooldown clear transient scan: %w", err)
	}

	var (
		found     bool
		bestKey   string
		bestUntil int64
		bestFails int64
	)
	for idx, lookup := range lookups {
		fields, err := lookup.Result()
		if err != nil {
			return false, fmt.Errorf("kiro cooldown clear transient state: %w", err)
		}
		if got := len(fields); got != 3 {
			return false, fmt.Errorf("kiro cooldown clear transient state: unexpected response length %d", got)
		}
		rawUntil, rawReason, rawFails := fields[0], fields[1], fields[2]

		// Missing hash fields arrive as nil and are treated as zero, not as errors.
		untilMS, err := luaInt64(rawUntil)
		if rawUntil != nil && err != nil {
			return false, fmt.Errorf("kiro cooldown clear transient cooldown_until_ms: %w", err)
		}
		reason, err := luaString(rawReason)
		if err != nil {
			return false, fmt.Errorf("kiro cooldown clear transient reason: %w", err)
		}
		fails, err := luaInt64(rawFails)
		if rawFails != nil && err != nil {
			return false, fmt.Errorf("kiro cooldown clear transient fail_count: %w", err)
		}

		isActive429 := untilMS > nowMS && reason == CooldownReason429
		if !isActive429 {
			continue
		}

		endsSooner := untilMS < bestUntil
		tieWithFewerFails := untilMS == bestUntil && fails < bestFails
		if !found || endsSooner || tieWithFewerFails {
			found = true
			bestKey = redisKeys[idx]
			bestUntil = untilMS
			bestFails = fails
		}
	}
	if !found {
		return false, nil
	}

	if err := s.client.HDel(redisCtx, bestKey, "cooldown_until_ms", "cooldown_reason").Err(); err != nil {
		return false, fmt.Errorf("kiro cooldown clear transient: %w", err)
	}
	if err := s.client.Expire(redisCtx, bestKey, activeTTL).Err(); err != nil {
		return false, fmt.Errorf("kiro cooldown clear transient ttl: %w", err)
	}
	return true, nil
}
// RedisKey maps a token key to the Redis key holding its state: keyPrefix
// followed by the hex SHA-256 of the whitespace-trimmed token key, wrapped in
// braces. The raw token never appears in the key.
func RedisKey(tokenKey string) string {
	digest := sha256.Sum256([]byte(strings.TrimSpace(tokenKey)))

	var key strings.Builder
	key.Grow(len(keyPrefix) + hex.EncodedLen(len(digest)) + 2)
	key.WriteString(keyPrefix)
	key.WriteByte('{')
	key.WriteString(hex.EncodeToString(digest[:]))
	key.WriteByte('}')
	return key.String()
}
// ActiveTTL reports the key TTL used when no failure is being tracked.
func ActiveTTL() time.Duration { return activeTTL }
// StateTTL reports the key TTL used while failures or a cooldown are recorded.
func StateTTL() time.Duration { return stateTTL }
// validate reports ErrStoreUnavailable when the Store itself or its Redis
// client is nil, and nil when the Store is usable.
func (s *Store) validate() error {
	if s != nil && s.client != nil {
		return nil
	}
	return ErrStoreUnavailable
}
// nextInterval draws the spacing to enforce before the next request: a random
// duration in [MinRequestInterval, MaxRequestInterval). When the bounds leave
// no room for jitter it returns MinRequestInterval. The random source is
// accessed under rngMu.
func (s *Store) nextInterval() time.Duration {
	s.rngMu.Lock()
	defer s.rngMu.Unlock()

	span := MaxRequestInterval - MinRequestInterval
	if span <= 0 {
		return MinRequestInterval
	}
	jitter := s.rng.Int63n(int64(span))
	return MinRequestInterval + time.Duration(jitter)
}
// withRedisTimeout derives a context limited to redisTimeout from ctx. A nil
// ctx is replaced with context.Background(). The caller must invoke the
// returned cancel function.
func withRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithTimeout(parent, redisTimeout)
}
// luaInt64 converts a value received from Redis into an int64. Integers are
// returned as-is; strings and byte slices are trimmed of surrounding
// whitespace and parsed as base-10. Any other type, including nil, is an error.
func luaInt64(v any) (int64, error) {
	var text string
	switch typed := v.(type) {
	case int64:
		return typed, nil
	case int:
		return int64(typed), nil
	case string:
		text = typed
	case []byte:
		text = string(typed)
	default:
		return 0, fmt.Errorf("unsupported lua numeric type %T", v)
	}
	return strconv.ParseInt(strings.TrimSpace(text), 10, 64)
}
// luaString converts a value received from Redis into a string. A nil value
// becomes the empty string, strings pass through, and byte slices are copied
// into a string. Any other type is an error.
func luaString(v any) (string, error) {
	if v == nil {
		return "", nil
	}
	if text, isString := v.(string); isString {
		return text, nil
	}
	if raw, isBytes := v.([]byte); isBytes {
		return string(raw), nil
	}
	return "", fmt.Errorf("unsupported lua string type %T", v)
}