WhisperLiveKit/whisperlivekit/vllm_realtime.py
2026-03-14 00:13:29 +01:00

416 lines
14 KiB
Python

"""
vLLM Realtime WebSocket streaming backend for WhisperLiveKit.
Connects to a vLLM server's ``/v1/realtime`` WebSocket endpoint to stream
audio and receive transcription deltas. Uses ``websockets.sync.client``
for simplicity since ``process_iter`` runs inside ``asyncio.to_thread``.
Provides ``VLLMRealtimeASR`` (lightweight model holder) and
``VLLMRealtimeOnlineProcessor`` (streaming processor) that plug into
WhisperLiveKit's audio processing pipeline.
"""
import base64
import json
import logging
import threading
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
# Module-level logger for the vLLM realtime backend (messages are prefixed "[vllm-realtime]").
logger = logging.getLogger(__name__)
class VLLMRealtimeASR:
    """Configuration holder for the vLLM realtime backend.

    Nothing is loaded locally. The instance only records which vLLM server
    WebSocket URL and which model name the streaming processor should use.
    ``lan="auto"`` is stored as ``original_language = None``; any other value
    is stored as given. Extra keyword arguments are accepted and ignored.
    """

    sep = " "
    SAMPLING_RATE = 16000
    backend_choice = "vllm-realtime"

    def __init__(self, vllm_url="ws://localhost:8000/v1/realtime",
                 model_name="Qwen/Qwen3-ASR-1.7B", lan="auto", **kwargs):
        self.vllm_url, self.model_name = vllm_url, model_name
        # "auto" means "let the server detect the language".
        self.original_language = lan if lan != "auto" else None
        # This backend does not use a local tokenizer.
        self.tokenizer = None

    def transcribe(self, audio):
        """No-op: this backend does not transcribe locally; always returns None."""
        return None
class VLLMRealtimeOnlineProcessor:
    """
    Online processor that streams audio to a vLLM Realtime WebSocket.

    Audio is buffered locally until at least ``_MIN_CONNECT_SAMPLES`` samples
    are available. A session is then opened with ``websockets.sync.client``,
    and a background thread accumulates ``transcription.delta`` /
    ``transcription.done`` text. Words are committed as ``ASRToken`` objects
    whose timestamps are estimated by spreading them linearly over the audio
    sent in the current session, shifted by ``_global_time_offset``.

    A session ends on ``start_silence``, on ``finish``, or when the server
    finishes or drops it. The next sufficiently long chunk of audio then opens
    a fresh session.
    """

    SAMPLING_RATE = 16000
    # Minimum audio samples before connecting (0.5s of audio)
    _MIN_CONNECT_SAMPLES = SAMPLING_RATE // 2
    # Audio is sent in messages of at most this many samples (0.5s)
    _SEND_CHUNK_SAMPLES = SAMPLING_RATE // 2

    def __init__(self, asr: VLLMRealtimeASR):
        self.asr = asr
        self.end = 0.0
        self.buffer = []
        self.audio_buffer = np.array([], dtype=np.float32)
        # Guards _accumulated_text, which the receive thread writes and the
        # caller's thread reads. Created once so every session (and any
        # lingering receive thread) shares the same lock object.
        self._text_lock = threading.Lock()
        # Words flushed outside of a return path (e.g. by new_speaker) that
        # must still be handed to the caller; survives session resets.
        self._deferred_words: List[ASRToken] = []
        self._reset_state()
        logger.info(
            "[vllm-realtime] Initialized. url=%s model=%s",
            asr.vllm_url, asr.model_name,
        )

    def _reset_state(self):
        """Reset all per-session state: connection, buffered audio, text and timing."""
        self._pending_audio = np.zeros(0, dtype=np.float32)
        self._ws = None
        self._recv_thread: Optional[threading.Thread] = None
        self._connected = False
        self._done = False
        self._recv_error: Optional[Exception] = None
        # Text accumulation and word extraction
        with self._text_lock:
            self._accumulated_text = ""
        self._n_committed_words = 0
        self._total_audio_duration = 0.0
        self._global_time_offset = 0.0

    # ── Interface methods ──

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Buffer ``audio`` until the next ``process_iter`` sends it."""
        self.end = audio_stream_end_time
        self._pending_audio = np.append(self._pending_audio, audio)
        self.audio_buffer = self._pending_audio

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Send buffered audio and return newly committed words.

        Never raises: any error is logged and ``([], self.end)`` is returned.
        """
        try:
            return self._process_iter_inner(is_last)
        except Exception as e:
            logger.warning("[vllm-realtime] process_iter exception: %s", e, exc_info=True)
            return [], self.end

    def get_buffer(self) -> Transcript:
        """Return all uncommitted text as buffer."""
        with self._text_lock:
            text = self._accumulated_text
        uncommitted = text.split()[self._n_committed_words:]
        if not uncommitted:
            return Transcript(start=None, end=None, text="")
        return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush all pending words when silence starts.

        Sends commit(final=true) to signal end of utterance, waits (up to 10s)
        for transcription.done, flushes all words, then closes the session so
        the next utterance opens a new one. If no live session exists, only
        flushes the words already received.
        """
        words = self._take_deferred_words()
        if not self._connected or self._done:
            words.extend(self._flush_all_pending_words())
            logger.info("[vllm-realtime] start_silence (not connected): flushed %d words", len(words))
            return words, self.end
        # Send any remaining buffered audio, then signal end of stream
        self._send_pending_audio()
        self._send_commit(final=True)
        self._wait_for_done(timeout=10.0)
        words.extend(self._end_session())
        logger.info("[vllm-realtime] start_silence: flushed %d words", len(words))
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Shift the timeline forward by ``silence_duration`` seconds."""
        self._global_time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """End the current utterance on a speaker change.

        The flushed words are not returned (this method returns None), so they
        are deferred and emitted by the next process_iter/start_silence/finish
        instead of being lost.
        """
        words, _ = self.start_silence()
        self._deferred_words.extend(words)

    def warmup(self, audio, init_prompt=""):
        """No-op: nothing to warm up locally."""
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Finalize the live session (waiting up to 30s) and flush all remaining words."""
        words = self._take_deferred_words()
        if self._connected and not self._done:
            self._send_pending_audio()
            self._send_commit(final=True)
            self._wait_for_done(timeout=30.0)
        # Flushes every word, closes the WebSocket and resets session state.
        words.extend(self._end_session())
        logger.info("[vllm-realtime] finish: flushed %d words", len(words))
        return words, self.end

    # ── WebSocket connection management ──

    def _connect(self):
        """Connect to the vLLM realtime WebSocket and start the receive thread."""
        from websockets.sync.client import connect
        url = self.asr.vllm_url
        logger.info("[vllm-realtime] Connecting to %s", url)
        self._ws = connect(url)
        try:
            # Send session.update to select model
            self._ws.send(json.dumps({
                "type": "session.update",
                "model": self.asr.model_name,
            }))
        except Exception:
            # Don't leak a half-initialised socket; the caller may retry later.
            self._close_ws()
            raise
        # Send initial commit(final=false) to start generation
        self._send_commit(final=False)
        # The thread is bound to this specific socket (see _recv_loop).
        self._recv_thread = threading.Thread(
            target=self._recv_loop, args=(self._ws,), daemon=True,
        )
        self._recv_thread.start()
        self._connected = True
        logger.info("[vllm-realtime] Connected and started receive thread")

    def _close_ws(self):
        """Close the WebSocket connection, join the receive thread, mark disconnected."""
        # Detach first so the receive thread sees the close as intentional.
        ws, self._ws = self._ws, None
        self._connected = False
        if ws is not None:
            try:
                ws.close()
            except Exception as e:
                logger.debug("[vllm-realtime] Error while closing WebSocket: %s", e)
        thread, self._recv_thread = self._recv_thread, None
        if thread is not None:
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning("[vllm-realtime] Receive thread did not exit within 5s")

    def _recv_loop(self, ws=None):
        """Background thread: receive messages from the vLLM WebSocket.

        ``ws`` is the connection this thread serves (defaults to the current
        one). The loop stops when the session is done or the processor has
        switched to another connection, so a lingering thread never reads
        from, or flags, a newer session.
        """
        if ws is None:
            ws = self._ws
        try:
            while ws is not None and not self._done and self._ws is ws:
                try:
                    raw = ws.recv(timeout=0.1)
                except TimeoutError:
                    continue
                except Exception as e:
                    # _close_ws detaches self._ws before closing, so an error on
                    # a still-current socket means the server dropped us: mark
                    # the session finished so _process_iter_inner reconnects.
                    if self._ws is ws:
                        logger.warning("[vllm-realtime] WebSocket connection lost: %s", e)
                        self._recv_error = e
                        self._done = True
                    break
                try:
                    msg = json.loads(raw)
                except (json.JSONDecodeError, TypeError):
                    continue
                if not isinstance(msg, dict):
                    continue
                msg_type = msg.get("type", "")
                if msg_type == "transcription.delta":
                    delta = msg.get("delta", "")
                    if delta:
                        with self._text_lock:
                            self._accumulated_text += delta
                elif msg_type == "transcription.done":
                    done_text = msg.get("text", "")
                    if done_text:
                        with self._text_lock:
                            # Replace accumulated text with final text
                            self._accumulated_text = done_text
                    self._done = True
                    break
        except Exception as e:
            logger.error("[vllm-realtime] recv_loop error: %s", e, exc_info=True)
            if self._ws is ws:
                self._recv_error = e
                self._done = True

    # ── Protocol messages ──

    def _send_commit(self, final: bool):
        """Send input_audio_buffer.commit message (best effort; failures are logged)."""
        if self._ws is None:
            return
        try:
            self._ws.send(json.dumps({
                "type": "input_audio_buffer.commit",
                "final": final,
            }))
        except Exception as e:
            logger.warning("[vllm-realtime] Failed to send commit: %s", e)

    def _send_audio(self, audio: np.ndarray):
        """Send audio as a base64-encoded little-endian PCM16 append message."""
        if self._ws is None:
            return
        # Clip first: float samples outside [-1, 1] would otherwise wrap
        # around when cast to int16 and turn into loud clicks.
        pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")
        audio_b64 = base64.b64encode(pcm16.tobytes()).decode("ascii")
        try:
            self._ws.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": audio_b64,
            }))
        except Exception as e:
            logger.warning("[vllm-realtime] Failed to send audio: %s", e)

    def _send_pending_audio(self):
        """Send all pending audio to the vLLM server in bounded-size messages."""
        pending = self._pending_audio
        if len(pending) == 0:
            return
        self._pending_audio = np.zeros(0, dtype=np.float32)
        self.audio_buffer = self._pending_audio
        # Track total audio duration for timestamp estimation
        self._total_audio_duration += len(pending) / self.SAMPLING_RATE
        # Chunk to avoid overwhelming the WebSocket with one huge message
        step = self._SEND_CHUNK_SAMPLES
        for start in range(0, len(pending), step):
            self._send_audio(pending[start:start + step])

    # ── Receive helpers ──

    def _drain_deltas(self):
        """No-op since the recv thread accumulates text directly."""
        pass

    def _wait_for_done(self, timeout: float = 10.0):
        """Wait for transcription.done, giving up early if the receive thread exited."""
        deadline = time.monotonic() + timeout
        while not self._done and time.monotonic() < deadline:
            thread = self._recv_thread
            if thread is None or not thread.is_alive():
                break
            time.sleep(0.05)
        if not self._done:
            logger.warning("[vllm-realtime] Timed out waiting for transcription.done")

    # ── Word extraction (same approach as VoxtralHF) ──

    def _time_for_word(self, word_idx: int, n_words_total: int) -> Tuple[float, float]:
        """Estimate timestamps by linearly distributing words across audio duration."""
        duration = max(self._total_audio_duration, 0.001)
        n_total = max(n_words_total, 1)
        start_time = (word_idx / n_total) * duration + self._global_time_offset
        end_time = ((word_idx + 1) / n_total) * duration + self._global_time_offset
        return start_time, end_time

    def _commit_words(self, keep_last: bool) -> List[ASRToken]:
        """Commit not-yet-committed words of the accumulated text as ASRTokens.

        With ``keep_last`` the final word is held back because the next delta
        may still extend it. Every word but the session's first gets a leading
        space.
        """
        with self._text_lock:
            text = self._accumulated_text
        words = text.split()
        n_words_total = len(words)
        limit = n_words_total - 1 if keep_last else n_words_total
        new_words: List[ASRToken] = []
        while self._n_committed_words < limit:
            idx = self._n_committed_words
            start_time, end_time = self._time_for_word(idx, n_words_total)
            text_out = words[idx] if idx == 0 else " " + words[idx]
            new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
            self._n_committed_words += 1
        return new_words

    def _extract_new_words(self) -> List[ASRToken]:
        """Extract complete words (all but the last, which may still grow)."""
        return self._commit_words(keep_last=True)

    def _flush_all_pending_words(self) -> List[ASRToken]:
        """Flush ALL words including the last partial one."""
        return self._commit_words(keep_last=False)

    def _take_deferred_words(self) -> List[ASRToken]:
        """Return and clear the words deferred by earlier calls."""
        words, self._deferred_words = self._deferred_words, []
        return words

    def _end_session(self) -> List[ASRToken]:
        """Flush every word of the current session, close it and reset state.

        Unsent audio is kept for the next session, and the time offset advances
        by the audio sent in this session so later timestamps stay on the
        global timeline.
        """
        words = self._flush_all_pending_words()
        self._close_ws()
        pending = self._pending_audio
        offset = self._global_time_offset + self._total_audio_duration
        self._reset_state()
        self._pending_audio = pending
        self.audio_buffer = pending
        self._global_time_offset = offset
        return words

    # ── Core processing ──

    def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
        # The session ended (transcription.done or a dropped connection) while
        # audio kept arriving: close it out here and reconnect below. Flushed
        # words are parked first so a failed reconnect does not lose them.
        if self._done and len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES:
            logger.info("[vllm-realtime] Session ended; reconnecting for new audio")
            self._deferred_words.extend(self._end_session())
        # Connect only once enough audio is buffered
        if not self._connected:
            if len(self._pending_audio) < self._MIN_CONNECT_SAMPLES:
                return self._take_deferred_words(), self.end
            self._connect()
        # Send any new pending audio (held back while the session is done)
        if not self._done:
            self._send_pending_audio()
        new_words = self._take_deferred_words()
        new_words.extend(self._extract_new_words())
        if new_words:
            logger.info(
                "[vllm-realtime] returning %d words: %s",
                len(new_words), [w.text for w in new_words],
            )
        self.buffer = []
        return new_words, self.end