Merge pull request #93 from Sanyela/main

fix: 修复 Kiro API 的兼容问题
This commit is contained in:
何夕2077 2025-12-03 13:58:17 +08:00 committed by GitHub
commit 377db070e2

View file

@ -604,25 +604,23 @@ async initializeAuth(forceRefresh = false) {
let userInputMessage = {
content: '',
modelId: codewhispererModel,
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR,
userInputMessageContext: {}
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR
};
let images = [];
let toolResults = [];
if (Array.isArray(message.content)) {
userInputMessage.images = []; // Initialize images array
for (const part of message.content) {
if (part.type === 'text') {
userInputMessage.content += part.text;
} else if (part.type === 'tool_result') {
if (!userInputMessage.userInputMessageContext.toolResults) {
userInputMessage.userInputMessageContext.toolResults = [];
}
userInputMessage.userInputMessageContext.toolResults.push({
toolResults.push({
content: [{ text: this.getContentText(part.content) }],
status: 'success',
toolUseId: part.tool_use_id
});
} else if (part.type === 'image') {
userInputMessage.images.push({
images.push({
format: part.source.media_type.split('/')[1],
source: {
bytes: part.source.data
@ -633,18 +631,28 @@ async initializeAuth(forceRefresh = false) {
} else {
userInputMessage.content = this.getContentText(message);
}
// 只添加非空字段API 不接受空数组或空对象
if (images.length > 0) {
userInputMessage.images = images;
}
if (toolResults.length > 0) {
userInputMessage.userInputMessageContext = { toolResults };
}
history.push({ userInputMessage });
} else if (message.role === 'assistant') {
let assistantResponseMessage = {
content: '',
toolUses: []
content: ''
};
let toolUses = [];
if (Array.isArray(message.content)) {
for (const part of message.content) {
if (part.type === 'text') {
assistantResponseMessage.content += part.text;
} else if (part.type === 'tool_use') {
assistantResponseMessage.toolUses.push({
toolUses.push({
input: part.input,
name: part.name,
toolUseId: part.id
@ -654,77 +662,129 @@ async initializeAuth(forceRefresh = false) {
} else {
assistantResponseMessage.content = this.getContentText(message);
}
// 只添加非空字段
if (toolUses.length > 0) {
assistantResponseMessage.toolUses = toolUses;
}
history.push({ assistantResponseMessage });
}
}
// Build current message
const currentMessage = processedMessages[processedMessages.length - 1];
let currentMessage = processedMessages[processedMessages.length - 1];
let currentContent = '';
let currentToolResults = [];
let currentToolUses = [];
let currentImages = [];
if (Array.isArray(currentMessage.content)) {
for (const part of currentMessage.content) {
if (part.type === 'text') {
currentContent += part.text;
} else if (part.type === 'tool_result') {
currentToolResults.push({
content: [{ text: this.getContentText(part.content) }],
status: 'success',
toolUseId: part.tool_use_id
});
} else if (part.type === 'tool_use') {
currentToolUses.push({
input: part.input,
name: part.name,
toolUseId: part.id
});
} else if (part.type === 'image') {
currentImages.push({
format: part.source.media_type.split('/')[1],
source: {
bytes: part.source.data
}
});
// 如果最后一条消息是 assistant需要将其加入 history然后创建一个 user 类型的 currentMessage
// 因为 CodeWhisperer API 的 currentMessage 必须是 userInputMessage 类型
if (currentMessage.role === 'assistant') {
console.log('[Kiro] Last message is assistant, moving it to history and creating user currentMessage');
// 构建 assistant 消息并加入 history
let assistantResponseMessage = {
content: '',
toolUses: []
};
if (Array.isArray(currentMessage.content)) {
for (const part of currentMessage.content) {
if (part.type === 'text') {
assistantResponseMessage.content += part.text;
} else if (part.type === 'tool_use') {
assistantResponseMessage.toolUses.push({
input: part.input,
name: part.name,
toolUseId: part.id
});
}
}
} else {
assistantResponseMessage.content = this.getContentText(currentMessage);
}
} else {
currentContent = this.getContentText(currentMessage);
}
if (!currentContent && currentToolResults.length === 0 && currentToolUses.length === 0) {
if (assistantResponseMessage.toolUses.length === 0) {
delete assistantResponseMessage.toolUses;
}
history.push({ assistantResponseMessage });
// 设置 currentContent 为 "Continue",因为我们需要一个 user 消息来触发 AI 继续
currentContent = 'Continue';
} else {
// 处理 user 消息
if (Array.isArray(currentMessage.content)) {
for (const part of currentMessage.content) {
if (part.type === 'text') {
currentContent += part.text;
} else if (part.type === 'tool_result') {
currentToolResults.push({
content: [{ text: this.getContentText(part.content) }],
status: 'success',
toolUseId: part.tool_use_id
});
} else if (part.type === 'tool_use') {
currentToolUses.push({
input: part.input,
name: part.name,
toolUseId: part.id
});
} else if (part.type === 'image') {
currentImages.push({
format: part.source.media_type.split('/')[1],
source: {
bytes: part.source.data
}
});
}
}
} else {
currentContent = this.getContentText(currentMessage);
}
if (!currentContent && currentToolResults.length === 0 && currentToolUses.length === 0) {
currentContent = 'Continue';
}
}
const request = {
conversationState: {
chatTriggerType: KIRO_CONSTANTS.CHAT_TRIGGER_TYPE_MANUAL,
conversationId: conversationId,
currentMessage: {}, // Will be populated based on the last message's role
currentMessage: {}, // Will be populated as userInputMessage
history: history
}
};
if (currentMessage.role === 'user') {
request.conversationState.currentMessage.userInputMessage = {
content: currentContent,
modelId: codewhispererModel,
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR,
images: currentImages && currentImages.length > 0 ? currentImages : null, // Add images here
userInputMessageContext: {
toolResults: currentToolResults.length > 0 ? currentToolResults : null,
tools: Object.keys(toolsContext).length > 0 ? toolsContext.tools : null
}
};
} else if (currentMessage.role === 'assistant') {
request.conversationState.currentMessage.assistantResponseMessage = {
content: currentContent,
toolUses: currentToolUses.length > 0 ? currentToolUses : undefined
};
// currentMessage 始终是 userInputMessage 类型
// 注意API 不接受 null 值,空字段应该完全不包含
const userInputMessage = {
content: currentContent,
modelId: codewhispererModel,
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR
};
// 只有当 images 非空时才添加
if (currentImages && currentImages.length > 0) {
userInputMessage.images = currentImages;
}
// 构建 userInputMessageContext只包含非空字段
const userInputMessageContext = {};
if (currentToolResults.length > 0) {
userInputMessageContext.toolResults = currentToolResults;
}
if (Object.keys(toolsContext).length > 0 && toolsContext.tools) {
userInputMessageContext.tools = toolsContext.tools;
}
// 只有当 userInputMessageContext 有内容时才添加
if (Object.keys(userInputMessageContext).length > 0) {
userInputMessage.userInputMessageContext = userInputMessageContext;
}
request.conversationState.currentMessage.userInputMessage = userInputMessage;
if (this.authMethod === KIRO_CONSTANTS.AUTH_METHOD_SOCIAL) {
request.profileArn = this.profileArn;
}
@ -952,18 +1012,146 @@ async initializeAuth(forceRefresh = false) {
}
}
//kiro提供的接口没有流式返回
async streamApi(method, model, body, isRetry = false, retryCount = 0) {
/**
* 解析 AWS Event Stream 格式提取所有完整的 JSON 事件
* 返回 { events: 解析出的事件数组, remaining: 未处理完的缓冲区 }
*/
parseAwsEventStreamBuffer(buffer) {
const events = [];
let remaining = buffer;
let searchStart = 0;
while (true) {
// 查找 {"content": 或 {"toolUse" 的起始位置
const contentStart = remaining.indexOf('{"content":', searchStart);
const toolUseStart = remaining.indexOf('{"toolUse":', searchStart);
let jsonStart = -1;
if (contentStart >= 0 && toolUseStart >= 0) {
jsonStart = Math.min(contentStart, toolUseStart);
} else if (contentStart >= 0) {
jsonStart = contentStart;
} else if (toolUseStart >= 0) {
jsonStart = toolUseStart;
}
if (jsonStart < 0) break;
// 查找对应的 } 结束位置
const jsonEnd = remaining.indexOf('}', jsonStart);
if (jsonEnd < 0) {
// 不完整的 JSON保留在缓冲区
remaining = remaining.substring(jsonStart);
break;
}
const jsonStr = remaining.substring(jsonStart, jsonEnd + 1);
try {
const parsed = JSON.parse(jsonStr);
if (parsed.content !== undefined) {
events.push({ type: 'content', data: parsed.content });
} else if (parsed.toolUse !== undefined) {
events.push({ type: 'toolUse', data: parsed.toolUse });
}
} catch (e) {
// JSON 解析失败,可能是不完整的,继续搜索
}
searchStart = jsonEnd + 1;
if (searchStart >= remaining.length) {
remaining = '';
break;
}
}
// 如果 searchStart 有进展,截取剩余部分
if (searchStart > 0 && remaining.length > 0) {
remaining = remaining.substring(searchStart);
}
return { events, remaining };
}
/**
* 真正的流式 API 调用 - 使用 responseType: 'stream'
*/
async * streamApiReal(method, model, body, isRetry = false, retryCount = 0) {
if (!this.isInitialized) await this.initialize();
const maxRetries = this.config.REQUEST_MAX_RETRIES || 3;
const baseDelay = this.config.REQUEST_BASE_DELAY || 1000;
const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system);
const token = this.accessToken;
const headers = {
'Authorization': `Bearer ${token}`,
'amz-sdk-invocation-id': `${uuidv4()}`,
};
const requestUrl = model.startsWith('amazonq') ? this.amazonQUrl : this.baseUrl;
try {
// 直接调用并返回Promise最终解析为response
return await this.callApi(method, model, body, isRetry, retryCount);
const response = await this.axiosInstance.post(requestUrl, requestData, {
headers,
responseType: 'stream'
});
const stream = response.data;
let buffer = '';
const processedPositions = new Set(); // 避免重复处理
for await (const chunk of stream) {
buffer += chunk.toString();
// 解析缓冲区中的事件
const { events, remaining } = this.parseAwsEventStreamBuffer(buffer);
buffer = remaining;
// 只 yield 新的事件
for (const event of events) {
const eventKey = `${event.type}:${event.data}`;
if (!processedPositions.has(eventKey)) {
processedPositions.add(eventKey);
if (event.type === 'content' && event.data) {
yield { type: 'content', content: event.data };
} else if (event.type === 'toolUse') {
yield { type: 'toolUse', toolUse: event.data };
}
}
}
}
} catch (error) {
console.error('[Kiro] Error calling API:', error);
throw error; // 向上抛出错误
if (error.response?.status === 403 && !isRetry) {
console.log('[Kiro] Received 403 in stream. Attempting token refresh and retrying...');
await this.initializeAuth(true);
yield* this.streamApiReal(method, model, body, true, retryCount);
return;
}
if (error.response?.status === 429 && retryCount < maxRetries) {
const delay = baseDelay * Math.pow(2, retryCount);
console.log(`[Kiro] Received 429 in stream. Retrying in ${delay}ms...`);
await new Promise(resolve => setTimeout(resolve, delay));
yield* this.streamApiReal(method, model, body, isRetry, retryCount + 1);
return;
}
console.error('[Kiro] Stream API call failed:', error.message);
throw error;
}
}
// 重构2: generateContentStream 调用新的普通async函数
// 保留旧的非流式方法用于 generateContent
async streamApi(method, model, body, isRetry = false, retryCount = 0) {
try {
return await this.callApi(method, model, body, isRetry, retryCount);
} catch (error) {
console.error('[Kiro] Error calling API:', error);
throw error;
}
}
// 真正的流式传输实现
async * generateContentStream(model, requestBody) {
if (!this.isInitialized) await this.initialize();
@ -974,28 +1162,98 @@ async initializeAuth(forceRefresh = false) {
}
const finalModel = MODEL_MAPPING[model] ? model : this.modelName;
console.log(`[Kiro] Calling generateContentStream with model: ${finalModel}`);
console.log(`[Kiro] Calling generateContentStream with model: ${finalModel} (real streaming)`);
// Estimate input tokens before making the API call
const inputTokens = this.estimateInputTokens(requestBody);
const messageId = `${uuidv4()}`;
try {
const response = await this.streamApi('', finalModel, requestBody);
const { responseText, toolCalls } = this._processApiResponse(response);
// 1. 先发送 message_start 事件
yield {
type: "message_start",
message: {
id: messageId,
type: "message",
role: "assistant",
model: model,
usage: { input_tokens: inputTokens, output_tokens: 0 },
content: []
}
};
// Pass both responseText and toolCalls to buildClaudeResponse
// buildClaudeResponse will handle the logic of combining them into a single stream
for (const chunkJson of this.buildClaudeResponse(responseText, true, 'assistant', model, toolCalls, inputTokens)) {
yield chunkJson;
// 2. 发送 content_block_start 事件
yield {
type: "content_block_start",
index: 0,
content_block: { type: "text", text: "" }
};
let totalContent = '';
let outputTokens = 0;
const toolCalls = [];
// 3. 流式接收并发送每个 content_block_delta
for await (const event of this.streamApiReal('', finalModel, requestBody)) {
if (event.type === 'content' && event.content) {
totalContent += event.content;
outputTokens += this.countTextTokens(event.content);
yield {
type: "content_block_delta",
index: 0,
delta: { type: "text_delta", text: event.content }
};
} else if (event.type === 'toolUse') {
toolCalls.push(event.toolUse);
}
}
// 4. 发送 content_block_stop 事件
yield { type: "content_block_stop", index: 0 };
// 5. 处理工具调用(如果有)
if (toolCalls.length > 0) {
for (let i = 0; i < toolCalls.length; i++) {
const tc = toolCalls[i];
const blockIndex = i + 1;
yield {
type: "content_block_start",
index: blockIndex,
content_block: {
type: "tool_use",
id: tc.toolUseId || `tool_${uuidv4()}`,
name: tc.name,
input: {}
}
};
yield {
type: "content_block_delta",
index: blockIndex,
delta: {
type: "input_json_delta",
partial_json: JSON.stringify(tc.input || {})
}
};
yield { type: "content_block_stop", index: blockIndex };
}
}
// 6. 发送 message_delta 事件
yield {
type: "message_delta",
delta: { stop_reason: toolCalls.length > 0 ? "tool_use" : "end_turn" },
usage: { output_tokens: outputTokens }
};
// 7. 发送 message_stop 事件
yield { type: "message_stop" };
} catch (error) {
console.error('[Kiro] Error in streaming generation:', error);
throw new Error(`Error processing response: ${error.message}`);
// For Claude, we yield an array of events for streaming error
// Ensure error message is passed as content, not toolCalls
// for (const chunkJson of this.buildClaudeResponse(`Error: ${error.message}`, true, 'assistant', model, null)) {
// yield chunkJson;
// }
}
}