diff --git a/src/adapter.js b/src/adapter.js index 8bff212..ed5d677 100644 --- a/src/adapter.js +++ b/src/adapter.js @@ -306,6 +306,15 @@ export class KiroApiServiceAdapter extends ApiServiceAdapter { } return this.kiroApiService.getUsageLimits(); } + + /** + * Count tokens for a message request (compatible with Anthropic API) + * @param {Object} requestBody - The request body containing model, messages, system, tools, etc. + * @returns {Object} { input_tokens: number } + */ + countTokens(requestBody) { + return this.kiroApiService.countTokens(requestBody); + } } // Qwen API 服务适配器 diff --git a/src/claude/claude-kiro.js b/src/claude/claude-kiro.js index 775644c..afec11f 100644 --- a/src/claude/claude-kiro.js +++ b/src/claude/claude-kiro.js @@ -1778,6 +1778,73 @@ async initializeAuth(forceRefresh = false) { } } + /** + * Count tokens for a message request (compatible with Anthropic API) + * POST /v1/messages/count_tokens + * @param {Object} requestBody - The request body containing model, messages, system, tools, etc. 
+ * @returns {Object} { input_tokens: number } + */ + countTokens(requestBody) { + let totalTokens = 0; + + // Count system prompt tokens + if (requestBody.system) { + const systemText = this.getContentText(requestBody.system); + totalTokens += this.countTextTokens(systemText); + } + + // Count all messages tokens + if (requestBody.messages && Array.isArray(requestBody.messages)) { + for (const message of requestBody.messages) { + if (message.content) { + if (typeof message.content === 'string') { + totalTokens += this.countTextTokens(message.content); + } else if (Array.isArray(message.content)) { + for (const block of message.content) { + if (block.type === 'text' && block.text) { + totalTokens += this.countTextTokens(block.text); + } else if (block.type === 'tool_use') { + // Count tool use block tokens + totalTokens += this.countTextTokens(block.name || ''); + totalTokens += this.countTextTokens(JSON.stringify(block.input || {})); + } else if (block.type === 'tool_result') { + // Count tool result block tokens + const resultContent = this.getContentText(block.content); + totalTokens += this.countTextTokens(resultContent); + } else if (block.type === 'image') { + // Images have a fixed token cost (approximately 1600 tokens for a typical image) + // This is an estimation as actual cost depends on image size + totalTokens += 1600; + } else if (block.type === 'document') { + // Documents - estimate based on content if available + if (block.source?.data) { + // For base64 encoded documents, estimate tokens + const estimatedChars = block.source.data.length * 0.75; // base64 to bytes ratio + totalTokens += Math.ceil(estimatedChars / 4); + } + } + } + } + } + } + } + + // Count tools definitions tokens if present + if (requestBody.tools && Array.isArray(requestBody.tools)) { + for (const tool of requestBody.tools) { + // Count tool name and description + totalTokens += this.countTextTokens(tool.name || ''); + totalTokens += this.countTextTokens(tool.description || ''); 
/**
 * Read an incoming request stream to completion and parse its body as JSON.
 *
 * Chunks are buffered and decoded once at the end, so a multi-byte UTF-8
 * character that happens to be split across two chunks is decoded correctly
 * (decoding each chunk separately would turn it into U+FFFD).
 *
 * @param {Object} req - Readable request stream emitting 'data', 'end' and 'error'
 *   (e.g. an http.IncomingMessage).
 * @returns {Promise<*>} Resolves with the parsed JSON value, or `{}` when the body is empty.
 *   Rejects with `Error('Invalid JSON in request body')` (original parse error attached
 *   as `cause`) when the body is not valid JSON, or with the stream's own error.
 */
function parseRequestBody(req) {
    return new Promise((resolve, reject) => {
        const chunks = [];

        req.on('data', (chunk) => {
            // Chunks are strings when an encoding was set on the stream.
            chunks.push(typeof chunk === 'string' ? Buffer.from(chunk) : chunk);
        });

        req.on('end', () => {
            const body = Buffer.concat(chunks).toString('utf8');
            if (body === '') {
                resolve({});
                return;
            }
            try {
                resolve(JSON.parse(body));
            } catch (err) {
                reject(new Error('Invalid JSON in request body', { cause: err }));
            }
        });

        req.on('error', reject);
    });
}
@@ -97,16 +115,6 @@ export function createRequestHandler(config, providerPoolManager) { } } - // Ignore count_tokens requests - if (path.includes('/count_tokens')) { - console.log(`[Server] Ignoring count_tokens request: ${path}`); - res.writeHead(200, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ - tokens: 0, - message: 'Token counting is not supported' - })); - return true; - } // Handle API requests // Allow overriding MODEL_PROVIDER via request header @@ -154,6 +162,37 @@ export function createRequestHandler(config, providerPoolManager) { return; } + // Handle count_tokens requests (Anthropic API compatible) + if (path.includes('/count_tokens') && method === 'POST') { + try { + const body = await parseRequestBody(req); + console.log(`[Server] Handling count_tokens request for model: ${body.model}`); + + // Check if apiService has countTokens method + if (apiService && typeof apiService.countTokens === 'function') { + const result = apiService.countTokens(body); + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify(result)); + } else { + // Fallback: use estimateInputTokens if available + if (apiService && typeof apiService.estimateInputTokens === 'function') { + const inputTokens = apiService.estimateInputTokens(body); + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ input_tokens: inputTokens })); + } else { + // Last resort: return 0 with a message + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ input_tokens: 0 })); + } + } + return true; + } catch (error) { + console.error(`[Server] count_tokens error: ${error.message}`); + handleError(res, { statusCode: 500, message: `Failed to count tokens: ${error.message}` }); + return; + } + } + try { // Handle Ollama request (normalize path and route to appropriate endpoints) const { handled, normalizedPath } = await handleOllamaRequest(method, path, requestUrl, req, res, apiService, 
currentConfig, providerPoolManager);