Files
sub2api/backend/internal/service/kiro_available_models.go
2026-05-16 15:31:08 +08:00

202 lines
6.2 KiB
Go

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
kiropkg "github.com/Wei-Shaw/sub2api/internal/pkg/kiro"
"github.com/google/uuid"
)
// ErrKiroModelListUnsupported is returned by FetchKiroUpstreamModels when no
// non-empty access token can be obtained for the account, so the upstream
// model list cannot be queried.
var ErrKiroModelListUnsupported = errors.New("kiro upstream model list requires an OAuth/access-token account")
// kiroAvailableModelsResponse is the JSON shape decoded from one page of the
// upstream ListAvailableModels response. The model list and the pagination
// cursor are each accepted under several key spellings; items() chooses which
// list is used and the pagination loop chooses which cursor is used.
type kiroAvailableModelsResponse struct {
// Model list under the camelCase key; items() prefers it when non-empty.
AvailableModels []kiroAvailableModelItem `json:"availableModels"`
// Model list under the snake_case key; used when AvailableModels is empty.
AvailableModelsSnake []kiroAvailableModelItem `json:"available_models"`
// Model list under the bare "models" key; the last fallback in items().
Models []kiroAvailableModelItem `json:"models"`
// Cursor for the next page (camelCase key); an empty value ends pagination
// unless NextTokenSnake is set.
NextToken string `json:"nextToken"`
// Cursor for the next page (snake_case key); consulted only when NextToken
// is blank.
NextTokenSnake string `json:"next_token"`
}
// kiroAvailableModelItem is a single model entry decoded from the upstream
// model list. The identifier and the human-readable name may each arrive under
// several key spellings; mapKiroAvailableModels resolves them by taking the
// first non-blank value in the order the fields are listed below.
type kiroAvailableModelItem struct {
// Identifier candidates, in order of preference: modelId, model_id, id.
// Entries with no non-blank identifier are skipped by mapKiroAvailableModels.
ModelID string `json:"modelId"`
ModelIDSnake string `json:"model_id"`
ID string `json:"id"`
// Display-name candidates, in order of preference: modelName, model_name,
// displayName, display_name, name. When all are blank the identifier is used
// as the display name.
ModelName string `json:"modelName"`
ModelNameSnake string `json:"model_name"`
DisplayName string `json:"displayName"`
DisplayNameSnake string `json:"display_name"`
Name string `json:"name"`
}
// FetchKiroUpstreamModels lists the models the Kiro upstream reports as
// available for the given account.
//
// The bearer token starts out as the account's stored "access_token"
// credential. For OAuth accounts it is replaced by whatever the service's
// kiroTokenProvider returns, which requires a non-nil service with a configured
// provider.
//
// Errors:
//   - a plain error when account is nil, is not a Kiro account, or the token
//     provider is missing for an OAuth account;
//   - a wrapped provider error when fetching the OAuth access token fails;
//   - ErrKiroModelListUnsupported when the resulting token is blank;
//   - whatever requestKiroAvailableModels returns otherwise.
func (s *AccountTestService) FetchKiroUpstreamModels(ctx context.Context, account *Account) ([]kiropkg.Model, error) {
	switch {
	case account == nil:
		return nil, errors.New("account is nil")
	case account.Platform != PlatformKiro:
		return nil, errors.New("not a kiro account")
	}

	bearer := strings.TrimSpace(account.GetCredential("access_token"))
	if account.Type == AccountTypeOAuth {
		if s == nil || s.kiroTokenProvider == nil {
			return nil, errors.New("kiro token provider not configured")
		}
		fresh, err := s.kiroTokenProvider.GetAccessToken(ctx, account)
		if err != nil {
			return nil, fmt.Errorf("get kiro access token failed: %w", err)
		}
		// The provider's token always wins over the stored credential.
		bearer = strings.TrimSpace(fresh)
	}
	if bearer == "" {
		return nil, ErrKiroModelListUnsupported
	}

	region := kiroAPIRegion(account)
	profileArn := strings.TrimSpace(account.GetCredential("profile_arn"))
	return requestKiroAvailableModels(ctx, account, region, profileArn, bearer)
}
// requestKiroAvailableModels pages through the upstream ListAvailableModels
// endpoint resolved for region and returns the result of
// mapKiroAvailableModels over every item collected (de-duplicated by ID and
// sorted by ID).
//
// account supplies the proxy URL for the HTTP client and is forwarded to the
// page helper for header construction; profileArn is optional; token is the
// bearer token sent upstream. The HTTP client uses a 30s timeout and allows
// private hosts only when the endpoint is a loopback address.
//
// Any page failure aborts the listing and returns that error with no partial
// result.
//
// Pagination is defensive against a misbehaving upstream:
//   - if the upstream hands back a cursor that was already followed, paging
//     stops and the items gathered so far are returned, because re-requesting
//     the same cursor would loop forever;
//   - if more than maxPages pages are required, an error is returned rather
//     than silently truncating the list.
func requestKiroAvailableModels(ctx context.Context, account *Account, region, profileArn, token string) ([]kiropkg.Model, error) {
	// Hard ceiling on pages fetched per listing. The page helper asks for 50
	// results per page, so this bounds a single listing at 5000 entries.
	const maxPages = 100

	endpoint := resolveKiroRuntimeEndpoint(region)
	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL:           accountProxyURL(account),
		Timeout:            30 * time.Second,
		ValidateResolvedIP: true,
		AllowPrivateHosts:  isLoopbackEndpoint(endpoint),
	})
	if err != nil {
		return nil, fmt.Errorf("create kiro model list client failed: %w", err)
	}

	var all []kiroAvailableModelItem
	followed := make(map[string]struct{})
	nextToken := ""
	for page := 0; ; page++ {
		if page >= maxPages {
			return nil, fmt.Errorf("kiro model list pagination exceeded %d pages", maxPages)
		}

		resp, err := requestKiroAvailableModelsPage(ctx, client, account, endpoint, profileArn, token, nextToken)
		if err != nil {
			return nil, err
		}
		all = append(all, resp.items()...)

		// Prefer the camelCase cursor, fall back to the snake_case one.
		nextToken = strings.TrimSpace(resp.NextToken)
		if nextToken == "" {
			nextToken = strings.TrimSpace(resp.NextTokenSnake)
		}
		if nextToken == "" {
			break
		}
		if _, repeated := followed[nextToken]; repeated {
			// The upstream returned a cursor we already used; following it
			// again can only yield pages we have seen.
			break
		}
		followed[nextToken] = struct{}{}
	}
	return mapKiroAvailableModels(all), nil
}
// requestKiroAvailableModelsPage performs one GET against
// <endpoint>/ListAvailableModels and decodes the JSON body.
//
// Query parameters: origin (kiroUsageOrigin) and maxResults=50 are always
// sent; profileArn and nextToken are added only when non-empty. Headers come
// from applyKiroModelListHeaders.
//
// At most 1 MiB of the response body is read. A status outside 2xx is reported
// as *kiroUsageHTTPError carrying the status code and the trimmed body; URL,
// request, transport, read and decode failures are returned as wrapped errors.
func requestKiroAvailableModelsPage(ctx context.Context, client *http.Client, account *Account, endpoint, profileArn, token, nextToken string) (*kiroAvailableModelsResponse, error) {
	pageURL, err := url.Parse(endpoint + "/ListAvailableModels")
	if err != nil {
		return nil, fmt.Errorf("build kiro model list url failed: %w", err)
	}

	params := pageURL.Query()
	params.Set("origin", kiroUsageOrigin)
	params.Set("maxResults", "50")
	if profileArn != "" {
		params.Set("profileArn", profileArn)
	}
	if nextToken != "" {
		params.Set("nextToken", nextToken)
	}
	pageURL.RawQuery = params.Encode()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("create kiro model list request failed: %w", err)
	}
	applyKiroModelListHeaders(req, account, token)

	httpResp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("kiro model list request failed: %w", err)
	}
	defer func() {
		// Nothing useful can be done with a close error at this point.
		_ = httpResp.Body.Close()
	}()

	const maxBodyBytes = 1 << 20
	payload, err := io.ReadAll(io.LimitReader(httpResp.Body, maxBodyBytes))
	if err != nil {
		return nil, fmt.Errorf("read kiro model list response failed: %w", err)
	}

	if success := httpResp.StatusCode >= 200 && httpResp.StatusCode < 300; !success {
		return nil, &kiroUsageHTTPError{
			StatusCode: httpResp.StatusCode,
			Body:       strings.TrimSpace(string(payload)),
		}
	}

	decoded := new(kiroAvailableModelsResponse)
	if err := json.Unmarshal(payload, decoded); err != nil {
		return nil, fmt.Errorf("decode kiro model list response failed: %w", err)
	}
	return decoded, nil
}
// applyKiroModelListHeaders sets the headers used for a model-list request on
// req: Accept, a Bearer Authorization built from the trimmed token, the two
// user-agent headers derived from the account key and machine ID, the
// CodeWhisperer opt-out flag, and the Amz-Sdk request/invocation headers (the
// invocation ID is a fresh UUID per call). When account is non-nil it then
// delegates to applyKiroConditionalHeaders. A nil req is a no-op.
func applyKiroModelListHeaders(req *http.Request, account *Account, token string) {
	if req == nil {
		return
	}

	key := buildKiroAccountKey(account)
	machine := buildKiroMachineID(account)

	hdr := req.Header
	hdr.Set("Accept", "*/*")
	hdr.Set("Authorization", "Bearer "+strings.TrimSpace(token))
	hdr.Set("User-Agent", kiropkg.BuildRuntimeUserAgent(key, machine))
	hdr.Set("X-Amz-User-Agent", kiropkg.BuildRuntimeAmzUserAgent(key, machine))
	hdr.Set("x-amzn-codewhisperer-optout", "true")
	hdr.Set("Amz-Sdk-Request", "attempt=1; max=3")
	hdr.Set("Amz-Sdk-Invocation-Id", uuid.NewString())

	if account == nil {
		return
	}
	applyKiroConditionalHeaders(req, account)
}
// items returns the model entries carried by the response. It yields the first
// non-empty list among AvailableModels and AvailableModelsSnake, and otherwise
// falls back to Models (which may itself be empty or nil). A nil receiver
// yields nil.
func (r *kiroAvailableModelsResponse) items() []kiroAvailableModelItem {
	if r == nil {
		return nil
	}
	preferred := [...][]kiroAvailableModelItem{r.AvailableModels, r.AvailableModelsSnake}
	for _, list := range preferred {
		if len(list) > 0 {
			return list
		}
	}
	return r.Models
}
// mapKiroAvailableModels converts raw upstream entries into kiropkg.Model
// values with Type "model".
//
// The ID is the first non-blank of ModelID, ModelIDSnake and ID; entries with
// no ID, or with an ID already emitted, are dropped (first occurrence wins).
// The display name is the first non-blank of ModelName, ModelNameSnake,
// DisplayName, DisplayNameSnake and Name, falling back to the ID. The result is
// sorted by ID in ascending order and is never nil.
func mapKiroAvailableModels(items []kiroAvailableModelItem) []kiropkg.Model {
	result := make([]kiropkg.Model, 0, len(items))
	emitted := make(map[string]struct{}, len(items))

	for i := range items {
		entry := &items[i]

		modelID := firstNonEmptyKiroModelField(entry.ModelID, entry.ModelIDSnake, entry.ID)
		if modelID == "" {
			continue
		}
		if _, dup := emitted[modelID]; dup {
			continue
		}
		emitted[modelID] = struct{}{}

		label := firstNonEmptyKiroModelField(
			entry.ModelName,
			entry.ModelNameSnake,
			entry.DisplayName,
			entry.DisplayNameSnake,
			entry.Name,
			modelID,
		)
		result = append(result, kiropkg.Model{ID: modelID, Type: "model", DisplayName: label})
	}

	sort.Slice(result, func(a, b int) bool {
		return result[a].ID < result[b].ID
	})
	return result
}
// firstNonEmptyKiroModelField scans values left to right and returns the first
// one that is non-empty after trimming surrounding whitespace, in its trimmed
// form. It returns "" when no value qualifies or when called with no values.
func firstNonEmptyKiroModelField(values ...string) string {
	for i := range values {
		candidate := strings.TrimSpace(values[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}