diff --git a/src/providers/claude/claude-kiro.js b/src/providers/claude/claude-kiro.js index 69d6dee..9ee99fc 100644 --- a/src/providers/claude/claude-kiro.js +++ b/src/providers/claude/claude-kiro.js @@ -11,6 +11,15 @@ import { countTokens } from '@anthropic-ai/tokenizer'; import { configureAxiosProxy } from '../../utils/proxy-utils.js'; import { isRetryableNetworkError } from '../../utils/common.js'; +const KIRO_THINKING = { + MAX_BUDGET_TOKENS: 24576, + DEFAULT_BUDGET_TOKENS: 20000, + START_TAG: '', + END_TAG: '', + MODE_TAG: '', + MAX_LEN_TAG: '', +}; + const KIRO_CONSTANTS = { REFRESH_URL: 'https://prod.{{region}}.auth.desktop.kiro.dev/refreshToken', REFRESH_IDC_URL: 'https://oidc.{{region}}.amazonaws.com/token', @@ -89,6 +98,28 @@ function getSystemRuntimeInfo() { // Helper functions for tool calls and JSON parsing +function isQuoteCharAt(text, index) { + if (index < 0 || index >= text.length) return false; + const ch = text[index]; + return ch === '"' || ch === "'" || ch === '`'; +} + +function findRealTag(text, tag, startIndex = 0) { + let searchStart = Math.max(0, startIndex); + while (true) { + const pos = text.indexOf(tag, searchStart); + if (pos === -1) return -1; + + const hasQuoteBefore = isQuoteCharAt(text, pos - 1); + const hasQuoteAfter = isQuoteCharAt(text, pos + tag.length); + if (!hasQuoteBefore && !hasQuoteAfter) { + return pos; + } + + searchStart = pos + 1; + } +} + /** * 通用的括号匹配函数 - 支持多种括号类型 * @param {string} text - 要搜索的文本 @@ -554,26 +585,85 @@ async initializeAuth(forceRefresh = false) { if(message==null){ return ""; } - if (Array.isArray(message) ) { - return message - .filter(part => part.type === 'text' && part.text) - .map(part => part.text) - .join(''); + if (Array.isArray(message)) { + return message.map(part => { + if (typeof part === 'string') return part; + if (part && typeof part === 'object') { + if (part.type === 'text' && part.text) return part.text; + if (part.text) return part.text; + } + return ''; + }).join(''); } else if (typeof message.content === 'string') { return message.content; - } else if (Array.isArray(message.content) ) { - return message.content - .filter(part => part.type === 'text' && part.text) - .map(part => part.text) - .join(''); - } + } else if (Array.isArray(message.content)) { + return message.content.map(part => { + if (typeof part === 'string') return part; + if (part && typeof part === 'object') { + if (part.type === 'text' && part.text) return part.text; + if (part.text) return part.text; + } + return ''; + }).join(''); + } return String(message.content || message); } + _normalizeThinkingBudgetTokens(budgetTokens) { + let value = Number(budgetTokens); + if (!Number.isFinite(value) || value <= 0) { + value = KIRO_THINKING.DEFAULT_BUDGET_TOKENS; + } + value = Math.floor(value); + return Math.min(value, KIRO_THINKING.MAX_BUDGET_TOKENS); + } + + _generateThinkingPrefix(thinking) { + if (!thinking || thinking.type !== 'enabled') return null; + const budget = this._normalizeThinkingBudgetTokens(thinking.budget_tokens); + return `enabled${budget}`; + } + + _hasThinkingPrefix(text) { + if (!text) return false; + return text.includes(KIRO_THINKING.MODE_TAG) || text.includes(KIRO_THINKING.MAX_LEN_TAG); + } + + _toClaudeContentBlocksFromKiroText(content) { + const raw = content ?? ''; + if (!raw) return []; + + const startPos = findRealTag(raw, KIRO_THINKING.START_TAG); + if (startPos === -1) { + return [{ type: "text", text: raw }]; + } + + const before = raw.slice(0, startPos); + let rest = raw.slice(startPos + KIRO_THINKING.START_TAG.length); + + const endPosInRest = findRealTag(rest, KIRO_THINKING.END_TAG); + let thinking = ''; + let after = ''; + if (endPosInRest === -1) { + thinking = rest; + } else { + thinking = rest.slice(0, endPosInRest); + after = rest.slice(endPosInRest + KIRO_THINKING.END_TAG.length); + } + + if (after.startsWith('\n\n')) after = after.slice(2); + + const blocks = []; + if (before) blocks.push({ type: "text", text: before }); + blocks.push({ type: "thinking", thinking }); + if (after) blocks.push({ type: "text", text: after }); + return blocks; + } + /** * Build CodeWhisperer request from OpenAI messages */ - buildCodewhispererRequest(messages, model, tools = null, inSystemPrompt = null) { + buildCodewhispererRequest(messages, model, tools = null, inSystemPrompt = null, thinking = null) { const conversationId = uuidv4(); let systemPrompt = this.getContentText(inSystemPrompt); @@ -583,6 +673,15 @@ async initializeAuth(forceRefresh = false) { throw new Error('No user messages found'); } + const thinkingPrefix = this._generateThinkingPrefix(thinking); + if (thinkingPrefix) { + if (!systemPrompt) { + systemPrompt = thinkingPrefix; + } else if (!this._hasThinkingPrefix(systemPrompt)) { + systemPrompt = `${thinkingPrefix}\n${systemPrompt}`; + } + } + // 判断最后一条消息是否为 assistant,如果是则移除 const lastMessage = processedMessages[processedMessages.length - 1]; if (processedMessages.length > 0 && lastMessage.role === 'assistant') { @@ -833,11 +932,14 @@ async initializeAuth(forceRefresh = false) { content: '' }; let toolUses = []; + let thinkingText = ''; if (Array.isArray(message.content)) { for (const part of message.content) { if (part.type === 'text') { assistantResponseMessage.content += part.text; + } else if (part.type === 'thinking') { + thinkingText += (part.thinking ?? part.text ?? ''); } else if (part.type === 'tool_use') { toolUses.push({ input: part.input, @@ -850,6 +952,12 @@ async initializeAuth(forceRefresh = false) { assistantResponseMessage.content = this.getContentText(message); } + if (thinkingText) { + assistantResponseMessage.content = assistantResponseMessage.content + ? `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}\n\n${assistantResponseMessage.content}` + : `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}`; + } + // 只添加非空字段 if (toolUses.length > 0) { assistantResponseMessage.toolUses = toolUses; @@ -876,10 +984,13 @@ async initializeAuth(forceRefresh = false) { content: '', toolUses: [] }; + let thinkingText = ''; if (Array.isArray(currentMessage.content)) { for (const part of currentMessage.content) { if (part.type === 'text') { assistantResponseMessage.content += part.text; + } else if (part.type === 'thinking') { + thinkingText += (part.thinking ?? part.text ?? ''); } else if (part.type === 'tool_use') { assistantResponseMessage.toolUses.push({ input: part.input, @@ -891,6 +1002,11 @@ async initializeAuth(forceRefresh = false) { } else { assistantResponseMessage.content = this.getContentText(currentMessage); } + if (thinkingText) { + assistantResponseMessage.content = assistantResponseMessage.content + ? `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}\n\n${assistantResponseMessage.content}` + : `${KIRO_THINKING.START_TAG}${thinkingText}${KIRO_THINKING.END_TAG}`; + } if (assistantResponseMessage.toolUses.length === 0) { delete assistantResponseMessage.toolUses; } @@ -1115,9 +1231,9 @@ async initializeAuth(forceRefresh = false) { async callApi(method, model, body, isRetry = false, retryCount = 0) { if (!this.isInitialized) await this.initialize(); const maxRetries = this.config.REQUEST_MAX_RETRIES || 3; - const baseDelay = this.config.REQUEST_BASE_DELAY || 1000; // 1 second base delay + const baseDelay = this.config.REQUEST_BASE_DELAY || 1000; - const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system); + const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system, body.thinking); try { const token = this.accessToken; // Use the already initialized token @@ -1401,7 +1517,7 @@ async initializeAuth(forceRefresh = false) { const maxRetries = this.config.REQUEST_MAX_RETRIES || 3; const baseDelay = this.config.REQUEST_BASE_DELAY || 1000; - const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system); + const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system, body.thinking); const token = this.accessToken; const headers = { @@ -1529,11 +1645,93 @@ async initializeAuth(forceRefresh = false) { const finalModel = MODEL_MAPPING[model] ? model : this.modelName; console.log(`[Kiro] Calling generateContentStream with model: ${finalModel} (real streaming)`); - - const inputTokens = this.estimateInputTokens(requestBody); + + let inputTokens = 0; + let contextUsagePercentage = null; const messageId = `${uuidv4()}`; - + + const thinkingRequested = requestBody?.thinking?.type === 'enabled'; + + const streamState = { + thinkingRequested, + buffer: '', + inThinking: false, + thinkingExtracted: false, + thinkingBlockIndex: null, + textBlockIndex: null, + nextBlockIndex: 0, + stoppedBlocks: new Set(), + }; + + const ensureBlockStart = (blockType) => { + if (blockType === 'thinking') { + if (streamState.thinkingBlockIndex != null) return []; + const idx = streamState.nextBlockIndex++; + streamState.thinkingBlockIndex = idx; + return [{ + type: "content_block_start", + index: idx, + content_block: { type: "thinking", thinking: "" } + }]; + } + if (blockType === 'text') { + if (streamState.textBlockIndex != null) return []; + const idx = streamState.nextBlockIndex++; + streamState.textBlockIndex = idx; + return [{ + type: "content_block_start", + index: idx, + content_block: { type: "text", text: "" } + }]; + } + return []; + }; + + const stopBlock = (index) => { + if (index == null) return []; + if (streamState.stoppedBlocks.has(index)) return []; + streamState.stoppedBlocks.add(index); + return [{ type: "content_block_stop", index }]; + }; + + const createTextDeltaEvents = (text) => { + if (!text) return []; + const events = []; + events.push(...ensureBlockStart('text')); + events.push({ + type: "content_block_delta", + index: streamState.textBlockIndex, + delta: { type: "text_delta", text } + }); + return events; + }; + + const createThinkingDeltaEvents = (thinking) => { + const events = []; + events.push(...ensureBlockStart('thinking')); + events.push({ + type: "content_block_delta", + index: streamState.thinkingBlockIndex, + delta: { type: "thinking_delta", thinking } + }); + return events; + }; + + function* pushEvents(events) { + for (const ev of events) { + yield ev; + } + } + try { + let totalContent = ''; + let outputTokens = 0; + const toolCalls = []; + let currentToolCall = null; // 用于累积结构化工具调用 + + const estimatedInputTokens = this.estimateInputTokens(requestBody); + const tokenBreakdown = this._lastTokenBreakdown || {}; + // 1. 先发送 message_start 事件 yield { type: "message_start", @@ -1542,39 +1740,104 @@ async initializeAuth(forceRefresh = false) { type: "message", role: "assistant", model: model, - usage: { input_tokens: inputTokens, output_tokens: 0 }, + usage: { + input_tokens: estimatedInputTokens, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0 + }, content: [] } }; - // 2. 发送 content_block_start 事件 - yield { - type: "content_block_start", - index: 0, - content_block: { type: "text", text: "" } - }; - - let totalContent = ''; - let outputTokens = 0; - const toolCalls = []; - let currentToolCall = null; // 用于累积结构化工具调用 - let contextUsagePercentage = null; // 用于存储上下文使用百分比 - - // 3. 流式接收并发送每个 content_block_delta + // 2. 流式接收并发送每个 content_block_delta for await (const event of this.streamApiReal('', finalModel, requestBody)) { - if (event.type === 'content' && event.content) { - totalContent += event.content; - // 不再每个 chunk 都计算 token,改为最后统一计算,避免阻塞事件循环 - - yield { - type: "content_block_delta", - index: 0, - delta: { type: "text_delta", text: event.content } - }; - } else if (event.type === 'contextUsage') { + if (event.type === 'contextUsage' && event.percentage) { // 捕获上下文使用百分比 - contextUsagePercentage = event.contextUsagePercentage; - console.log(`[Kiro] Received contextUsagePercentage: ${contextUsagePercentage}%`); + contextUsagePercentage = event.percentage; + inputTokens = this.calculateInputTokensFromPercentage(contextUsagePercentage); + + if (Math.abs(inputTokens - estimatedInputTokens) > estimatedInputTokens * 0.1) { + yield { + type: "message_delta", + delta: {}, + usage: { + input_tokens: inputTokens, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0 + } + }; + } + } else if (event.type === 'content' && event.content) { + totalContent += event.content; + + if (!thinkingRequested) { + yield* pushEvents(createTextDeltaEvents(event.content)); + continue; + } + + streamState.buffer += event.content; + const events = []; + + while (streamState.buffer.length > 0) { + if (!streamState.inThinking && !streamState.thinkingExtracted) { + const startPos = findRealTag(streamState.buffer, KIRO_THINKING.START_TAG); + if (startPos !== -1) { + const before = streamState.buffer.slice(0, startPos); + if (before) events.push(...createTextDeltaEvents(before)); + + streamState.buffer = streamState.buffer.slice(startPos + KIRO_THINKING.START_TAG.length); + streamState.inThinking = true; + continue; + } + + const safeLen = Math.max(0, streamState.buffer.length - KIRO_THINKING.START_TAG.length); + if (safeLen > 0) { + const safeText = streamState.buffer.slice(0, safeLen); + if (safeText) events.push(...createTextDeltaEvents(safeText)); + streamState.buffer = streamState.buffer.slice(safeLen); + } + break; + } + + if (streamState.inThinking) { + const endPos = findRealTag(streamState.buffer, KIRO_THINKING.END_TAG); + if (endPos !== -1) { + const thinkingPart = streamState.buffer.slice(0, endPos); + if (thinkingPart) events.push(...createThinkingDeltaEvents(thinkingPart)); + + streamState.buffer = streamState.buffer.slice(endPos + KIRO_THINKING.END_TAG.length); + streamState.inThinking = false; + streamState.thinkingExtracted = true; + + events.push(...createThinkingDeltaEvents("")); + events.push(...stopBlock(streamState.thinkingBlockIndex)); + + if (streamState.buffer.startsWith('\n\n')) { + streamState.buffer = streamState.buffer.slice(2); + } + continue; + } + + const safeLen = Math.max(0, streamState.buffer.length - KIRO_THINKING.END_TAG.length); + if (safeLen > 0) { + const safeThinking = streamState.buffer.slice(0, safeLen); + if (safeThinking) events.push(...createThinkingDeltaEvents(safeThinking)); + streamState.buffer = streamState.buffer.slice(safeLen); + } + break; + } + + if (streamState.thinkingExtracted) { + const rest = streamState.buffer; + streamState.buffer = ''; + if (rest) events.push(...createTextDeltaEvents(rest)); + break; + } + } + + yield* pushEvents(events); } else if (event.type === 'toolUse') { const tc = event.toolUse; // 统计工具调用的内容到 totalContent(用于 token 计算) @@ -1648,7 +1911,30 @@ async initializeAuth(forceRefresh = false) { toolCalls.push(currentToolCall); currentToolCall = null; } - + + if (thinkingRequested && streamState.buffer) { + if (streamState.inThinking) { + console.warn('[Kiro] Incomplete thinking tag at stream end'); + yield* pushEvents(createThinkingDeltaEvents(streamState.buffer)); + streamState.buffer = ''; + yield* pushEvents(createThinkingDeltaEvents("")); + yield* pushEvents(stopBlock(streamState.thinkingBlockIndex)); + } else if (!streamState.thinkingExtracted) { + yield* pushEvents(createTextDeltaEvents(streamState.buffer)); + streamState.buffer = ''; + } else { + yield* pushEvents(createTextDeltaEvents(streamState.buffer)); + streamState.buffer = ''; + } + } + + yield* pushEvents(stopBlock(streamState.textBlockIndex)); + + if (contextUsagePercentage === null) { + console.warn('[Kiro Stream] contextUsagePercentage not received, using estimation'); + inputTokens = estimatedInputTokens; + } + // 检查文本内容中的 bracket 格式工具调用 const bracketToolCalls = parseBracketToolCalls(totalContent); if (bracketToolCalls && bracketToolCalls.length > 0) { @@ -1661,15 +1947,13 @@ async initializeAuth(forceRefresh = false) { } } - // 4. 发送 content_block_stop 事件 - yield { type: "content_block_stop", index: 0 }; - - // 5. 处理工具调用(如果有) + // 3. 处理工具调用(如果有) if (toolCalls.length > 0) { + const baseIndex = streamState.nextBlockIndex; for (let i = 0; i < toolCalls.length; i++) { const tc = toolCalls[i]; - const blockIndex = i + 1; - + const blockIndex = baseIndex + i; + yield { type: "content_block_start", index: blockIndex, @@ -1694,32 +1978,31 @@ async initializeAuth(forceRefresh = false) { } } - // 6. 发送 message_delta 事件 - // 如果有 contextUsagePercentage,使用它来计算 token - // 总上下文 200k tokens,通过百分比计算总使用量,再减去输入 token 得到输出 token - let totalTokens = 0; - if (contextUsagePercentage !== null && contextUsagePercentage > 0 && true) { - const totalContextTokens = KIRO_CONSTANTS.TOTAL_CONTEXT_TOKENS; - // totalUsedTokens 就是通过百分比计算出的总使用量,直接作为 total_tokens - totalTokens = Math.round(totalContextTokens * contextUsagePercentage / 100); - outputTokens = Math.max(0, totalTokens - inputTokens); - console.log(`[Kiro] Token calculation from contextUsagePercentage: total=${totalTokens}, input=${inputTokens}, output=${outputTokens}`); - } else { - // 回退到原来的计算方式 - outputTokens = this.countTextTokens(totalContent); - for (const tc of toolCalls) { - outputTokens += this.countTextTokens(JSON.stringify(tc.input || {})); - } - totalTokens = inputTokens + outputTokens; + const contentBlocksForCount = thinkingRequested + ? this._toClaudeContentBlocksFromKiroText(totalContent) + : [{ type: "text", text: totalContent }]; + const plainForCount = contentBlocksForCount + .map(b => (b.type === 'thinking' ? (b.thinking ?? '') : (b.text ?? ''))) + .join(''); + outputTokens = this.countTextTokens(plainForCount); + + for (const tc of toolCalls) { + outputTokens += this.countTextTokens(JSON.stringify(tc.input || {})); } - + + // 4. 发送 message_delta 事件 yield { type: "message_delta", delta: { stop_reason: toolCalls.length > 0 ? "tool_use" : "end_turn" }, - usage: { input_tokens: inputTokens, output_tokens: outputTokens, total_tokens: totalTokens } + usage: { + input_tokens: inputTokens, + output_tokens: outputTokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0 + } }; - // 7. 发送 message_stop 事件 + // 5. 发送 message_stop 事件 yield { type: "message_stop" }; } catch (error) { @@ -1728,9 +2011,6 @@ async initializeAuth(forceRefresh = false) { } } - /** - * Count tokens for a given text using Claude's official tokenizer - */ countTextTokens(text) { if (!text) return 0; try { @@ -1748,28 +2028,117 @@ async initializeAuth(forceRefresh = false) { estimateInputTokens(requestBody) { let totalTokens = 0; + // 定义各类内容的开销乘数 + const OVERHEAD_MULTIPLIERS = { + system: 1.0, + message: 1.0, + tools: 1.0, + thinking: 1.0, + tool_result: 1.0, + tool_use_input: 1.0, + image: 1500 + }; + + const breakdown = { + system: 0, + thinking: 0, + text: 0, + tool_result: 0, + tool_use_input: 0, + image: 0, + thinking_content: 0, + tools_def: 0 + }; + // Count system prompt tokens if (requestBody.system) { const systemText = this.getContentText(requestBody.system); - totalTokens += this.countTextTokens(systemText); + const systemTokens = this.countTextTokens(systemText); + const counted = Math.ceil(systemTokens * OVERHEAD_MULTIPLIERS.system); + breakdown.system = counted; + totalTokens += counted; } - + + if (requestBody.thinking?.type === 'enabled') { + const budget = this._normalizeThinkingBudgetTokens(requestBody.thinking.budget_tokens); + const prefixText = `enabled${budget}`; + const prefixTokens = this.countTextTokens(prefixText); + const counted = Math.ceil(prefixTokens * OVERHEAD_MULTIPLIERS.thinking); + breakdown.thinking = counted; + totalTokens += counted; + } + // Count all messages tokens if (requestBody.messages && Array.isArray(requestBody.messages)) { for (const message of requestBody.messages) { - if (message.content) { - const contentText = this.getContentText(message); - totalTokens += this.countTextTokens(contentText); + if (!message.content) { + continue; + } + + if (Array.isArray(message.content)) { + for (const part of message.content) { + if (part.type === 'text' && part.text) { + const counted = Math.ceil(this.countTextTokens(part.text) * OVERHEAD_MULTIPLIERS.message); + breakdown.text += counted; + totalTokens += counted; + } + else if (part.type === 'tool_result') { + const toolResultText = this.getContentText(part.content); + const counted = Math.ceil(this.countTextTokens(toolResultText) * OVERHEAD_MULTIPLIERS.tool_result); + breakdown.tool_result += counted; + totalTokens += counted; + } + else if (part.type === 'tool_use' && part.input) { + const inputJson = JSON.stringify(part.input); + const counted = Math.ceil(this.countTextTokens(inputJson) * OVERHEAD_MULTIPLIERS.tool_use_input); + breakdown.tool_use_input += counted; + totalTokens += counted; + } + else if (part.type === 'image') { + breakdown.image += OVERHEAD_MULTIPLIERS.image; + totalTokens += OVERHEAD_MULTIPLIERS.image; + } + else if (part.type === 'thinking' && part.thinking) { + const counted = Math.ceil(this.countTextTokens(part.thinking) * OVERHEAD_MULTIPLIERS.message); + breakdown.thinking_content += counted; + totalTokens += counted; + } + } + } + else if (typeof message.content === 'string') { + const counted = Math.ceil(this.countTextTokens(message.content) * OVERHEAD_MULTIPLIERS.message); + breakdown.text += counted; + totalTokens += counted; } } } - + // Count tools definitions tokens if present if (requestBody.tools && Array.isArray(requestBody.tools)) { - totalTokens += this.countTextTokens(JSON.stringify(requestBody.tools)); + for (const tool of requestBody.tools) { + const toolJson = JSON.stringify(tool); + const toolTokens = this.countTextTokens(toolJson); + const counted = Math.ceil(toolTokens * OVERHEAD_MULTIPLIERS.tools); + breakdown.tools_def += counted; + totalTokens += counted; + } } + + const hasTools = requestBody.tools && requestBody.tools.length > 0; + const toolsDefTokens = breakdown.tools_def || 0; + const isSmallToolsDef = toolsDefTokens > 0 && toolsDefTokens < 21000; + + const KIRO_BASE_OVERHEAD = 400; + const KIRO_PERCENTAGE_OVERHEAD = hasTools + ? (isSmallToolsDef ? 0.18 : 0.08) + : 0.25; + + const baseOverhead = KIRO_BASE_OVERHEAD; + const percentageOverhead = Math.ceil(totalTokens * KIRO_PERCENTAGE_OVERHEAD); + totalTokens += baseOverhead + percentageOverhead; - return totalTokens; + this._lastTokenBreakdown = breakdown; + return Math.ceil(totalTokens); } /**