diff --git a/package.json b/package.json index 49b1fff..57c009b 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,7 @@ { "type": "module", "dependencies": { + "@anthropic-ai/tokenizer": "^0.0.4", "axios": "^1.10.0", "deepmerge": "^4.3.1", "dotenv": "^16.4.5", diff --git a/src/claude/claude-kiro.js b/src/claude/claude-kiro.js index 5df9d87..2d25800 100644 --- a/src/claude/claude-kiro.js +++ b/src/claude/claude-kiro.js @@ -5,6 +5,7 @@ import * as path from 'path'; import * as os from 'os'; import * as crypto from 'crypto'; import { getProviderModels } from '../provider-models.js'; +import { countTokens } from '@anthropic-ai/tokenizer'; const KIRO_CONSTANTS = { REFRESH_URL: 'https://prod.{{region}}.auth.desktop.kiro.dev/refreshToken', @@ -936,11 +937,15 @@ async initializeAuth(forceRefresh = false) { const finalModel = MODEL_MAPPING[model] ? model : this.modelName; console.log(`[Kiro] Calling generateContent with model: ${finalModel}`); + + // Estimate input tokens locally from requestBody before the API call; the result is passed to buildClaudeResponse and reported as usage.input_tokens + const inputTokens = this.estimateInputTokens(requestBody); + const response = await this.callApi('', finalModel, requestBody); try { const { responseText, toolCalls } = this._processApiResponse(response); - return this.buildClaudeResponse(responseText, false, 'assistant', model, toolCalls); + return this.buildClaudeResponse(responseText, false, 'assistant', model, toolCalls, inputTokens); } catch (error) { console.error('[Kiro] Error in generateContent:', error); throw new Error(`Error processing response: ${error.message}`); @@ -971,13 +976,16 @@ async initializeAuth(forceRefresh = false) { const finalModel = MODEL_MAPPING[model] ? 
model : this.modelName; console.log(`[Kiro] Calling generateContentStream with model: ${finalModel}`); + // Estimate input tokens locally from requestBody before the API call; the result is passed to buildClaudeResponse and reported as usage.input_tokens in the streamed events + const inputTokens = this.estimateInputTokens(requestBody); + try { const response = await this.streamApi('', finalModel, requestBody); const { responseText, toolCalls } = this._processApiResponse(response); // Pass both responseText and toolCalls to buildClaudeResponse // buildClaudeResponse will handle the logic of combining them into a single stream - for (const chunkJson of this.buildClaudeResponse(responseText, true, 'assistant', model, toolCalls)) { + for (const chunkJson of this.buildClaudeResponse(responseText, true, 'assistant', model, toolCalls, inputTokens)) { yield chunkJson; } } catch (error) { @@ -991,13 +999,55 @@ async initializeAuth(forceRefresh = false) { } } + /** + * Count tokens in text with countTokens from @anthropic-ai/tokenizer; returns 0 for falsy text and falls back to Math.ceil(text.length / 4) when the tokenizer throws + */ + countTextTokens(text) { + if (!text) return 0; + try { + return countTokens(text); + } catch (error) { + // Tokenizer threw: log a warning and fall back to the 4-characters-per-token heuristic instead of failing the request + console.warn('[Kiro] Tokenizer error, falling back to estimation:', error.message); + return Math.ceil((text || '').length / 4); + } + } + + /** + * Estimate input tokens for requestBody by summing countTextTokens over the system prompt text, the text of each message (both obtained via getContentText) and the JSON-stringified tools array + */ + estimateInputTokens(requestBody) { + let totalTokens = 0; + + // Count system prompt tokens when requestBody.system is truthy. NOTE(review): getContentText receives requestBody.system here but a whole message object below; confirm it accepts both shapes + if (requestBody.system) { + const systemText = this.getContentText(requestBody.system); + totalTokens += this.countTextTokens(systemText); + } + + // Count tokens for every message with truthy content; the whole message object (not message.content) is handed to getContentText + if (requestBody.messages && Array.isArray(requestBody.messages)) { + for (const message of requestBody.messages) { + if (message.content) { + const contentText = this.getContentText(message); + totalTokens += this.countTextTokens(contentText); + } + } + } + + // Count tool definitions, if present, by tokenizing the JSON.stringify output of the tools array + if (requestBody.tools && Array.isArray(requestBody.tools)) { + totalTokens += 
this.countTextTokens(JSON.stringify(requestBody.tools)); + } + + return totalTokens; + } + /** * Build Claude compatible response object */ - buildClaudeResponse(content, isStream = false, role = 'assistant', model, toolCalls = null) { + buildClaudeResponse(content, isStream = false, role = 'assistant', model, toolCalls = null, inputTokens = 0) { const messageId = `${uuidv4()}`; - // Helper to estimate tokens (simple heuristic) - const estimateTokens = (text) => Math.ceil((text || '').length / 4); if (isStream) { // Kiro API is "pseudo-streaming", so we'll send a few events to simulate @@ -1013,7 +1063,7 @@ async initializeAuth(forceRefresh = false) { role: role, model: model, usage: { - input_tokens: 0, // Kiro API doesn't provide this + input_tokens: inputTokens, output_tokens: 0 // Will be updated in message_delta }, content: [] // Content will be streamed via content_block_delta @@ -1050,7 +1100,7 @@ async initializeAuth(forceRefresh = false) { type: "content_block_stop", index: contentBlockIndex }); - totalOutputTokens += estimateTokens(content); + totalOutputTokens += this.countTextTokens(content); // If there are tool calls, the stop reason remains "tool_use". // If only content, it's "end_turn". 
if (!toolCalls || toolCalls.length === 0) { @@ -1098,7 +1148,7 @@ async initializeAuth(forceRefresh = false) { type: "content_block_stop", index: index }); - totalOutputTokens += estimateTokens(JSON.stringify(inputObject)); + totalOutputTokens += this.countTextTokens(JSON.stringify(inputObject)); }); stopReason = "tool_use"; // If there are tool calls, the stop reason is tool_use } @@ -1143,7 +1193,7 @@ async initializeAuth(forceRefresh = false) { name: tc.function.name, input: inputObject }); - outputTokens += estimateTokens(tc.function.arguments); + outputTokens += this.countTextTokens(tc.function.arguments); } stopReason = "tool_use"; // Set stop_reason to "tool_use" when toolCalls exist } else if (content) { @@ -1151,7 +1201,7 @@ async initializeAuth(forceRefresh = false) { type: "text", text: content }); - outputTokens += estimateTokens(content); + outputTokens += this.countTextTokens(content); } return { @@ -1162,7 +1212,7 @@ async initializeAuth(forceRefresh = false) { stop_reason: stopReason, stop_sequence: null, usage: { - input_tokens: 0, // Kiro API doesn't provide this + input_tokens: inputTokens, output_tokens: outputTokens }, content: contentArray