From de3f46149f97293434c55595d710cf444f380b9e Mon Sep 17 00:00:00 2001 From: hex2077 Date: Wed, 1 Apr 2026 23:07:02 +0800 Subject: [PATCH] =?UTF-8?q?feat(grok):=20=E6=B7=BB=E5=8A=A0WebSocket?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E7=94=9F=E6=88=90=E6=94=AF=E6=8C=81=E4=B8=8E?= =?UTF-8?q?=E5=A4=9A=E5=9B=BE=E5=B9=B6=E5=8F=91=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 WebSocket 图片生成服务类,支持流式生成图片 - 在常规 API 失败时自动回退到 WebSocket 方式生成图片 - 支持单次生成超过2张图片时自动拆分为并发请求 - 改进图片生成参数处理,支持返回 base64 格式图片 - 更新版本号至 2.12.2.2 --- VERSION | 2 +- src/providers/grok/grok-core.js | 213 ++++++++++++++++++++++++++++++- src/providers/grok/ws-imagine.js | 114 +++++++++++++++++ 3 files changed, 321 insertions(+), 8 deletions(-) create mode 100644 src/providers/grok/ws-imagine.js diff --git a/VERSION b/VERSION index 500b629..d463c35 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.12.2.1 +2.12.2.2 diff --git a/src/providers/grok/grok-core.js b/src/providers/grok/grok-core.js index 815e07b..7fb962d 100644 --- a/src/providers/grok/grok-core.js +++ b/src/providers/grok/grok-core.js @@ -10,6 +10,7 @@ import { MODEL_PROVIDER } from '../../utils/common.js'; import { ConverterFactory } from '../../converters/ConverterFactory.js'; import * as readline from 'readline'; import { getProviderPoolManager } from '../../services/service-manager.js'; +import { ImagineWebSocketService } from './ws-imagine.js'; // Chrome 136 TLS cipher suites const CHROME_CIPHERS = [ @@ -416,28 +417,35 @@ export class GrokApiService { const isMediaModel = modelLower.includes('imagine') || modelLower.includes('video') || modelLower.includes('edit'); const isNsfw = isGrokNsfwModel(rawModelId) || requestBody.nsfw === true || requestBody.disableNsfwFilter === true; + // 处理生成图片数量,API 通常限制单次最多 2 张 + const imageGenerationCount = Math.min(parseInt(requestBody.n || requestBody.imageGenerationCount || (isMediaModel ? 
2 : 0)), 2); + + // 处理响应格式 + const returnImageBytes = requestBody.response_format === 'b64_json' || requestBody.responseFormat === 'b64_json'; + const payload = { "deviceEnvInfo": { "darkModeEnabled": false, "devicePixelRatio": 2, "screenWidth": 2056, "screenHeight": 1329, "viewportWidth": 2056, "viewportHeight": 1083 }, "disableMemory": false, "disableNsfwFilter": isNsfw, "disableSearch": false, "disableSelfHarmShortCircuit": false, "disableTextFollowUps": false, "enableImageGeneration": isMediaModel, "enableImageStreaming": isMediaModel, "enableSideBySide": true, - "fileAttachments": fileAttachments, "forceConcise": false, "forceSideBySide": false, "imageAttachments": [], "imageGenerationCount": 2, + "fileAttachments": fileAttachments, "forceConcise": false, "forceSideBySide": false, "imageAttachments": [], + "imageGenerationCount": imageGenerationCount, "isAsyncChat": false, "isReasoning": false, "message": message, "modelMode": mapping.mode, "modelName": mapping.name, "responseMetadata": { "requestModelDetails": { "modelId": mapping.name }, "modelConfigOverride": modelConfigOverride }, - "returnImageBytes": false, "returnRawGrokInXaiRequest": false, "sendFinalMetadata": true, "temporary": true, "toolOverrides": toolOverrides, + "returnImageBytes": returnImageBytes, "returnRawGrokInXaiRequest": false, "sendFinalMetadata": true, "temporary": true, "toolOverrides": toolOverrides, }; if (isMediaModel && !modelLower.includes('video')) { payload.enable_nsfw = isNsfw; - if (requestBody.aspect_ratio || requestBody.aspectRatio) { - payload.aspect_ratio = requestBody.aspect_ratio || requestBody.aspectRatio; + const aspectRatio = requestBody.aspect_ratio || requestBody.aspectRatio; + if (aspectRatio) { + payload.aspect_ratio = aspectRatio; } } return payload; } - async generateContent(model, requestBody) { - logger.info(`[Grok] Starting generateContent (unified processing)`); + async _generateAndCollect(model, requestBody) { const stream = 
this.generateContentStream(model, requestBody); const collected = { message: "", @@ -510,9 +518,67 @@ export class GrokApiService { } } } + return collected; + } + + async generateContent(model, requestBody) { + logger.info(`[Grok] Starting generateContent (unified processing)`); + + const n = parseInt(requestBody.n || 1); + const isImagine = model.toLowerCase().includes('imagine'); + + let collected; + try { + if (n <= 2 || !isImagine) { + // 单次请求处理 + collected = await this._generateAndCollect(model, requestBody); + } else { + // 处理 n > 2 的情况,分批并发请求 + logger.info(`[Grok] Multi-image request detected (n=${n}), splitting into multiple tasks`); + const perCall = 2; + const callsNeeded = Math.ceil(n / perCall); + const tasks = []; + + for (let i = 0; i < callsNeeded; i++) { + const count = Math.min(perCall, n - i * perCall); + const subRequestBody = { ...requestBody, n: count }; + tasks.push(this._generateAndCollect(model, subRequestBody)); + } + + const results = await Promise.all(tasks); + + // 合并所有批次的结果 + collected = results[0]; + for (let i = 1; i < results.length; i++) { + const res = results[i]; + // 合并消息文本 + if (res.message) collected.message += "\n" + res.message; + // 合并卡片附件 + if (res.cardAttachments) collected.cardAttachments.push(...res.cardAttachments); + // 合并 modelResponse 中的卡片 JSON + if (res.modelResponse?.cardAttachmentsJson) { + if (!collected.modelResponse) collected.modelResponse = { cardAttachmentsJson: [] }; + if (!collected.modelResponse.cardAttachmentsJson) collected.modelResponse.cardAttachmentsJson = []; + collected.modelResponse.cardAttachmentsJson.push(...res.modelResponse.cardAttachmentsJson); + } + } + } + } catch (error) { + // 只有图片生成才支持 WebSocket Fallback + if (isImagine) { + logger.warn(`[Grok] app_chat image generation failed, trying ws_imagine fallback: ${error.message}`); + try { + return await this._generateAndCollectWS(model, requestBody); + } catch (wsError) { + logger.error(`[Grok] ws_imagine fallback also failed: 
${wsError.message}`); + throw error; // 抛出原始错误 + } + } + throw error; + } logger.info(`[Grok] Finalizing collection. model: ${model}, respId: ${collected.responseId}, videoPostId: ${collected.postId}`); - + // 1. 仅针对视频进行 postId 提取和分享链接创建 const isVideo = !!(collected.finalVideoUrl || collected.streamingVideoGenerationResponse || model.toLowerCase().includes('video')); logger.info(`[Grok Decision] isVideo detected: ${isVideo}. (finalUrl: ${!!collected.finalVideoUrl}, streamResp: ${!!collected.streamingVideoGenerationResponse}, modelIncludeVideo: ${model.toLowerCase().includes('video')})`); @@ -551,6 +617,124 @@ export class GrokApiService { return collected; } + /** + * WebSocket 方式生成图片 (Fallback) + */ + async _generateAndCollectWS(model, requestBody) { + const n = parseInt(requestBody.n || 1); + // 提取 prompt + let prompt = requestBody.message || requestBody.videoGenPrompt; + if (!prompt && requestBody.messages?.length > 0) { + const lastMsg = requestBody.messages[requestBody.messages.length - 1]; + prompt = typeof lastMsg.content === 'string' ? 
lastMsg.content : (lastMsg.content?.find(p => p.type === 'text')?.text || ""); + } + prompt = prompt || "A beautiful image"; + + const aspectRatio = requestBody.aspect_ratio || requestBody.aspectRatio || "1:1"; + const enableNsfw = requestBody.nsfw !== false; + + logger.info(`[Grok WS] Starting fallback image generation for: ${prompt.substring(0, 50)}...`); + + const wsService = new ImagineWebSocketService(this.config); + const stream = wsService.stream(this.token, prompt, aspectRatio, n, enableNsfw); + + const collected = { + message: "", + responseId: `ws-${uuidv4()}`, + postId: "", + llmInfo: { modelHash: "ws-imagine" }, + rolloutId: "", + modelResponse: { cardAttachmentsJson: [] }, + cardAttachments: [] + }; + + for await (const item of stream) { + if (item.type === 'error') { + throw new Error(item.error || 'WebSocket generation failed'); + } + if (item.type === 'image' && item.stage === 'final') { + const cardData = { + id: item.image_id || uuidv4(), + image: { + original: item.blob.startsWith('data:') ? item.blob : `data:image/png;base64,${item.blob}`, + title: "Generated Image" + } + }; + const jsonStr = JSON.stringify(cardData); + collected.modelResponse.cardAttachmentsJson.push(jsonStr); + collected.cardAttachments.push({ jsonData: jsonStr }); + logger.info(`[Grok WS] Received image: ${cardData.id}`); + } + } + + if (collected.cardAttachments.length === 0) { + throw new Error("WebSocket generation returned no images"); + } + + return collected; + } + + /** + * WebSocket 方式流式生成图片 (Fallback) + */ + async * _generateContentStreamWS(model, requestBody) { + const n = parseInt(requestBody.n || 1); + let prompt = requestBody.message || requestBody.videoGenPrompt; + if (!prompt && requestBody.messages?.length > 0) { + const lastMsg = requestBody.messages[requestBody.messages.length - 1]; + prompt = typeof lastMsg.content === 'string' ? 
lastMsg.content : (lastMsg.content?.find(p => p.type === 'text')?.text || ""); + } + prompt = prompt || "A beautiful image"; + + const aspectRatio = requestBody.aspect_ratio || requestBody.aspectRatio || "1:1"; + const enableNsfw = requestBody.nsfw !== false; + + const wsService = new ImagineWebSocketService(this.config); + const stream = wsService.stream(this.token, prompt, aspectRatio, n, enableNsfw); + + const responseId = `ws-${uuidv4()}`; + + for await (const item of stream) { + if (item.type === 'error') { + throw new Error(item.error || 'WebSocket generation failed'); + } + if (item.type === 'image') { + yield { + result: { + response: { + responseId, + streamingImageGenerationResponse: { + imageIndex: 0, + progress: item.stage === 'final' ? 100 : (item.stage === 'medium' ? 50 : 10) + } + } + } + }; + + if (item.stage === 'final') { + const cardData = { + id: item.image_id || uuidv4(), + image: { + original: item.blob.startsWith('data:') ? item.blob : `data:image/png;base64,${item.blob}`, + title: "Generated Image" + } + }; + yield { + result: { + response: { + responseId, + cardAttachment: { + jsonData: JSON.stringify(cardData) + } + } + } + }; + } + } + } + yield { result: { response: { isDone: true, responseId } } }; + } + async uploadFile(fileInput) { let b64 = "", mime = "application/octet-stream"; if (fileInput.startsWith("data:")) { @@ -677,6 +861,9 @@ export class GrokApiService { resp._requestBaseUrl = reqBaseUrl; resp._uuid = this.uuid; if (resp.responseId) lastResponseId = resp.responseId; + if (resp.streamingImageGenerationResponse) { + // 图片生成进度通过流透传,暂无额外处理 + } if (resp.streamingVideoGenerationResponse) { const vid = resp.streamingVideoGenerationResponse; if (vid.progress === 100 && vid.videoUrl && (requestBody.videoGenModelConfig?.resolutionName === "720p")) { @@ -694,6 +881,18 @@ export class GrokApiService { const { status, errorCode, errorMessage, isNetworkError } = this.classifyApiError(error); const canRetryInRequest = !hasYieldedData && 
retryCount < maxRetries; + // 只有图片生成且未发送过数据时才尝试 WebSocket Fallback + const isImagine = modelLower.includes('imagine'); + if (isImagine && !hasYieldedData && retryCount === 0) { + logger.warn(`[Grok] app_chat stream failed, trying ws_imagine fallback: ${error.message}`); + try { + yield* this._generateContentStreamWS(model, requestBody); + return; + } catch (wsError) { + logger.error(`[Grok] ws_imagine fallback also failed: ${wsError.message}`); + } + } + if (status === 429 && canRetryInRequest) { const delay = baseDelay * Math.pow(2, retryCount); logger.info(`[Grok API] Received 429 during stream. Retrying in ${delay}ms... (attempt ${retryCount + 1}/${maxRetries})`); diff --git a/src/providers/grok/ws-imagine.js b/src/providers/grok/ws-imagine.js new file mode 100644 index 0000000..d113c3f --- /dev/null +++ b/src/providers/grok/ws-imagine.js @@ -0,0 +1,114 @@ +import WebSocket from 'ws'; +import logger from '../../utils/logger.js'; +import { getProxyConfigForProvider } from '../../utils/proxy-utils.js'; +import { MODEL_PROVIDER } from '../../utils/common.js'; + +/** + * Grok WebSocket Imagine Service + * Handles image generation via Grok's WebSocket endpoint. + */ +export class ImagineWebSocketService { + constructor(config) { + this.config = config; + this.baseUrl = (config.GROK_BASE_URL || 'https://grok.com').replace(/\/$/, ''); + this.wsUrl = this.baseUrl.replace(/^http/, 'ws') + '/rpc/imagine/streaming'; + } + + /** + * Start an image generation stream via WebSocket. + * + * @param {string} token - SSO token + * @param {string} prompt - Image prompt + * @param {string} aspectRatio - Aspect ratio (e.g. 
"1:1") + * @param {number} n - Number of images + * @param {boolean} enableNsfw - Whether to allow NSFW content (forwarded as `enableNsfw` in the request params) + * @returns {AsyncGenerator} + */ + async *stream(token, prompt, aspectRatio = '1:1', n = 1, enableNsfw = true) { + const proxyConfig = getProxyConfigForProvider(this.config, MODEL_PROVIDER.GROK_CUSTOM); + const agent = proxyConfig?.httpsAgent; + + let ssoToken = token || ""; + if (ssoToken.startsWith("sso=")) ssoToken = ssoToken.substring(4); + const cookie = ssoToken ? `sso=${ssoToken}; sso-rw=${ssoToken}` : ""; + + const headers = { + 'Cookie': cookie, + 'Origin': this.baseUrl, + 'User-Agent': this.config.GROK_USER_AGENT || 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/143.0.0.0 Safari/537.36', + }; + + logger.debug(`[Grok WS] Connecting to ${this.wsUrl} for prompt: ${prompt.substring(0, 50)}...`); + + const ws = new WebSocket(this.wsUrl, { + headers, + agent, + handshakeTimeout: 15000, + rejectUnauthorized: false + }); + + const queue = []; + let done = false; + let resolveNext = null; + + ws.on('open', () => { + logger.debug(`[Grok WS] Connected. 
Sending imagine request.`); + ws.send(JSON.stringify({ + method: 'imagine', + params: { + prompt, + aspectRatio, + count: n, + enableNsfw + } + })); + }); + + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + queue.push(msg); + if (resolveNext) { + resolveNext(); + resolveNext = null; + } + } catch (e) { + logger.error(`[Grok WS] Failed to parse message: ${data.toString().substring(0, 100)}`); + } + }); + + ws.on('close', (code, reason) => { + logger.debug(`[Grok WS] Connection closed: ${code} ${reason}`); + done = true; + if (resolveNext) { + resolveNext(); + resolveNext = null; + } + }); + + ws.on('error', (err) => { + logger.error(`[Grok WS] WebSocket error: ${err.message}`); + queue.push({ type: 'error', error: err.message }); + done = true; + if (resolveNext) { + resolveNext(); + resolveNext = null; + } + }); + + try { + while (!done || queue.length > 0) { + if (queue.length === 0 && !done) { + await new Promise(r => resolveNext = r); + } + while (queue.length > 0) { + yield queue.shift(); + } + } + } finally { + if (ws.readyState === WebSocket.OPEN || ws.readyState === WebSocket.CONNECTING) { + ws.close(); + } + } + } +}