Merge pull request #142 from clansty/feature/count-tokens-api

feat: 实现 Anthropic 兼容的 count_tokens API
This commit is contained in:
何夕2077 2025-12-28 16:42:27 +08:00 committed by GitHub
commit ecdf7f31e7
3 changed files with 125 additions and 10 deletions

View file

@ -306,6 +306,15 @@ export class KiroApiServiceAdapter extends ApiServiceAdapter {
}
return this.kiroApiService.getUsageLimits();
}
/**
* Count tokens for a message request (compatible with Anthropic API)
* @param {Object} requestBody - The request body containing model, messages, system, tools, etc.
* @returns {Object} { input_tokens: number }
*/
countTokens(requestBody) {
return this.kiroApiService.countTokens(requestBody);
}
}
// Qwen API 服务适配器

View file

@ -1778,6 +1778,73 @@ async initializeAuth(forceRefresh = false) {
}
}
/**
* Count tokens for a message request (compatible with Anthropic API)
* POST /v1/messages/count_tokens
* @param {Object} requestBody - The request body containing model, messages, system, tools, etc.
* @returns {Object} { input_tokens: number }
*/
countTokens(requestBody) {
let totalTokens = 0;
// Count system prompt tokens
if (requestBody.system) {
const systemText = this.getContentText(requestBody.system);
totalTokens += this.countTextTokens(systemText);
}
// Count all messages tokens
if (requestBody.messages && Array.isArray(requestBody.messages)) {
for (const message of requestBody.messages) {
if (message.content) {
if (typeof message.content === 'string') {
totalTokens += this.countTextTokens(message.content);
} else if (Array.isArray(message.content)) {
for (const block of message.content) {
if (block.type === 'text' && block.text) {
totalTokens += this.countTextTokens(block.text);
} else if (block.type === 'tool_use') {
// Count tool use block tokens
totalTokens += this.countTextTokens(block.name || '');
totalTokens += this.countTextTokens(JSON.stringify(block.input || {}));
} else if (block.type === 'tool_result') {
// Count tool result block tokens
const resultContent = this.getContentText(block.content);
totalTokens += this.countTextTokens(resultContent);
} else if (block.type === 'image') {
// Images have a fixed token cost (approximately 1600 tokens for a typical image)
// This is an estimation as actual cost depends on image size
totalTokens += 1600;
} else if (block.type === 'document') {
// Documents - estimate based on content if available
if (block.source?.data) {
// For base64 encoded documents, estimate tokens
const estimatedChars = block.source.data.length * 0.75; // base64 to bytes ratio
totalTokens += Math.ceil(estimatedChars / 4);
}
}
}
}
}
}
}
// Count tools definitions tokens if present
if (requestBody.tools && Array.isArray(requestBody.tools)) {
for (const tool of requestBody.tools) {
// Count tool name and description
totalTokens += this.countTextTokens(tool.name || '');
totalTokens += this.countTextTokens(tool.description || '');
// Count input schema
if (tool.input_schema) {
totalTokens += this.countTextTokens(JSON.stringify(tool.input_schema));
}
}
}
return { input_tokens: totalTokens };
}
/**
* 获取用量限制信息
* @returns {Promise<Object>} 用量限制信息

View file

@ -8,6 +8,24 @@ import { MODEL_PROVIDER } from './common.js';
import { PROMPT_LOG_FILENAME } from './config-manager.js';
import { handleOllamaRequest, handleOllamaShow } from './ollama-handler.js';
/**
 * Read the whole request stream and parse it as JSON.
 *
 * Chunks are buffered as raw bytes and decoded once at the end, so a
 * multi-byte UTF-8 character that straddles two chunks is decoded correctly
 * (decoding each chunk separately would turn it into U+FFFD).
 *
 * @param {Object} req - Stream-like request emitting 'data', 'end' and 'error' events.
 * @returns {Promise<Object>} Resolves with the parsed JSON, or {} when the body is empty.
 *   Rejects with Error('Invalid JSON in request body') on malformed JSON (the parser
 *   error is attached as `cause`), or with the stream's own error.
 */
function parseRequestBody(req) {
    return new Promise((resolve, reject) => {
        const chunks = [];
        req.on('data', (chunk) => {
            // Streams with an encoding set emit strings; normalize everything to bytes.
            chunks.push(typeof chunk === 'string' ? Buffer.from(chunk) : chunk);
        });
        req.on('end', () => {
            const body = Buffer.concat(chunks).toString('utf8');
            try {
                resolve(body ? JSON.parse(body) : {});
            } catch (e) {
                reject(new Error('Invalid JSON in request body', { cause: e }));
            }
        });
        req.on('error', reject);
    });
}
/**
* Main request handler. It authenticates the request, determines the endpoint type,
* and delegates to the appropriate specialized handler function.
@ -97,16 +115,6 @@ export function createRequestHandler(config, providerPoolManager) {
}
}
// Ignore count_tokens requests
if (path.includes('/count_tokens')) {
console.log(`[Server] Ignoring count_tokens request: ${path}`);
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({
tokens: 0,
message: 'Token counting is not supported'
}));
return true;
}
// Handle API requests
// Allow overriding MODEL_PROVIDER via request header
@ -154,6 +162,37 @@ export function createRequestHandler(config, providerPoolManager) {
return;
}
// Handle count_tokens requests (Anthropic API compatible)
if (path.includes('/count_tokens') && method === 'POST') {
try {
const body = await parseRequestBody(req);
console.log(`[Server] Handling count_tokens request for model: ${body.model}`);
// Check if apiService has countTokens method
if (apiService && typeof apiService.countTokens === 'function') {
const result = apiService.countTokens(body);
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(result));
} else {
// Fallback: use estimateInputTokens if available
if (apiService && typeof apiService.estimateInputTokens === 'function') {
const inputTokens = apiService.estimateInputTokens(body);
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ input_tokens: inputTokens }));
} else {
// Last resort: return 0 with a message
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ input_tokens: 0 }));
}
}
return true;
} catch (error) {
console.error(`[Server] count_tokens error: ${error.message}`);
handleError(res, { statusCode: 500, message: `Failed to count tokens: ${error.message}` });
return;
}
}
try {
// Handle Ollama request (normalize path and route to appropriate endpoints)
const { handled, normalizedPath } = await handleOllamaRequest(method, path, requestUrl, req, res, apiService, currentConfig, providerPoolManager);