commit
377db070e2
1 changed files with 335 additions and 77 deletions
|
|
@ -604,25 +604,23 @@ async initializeAuth(forceRefresh = false) {
|
|||
let userInputMessage = {
|
||||
content: '',
|
||||
modelId: codewhispererModel,
|
||||
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR,
|
||||
userInputMessageContext: {}
|
||||
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR
|
||||
};
|
||||
let images = [];
|
||||
let toolResults = [];
|
||||
|
||||
if (Array.isArray(message.content)) {
|
||||
userInputMessage.images = []; // Initialize images array
|
||||
for (const part of message.content) {
|
||||
if (part.type === 'text') {
|
||||
userInputMessage.content += part.text;
|
||||
} else if (part.type === 'tool_result') {
|
||||
if (!userInputMessage.userInputMessageContext.toolResults) {
|
||||
userInputMessage.userInputMessageContext.toolResults = [];
|
||||
}
|
||||
userInputMessage.userInputMessageContext.toolResults.push({
|
||||
toolResults.push({
|
||||
content: [{ text: this.getContentText(part.content) }],
|
||||
status: 'success',
|
||||
toolUseId: part.tool_use_id
|
||||
});
|
||||
} else if (part.type === 'image') {
|
||||
userInputMessage.images.push({
|
||||
images.push({
|
||||
format: part.source.media_type.split('/')[1],
|
||||
source: {
|
||||
bytes: part.source.data
|
||||
|
|
@ -633,18 +631,28 @@ async initializeAuth(forceRefresh = false) {
|
|||
} else {
|
||||
userInputMessage.content = this.getContentText(message);
|
||||
}
|
||||
|
||||
// 只添加非空字段,API 不接受空数组或空对象
|
||||
if (images.length > 0) {
|
||||
userInputMessage.images = images;
|
||||
}
|
||||
if (toolResults.length > 0) {
|
||||
userInputMessage.userInputMessageContext = { toolResults };
|
||||
}
|
||||
|
||||
history.push({ userInputMessage });
|
||||
} else if (message.role === 'assistant') {
|
||||
let assistantResponseMessage = {
|
||||
content: '',
|
||||
toolUses: []
|
||||
content: ''
|
||||
};
|
||||
let toolUses = [];
|
||||
|
||||
if (Array.isArray(message.content)) {
|
||||
for (const part of message.content) {
|
||||
if (part.type === 'text') {
|
||||
assistantResponseMessage.content += part.text;
|
||||
} else if (part.type === 'tool_use') {
|
||||
assistantResponseMessage.toolUses.push({
|
||||
toolUses.push({
|
||||
input: part.input,
|
||||
name: part.name,
|
||||
toolUseId: part.id
|
||||
|
|
@ -654,77 +662,129 @@ async initializeAuth(forceRefresh = false) {
|
|||
} else {
|
||||
assistantResponseMessage.content = this.getContentText(message);
|
||||
}
|
||||
|
||||
// 只添加非空字段
|
||||
if (toolUses.length > 0) {
|
||||
assistantResponseMessage.toolUses = toolUses;
|
||||
}
|
||||
|
||||
history.push({ assistantResponseMessage });
|
||||
}
|
||||
}
|
||||
|
||||
// Build current message
|
||||
const currentMessage = processedMessages[processedMessages.length - 1];
|
||||
let currentMessage = processedMessages[processedMessages.length - 1];
|
||||
let currentContent = '';
|
||||
let currentToolResults = [];
|
||||
let currentToolUses = [];
|
||||
let currentImages = [];
|
||||
|
||||
if (Array.isArray(currentMessage.content)) {
|
||||
for (const part of currentMessage.content) {
|
||||
if (part.type === 'text') {
|
||||
currentContent += part.text;
|
||||
} else if (part.type === 'tool_result') {
|
||||
currentToolResults.push({
|
||||
content: [{ text: this.getContentText(part.content) }],
|
||||
status: 'success',
|
||||
toolUseId: part.tool_use_id
|
||||
});
|
||||
} else if (part.type === 'tool_use') {
|
||||
currentToolUses.push({
|
||||
input: part.input,
|
||||
name: part.name,
|
||||
toolUseId: part.id
|
||||
});
|
||||
} else if (part.type === 'image') {
|
||||
currentImages.push({
|
||||
format: part.source.media_type.split('/')[1],
|
||||
source: {
|
||||
bytes: part.source.data
|
||||
}
|
||||
});
|
||||
// 如果最后一条消息是 assistant,需要将其加入 history,然后创建一个 user 类型的 currentMessage
|
||||
// 因为 CodeWhisperer API 的 currentMessage 必须是 userInputMessage 类型
|
||||
if (currentMessage.role === 'assistant') {
|
||||
console.log('[Kiro] Last message is assistant, moving it to history and creating user currentMessage');
|
||||
|
||||
// 构建 assistant 消息并加入 history
|
||||
let assistantResponseMessage = {
|
||||
content: '',
|
||||
toolUses: []
|
||||
};
|
||||
if (Array.isArray(currentMessage.content)) {
|
||||
for (const part of currentMessage.content) {
|
||||
if (part.type === 'text') {
|
||||
assistantResponseMessage.content += part.text;
|
||||
} else if (part.type === 'tool_use') {
|
||||
assistantResponseMessage.toolUses.push({
|
||||
input: part.input,
|
||||
name: part.name,
|
||||
toolUseId: part.id
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assistantResponseMessage.content = this.getContentText(currentMessage);
|
||||
}
|
||||
} else {
|
||||
currentContent = this.getContentText(currentMessage);
|
||||
}
|
||||
|
||||
if (!currentContent && currentToolResults.length === 0 && currentToolUses.length === 0) {
|
||||
if (assistantResponseMessage.toolUses.length === 0) {
|
||||
delete assistantResponseMessage.toolUses;
|
||||
}
|
||||
history.push({ assistantResponseMessage });
|
||||
|
||||
// 设置 currentContent 为 "Continue",因为我们需要一个 user 消息来触发 AI 继续
|
||||
currentContent = 'Continue';
|
||||
} else {
|
||||
// 处理 user 消息
|
||||
if (Array.isArray(currentMessage.content)) {
|
||||
for (const part of currentMessage.content) {
|
||||
if (part.type === 'text') {
|
||||
currentContent += part.text;
|
||||
} else if (part.type === 'tool_result') {
|
||||
currentToolResults.push({
|
||||
content: [{ text: this.getContentText(part.content) }],
|
||||
status: 'success',
|
||||
toolUseId: part.tool_use_id
|
||||
});
|
||||
} else if (part.type === 'tool_use') {
|
||||
currentToolUses.push({
|
||||
input: part.input,
|
||||
name: part.name,
|
||||
toolUseId: part.id
|
||||
});
|
||||
} else if (part.type === 'image') {
|
||||
currentImages.push({
|
||||
format: part.source.media_type.split('/')[1],
|
||||
source: {
|
||||
bytes: part.source.data
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
currentContent = this.getContentText(currentMessage);
|
||||
}
|
||||
|
||||
if (!currentContent && currentToolResults.length === 0 && currentToolUses.length === 0) {
|
||||
currentContent = 'Continue';
|
||||
}
|
||||
}
|
||||
|
||||
const request = {
|
||||
conversationState: {
|
||||
chatTriggerType: KIRO_CONSTANTS.CHAT_TRIGGER_TYPE_MANUAL,
|
||||
conversationId: conversationId,
|
||||
currentMessage: {}, // Will be populated based on the last message's role
|
||||
currentMessage: {}, // Will be populated as userInputMessage
|
||||
history: history
|
||||
}
|
||||
};
|
||||
|
||||
if (currentMessage.role === 'user') {
|
||||
request.conversationState.currentMessage.userInputMessage = {
|
||||
content: currentContent,
|
||||
modelId: codewhispererModel,
|
||||
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR,
|
||||
images: currentImages && currentImages.length > 0 ? currentImages : null, // Add images here
|
||||
userInputMessageContext: {
|
||||
toolResults: currentToolResults.length > 0 ? currentToolResults : null,
|
||||
tools: Object.keys(toolsContext).length > 0 ? toolsContext.tools : null
|
||||
}
|
||||
};
|
||||
} else if (currentMessage.role === 'assistant') {
|
||||
request.conversationState.currentMessage.assistantResponseMessage = {
|
||||
content: currentContent,
|
||||
toolUses: currentToolUses.length > 0 ? currentToolUses : undefined
|
||||
};
|
||||
// currentMessage 始终是 userInputMessage 类型
|
||||
// 注意:API 不接受 null 值,空字段应该完全不包含
|
||||
const userInputMessage = {
|
||||
content: currentContent,
|
||||
modelId: codewhispererModel,
|
||||
origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR
|
||||
};
|
||||
|
||||
// 只有当 images 非空时才添加
|
||||
if (currentImages && currentImages.length > 0) {
|
||||
userInputMessage.images = currentImages;
|
||||
}
|
||||
|
||||
// 构建 userInputMessageContext,只包含非空字段
|
||||
const userInputMessageContext = {};
|
||||
if (currentToolResults.length > 0) {
|
||||
userInputMessageContext.toolResults = currentToolResults;
|
||||
}
|
||||
if (Object.keys(toolsContext).length > 0 && toolsContext.tools) {
|
||||
userInputMessageContext.tools = toolsContext.tools;
|
||||
}
|
||||
|
||||
// 只有当 userInputMessageContext 有内容时才添加
|
||||
if (Object.keys(userInputMessageContext).length > 0) {
|
||||
userInputMessage.userInputMessageContext = userInputMessageContext;
|
||||
}
|
||||
|
||||
request.conversationState.currentMessage.userInputMessage = userInputMessage;
|
||||
|
||||
if (this.authMethod === KIRO_CONSTANTS.AUTH_METHOD_SOCIAL) {
|
||||
request.profileArn = this.profileArn;
|
||||
}
|
||||
|
|
@ -952,18 +1012,146 @@ async initializeAuth(forceRefresh = false) {
|
|||
}
|
||||
}
|
||||
|
||||
//kiro提供的接口没有流式返回
|
||||
async streamApi(method, model, body, isRetry = false, retryCount = 0) {
|
||||
/**
|
||||
* 解析 AWS Event Stream 格式,提取所有完整的 JSON 事件
|
||||
* 返回 { events: 解析出的事件数组, remaining: 未处理完的缓冲区 }
|
||||
*/
|
||||
parseAwsEventStreamBuffer(buffer) {
|
||||
const events = [];
|
||||
let remaining = buffer;
|
||||
let searchStart = 0;
|
||||
|
||||
while (true) {
|
||||
// 查找 {"content": 或 {"toolUse" 的起始位置
|
||||
const contentStart = remaining.indexOf('{"content":', searchStart);
|
||||
const toolUseStart = remaining.indexOf('{"toolUse":', searchStart);
|
||||
|
||||
let jsonStart = -1;
|
||||
if (contentStart >= 0 && toolUseStart >= 0) {
|
||||
jsonStart = Math.min(contentStart, toolUseStart);
|
||||
} else if (contentStart >= 0) {
|
||||
jsonStart = contentStart;
|
||||
} else if (toolUseStart >= 0) {
|
||||
jsonStart = toolUseStart;
|
||||
}
|
||||
|
||||
if (jsonStart < 0) break;
|
||||
|
||||
// 查找对应的 } 结束位置
|
||||
const jsonEnd = remaining.indexOf('}', jsonStart);
|
||||
if (jsonEnd < 0) {
|
||||
// 不完整的 JSON,保留在缓冲区
|
||||
remaining = remaining.substring(jsonStart);
|
||||
break;
|
||||
}
|
||||
|
||||
const jsonStr = remaining.substring(jsonStart, jsonEnd + 1);
|
||||
try {
|
||||
const parsed = JSON.parse(jsonStr);
|
||||
if (parsed.content !== undefined) {
|
||||
events.push({ type: 'content', data: parsed.content });
|
||||
} else if (parsed.toolUse !== undefined) {
|
||||
events.push({ type: 'toolUse', data: parsed.toolUse });
|
||||
}
|
||||
} catch (e) {
|
||||
// JSON 解析失败,可能是不完整的,继续搜索
|
||||
}
|
||||
|
||||
searchStart = jsonEnd + 1;
|
||||
if (searchStart >= remaining.length) {
|
||||
remaining = '';
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 searchStart 有进展,截取剩余部分
|
||||
if (searchStart > 0 && remaining.length > 0) {
|
||||
remaining = remaining.substring(searchStart);
|
||||
}
|
||||
|
||||
return { events, remaining };
|
||||
}
|
||||
|
||||
/**
|
||||
* 真正的流式 API 调用 - 使用 responseType: 'stream'
|
||||
*/
|
||||
async * streamApiReal(method, model, body, isRetry = false, retryCount = 0) {
|
||||
if (!this.isInitialized) await this.initialize();
|
||||
const maxRetries = this.config.REQUEST_MAX_RETRIES || 3;
|
||||
const baseDelay = this.config.REQUEST_BASE_DELAY || 1000;
|
||||
|
||||
const requestData = this.buildCodewhispererRequest(body.messages, model, body.tools, body.system);
|
||||
|
||||
const token = this.accessToken;
|
||||
const headers = {
|
||||
'Authorization': `Bearer ${token}`,
|
||||
'amz-sdk-invocation-id': `${uuidv4()}`,
|
||||
};
|
||||
|
||||
const requestUrl = model.startsWith('amazonq') ? this.amazonQUrl : this.baseUrl;
|
||||
|
||||
try {
|
||||
// 直接调用并返回Promise,最终解析为response
|
||||
return await this.callApi(method, model, body, isRetry, retryCount);
|
||||
const response = await this.axiosInstance.post(requestUrl, requestData, {
|
||||
headers,
|
||||
responseType: 'stream'
|
||||
});
|
||||
|
||||
const stream = response.data;
|
||||
let buffer = '';
|
||||
const processedPositions = new Set(); // 避免重复处理
|
||||
|
||||
for await (const chunk of stream) {
|
||||
buffer += chunk.toString();
|
||||
|
||||
// 解析缓冲区中的事件
|
||||
const { events, remaining } = this.parseAwsEventStreamBuffer(buffer);
|
||||
buffer = remaining;
|
||||
|
||||
// 只 yield 新的事件
|
||||
for (const event of events) {
|
||||
const eventKey = `${event.type}:${event.data}`;
|
||||
if (!processedPositions.has(eventKey)) {
|
||||
processedPositions.add(eventKey);
|
||||
if (event.type === 'content' && event.data) {
|
||||
yield { type: 'content', content: event.data };
|
||||
} else if (event.type === 'toolUse') {
|
||||
yield { type: 'toolUse', toolUse: event.data };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[Kiro] Error calling API:', error);
|
||||
throw error; // 向上抛出错误
|
||||
if (error.response?.status === 403 && !isRetry) {
|
||||
console.log('[Kiro] Received 403 in stream. Attempting token refresh and retrying...');
|
||||
await this.initializeAuth(true);
|
||||
yield* this.streamApiReal(method, model, body, true, retryCount);
|
||||
return;
|
||||
}
|
||||
|
||||
if (error.response?.status === 429 && retryCount < maxRetries) {
|
||||
const delay = baseDelay * Math.pow(2, retryCount);
|
||||
console.log(`[Kiro] Received 429 in stream. Retrying in ${delay}ms...`);
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
yield* this.streamApiReal(method, model, body, isRetry, retryCount + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
console.error('[Kiro] Stream API call failed:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 重构2: generateContentStream 调用新的普通async函数
|
||||
// 保留旧的非流式方法用于 generateContent
|
||||
async streamApi(method, model, body, isRetry = false, retryCount = 0) {
|
||||
try {
|
||||
return await this.callApi(method, model, body, isRetry, retryCount);
|
||||
} catch (error) {
|
||||
console.error('[Kiro] Error calling API:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 真正的流式传输实现
|
||||
async * generateContentStream(model, requestBody) {
|
||||
if (!this.isInitialized) await this.initialize();
|
||||
|
||||
|
|
@ -974,28 +1162,98 @@ async initializeAuth(forceRefresh = false) {
|
|||
}
|
||||
|
||||
const finalModel = MODEL_MAPPING[model] ? model : this.modelName;
|
||||
console.log(`[Kiro] Calling generateContentStream with model: ${finalModel}`);
|
||||
console.log(`[Kiro] Calling generateContentStream with model: ${finalModel} (real streaming)`);
|
||||
|
||||
// Estimate input tokens before making the API call
|
||||
const inputTokens = this.estimateInputTokens(requestBody);
|
||||
const messageId = `${uuidv4()}`;
|
||||
|
||||
try {
|
||||
const response = await this.streamApi('', finalModel, requestBody);
|
||||
const { responseText, toolCalls } = this._processApiResponse(response);
|
||||
// 1. 先发送 message_start 事件
|
||||
yield {
|
||||
type: "message_start",
|
||||
message: {
|
||||
id: messageId,
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: model,
|
||||
usage: { input_tokens: inputTokens, output_tokens: 0 },
|
||||
content: []
|
||||
}
|
||||
};
|
||||
|
||||
// Pass both responseText and toolCalls to buildClaudeResponse
|
||||
// buildClaudeResponse will handle the logic of combining them into a single stream
|
||||
for (const chunkJson of this.buildClaudeResponse(responseText, true, 'assistant', model, toolCalls, inputTokens)) {
|
||||
yield chunkJson;
|
||||
// 2. 发送 content_block_start 事件
|
||||
yield {
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: { type: "text", text: "" }
|
||||
};
|
||||
|
||||
let totalContent = '';
|
||||
let outputTokens = 0;
|
||||
const toolCalls = [];
|
||||
|
||||
// 3. 流式接收并发送每个 content_block_delta
|
||||
for await (const event of this.streamApiReal('', finalModel, requestBody)) {
|
||||
if (event.type === 'content' && event.content) {
|
||||
totalContent += event.content;
|
||||
outputTokens += this.countTextTokens(event.content);
|
||||
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: { type: "text_delta", text: event.content }
|
||||
};
|
||||
} else if (event.type === 'toolUse') {
|
||||
toolCalls.push(event.toolUse);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 发送 content_block_stop 事件
|
||||
yield { type: "content_block_stop", index: 0 };
|
||||
|
||||
// 5. 处理工具调用(如果有)
|
||||
if (toolCalls.length > 0) {
|
||||
for (let i = 0; i < toolCalls.length; i++) {
|
||||
const tc = toolCalls[i];
|
||||
const blockIndex = i + 1;
|
||||
|
||||
yield {
|
||||
type: "content_block_start",
|
||||
index: blockIndex,
|
||||
content_block: {
|
||||
type: "tool_use",
|
||||
id: tc.toolUseId || `tool_${uuidv4()}`,
|
||||
name: tc.name,
|
||||
input: {}
|
||||
}
|
||||
};
|
||||
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
index: blockIndex,
|
||||
delta: {
|
||||
type: "input_json_delta",
|
||||
partial_json: JSON.stringify(tc.input || {})
|
||||
}
|
||||
};
|
||||
|
||||
yield { type: "content_block_stop", index: blockIndex };
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 发送 message_delta 事件
|
||||
yield {
|
||||
type: "message_delta",
|
||||
delta: { stop_reason: toolCalls.length > 0 ? "tool_use" : "end_turn" },
|
||||
usage: { output_tokens: outputTokens }
|
||||
};
|
||||
|
||||
// 7. 发送 message_stop 事件
|
||||
yield { type: "message_stop" };
|
||||
|
||||
} catch (error) {
|
||||
console.error('[Kiro] Error in streaming generation:', error);
|
||||
throw new Error(`Error processing response: ${error.message}`);
|
||||
// For Claude, we yield an array of events for streaming error
|
||||
// Ensure error message is passed as content, not toolCalls
|
||||
// for (const chunkJson of this.buildClaudeResponse(`Error: ${error.message}`, true, 'assistant', model, null)) {
|
||||
// yield chunkJson;
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue