fix: improve kiro usage token accounting

This commit is contained in:
kone
2026-05-17 01:25:37 +08:00
parent 4a06371bba
commit ebb03dc91c
3 changed files with 111 additions and 11 deletions
+59 -10
View File
@@ -1827,6 +1827,9 @@ func parseEventStream(body io.Reader) (string, []KiroToolUse, Usage, string, err
cleanText, embeddedToolUses, _ := drainEmbeddedToolText(content.String())
toolUses = append(toolUses, embeddedToolUses...)
toolUses = deduplicateToolUses(toolUses)
if usage.OutputTokens == 0 {
usage.OutputTokens = estimateKiroOutputTokens(cleanText, toolUses)
}
if usage.TotalTokens == 0 {
usage.TotalTokens = usage.InputTokens + usage.OutputTokens
@@ -2629,16 +2632,16 @@ func updateUsageFromEvent(usage *Usage, eventType string, event map[string]inter
meta = event
}
if tokenUsage, ok := meta["tokenUsage"].(map[string]interface{}); ok {
if value, ok := toInt(tokenUsage["uncachedInputTokens"]); ok {
if value, ok := firstInt(tokenUsage, "uncachedInputTokens", "inputTokens", "inputTokenCount", "promptTokens", "prompt_tokens"); ok {
usage.InputTokens = value
}
if value, ok := toInt(tokenUsage["outputTokens"]); ok {
if value, ok := firstInt(tokenUsage, "outputTokens", "outputTokenCount", "completionTokens", "completion_tokens", "generatedTokens", "generatedTokenCount"); ok {
usage.OutputTokens = value
}
if value, ok := toInt(tokenUsage["totalTokens"]); ok {
if value, ok := firstInt(tokenUsage, "totalTokens", "totalTokenCount"); ok {
usage.TotalTokens = value
}
if value, ok := toInt(tokenUsage["cacheReadInputTokens"]); ok {
if value, ok := firstInt(tokenUsage, "cacheReadInputTokens", "cachedInputTokens", "cacheReadTokens", "cachedTokens", "cached_tokens"); ok {
usage.CacheReadInputTokens = value
if usage.InputTokens == 0 {
usage.InputTokens = value
@@ -2647,26 +2650,72 @@ func updateUsageFromEvent(usage *Usage, eventType string, event map[string]inter
}
}
}
if value, ok := toInt(event["inputTokens"]); ok && value > 0 {
if value, ok := firstInt(event, "inputTokens", "inputTokenCount", "promptTokens", "prompt_tokens"); ok && value > 0 {
usage.InputTokens = value
}
if value, ok := toInt(event["outputTokens"]); ok && value > 0 {
if value, ok := firstInt(event, "outputTokens", "outputTokenCount", "completionTokens", "completion_tokens", "generatedTokens", "generatedTokenCount"); ok && value > 0 {
usage.OutputTokens = value
}
if value, ok := toInt(event["totalTokens"]); ok && value > 0 {
if value, ok := firstInt(event, "totalTokens", "totalTokenCount"); ok && value > 0 {
usage.TotalTokens = value
}
if value, ok := toInt(meta["inputTokens"]); ok && value > 0 {
if value, ok := firstInt(meta, "inputTokens", "inputTokenCount", "promptTokens", "prompt_tokens"); ok && value > 0 {
usage.InputTokens = value
}
if value, ok := toInt(meta["outputTokens"]); ok && value > 0 {
if value, ok := firstInt(meta, "outputTokens", "outputTokenCount", "completionTokens", "completion_tokens", "generatedTokens", "generatedTokenCount"); ok && value > 0 {
usage.OutputTokens = value
}
if value, ok := toInt(meta["totalTokens"]); ok && value > 0 {
if value, ok := firstInt(meta, "totalTokens", "totalTokenCount"); ok && value > 0 {
usage.TotalTokens = value
}
}
func firstInt(m map[string]interface{}, keys ...string) (int, bool) {
for _, key := range keys {
if value, ok := toInt(m[key]); ok {
return value, true
}
}
return 0, false
}
func estimateKiroOutputTokens(content string, toolUses []KiroToolUse) int {
total := countKiroTextTokens(content)
for _, tool := range toolUses {
if tool.IsTruncated {
continue
}
if tool.Name != "" {
total += countKiroTextTokens(tool.Name)
}
if tool.Input != nil {
if b, err := json.Marshal(tool.Input); err == nil {
total += countKiroTextTokens(string(b))
}
}
}
return total
}
func countKiroTextTokens(text string) int {
if strings.TrimSpace(text) == "" {
return 0
}
units := 0
for _, r := range text {
if r >= '\u4e00' && r <= '\u9fff' {
units += 2
} else {
units++
}
}
tokens := (units + 3) / 4
if tokens < 1 {
return 1
}
return tokens
}
func readToolUses(primary, fallback map[string]interface{}) []KiroToolUse {
var raw []interface{}
if value, ok := primary["toolUses"].([]interface{}); ok {
@@ -279,6 +279,57 @@ func TestParseNonStreamingEventStream(t *testing.T) {
require.True(t, strings.Contains(firstText, "hello from kiro"))
}
func TestParseNonStreamingEventStreamUsageAliases(t *testing.T) {
stream := bytes.NewBuffer(nil)
_, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
"assistantResponseEvent": map[string]any{
"content": "hello",
},
}))
_, _ = stream.Write(buildEventStreamFrame(t, "metadataEvent", map[string]any{
"metadataEvent": map[string]any{
"tokenUsage": map[string]any{
"inputTokenCount": 12,
"completionTokens": 7,
"cachedTokens": 3,
"totalTokenCount": 22,
},
},
}))
result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
require.NoError(t, err)
require.Equal(t, 15, result.Usage.InputTokens)
require.Equal(t, 7, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
require.Equal(t, 22, result.Usage.TotalTokens)
require.Equal(t, float64(3), gjson.GetBytes(result.ResponseBody, "usage.cache_read_input_tokens").Float())
}
func TestParseNonStreamingEventStreamEstimatesMissingOutputTokens(t *testing.T) {
stream := bytes.NewBuffer(nil)
_, _ = stream.Write(buildEventStreamFrame(t, "assistantResponseEvent", map[string]any{
"assistantResponseEvent": map[string]any{
"content": "hello from kiro",
},
}))
_, _ = stream.Write(buildEventStreamFrame(t, "messageMetadataEvent", map[string]any{
"messageMetadataEvent": map[string]any{
"tokenUsage": map[string]any{
"uncachedInputTokens": 12,
"cacheReadInputTokens": 3,
},
},
}))
result, err := ParseNonStreamingEventStream(stream, "claude-sonnet-4-5")
require.NoError(t, err)
require.Equal(t, 15, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 19, result.Usage.TotalTokens)
require.Equal(t, float64(4), gjson.GetBytes(result.ResponseBody, "usage.output_tokens").Float())
}
func TestExtractThinkingBlocksIgnoresLiteralTags(t *testing.T) {
content := strings.Join([]string{
"Use `<thinking>` literally.",