diff --git a/src/claude/claude-kiro.js b/src/claude/claude-kiro.js index 2d25800..e30b48d 100644 --- a/src/claude/claude-kiro.js +++ b/src/claude/claude-kiro.js @@ -604,25 +604,23 @@ async initializeAuth(forceRefresh = false) { let userInputMessage = { content: '', modelId: codewhispererModel, - origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR, - userInputMessageContext: {} + origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR }; + let images = []; + let toolResults = []; + if (Array.isArray(message.content)) { - userInputMessage.images = []; // Initialize images array for (const part of message.content) { if (part.type === 'text') { userInputMessage.content += part.text; } else if (part.type === 'tool_result') { - if (!userInputMessage.userInputMessageContext.toolResults) { - userInputMessage.userInputMessageContext.toolResults = []; - } - userInputMessage.userInputMessageContext.toolResults.push({ + toolResults.push({ content: [{ text: this.getContentText(part.content) }], status: 'success', toolUseId: part.tool_use_id }); } else if (part.type === 'image') { - userInputMessage.images.push({ + images.push({ format: part.source.media_type.split('/')[1], source: { bytes: part.source.data @@ -633,18 +631,28 @@ async initializeAuth(forceRefresh = false) { } else { userInputMessage.content = this.getContentText(message); } + + // 只添加非空字段,API 不接受空数组或空对象 + if (images.length > 0) { + userInputMessage.images = images; + } + if (toolResults.length > 0) { + userInputMessage.userInputMessageContext = { toolResults }; + } + history.push({ userInputMessage }); } else if (message.role === 'assistant') { let assistantResponseMessage = { - content: '', - toolUses: [] + content: '' }; + let toolUses = []; + if (Array.isArray(message.content)) { for (const part of message.content) { if (part.type === 'text') { assistantResponseMessage.content += part.text; } else if (part.type === 'tool_use') { - assistantResponseMessage.toolUses.push({ + toolUses.push({ input: part.input, name: part.name, toolUseId: part.id @@ -654,77 +662,129 @@ async initializeAuth(forceRefresh = false) { } else { assistantResponseMessage.content = this.getContentText(message); } + + // 只添加非空字段 + if (toolUses.length > 0) { + assistantResponseMessage.toolUses = toolUses; + } + history.push({ assistantResponseMessage }); } } // Build current message - const currentMessage = processedMessages[processedMessages.length - 1]; + let currentMessage = processedMessages[processedMessages.length - 1]; let currentContent = ''; let currentToolResults = []; let currentToolUses = []; let currentImages = []; - if (Array.isArray(currentMessage.content)) { - for (const part of currentMessage.content) { - if (part.type === 'text') { - currentContent += part.text; - } else if (part.type === 'tool_result') { - currentToolResults.push({ - content: [{ text: this.getContentText(part.content) }], - status: 'success', - toolUseId: part.tool_use_id - }); - } else if (part.type === 'tool_use') { - currentToolUses.push({ - input: part.input, - name: part.name, - toolUseId: part.id - }); - } else if (part.type === 'image') { - currentImages.push({ - format: part.source.media_type.split('/')[1], - source: { - bytes: part.source.data - } - }); + // 如果最后一条消息是 assistant,需要将其加入 history,然后创建一个 user 类型的 currentMessage + // 因为 CodeWhisperer API 的 currentMessage 必须是 userInputMessage 类型 + if (currentMessage.role === 'assistant') { + console.log('[Kiro] Last message is assistant, moving it to history and creating user currentMessage'); + + // 构建 assistant 消息并加入 history + let assistantResponseMessage = { + content: '', + toolUses: [] + }; + if (Array.isArray(currentMessage.content)) { + for (const part of currentMessage.content) { + if (part.type === 'text') { + assistantResponseMessage.content += part.text; + } else if (part.type === 'tool_use') { + assistantResponseMessage.toolUses.push({ + input: part.input, + name: part.name, + toolUseId: part.id + }); + } } + } else { + assistantResponseMessage.content = this.getContentText(currentMessage); } - } else { - currentContent = this.getContentText(currentMessage); - } - - if (!currentContent && currentToolResults.length === 0 && currentToolUses.length === 0) { + if (assistantResponseMessage.toolUses.length === 0) { + delete assistantResponseMessage.toolUses; + } + history.push({ assistantResponseMessage }); + + // 设置 currentContent 为 "Continue",因为我们需要一个 user 消息来触发 AI 继续 currentContent = 'Continue'; + } else { + // 处理 user 消息 + if (Array.isArray(currentMessage.content)) { + for (const part of currentMessage.content) { + if (part.type === 'text') { + currentContent += part.text; + } else if (part.type === 'tool_result') { + currentToolResults.push({ + content: [{ text: this.getContentText(part.content) }], + status: 'success', + toolUseId: part.tool_use_id + }); + } else if (part.type === 'tool_use') { + currentToolUses.push({ + input: part.input, + name: part.name, + toolUseId: part.id + }); + } else if (part.type === 'image') { + currentImages.push({ + format: part.source.media_type.split('/')[1], + source: { + bytes: part.source.data + } + }); + } + } + } else { + currentContent = this.getContentText(currentMessage); + } + + if (!currentContent && currentToolResults.length === 0 && currentToolUses.length === 0) { + currentContent = 'Continue'; + } } const request = { conversationState: { chatTriggerType: KIRO_CONSTANTS.CHAT_TRIGGER_TYPE_MANUAL, conversationId: conversationId, - currentMessage: {}, // Will be populated based on the last message's role + currentMessage: {}, // Will be populated as userInputMessage history: history } }; - if (currentMessage.role === 'user') { - request.conversationState.currentMessage.userInputMessage = { - content: currentContent, - modelId: codewhispererModel, - origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR, - images: currentImages && currentImages.length > 0 ? currentImages : null, // Add images here - userInputMessageContext: { - toolResults: currentToolResults.length > 0 ? currentToolResults : null, - tools: Object.keys(toolsContext).length > 0 ? toolsContext.tools : null - } - }; - } else if (currentMessage.role === 'assistant') { - request.conversationState.currentMessage.assistantResponseMessage = { - content: currentContent, - toolUses: currentToolUses.length > 0 ? currentToolUses : undefined - }; + // currentMessage 始终是 userInputMessage 类型 + // 注意:API 不接受 null 值,空字段应该完全不包含 + const userInputMessage = { + content: currentContent, + modelId: codewhispererModel, + origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR + }; + + // 只有当 images 非空时才添加 + if (currentImages && currentImages.length > 0) { + userInputMessage.images = currentImages; } + // 构建 userInputMessageContext,只包含非空字段 + const userInputMessageContext = {}; + if (currentToolResults.length > 0) { + userInputMessageContext.toolResults = currentToolResults; + } + if (Object.keys(toolsContext).length > 0 && toolsContext.tools) { + userInputMessageContext.tools = toolsContext.tools; + } + + // 只有当 userInputMessageContext 有内容时才添加 + if (Object.keys(userInputMessageContext).length > 0) { + userInputMessage.userInputMessageContext = userInputMessageContext; + } + + request.conversationState.currentMessage.userInputMessage = userInputMessage; + if (this.authMethod === KIRO_CONSTANTS.AUTH_METHOD_SOCIAL) { request.profileArn = this.profileArn; } @@ -952,18 +1012,146 @@ async initializeAuth(forceRefresh = false) { } } - //kiro提供的接口没有流式返回 - async streamApi(method, model, body, isRetry = false, retryCount = 0) { + /** + * 解析 AWS Event Stream 格式,提取所有完整的 JSON 事件 + * 返回 { events: 解析出的事件数组, remaining: 未处理完的缓冲区 } + */ + parseAwsEventStreamBuffer(buffer) { + const events = []; + let remaining = buffer; + let searchStart = 0; + + while (true) { + // 查找 {"content": 或 {"toolUse" 的起始位置 + const contentStart = remaining.indexOf('{"content":', searchStart); + const toolUseStart = remaining.indexOf('{"toolUse":', searchStart); + + let jsonStart = -1; + if (contentStart >= 0 && toolUseStart >= 0) { + jsonStart = Math.min(contentStart, toolUseStart); + } else if (contentStart >= 0) { + jsonStart = contentStart; + } else if (toolUseStart >= 0) { + jsonStart = toolUseStart; + } + + if (jsonStart < 0) break; + + // 查找对应的 } 结束位置 + const jsonEnd = remaining.indexOf('}', jsonStart); + if (jsonEnd < 0) { + // 不完整的 JSON,保留在缓冲区 + remaining = remaining.substring(jsonStart); + break; + } + + const jsonStr = remaining.substring(jsonStart, jsonEnd + 1); + try { + const parsed = JSON.parse(jsonStr); + if (parsed.content !== undefined) { + events.push({ type: 'content', data: parsed.content }); + } else if (parsed.toolUse !== undefined) { + events.push({ type: 'toolUse', data: parsed.toolUse }); + } + } catch (e) { + // JSON 解析失败,可能是不完整的,继续搜索 + } + + searchStart = jsonEnd + 1; + if (searchStart >= remaining.length) { + remaining = ''; + break; + } + } + + // 如果 searchStart 有进展,截取剩余部分 + if (searchStart > 0 && remaining.length > 0) { + remaining = remaining.substring(searchStart); + } + + return { events, remaining }; + } + + /** + * 真正的流式 API 调用 - 使用 responseType: 'stream' + */ + async * streamApiReal(method, model, body, isRetry = false, retryCount = 0) { + if (!this.isInitialized) await this.initialize(); + const maxRetries = this.config.REQUEST_MAX_RETRIES || 3; + const baseDelay = this.config.REQUEST_BASE_DELAY || 1000; + + const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system); + + const token = this.accessToken; + const headers = { + 'Authorization': `Bearer ${token}`, + 'amz-sdk-invocation-id': `${uuidv4()}`, + }; + + const requestUrl = model.startsWith('amazonq') ? this.amazonQUrl : this.baseUrl; + try { - // 直接调用并返回Promise,最终解析为response - return await this.callApi(method, model, body, isRetry, retryCount); + const response = await this.axiosInstance.post(requestUrl, requestData, { + headers, + responseType: 'stream' + }); + + const stream = response.data; + let buffer = ''; + const processedPositions = new Set(); // 避免重复处理 + + for await (const chunk of stream) { + buffer += chunk.toString(); + + // 解析缓冲区中的事件 + const { events, remaining } = this.parseAwsEventStreamBuffer(buffer); + buffer = remaining; + + // 只 yield 新的事件 + for (const event of events) { + const eventKey = `${event.type}:${event.data}`; + if (!processedPositions.has(eventKey)) { + processedPositions.add(eventKey); + if (event.type === 'content' && event.data) { + yield { type: 'content', content: event.data }; + } else if (event.type === 'toolUse') { + yield { type: 'toolUse', toolUse: event.data }; + } + } + } + } } catch (error) { - console.error('[Kiro] Error calling API:', error); - throw error; // 向上抛出错误 + if (error.response?.status === 403 && !isRetry) { + console.log('[Kiro] Received 403 in stream. Attempting token refresh and retrying...'); + await this.initializeAuth(true); + yield* this.streamApiReal(method, model, body, true, retryCount); + return; + } + + if (error.response?.status === 429 && retryCount < maxRetries) { + const delay = baseDelay * Math.pow(2, retryCount); + console.log(`[Kiro] Received 429 in stream. Retrying in ${delay}ms...`); + await new Promise(resolve => setTimeout(resolve, delay)); + yield* this.streamApiReal(method, model, body, isRetry, retryCount + 1); + return; + } + + console.error('[Kiro] Stream API call failed:', error.message); + throw error; } } - // 重构2: generateContentStream 调用新的普通async函数 + // 保留旧的非流式方法用于 generateContent + async streamApi(method, model, body, isRetry = false, retryCount = 0) { + try { + return await this.callApi(method, model, body, isRetry, retryCount); + } catch (error) { + console.error('[Kiro] Error calling API:', error); + throw error; + } + } + + // 真正的流式传输实现 async * generateContentStream(model, requestBody) { if (!this.isInitialized) await this.initialize(); @@ -974,28 +1162,98 @@ async initializeAuth(forceRefresh = false) { } const finalModel = MODEL_MAPPING[model] ? model : this.modelName; - console.log(`[Kiro] Calling generateContentStream with model: ${finalModel}`); + console.log(`[Kiro] Calling generateContentStream with model: ${finalModel} (real streaming)`); - // Estimate input tokens before making the API call const inputTokens = this.estimateInputTokens(requestBody); + const messageId = `${uuidv4()}`; try { - const response = await this.streamApi('', finalModel, requestBody); - const { responseText, toolCalls } = this._processApiResponse(response); + // 1. 先发送 message_start 事件 + yield { + type: "message_start", + message: { + id: messageId, + type: "message", + role: "assistant", + model: model, + usage: { input_tokens: inputTokens, output_tokens: 0 }, + content: [] + } + }; - // Pass both responseText and toolCalls to buildClaudeResponse - // buildClaudeResponse will handle the logic of combining them into a single stream - for (const chunkJson of this.buildClaudeResponse(responseText, true, 'assistant', model, toolCalls, inputTokens)) { - yield chunkJson; + // 2. 发送 content_block_start 事件 + yield { + type: "content_block_start", + index: 0, + content_block: { type: "text", text: "" } + }; + + let totalContent = ''; + let outputTokens = 0; + const toolCalls = []; + + // 3. 流式接收并发送每个 content_block_delta + for await (const event of this.streamApiReal('', finalModel, requestBody)) { + if (event.type === 'content' && event.content) { + totalContent += event.content; + outputTokens += this.countTextTokens(event.content); + + yield { + type: "content_block_delta", + index: 0, + delta: { type: "text_delta", text: event.content } + }; + } else if (event.type === 'toolUse') { + toolCalls.push(event.toolUse); + } } + + // 4. 发送 content_block_stop 事件 + yield { type: "content_block_stop", index: 0 }; + + // 5. 处理工具调用(如果有) + if (toolCalls.length > 0) { + for (let i = 0; i < toolCalls.length; i++) { + const tc = toolCalls[i]; + const blockIndex = i + 1; + + yield { + type: "content_block_start", + index: blockIndex, + content_block: { + type: "tool_use", + id: tc.toolUseId || `tool_${uuidv4()}`, + name: tc.name, + input: {} + } + }; + + yield { + type: "content_block_delta", + index: blockIndex, + delta: { + type: "input_json_delta", + partial_json: JSON.stringify(tc.input || {}) + } + }; + + yield { type: "content_block_stop", index: blockIndex }; + } + } + + // 6. 发送 message_delta 事件 + yield { + type: "message_delta", + delta: { stop_reason: toolCalls.length > 0 ? "tool_use" : "end_turn" }, + usage: { output_tokens: outputTokens } + }; + + // 7. 发送 message_stop 事件 + yield { type: "message_stop" }; + } catch (error) { console.error('[Kiro] Error in streaming generation:', error); throw new Error(`Error processing response: ${error.message}`); - // For Claude, we yield an array of events for streaming error - // Ensure error message is passed as content, not toolCalls - // for (const chunkJson of this.buildClaudeResponse(`Error: ${error.message}`, true, 'assistant', model, null)) { - // yield chunkJson; - // } } }