Add OpenAI REST API and Deepgram WebSocket
This commit is contained in:
parent
c0e2600993
commit
9ac7c26a0b
2 changed files with 565 additions and 13 deletions
|
|
@ -1,13 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
|
||||||
|
|
||||||
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
from whisperlivekit import AudioProcessor, TranscriptionEngine, get_inline_ui_html, parse_args
|
||||||
get_inline_ui_html, parse_args)
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logging.getLogger().setLevel(logging.WARNING)
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
|
|
@ -37,11 +37,26 @@ async def get():
|
||||||
return HTMLResponse(get_inline_ui_html())
|
return HTMLResponse(get_inline_ui_html())
|
||||||
|
|
||||||
|
|
||||||
async def handle_websocket_results(websocket, results_generator):
|
@app.get("/health")
async def health():
    """Report liveness together with the loaded backend, if any.

    Returns a JSON body with a constant ``"status": "ok"``, the engine's
    configured ``backend`` (``None`` while no engine is loaded) and a
    ``ready`` flag that is true once the engine object exists.
    """
    engine = transcription_engine
    if engine:
        # Fall back to "whisper" when the config carries no backend attribute.
        backend_name = getattr(engine.config, "backend", "whisper")
    else:
        backend_name = None

    payload = {
        "status": "ok",
        "backend": backend_name,
        "ready": engine is not None,
    }
    return JSONResponse(payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_websocket_results(websocket, results_generator, diff_tracker=None):
|
||||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||||
try:
|
try:
|
||||||
async for response in results_generator:
|
async for response in results_generator:
|
||||||
await websocket.send_json(response.to_dict())
|
if diff_tracker is not None:
|
||||||
|
await websocket.send_json(diff_tracker.to_message(response))
|
||||||
|
else:
|
||||||
|
await websocket.send_json(response.to_dict())
|
||||||
# when the results_generator finishes it means all audio has been processed
|
# when the results_generator finishes it means all audio has been processed
|
||||||
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
|
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
|
||||||
await websocket.send_json({"type": "ready_to_stop"})
|
await websocket.send_json({"type": "ready_to_stop"})
|
||||||
|
|
@ -54,19 +69,33 @@ async def handle_websocket_results(websocket, results_generator):
|
||||||
@app.websocket("/asr")
|
@app.websocket("/asr")
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
global transcription_engine
|
global transcription_engine
|
||||||
|
|
||||||
|
# Read per-session options from query parameters
|
||||||
|
session_language = websocket.query_params.get("language", None)
|
||||||
|
mode = websocket.query_params.get("mode", "full")
|
||||||
|
|
||||||
audio_processor = AudioProcessor(
|
audio_processor = AudioProcessor(
|
||||||
transcription_engine=transcription_engine,
|
transcription_engine=transcription_engine,
|
||||||
|
language=session_language,
|
||||||
)
|
)
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
logger.info("WebSocket connection opened.")
|
logger.info(
|
||||||
|
"WebSocket connection opened.%s",
|
||||||
|
f" language={session_language}" if session_language else "",
|
||||||
|
)
|
||||||
|
diff_tracker = None
|
||||||
|
if mode == "diff":
|
||||||
|
from whisperlivekit.diff_protocol import DiffTracker
|
||||||
|
diff_tracker = DiffTracker()
|
||||||
|
logger.info("Client requested diff mode")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input)})
|
await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input), "mode": mode})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to send config to client: {e}")
|
logger.warning(f"Failed to send config to client: {e}")
|
||||||
|
|
||||||
results_generator = await audio_processor.create_tasks()
|
results_generator = await audio_processor.create_tasks()
|
||||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator, diff_tracker))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -74,7 +103,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
await audio_processor.process_audio(message)
|
await audio_processor.process_audio(message)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
if 'bytes' in str(e):
|
if 'bytes' in str(e):
|
||||||
logger.warning(f"Client has closed the connection.")
|
logger.warning("Client has closed the connection.")
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
|
logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
|
|
@ -91,14 +120,227 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
logger.info("WebSocket results handler task was cancelled.")
|
logger.info("WebSocket results handler task was cancelled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception while awaiting websocket_task completion: {e}")
|
logger.warning(f"Exception while awaiting websocket_task completion: {e}")
|
||||||
|
|
||||||
await audio_processor.cleanup()
|
await audio_processor.cleanup()
|
||||||
logger.info("WebSocket endpoint cleaned up successfully.")
|
logger.info("WebSocket endpoint cleaned up successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deepgram-compatible WebSocket API (/v1/listen)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@app.websocket("/v1/listen")
async def deepgram_websocket_endpoint(websocket: WebSocket):
    """Serve the Deepgram-compatible live transcription WebSocket.

    The whole session is delegated to
    ``whisperlivekit.deepgram_compat.handle_deepgram_websocket``, which is
    given the shared transcription engine and the server config.
    """
    # Imported at call time rather than at module load.
    from whisperlivekit import deepgram_compat

    await deepgram_compat.handle_deepgram_websocket(
        websocket,
        transcription_engine,
        config,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenAI-compatible REST API (/v1/audio/transcriptions)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _convert_to_pcm(audio_bytes: bytes) -> bytes:
    """Convert any audio format to PCM s16le mono 16kHz using ffmpeg.

    The input is piped to an ``ffmpeg`` subprocess on stdin and the raw PCM
    is read back from stdout.

    Args:
        audio_bytes: Encoded audio in any container/codec ffmpeg can read
            from a pipe.

    Returns:
        Raw signed 16-bit little-endian mono samples at 16 kHz.

    Raises:
        HTTPException: 500 when the ``ffmpeg`` executable cannot be found,
            400 when ffmpeg exits with a non-zero status (its stderr output
            is included in the detail).
    """
    from fastapi import HTTPException

    try:
        proc = await asyncio.create_subprocess_exec(
            "ffmpeg", "-i", "pipe:0",
            "-f", "s16le", "-acodec", "pcm_s16le",
            "-ar", "16000", "-ac", "1",
            "-loglevel", "error",
            "pipe:1",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except FileNotFoundError as exc:
        # A missing binary is a server-side problem, not a bad upload.
        raise HTTPException(
            status_code=500,
            detail="Audio conversion unavailable: ffmpeg executable not found",
        ) from exc

    stdout, stderr = await proc.communicate(input=audio_bytes)
    if proc.returncode != 0:
        # ffmpeg may echo undecodable bytes from the input; never let the
        # error report itself fail with UnicodeDecodeError.
        reason = stderr.decode(errors="replace").strip()
        raise HTTPException(status_code=400, detail=f"Audio conversion failed: {reason}")
    return stdout
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_time_str(time_str: str) -> float:
|
||||||
|
"""Parse 'H:MM:SS.cc' to seconds."""
|
||||||
|
parts = time_str.split(":")
|
||||||
|
if len(parts) == 3:
|
||||||
|
return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
|
||||||
|
if len(parts) == 2:
|
||||||
|
return int(parts[0]) * 60 + float(parts[1])
|
||||||
|
return float(parts[0])
|
||||||
|
|
||||||
|
|
||||||
|
def _format_openai_response(front_data, response_format: str, language: Optional[str], duration: float) -> dict:
|
||||||
|
"""Convert FrontData to OpenAI-compatible response."""
|
||||||
|
d = front_data.to_dict()
|
||||||
|
lines = d.get("lines", [])
|
||||||
|
|
||||||
|
# Combine all speech text (exclude silence segments)
|
||||||
|
text_parts = [l["text"] for l in lines if l.get("text") and l.get("speaker", 0) != -2]
|
||||||
|
full_text = " ".join(text_parts).strip()
|
||||||
|
|
||||||
|
if response_format == "text":
|
||||||
|
return full_text
|
||||||
|
|
||||||
|
# Build segments and words for verbose_json
|
||||||
|
segments = []
|
||||||
|
words = []
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if line.get("speaker") == -2 or not line.get("text"):
|
||||||
|
continue
|
||||||
|
start = _parse_time_str(line.get("start", "0:00:00"))
|
||||||
|
end = _parse_time_str(line.get("end", "0:00:00"))
|
||||||
|
segments.append({
|
||||||
|
"id": len(segments),
|
||||||
|
"start": round(start, 2),
|
||||||
|
"end": round(end, 2),
|
||||||
|
"text": line["text"],
|
||||||
|
})
|
||||||
|
# Split segment text into approximate words with estimated timestamps
|
||||||
|
seg_words = line["text"].split()
|
||||||
|
if seg_words:
|
||||||
|
word_duration = (end - start) / max(len(seg_words), 1)
|
||||||
|
for j, word in enumerate(seg_words):
|
||||||
|
words.append({
|
||||||
|
"word": word,
|
||||||
|
"start": round(start + j * word_duration, 2),
|
||||||
|
"end": round(start + (j + 1) * word_duration, 2),
|
||||||
|
})
|
||||||
|
|
||||||
|
if response_format == "verbose_json":
|
||||||
|
return {
|
||||||
|
"task": "transcribe",
|
||||||
|
"language": language or "unknown",
|
||||||
|
"duration": round(duration, 2),
|
||||||
|
"text": full_text,
|
||||||
|
"words": words,
|
||||||
|
"segments": segments,
|
||||||
|
}
|
||||||
|
|
||||||
|
if response_format in ("srt", "vtt"):
|
||||||
|
lines_out = []
|
||||||
|
if response_format == "vtt":
|
||||||
|
lines_out.append("WEBVTT\n")
|
||||||
|
for i, seg in enumerate(segments):
|
||||||
|
start_ts = _srt_timestamp(seg["start"], response_format)
|
||||||
|
end_ts = _srt_timestamp(seg["end"], response_format)
|
||||||
|
if response_format == "srt":
|
||||||
|
lines_out.append(f"{i + 1}")
|
||||||
|
lines_out.append(f"{start_ts} --> {end_ts}")
|
||||||
|
lines_out.append(seg["text"])
|
||||||
|
lines_out.append("")
|
||||||
|
return "\n".join(lines_out)
|
||||||
|
|
||||||
|
# Default: json
|
||||||
|
return {"text": full_text}
|
||||||
|
|
||||||
|
|
||||||
|
def _srt_timestamp(seconds: float, fmt: str) -> str:
|
||||||
|
"""Format seconds as SRT (HH:MM:SS,mmm) or VTT (HH:MM:SS.mmm) timestamp."""
|
||||||
|
h = int(seconds // 3600)
|
||||||
|
m = int((seconds % 3600) // 60)
|
||||||
|
s = int(seconds % 60)
|
||||||
|
ms = int(round((seconds % 1) * 1000))
|
||||||
|
sep = "," if fmt == "srt" else "."
|
||||||
|
return f"{h:02d}:{m:02d}:{s:02d}{sep}{ms:03d}"
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/audio/transcriptions")
async def create_transcription(
    file: UploadFile = File(...),
    model: str = Form(default=""),
    language: Optional[str] = Form(default=None),
    prompt: str = Form(default=""),
    response_format: str = Form(default="json"),
    timestamp_granularities: Optional[List[str]] = Form(default=None),
):
    """OpenAI-compatible audio transcription endpoint.

    Accepts the same parameters as OpenAI's /v1/audio/transcriptions API.
    The `model` parameter is accepted but ignored (uses the server's configured backend).
    `prompt` and `timestamp_granularities` are likewise accepted but not used.

    The upload is converted to 16 kHz mono PCM, fed through an
    ``AudioProcessor`` in one-second chunks, and the last result the pipeline
    yields is formatted according to ``response_format``.

    Raises:
        HTTPException: 400 when the uploaded file is empty (audio conversion
            may raise its own HTTPException as well).
    """
    global transcription_engine

    audio_bytes = await file.read()
    if not audio_bytes:
        from fastapi import HTTPException
        raise HTTPException(status_code=400, detail="Empty audio file")

    # Convert to PCM for pipeline processing
    pcm_data = await _convert_to_pcm(audio_bytes)
    duration = len(pcm_data) / (16000 * 2)  # 16kHz, 16-bit

    # Process through the full pipeline
    processor = AudioProcessor(
        transcription_engine=transcription_engine,
        language=language,
    )
    # Force PCM input regardless of server config
    processor.is_pcm_input = True

    results_gen = await processor.create_tasks()

    # Collect results in background while feeding audio; only the most
    # recent result is kept.
    final_result = None

    async def collect():
        nonlocal final_result
        async for result in results_gen:
            final_result = result

    collect_task = asyncio.create_task(collect())

    # Feeding sits inside the try block so that a failure part-way through
    # still stops the collector task and releases the processor.
    try:
        # Feed audio in chunks (1 second each)
        chunk_size = 16000 * 2  # 1 second of PCM
        for offset in range(0, len(pcm_data), chunk_size):
            await processor.process_audio(pcm_data[offset:offset + chunk_size])

        # Signal end of audio
        await processor.process_audio(b"")

        # Wait for pipeline to finish
        try:
            await asyncio.wait_for(collect_task, timeout=120.0)
        except asyncio.TimeoutError:
            # Best effort: answer with whatever was collected so far.
            logger.warning("Transcription timed out after 120s")
    finally:
        if not collect_task.done():
            collect_task.cancel()
        await processor.cleanup()

    if final_result is None:
        return JSONResponse({"text": ""})

    result = _format_openai_response(final_result, response_format, language, duration)

    # text / srt / vtt come back as plain strings, the JSON formats as dicts.
    if isinstance(result, str):
        return PlainTextResponse(result)
    return JSONResponse(result)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
async def list_models():
    """List the single model this server exposes, in OpenAI's list format.

    The model id is ``whisper-<model_size>`` for the whisper backend and
    ``<backend>/<model_size>`` otherwise. Without a loaded engine the
    defaults ``whisper`` / ``base`` are reported.
    """
    engine_cfg = transcription_engine.config if transcription_engine else None
    backend_name = getattr(engine_cfg, "backend", "whisper")
    size = getattr(engine_cfg, "model_size", "base")

    if backend_name == "whisper":
        model_id = f"whisper-{size}"
    else:
        model_id = f"{backend_name}/{size}"

    entry = {
        "id": model_id,
        "object": "model",
        "owned_by": "whisperlivekit",
    }
    return JSONResponse({"object": "list", "data": [entry]})
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Entry point for the CLI command."""
|
"""Entry point for the CLI command."""
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
from whisperlivekit.cli import print_banner
|
||||||
|
|
||||||
|
ssl = bool(config.ssl_certfile and config.ssl_keyfile)
|
||||||
|
print_banner(config, config.host, config.port, ssl=ssl)
|
||||||
|
|
||||||
uvicorn_kwargs = {
|
uvicorn_kwargs = {
|
||||||
"app": "whisperlivekit.basic_server:app",
|
"app": "whisperlivekit.basic_server:app",
|
||||||
"host": config.host,
|
"host": config.host,
|
||||||
|
|
|
||||||
310
whisperlivekit/deepgram_compat.py
Normal file
310
whisperlivekit/deepgram_compat.py
Normal file
|
|
@ -0,0 +1,310 @@
|
||||||
|
"""Deepgram-compatible WebSocket endpoint for WhisperLiveKit.
|
||||||
|
|
||||||
|
Provides a /v1/listen endpoint that speaks the Deepgram Live Transcription
|
||||||
|
protocol, enabling drop-in compatibility with Deepgram client SDKs.
|
||||||
|
|
||||||
|
Protocol mapping:
|
||||||
|
- Client sends binary audio frames → forwarded to AudioProcessor
|
||||||
|
- Client sends JSON control messages (KeepAlive, CloseStream, Finalize)
|
||||||
|
- Server sends Results, Metadata, UtteranceEnd messages
|
||||||
|
|
||||||
|
Differences from Deepgram:
|
||||||
|
- No authentication required (self-hosted)
|
||||||
|
- Word-level timestamps approximate (interpolated from segment boundaries)
|
||||||
|
- Confidence scores not available (set to 0.0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_time_str(time_str: str) -> float:
|
||||||
|
"""Parse 'H:MM:SS.cc' to seconds."""
|
||||||
|
parts = time_str.split(":")
|
||||||
|
if len(parts) == 3:
|
||||||
|
return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
|
||||||
|
if len(parts) == 2:
|
||||||
|
return int(parts[0]) * 60 + float(parts[1])
|
||||||
|
return float(parts[0])
|
||||||
|
|
||||||
|
|
||||||
|
def _line_to_words(line: dict) -> list:
|
||||||
|
"""Convert a line dict to Deepgram-style word objects.
|
||||||
|
|
||||||
|
Distributes timestamps proportionally across words since
|
||||||
|
WhisperLiveKit provides segment-level timestamps.
|
||||||
|
"""
|
||||||
|
text = line.get("text", "")
|
||||||
|
if not text or not text.strip():
|
||||||
|
return []
|
||||||
|
|
||||||
|
start = _parse_time_str(line.get("start", "0:00:00"))
|
||||||
|
end = _parse_time_str(line.get("end", "0:00:00"))
|
||||||
|
speaker = line.get("speaker", 0)
|
||||||
|
if speaker == -2:
|
||||||
|
return []
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
|
||||||
|
duration = end - start
|
||||||
|
step = duration / max(len(words), 1)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"word": w,
|
||||||
|
"start": round(start + i * step, 3),
|
||||||
|
"end": round(start + (i + 1) * step, 3),
|
||||||
|
"confidence": 0.0,
|
||||||
|
"punctuated_word": w,
|
||||||
|
"speaker": speaker if speaker > 0 else 0,
|
||||||
|
}
|
||||||
|
for i, w in enumerate(words)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _lines_to_result(lines: list, is_final: bool, speech_final: bool,
|
||||||
|
start_time: float = 0.0) -> dict:
|
||||||
|
"""Convert FrontData lines to a Deepgram Results message."""
|
||||||
|
all_words = []
|
||||||
|
full_text_parts = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.get("speaker") == -2:
|
||||||
|
continue
|
||||||
|
words = _line_to_words(line)
|
||||||
|
all_words.extend(words)
|
||||||
|
text = line.get("text", "")
|
||||||
|
if text and text.strip():
|
||||||
|
full_text_parts.append(text.strip())
|
||||||
|
|
||||||
|
transcript = " ".join(full_text_parts)
|
||||||
|
|
||||||
|
# Calculate duration from word boundaries
|
||||||
|
if all_words:
|
||||||
|
seg_start = all_words[0]["start"]
|
||||||
|
seg_end = all_words[-1]["end"]
|
||||||
|
duration = seg_end - seg_start
|
||||||
|
else:
|
||||||
|
seg_start = start_time
|
||||||
|
seg_end = start_time
|
||||||
|
duration = 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "Results",
|
||||||
|
"channel_index": [0, 1],
|
||||||
|
"duration": round(duration, 3),
|
||||||
|
"start": round(seg_start, 3),
|
||||||
|
"is_final": is_final,
|
||||||
|
"speech_final": speech_final,
|
||||||
|
"channel": {
|
||||||
|
"alternatives": [
|
||||||
|
{
|
||||||
|
"transcript": transcript,
|
||||||
|
"confidence": 0.0,
|
||||||
|
"words": all_words,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DeepgramAdapter:
    """Adapts WhisperLiveKit's FrontData stream to Deepgram's protocol.

    One adapter serves one WebSocket session. It remembers how many speech
    lines have already been sent as final results so that each update only
    emits what is new, and it emits ``UtteranceEnd`` once per utterance.
    """

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket
        self.request_id = str(uuid.uuid4())
        self._prev_n_lines = 0  # not read anywhere in this class
        self._sent_lines = 0  # number of speech lines already sent as final
        self._last_word_end = 0.0  # end time (s) of the last final word sent
        self._speech_started_sent = False
        self._vad_events = False  # set by the session handler from the query string
        # True once UtteranceEnd has been sent for the current utterance;
        # reset whenever new final lines are emitted.
        self._utterance_end_sent = False

    async def send_metadata(self, config):
        """Send initial Metadata message.

        The backend name comes from ``config.backend`` and falls back to
        ``"whisper"`` when ``config`` is falsy or lacks that attribute.
        """
        backend = getattr(config, "backend", "whisper") if config else "whisper"
        msg = {
            "type": "Metadata",
            "request_id": self.request_id,
            "sha256": "",
            "created": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "duration": 0,
            "channels": 1,
            "models": [backend],
            "model_info": {
                backend: {
                    "name": backend,
                    "version": "whisperlivekit",
                }
            },
        }
        await self.websocket.send_json(msg)

    async def process_update(self, front_data_dict: dict):
        """Convert a FrontData dict into Deepgram messages and send them.

        Per update, in this order:
        - speech lines beyond those already sent become one final ``Results``;
        - otherwise a non-empty ``buffer_transcription`` becomes an interim
          ``Results`` (preceded by ``SpeechStarted`` if VAD events are on);
        - a silence line starting at or after the last final word triggers a
          single ``UtteranceEnd``.
        """
        lines = front_data_dict.get("lines", [])
        buffer = front_data_dict.get("buffer_transcription", "")

        # speaker == -2 marks silence lines.
        speech_lines = [ln for ln in lines if ln.get("speaker", 0) != -2]
        n_speech = len(speech_lines)

        # Detect new committed lines → emit as is_final=true results
        if n_speech > self._sent_lines:
            new_lines = speech_lines[self._sent_lines:]
            result = _lines_to_result(new_lines, is_final=True, speech_final=True)
            await self.websocket.send_json(result)

            # Track last word end for UtteranceEnd
            final_words = result["channel"]["alternatives"][0]["words"]
            if final_words:
                self._last_word_end = final_words[-1]["end"]

            self._sent_lines = n_speech
            # New speech means a later silence closes a new utterance.
            self._utterance_end_sent = False

        # Emit buffer as interim result (is_final=false)
        elif buffer and buffer.strip():
            # SpeechStarted event
            if self._vad_events and not self._speech_started_sent:
                await self.websocket.send_json({
                    "type": "SpeechStarted",
                    "channel_index": [0],
                    # NOTE(review): always 0.0 rather than the real speech
                    # onset time — confirm whether clients rely on it.
                    "timestamp": 0.0,
                })
                self._speech_started_sent = True

            # Create interim result from buffer
            interim = {
                "type": "Results",
                "channel_index": [0, 1],
                "duration": 0.0,
                "start": self._last_word_end,
                "is_final": False,
                "speech_final": False,
                "channel": {
                    "alternatives": [
                        {
                            "transcript": buffer.strip(),
                            "confidence": 0.0,
                            "words": [],
                        }
                    ]
                },
            }
            await self.websocket.send_json(interim)

        # Detect silence → emit UtteranceEnd. The incoming line list still
        # contains the same silence line on later updates, so without the
        # sent-flag the event would repeat on every update.
        if n_speech > 0 and not self._utterance_end_sent:
            silence_after_speech = any(
                _parse_time_str(ln.get("start", "0:00:00")) >= self._last_word_end
                for ln in lines
                if ln.get("speaker") == -2
            )
            if silence_after_speech:
                await self.websocket.send_json({
                    "type": "UtteranceEnd",
                    "channel": [0, 1],
                    "last_word_end": round(self._last_word_end, 3),
                })
                self._speech_started_sent = False
                self._utterance_end_sent = True
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_deepgram_websocket(websocket: WebSocket, transcription_engine, config):
    """Handle a Deepgram-compatible WebSocket session.

    Binary frames are forwarded to an ``AudioProcessor``; JSON text frames
    are read as control messages (``CloseStream``, ``Finalize``,
    ``KeepAlive``). Results from the processor are translated by a
    ``DeepgramAdapter`` in a background task.

    Args:
        websocket: The not-yet-accepted client connection. The ``language``
            and ``vad_events`` query parameters are honoured.
        transcription_engine: Engine handed to the ``AudioProcessor``.
        config: Server config, used for the initial Metadata message.
    """
    from whisperlivekit.audio_processor import AudioProcessor

    # How long to wait for outstanding results after the client ends the stream.
    drain_timeout = 30.0

    # Parse Deepgram query parameters
    params = websocket.query_params
    language = params.get("language", None)
    vad_events = params.get("vad_events", "false").lower() == "true"

    audio_processor = AudioProcessor(
        transcription_engine=transcription_engine,
        language=language,
    )

    await websocket.accept()
    logger.info("Deepgram-compat WebSocket opened")

    adapter = DeepgramAdapter(websocket)
    adapter._vad_events = vad_events

    # Send metadata
    await adapter.send_metadata(config)

    results_generator = await audio_processor.create_tasks()

    # Results consumer. ``after`` is the previous consumer task, if any; it
    # is awaited first so messages from successive generators stay in order.
    async def handle_results(generator, after=None):
        try:
            if after is not None:
                await after
            async for response in generator:
                await adapter.process_update(response.to_dict())
        except WebSocketDisconnect:
            pass
        except Exception as e:
            logger.exception(f"Deepgram compat results error: {e}")

    results_task = asyncio.create_task(handle_results(results_generator))

    # Audio / control message consumer
    try:
        stream_ended = False  # True when the client explicitly ended the audio
        while True:
            try:
                message = await asyncio.wait_for(
                    websocket.receive(), timeout=30.0,
                )
            except asyncio.TimeoutError:
                # No data for 30s — close
                break

            # ASGI allows both keys to be present with one of them None, so
            # test the values rather than key membership.
            data = message.get("bytes")
            text = message.get("text")

            if data is not None:
                if data:
                    await audio_processor.process_audio(data)
                else:
                    # Empty bytes = end of audio
                    await audio_processor.process_audio(b"")
                    stream_ended = True
                    break
            elif text is not None:
                try:
                    ctrl = json.loads(text)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON control message")
                    continue
                # Valid JSON that is not an object carries no message type.
                msg_type = ctrl.get("type", "") if isinstance(ctrl, dict) else ""

                if msg_type == "CloseStream":
                    await audio_processor.process_audio(b"")
                    stream_ended = True
                    break
                elif msg_type == "Finalize":
                    # Flush current audio — trigger end-of-utterance
                    await audio_processor.process_audio(b"")
                    results_generator = await audio_processor.create_tasks()
                    # The running consumer is bound to the old generator, so
                    # the new one needs its own consumer or its results are
                    # never sent.
                    # NOTE(review): assumes create_tasks() may be called again
                    # on the same AudioProcessor and that the adapter's line
                    # counters remain valid for the new generator — confirm.
                    results_task = asyncio.create_task(
                        handle_results(results_generator, after=results_task)
                    )
                elif msg_type == "KeepAlive":
                    pass  # Just keep the connection alive
                else:
                    logger.debug("Unknown Deepgram control message: %s", msg_type)
            else:
                # WebSocket close
                break

        if stream_ended:
            # Let results for audio already received reach the client before
            # tearing the session down.
            try:
                await asyncio.wait_for(results_task, timeout=drain_timeout)
            except asyncio.TimeoutError:
                logger.warning("Deepgram-compat results drain timed out")

    except WebSocketDisconnect:
        logger.info("Deepgram-compat WebSocket disconnected")
    except Exception as e:
        logger.error(f"Deepgram-compat error: {e}", exc_info=True)
    finally:
        if not results_task.done():
            # Cancelling the newest consumer also cancels an older one it is
            # still awaiting.
            results_task.cancel()
            try:
                await results_task
            except (asyncio.CancelledError, Exception):
                pass
        await audio_processor.cleanup()
        logger.info("Deepgram-compat WebSocket cleaned up")
|
||||||
Loading…
Reference in a new issue