Merge pull request #210 from tickernelz/feat/kiro-think-token-fix

feat(kiro): extended thinking support and fix token counting
This commit is contained in:
何夕2077 2026-01-12 12:53:19 +08:00 committed by GitHub
commit 86014b217b

View file

@ -11,6 +11,15 @@ import { countTokens } from '@anthropic-ai/tokenizer';
import { configureAxiosProxy } from '../../utils/proxy-utils.js';
import { isRetryableNetworkError } from '../../utils/common.js';
/**
 * Constants for Kiro extended-thinking support.
 *
 * Frozen so that the shared table cannot be modified by accident at runtime.
 */
const KIRO_THINKING = Object.freeze({
  // Upper bound applied to any requested thinking budget.
  MAX_BUDGET_TOKENS: 24576,
  // Budget used when the request does not carry a usable positive number.
  DEFAULT_BUDGET_TOKENS: 20000,
  // Opening / closing tags that delimit thinking text inside Kiro content.
  START_TAG: '<thinking>',
  END_TAG: '</thinking>',
  // Opening tags of the thinking directives; used to detect an existing prefix
  // in a system prompt.
  MODE_TAG: '<thinking_mode>',
  MAX_LEN_TAG: '<max_thinking_length>',
});
const KIRO_CONSTANTS = {
REFRESH_URL: 'https://prod.{{region}}.auth.desktop.kiro.dev/refreshToken',
REFRESH_IDC_URL: 'https://oidc.{{region}}.amazonaws.com/token',
@ -89,6 +98,28 @@ function getSystemRuntimeInfo() {
// Helper functions for tool calls and JSON parsing
/**
 * Report whether the character at `index` in `text` is a quote character
 * (double quote, single quote, or backtick).
 *
 * @param {string} text - Text to inspect.
 * @param {number} index - Position to test; out-of-range positions yield false.
 * @returns {boolean} True when the character at that position is a quote.
 */
function isQuoteCharAt(text, index) {
  if (index < 0 || index >= text.length) {
    return false;
  }
  switch (text[index]) {
    case '"':
    case "'":
    case '`':
      return true;
    default:
      return false;
  }
}
/**
 * Find the first occurrence of `tag` in `text`, at or after `startIndex`,
 * that is neither immediately preceded nor immediately followed by a quote
 * character (as decided by `isQuoteCharAt`).
 *
 * @param {string} text - Text to search.
 * @param {string} tag - Literal tag to look for.
 * @param {number} [startIndex=0] - Search start; negative values are treated as 0.
 * @returns {number} Index of the match, or -1 when there is none.
 */
function findRealTag(text, tag, startIndex = 0) {
  for (
    let candidate = text.indexOf(tag, Math.max(0, startIndex));
    candidate !== -1;
    candidate = text.indexOf(tag, candidate + 1)
  ) {
    const quotedBefore = isQuoteCharAt(text, candidate - 1);
    const quotedAfter = isQuoteCharAt(text, candidate + tag.length);
    // An occurrence touching a quote on either side is skipped.
    if (!(quotedBefore || quotedAfter)) {
      return candidate;
    }
  }
  return -1;
}
/**
* 通用的括号匹配函数 - 支持多种括号类型
* @param {string} text - 要搜索的文本
@ -554,26 +585,85 @@ async initializeAuth(forceRefresh = false) {
if(message==null){
return "";
}
if (Array.isArray(message) ) {
return message
.filter(part => part.type === 'text' && part.text)
.map(part => part.text)
.join('');
if (Array.isArray(message)) {
return message.map(part => {
if (typeof part === 'string') return part;
if (part && typeof part === 'object') {
if (part.type === 'text' && part.text) return part.text;
if (part.text) return part.text;
}
return '';
}).join('');
} else if (typeof message.content === 'string') {
return message.content;
} else if (Array.isArray(message.content) ) {
return message.content
.filter(part => part.type === 'text' && part.text)
.map(part => part.text)
.join('');
}
} else if (Array.isArray(message.content)) {
return message.content.map(part => {
if (typeof part === 'string') return part;
if (part && typeof part === 'object') {
if (part.type === 'text' && part.text) return part.text;
if (part.text) return part.text;
}
return '';
}).join('');
}
return String(message.content || message);
}
_normalizeThinkingBudgetTokens(budgetTokens) {
let value = Number(budgetTokens);
if (!Number.isFinite(value) || value <= 0) {
value = KIRO_THINKING.DEFAULT_BUDGET_TOKENS;
}
value = Math.floor(value);
return Math.min(value, KIRO_THINKING.MAX_BUDGET_TOKENS);
}
_generateThinkingPrefix(thinking) {
if (!thinking || thinking.type !== 'enabled') return null;
const budget = this._normalizeThinkingBudgetTokens(thinking.budget_tokens);
return `<thinking_mode>enabled</thinking_mode><max_thinking_length>${budget}</max_thinking_length>`;
}
_hasThinkingPrefix(text) {
if (!text) return false;
return text.includes(KIRO_THINKING.MODE_TAG) || text.includes(KIRO_THINKING.MAX_LEN_TAG);
}
_toClaudeContentBlocksFromKiroText(content) {
const raw = content ?? '';
if (!raw) return [];
const startPos = findRealTag(raw, KIRO_THINKING.START_TAG);
if (startPos === -1) {
return [{ type: "text", text: raw }];
}
const before = raw.slice(0, startPos);
let rest = raw.slice(startPos + KIRO_THINKING.START_TAG.length);
const endPosInRest = findRealTag(rest, KIRO_THINKING.END_TAG);
let thinking = '';
let after = '';
if (endPosInRest === -1) {
thinking = rest;
} else {
thinking = rest.slice(0, endPosInRest);
after = rest.slice(endPosInRest + KIRO_THINKING.END_TAG.length);
}
if (after.startsWith('\n\n')) after = after.slice(2);
const blocks = [];
if (before) blocks.push({ type: "text", text: before });
blocks.push({ type: "thinking", thinking });
if (after) blocks.push({ type: "text", text: after });
return blocks;
}
/**
* Build CodeWhisperer request from OpenAI messages
*/
buildCodewhispererRequest(messages, model, tools = null, inSystemPrompt = null) {
buildCodewhispererRequest(messages, model, tools = null, inSystemPrompt = null, thinking = null) {
const conversationId = uuidv4();
let systemPrompt = this.getContentText(inSystemPrompt);
@ -583,6 +673,15 @@ async initializeAuth(forceRefresh = false) {
throw new Error('No user messages found');
}
const thinkingPrefix = this._generateThinkingPrefix(thinking);
if (thinkingPrefix) {
if (!systemPrompt) {
systemPrompt = thinkingPrefix;
} else if (!this._hasThinkingPrefix(systemPrompt)) {
systemPrompt = `${thinkingPrefix}\n${systemPrompt}`;
}
}
// 判断最后一条消息是否为 assistant,如果是则移除
const lastMessage = processedMessages[processedMessages.length - 1];
if (processedMessages.length > 0 && lastMessage.role === 'assistant') {
@ -833,11 +932,14 @@ async initializeAuth(forceRefresh = false) {
content: ''
};
let toolUses = [];
let thinkingText = '';
if (Array.isArray(message.content)) {
for (const part of message.content) {
if (part.type === 'text') {
assistantResponseMessage.content += part.text;
} else if (part.type === 'thinking') {
thinkingText += (part.thinking ?? part.text ?? '');
} else if (part.type === 'tool_use') {
toolUses.push({
input: part.input,
@ -850,6 +952,12 @@ async initializeAuth(forceRefresh = false) {
assistantResponseMessage.content = this.getContentText(message);
}
if (thinkingText) {
assistantResponseMessage.content = assistantResponseMessage.content
? `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}\n\n${assistantResponseMessage.content}`
: `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}`;
}
// 只添加非空字段
if (toolUses.length > 0) {
assistantResponseMessage.toolUses = toolUses;
@ -876,10 +984,13 @@ async initializeAuth(forceRefresh = false) {
content: '',
toolUses: []
};
let thinkingText = '';
if (Array.isArray(currentMessage.content)) {
for (const part of currentMessage.content) {
if (part.type === 'text') {
assistantResponseMessage.content += part.text;
} else if (part.type === 'thinking') {
thinkingText += (part.thinking ?? part.text ?? '');
} else if (part.type === 'tool_use') {
assistantResponseMessage.toolUses.push({
input: part.input,
@ -891,6 +1002,11 @@ async initializeAuth(forceRefresh = false) {
} else {
assistantResponseMessage.content = this.getContentText(currentMessage);
}
if (thinkingText) {
assistantResponseMessage.content = assistantResponseMessage.content
? `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}\n\n${assistantResponseMessage.content}`
: `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}`;
}
if (assistantResponseMessage.toolUses.length === 0) {
delete assistantResponseMessage.toolUses;
}
@ -1115,9 +1231,9 @@ async initializeAuth(forceRefresh = false) {
async callApi(method, model, body, isRetry = false, retryCount = 0) {
if (!this.isInitialized) await this.initialize();
const maxRetries = this.config.REQUEST_MAX_RETRIES || 3;
const baseDelay = this.config.REQUEST_BASE_DELAY || 1000; // 1 second base delay
const baseDelay = this.config.REQUEST_BASE_DELAY || 1000;
const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system);
const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system, body.thinking);
try {
const token = this.accessToken; // Use the already initialized token
@ -1401,7 +1517,7 @@ async initializeAuth(forceRefresh = false) {
const maxRetries = this.config.REQUEST_MAX_RETRIES || 3;
const baseDelay = this.config.REQUEST_BASE_DELAY || 1000;
const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system);
const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system, body.thinking);
const token = this.accessToken;
const headers = {
@ -1529,11 +1645,93 @@ async initializeAuth(forceRefresh = false) {
const finalModel = MODEL_MAPPING[model] ? model : this.modelName;
console.log(`[Kiro] Calling generateContentStream with model: ${finalModel} (real streaming)`);
const inputTokens = this.estimateInputTokens(requestBody);
let inputTokens = 0;
let contextUsagePercentage = null;
const messageId = `${uuidv4()}`;
// True only when the request explicitly asks for thinking (thinking.type === 'enabled').
const thinkingRequested = requestBody?.thinking?.type === 'enabled';
// Mutable bookkeeping shared by the streaming helpers defined below.
const streamState = {
thinkingRequested,
// Content received from the stream that has not been emitted as deltas yet.
buffer: '',
// True after an opening thinking tag was consumed and until its closing tag is handled.
inThinking: false,
// True once a closing thinking tag was handled; later content is emitted as text.
thinkingExtracted: false,
// Content-block indices assigned by ensureBlockStart; null until the block is started.
thinkingBlockIndex: null,
textBlockIndex: null,
// Next content-block index to hand out.
nextBlockIndex: 0,
// Indices for which a content_block_stop event was already produced.
stoppedBlocks: new Set(),
};
/**
 * Produce the content_block_start event for a 'thinking' or 'text' block the
 * first time that block type is needed, assigning it the next free index in
 * streamState.
 *
 * @param {string} blockType - 'thinking' or 'text'.
 * @returns {Array<object>} One start event, or an empty array when the block
 *   was already started or the type is not recognised.
 */
const ensureBlockStart = (blockType) => {
  let slot;
  let emptyBlock;
  switch (blockType) {
    case 'thinking':
      slot = 'thinkingBlockIndex';
      emptyBlock = { type: "thinking", thinking: "" };
      break;
    case 'text':
      slot = 'textBlockIndex';
      emptyBlock = { type: "text", text: "" };
      break;
    default:
      return [];
  }

  // An index already recorded for this type means the block was started earlier.
  if (streamState[slot] != null) {
    return [];
  }

  const assigned = streamState.nextBlockIndex++;
  streamState[slot] = assigned;
  return [{
    type: "content_block_start",
    index: assigned,
    content_block: emptyBlock
  }];
};
/**
 * Produce the content_block_stop event for a block index, at most once per index.
 *
 * @param {number|null|undefined} index - Block index; null/undefined means the
 *   block was never started.
 * @returns {Array<object>} One stop event, or an empty array when there is
 *   nothing to stop or the index was already stopped.
 */
const stopBlock = (index) => {
  const nothingToStop = index == null || streamState.stoppedBlocks.has(index);
  if (nothingToStop) {
    return [];
  }
  streamState.stoppedBlocks.add(index);
  return [{ type: "content_block_stop", index }];
};
/**
 * Build the events needed to stream a piece of text: the text block's start
 * event when it has not been started yet, followed by a text_delta event.
 *
 * @param {string} text - Text to emit; empty/falsy text produces no events.
 * @returns {Array<object>} Events in emission order.
 */
const createTextDeltaEvents = (text) => {
  if (!text) {
    return [];
  }
  // Start the block first so that textBlockIndex is assigned before it is read.
  const startEvents = ensureBlockStart('text');
  const deltaEvent = {
    type: "content_block_delta",
    index: streamState.textBlockIndex,
    delta: { type: "text_delta", text }
  };
  return [...startEvents, deltaEvent];
};
/**
 * Build the events needed to stream a piece of thinking text: the thinking
 * block's start event when it has not been started yet, followed by a
 * thinking_delta event.
 *
 * Unlike the text variant, an empty string still produces a delta event.
 *
 * @param {string} thinking - Thinking text to emit.
 * @returns {Array<object>} Events in emission order.
 */
const createThinkingDeltaEvents = (thinking) => {
  // Start the block first so that thinkingBlockIndex is assigned before it is read.
  const startEvents = ensureBlockStart('thinking');
  const deltaEvent = {
    type: "content_block_delta",
    index: streamState.thinkingBlockIndex,
    delta: { type: "thinking_delta", thinking }
  };
  return [...startEvents, deltaEvent];
};
/**
 * Yield every event of an array, in order, so that callers can forward a
 * batch of events with `yield*`.
 *
 * @param {Array<object>} events - Events to emit.
 * @yields {object} Each event in turn.
 */
function* pushEvents(events) {
  for (let position = 0; position < events.length; position += 1) {
    yield events[position];
  }
}
try {
let totalContent = '';
let outputTokens = 0;
const toolCalls = [];
let currentToolCall = null; // 用于累积结构化工具调用
const estimatedInputTokens = this.estimateInputTokens(requestBody);
const tokenBreakdown = this._lastTokenBreakdown || {};
// 1. 先发送 message_start 事件
yield {
type: "message_start",
@ -1542,39 +1740,104 @@ async initializeAuth(forceRefresh = false) {
type: "message",
role: "assistant",
model: model,
usage: { input_tokens: inputTokens, output_tokens: 0 },
usage: {
input_tokens: estimatedInputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0
},
content: []
}
};
// 2. 发送 content_block_start 事件
yield {
type: "content_block_start",
index: 0,
content_block: { type: "text", text: "" }
};
let totalContent = '';
let outputTokens = 0;
const toolCalls = [];
let currentToolCall = null; // 用于累积结构化工具调用
let contextUsagePercentage = null; // 用于存储上下文使用百分比
// 3. 流式接收并发送每个 content_block_delta
// 2. 流式接收并发送每个 content_block_delta
for await (const event of this.streamApiReal('', finalModel, requestBody)) {
if (event.type === 'content' && event.content) {
totalContent += event.content;
// 不再每个 chunk 都计算 token,改为最后统一计算,避免阻塞事件循环
yield {
type: "content_block_delta",
index: 0,
delta: { type: "text_delta", text: event.content }
};
} else if (event.type === 'contextUsage') {
if (event.type === 'contextUsage' && event.percentage) {
// 捕获上下文使用百分比
contextUsagePercentage = event.contextUsagePercentage;
console.log(`[Kiro] Received contextUsagePercentage: ${contextUsagePercentage}%`);
contextUsagePercentage = event.percentage;
inputTokens = this.calculateInputTokensFromPercentage(contextUsagePercentage);
if (Math.abs(inputTokens - estimatedInputTokens) > estimatedInputTokens * 0.1) {
yield {
type: "message_delta",
delta: {},
usage: {
input_tokens: inputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0
}
};
}
} else if (event.type === 'content' && event.content) {
totalContent += event.content;
if (!thinkingRequested) {
yield* pushEvents(createTextDeltaEvents(event.content));
continue;
}
streamState.buffer += event.content;
const events = [];
while (streamState.buffer.length > 0) {
if (!streamState.inThinking && !streamState.thinkingExtracted) {
const startPos = findRealTag(streamState.buffer, KIRO_THINKING.START_TAG);
if (startPos !== -1) {
const before = streamState.buffer.slice(0, startPos);
if (before) events.push(...createTextDeltaEvents(before));
streamState.buffer = streamState.buffer.slice(startPos + KIRO_THINKING.START_TAG.length);
streamState.inThinking = true;
continue;
}
const safeLen = Math.max(0, streamState.buffer.length - KIRO_THINKING.START_TAG.length);
if (safeLen > 0) {
const safeText = streamState.buffer.slice(0, safeLen);
if (safeText) events.push(...createTextDeltaEvents(safeText));
streamState.buffer = streamState.buffer.slice(safeLen);
}
break;
}
if (streamState.inThinking) {
const endPos = findRealTag(streamState.buffer, KIRO_THINKING.END_TAG);
if (endPos !== -1) {
const thinkingPart = streamState.buffer.slice(0, endPos);
if (thinkingPart) events.push(...createThinkingDeltaEvents(thinkingPart));
streamState.buffer = streamState.buffer.slice(endPos + KIRO_THINKING.END_TAG.length);
streamState.inThinking = false;
streamState.thinkingExtracted = true;
events.push(...createThinkingDeltaEvents(""));
events.push(...stopBlock(streamState.thinkingBlockIndex));
if (streamState.buffer.startsWith('\n\n')) {
streamState.buffer = streamState.buffer.slice(2);
}
continue;
}
const safeLen = Math.max(0, streamState.buffer.length - KIRO_THINKING.END_TAG.length);
if (safeLen > 0) {
const safeThinking = streamState.buffer.slice(0, safeLen);
if (safeThinking) events.push(...createThinkingDeltaEvents(safeThinking));
streamState.buffer = streamState.buffer.slice(safeLen);
}
break;
}
if (streamState.thinkingExtracted) {
const rest = streamState.buffer;
streamState.buffer = '';
if (rest) events.push(...createTextDeltaEvents(rest));
break;
}
}
yield* pushEvents(events);
} else if (event.type === 'toolUse') {
const tc = event.toolUse;
// 统计工具调用的内容到 totalContent(用于 token 计算)
@ -1648,7 +1911,30 @@ async initializeAuth(forceRefresh = false) {
toolCalls.push(currentToolCall);
currentToolCall = null;
}
if (thinkingRequested && streamState.buffer) {
if (streamState.inThinking) {
console.warn('[Kiro] Incomplete thinking tag at stream end');
yield* pushEvents(createThinkingDeltaEvents(streamState.buffer));
streamState.buffer = '';
yield* pushEvents(createThinkingDeltaEvents(""));
yield* pushEvents(stopBlock(streamState.thinkingBlockIndex));
} else if (!streamState.thinkingExtracted) {
yield* pushEvents(createTextDeltaEvents(streamState.buffer));
streamState.buffer = '';
} else {
yield* pushEvents(createTextDeltaEvents(streamState.buffer));
streamState.buffer = '';
}
}
yield* pushEvents(stopBlock(streamState.textBlockIndex));
if (contextUsagePercentage === null) {
console.warn('[Kiro Stream] contextUsagePercentage not received, using estimation');
inputTokens = estimatedInputTokens;
}
// 检查文本内容中的 bracket 格式工具调用
const bracketToolCalls = parseBracketToolCalls(totalContent);
if (bracketToolCalls && bracketToolCalls.length > 0) {
@ -1661,15 +1947,13 @@ async initializeAuth(forceRefresh = false) {
}
}
// 4. 发送 content_block_stop 事件
yield { type: "content_block_stop", index: 0 };
// 5. 处理工具调用(如果有)
// 3. 处理工具调用(如果有)
if (toolCalls.length > 0) {
const baseIndex = streamState.nextBlockIndex;
for (let i = 0; i < toolCalls.length; i++) {
const tc = toolCalls[i];
const blockIndex = i + 1;
const blockIndex = baseIndex + i;
yield {
type: "content_block_start",
index: blockIndex,
@ -1694,32 +1978,31 @@ async initializeAuth(forceRefresh = false) {
}
}
// 6. 发送 message_delta 事件
// 如果有 contextUsagePercentage,使用它来计算 token
// 总上下文 200k tokens,通过百分比计算总使用量,再减去输入 token 得到输出 token
let totalTokens = 0;
if (contextUsagePercentage !== null && contextUsagePercentage > 0 && true) {
const totalContextTokens = KIRO_CONSTANTS.TOTAL_CONTEXT_TOKENS;
// totalUsedTokens 就是通过百分比计算出的总使用量,直接作为 total_tokens
totalTokens = Math.round(totalContextTokens * contextUsagePercentage / 100);
outputTokens = Math.max(0, totalTokens - inputTokens);
console.log(`[Kiro] Token calculation from contextUsagePercentage: total=${totalTokens}, input=${inputTokens}, output=${outputTokens}`);
} else {
// 回退到原来的计算方式
outputTokens = this.countTextTokens(totalContent);
for (const tc of toolCalls) {
outputTokens += this.countTextTokens(JSON.stringify(tc.input || {}));
}
totalTokens = inputTokens + outputTokens;
const contentBlocksForCount = thinkingRequested
? this._toClaudeContentBlocksFromKiroText(totalContent)
: [{ type: "text", text: totalContent }];
const plainForCount = contentBlocksForCount
.map(b => (b.type === 'thinking' ? (b.thinking ?? '') : (b.text ?? '')))
.join('');
outputTokens = this.countTextTokens(plainForCount);
for (const tc of toolCalls) {
outputTokens += this.countTextTokens(JSON.stringify(tc.input || {}));
}
// 4. 发送 message_delta 事件
yield {
type: "message_delta",
delta: { stop_reason: toolCalls.length > 0 ? "tool_use" : "end_turn" },
usage: { input_tokens: inputTokens, output_tokens: outputTokens, total_tokens: totalTokens }
usage: {
input_tokens: inputTokens,
output_tokens: outputTokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0
}
};
// 7. 发送 message_stop 事件
// 5. 发送 message_stop 事件
yield { type: "message_stop" };
} catch (error) {
@ -1728,9 +2011,6 @@ async initializeAuth(forceRefresh = false) {
}
}
/**
* Count tokens for a given text using Claude's official tokenizer
*/
countTextTokens(text) {
if (!text) return 0;
try {
@ -1748,28 +2028,117 @@ async initializeAuth(forceRefresh = false) {
estimateInputTokens(requestBody) {
let totalTokens = 0;
// 定义各类内容的开销乘数
const OVERHEAD_MULTIPLIERS = {
system: 1.0,
message: 1.0,
tools: 1.0,
thinking: 1.0,
tool_result: 1.0,
tool_use_input: 1.0,
image: 1500
};
const breakdown = {
system: 0,
thinking: 0,
text: 0,
tool_result: 0,
tool_use_input: 0,
image: 0,
thinking_content: 0,
tools_def: 0
};
// Count system prompt tokens
if (requestBody.system) {
const systemText = this.getContentText(requestBody.system);
totalTokens += this.countTextTokens(systemText);
const systemTokens = this.countTextTokens(systemText);
const counted = Math.ceil(systemTokens * OVERHEAD_MULTIPLIERS.system);
breakdown.system = counted;
totalTokens += counted;
}
if (requestBody.thinking?.type === 'enabled') {
const budget = this._normalizeThinkingBudgetTokens(requestBody.thinking.budget_tokens);
const prefixText = `<thinking_mode>enabled</thinking_mode><max_thinking_length>${budget}</max_thinking_length>`;
const prefixTokens = this.countTextTokens(prefixText);
const counted = Math.ceil(prefixTokens * OVERHEAD_MULTIPLIERS.thinking);
breakdown.thinking = counted;
totalTokens += counted;
}
// Count all messages tokens
if (requestBody.messages && Array.isArray(requestBody.messages)) {
for (const message of requestBody.messages) {
if (message.content) {
const contentText = this.getContentText(message);
totalTokens += this.countTextTokens(contentText);
if (!message.content) {
continue;
}
if (Array.isArray(message.content)) {
for (const part of message.content) {
if (part.type === 'text' && part.text) {
const counted = Math.ceil(this.countTextTokens(part.text) * OVERHEAD_MULTIPLIERS.message);
breakdown.text += counted;
totalTokens += counted;
}
else if (part.type === 'tool_result') {
const toolResultText = this.getContentText(part.content);
const counted = Math.ceil(this.countTextTokens(toolResultText) * OVERHEAD_MULTIPLIERS.tool_result);
breakdown.tool_result += counted;
totalTokens += counted;
}
else if (part.type === 'tool_use' && part.input) {
const inputJson = JSON.stringify(part.input);
const counted = Math.ceil(this.countTextTokens(inputJson) * OVERHEAD_MULTIPLIERS.tool_use_input);
breakdown.tool_use_input += counted;
totalTokens += counted;
}
else if (part.type === 'image') {
breakdown.image += OVERHEAD_MULTIPLIERS.image;
totalTokens += OVERHEAD_MULTIPLIERS.image;
}
else if (part.type === 'thinking' && part.thinking) {
const counted = Math.ceil(this.countTextTokens(part.thinking) * OVERHEAD_MULTIPLIERS.message);
breakdown.thinking_content += counted;
totalTokens += counted;
}
}
}
else if (typeof message.content === 'string') {
const counted = Math.ceil(this.countTextTokens(message.content) * OVERHEAD_MULTIPLIERS.message);
breakdown.text += counted;
totalTokens += counted;
}
}
}
// Count tools definitions tokens if present
if (requestBody.tools && Array.isArray(requestBody.tools)) {
totalTokens += this.countTextTokens(JSON.stringify(requestBody.tools));
for (const tool of requestBody.tools) {
const toolJson = JSON.stringify(tool);
const toolTokens = this.countTextTokens(toolJson);
const counted = Math.ceil(toolTokens * OVERHEAD_MULTIPLIERS.tools);
breakdown.tools_def += counted;
totalTokens += counted;
}
}
const hasTools = requestBody.tools && requestBody.tools.length > 0;
const toolsDefTokens = breakdown.tools_def || 0;
const isSmallToolsDef = toolsDefTokens > 0 && toolsDefTokens < 21000;
const KIRO_BASE_OVERHEAD = 400;
const KIRO_PERCENTAGE_OVERHEAD = hasTools
? (isSmallToolsDef ? 0.18 : 0.08)
: 0.25;
const baseOverhead = KIRO_BASE_OVERHEAD;
const percentageOverhead = Math.ceil(totalTokens * KIRO_PERCENTAGE_OVERHEAD);
totalTokens += baseOverhead + percentageOverhead;
return totalTokens;
this._lastTokenBreakdown = breakdown;
return Math.ceil(totalTokens);
}
/**