WhisperLiveKit/whisperlivekit/voxtral_hf_streaming.py
2026-03-14 00:13:29 +01:00

573 lines
22 KiB
Python

"""
Voxtral Mini Realtime streaming backend using HuggingFace Transformers.
Uses VoxtralRealtimeForConditionalGeneration with a background generate thread
and queue-based audio feeding for real-time streaming transcription.
Supports CUDA, CPU, and MPS devices.
"""
import logging
import queue
import sys
import threading
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
class VoxtralHFStreamingASR:
    """Container for the Voxtral realtime model and its processor.

    The heavy objects (``processor`` and ``model``) are loaded once in
    ``__init__`` via HuggingFace Transformers and then read by the online
    processor through ``asr.processor`` / ``asr.model``.
    """

    sep = " "

    def __init__(self, logfile=sys.stderr, **kwargs):
        """Resolve which checkpoint to use and load processor + model.

        Keyword arguments read from ``kwargs`` (all optional):
            lan: language code; ``"auto"`` (the default) is stored as ``None``
                in ``original_language``.
            model_dir / model_path: explicit checkpoint location; the first
                truthy one wins.
            model_size: only used as the checkpoint when it contains ``"/"``
                or starts with ``"."``; otherwise the default checkpoint
                is loaded.
        """
        import torch
        from transformers import (
            AutoProcessor,
            VoxtralRealtimeForConditionalGeneration,
        )

        default_checkpoint = "mistralai/Voxtral-Mini-4B-Realtime-2602"

        self.logfile = logfile
        self.transcribe_kargs = {}

        requested_language = kwargs.get("lan", "auto")
        self.original_language = (
            requested_language if requested_language != "auto" else None
        )

        # Explicit locations take precedence over the model_size hint.
        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
        if not model_path:
            size_hint = kwargs.get("model_size", "")
            hint_is_location = size_hint and (
                "/" in size_hint or size_hint.startswith(".")
            )
            model_path = size_hint if hint_is_location else default_checkpoint

        load_started = time.time()
        logger.info(f"Loading Voxtral model '{model_path}' via HF Transformers...")
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        logger.info(f"Voxtral HF model loaded in {time.time() - load_started:.2f}s on {self.model.device}")

        self.backend_choice = "voxtral"
        self.tokenizer = None  # sentence tokenizer — not needed for streaming

    def transcribe(self, audio):
        """Intentionally a no-op that returns ``None``; the audio is ignored."""
        return None
class VoxtralHFStreamingOnlineProcessor:
    """
    Online (incremental) processor for Voxtral streaming ASR through
    HuggingFace Transformers.

    A background thread runs model.generate(); audio reaches it through a
    queue-backed input-features generator and decoded text comes back through
    a TextIteratorStreamer, so words can be emitted while audio is still
    arriving. Each decoded token corresponds to ~80ms of audio.
    """

    # Audio sample rate in Hz.
    SAMPLING_RATE = 16000
def __init__(self, asr: VoxtralHFStreamingASR, logfile=sys.stderr):
    """Bind to a loaded ASR holder and derive chunk geometry from its processor.

    Args:
        asr: model holder; ``asr.processor`` supplies the chunk sizes.
        logfile: stored on the instance as ``self.logfile``.
    """
    self.asr = asr
    self.logfile = logfile
    self.end = 0.0
    self.buffer = []
    self.audio_buffer = np.array([], dtype=np.float32)

    proc = asr.processor
    self._first_chunk_samples = proc.num_samples_first_audio_chunk
    self._chunk_samples = proc.num_samples_per_audio_chunk
    samples_per_token = proc.raw_audio_length_per_tok
    self._chunk_step = samples_per_token

    # num_right_pad_tokens is a method in some transformers versions, a property in others
    right_pad_tokens = proc.num_right_pad_tokens
    if callable(right_pad_tokens):
        right_pad_tokens = right_pad_tokens()
    self._right_pad_samples = int(right_pad_tokens * samples_per_token)
    self._seconds_per_token = samples_per_token / self.SAMPLING_RATE

    self._reset_state()

    logger.info(
        f"[voxtral-hf] Initialized. first_chunk={self._first_chunk_samples} samples, "
        f"chunk={self._chunk_samples}, step={self._chunk_step}, "
        f"right_pad={self._right_pad_samples}"
    )
def _reset_state(self):
    """(Re)create all per-generation state: buffers, counters and sync objects.

    Note that this also clears any pending audio and replaces the audio
    queue, the done-event and the text lock with fresh objects.
    """
    # Audio waiting to be windowed, and the hand-off queue read by the generator.
    self._pending_chunks: List[np.ndarray] = []
    self._pending_len = 0
    self._audio_queue: queue.Queue = queue.Queue()

    # Generate-thread lifecycle.
    self._generate_thread: Optional[threading.Thread] = None
    self._generate_started = False
    self._generate_finished = False
    self._generate_error: Optional[Exception] = None
    # Event signalled by the generate thread when it finishes.
    self._generate_done = threading.Event()

    # Text accumulation (list of fragments, joined on demand).
    self._streamer_texts: List[str] = []
    self._text_fragments: List[str] = []
    self._text_len = 0
    # Each entry is (char_offset_in_full_text, audio_tok_pos_consumed);
    # used to derive word timestamps.
    self._fragment_positions: List[Tuple[int, int]] = []
    # Lock guarding the text state above.
    self._text_lock = threading.Lock()

    # Counters.
    self._n_text_tokens_received = 0
    self._n_audio_tokens_fed = 0
    # Audio tokens actually consumed by the model (tracked inside generator).
    self._n_audio_tokens_consumed = 0
    self._n_committed_words = 0
    self._global_time_offset = 0.0
# ── Audio / text helpers ──
def _get_pending_audio(self) -> np.ndarray:
    """Return all pending audio as one array, compacting the chunk list.

    With several chunks buffered they are concatenated and the list is
    replaced by the single merged array, so repeated calls do not
    re-concatenate. A lone chunk is returned as-is (same object).
    """
    chunks = self._pending_chunks
    if not chunks:
        return np.zeros(0, dtype=np.float32)
    if len(chunks) > 1:
        merged = np.concatenate(chunks)
        self._pending_chunks = [merged]
        return merged
    return chunks[0]
def _set_pending_audio(self, arr: np.ndarray):
    """Make ``arr`` the entire pending buffer (an empty ``arr`` clears it)."""
    n_samples = len(arr)
    self._pending_chunks = [arr] if n_samples else []
    self._pending_len = n_samples
def _get_accumulated_text(self) -> str:
    """Return everything received so far as one string.

    When more than one fragment is stored they are joined and the fragment
    list is collapsed to that single string, so later calls are cheap.
    """
    fragments = self._text_fragments
    if not fragments:
        return ""
    if len(fragments) > 1:
        full_text = "".join(fragments)
        self._text_fragments = [full_text]
        return full_text
    return fragments[0]
# ── Interface methods ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
    """Queue a new audio chunk and record the stream end time.

    Args:
        audio: samples to append to the pending buffer.
        audio_stream_end_time: stored as ``self.end``.
    """
    self.end = audio_stream_end_time
    n_new = len(audio)
    self._pending_chunks.append(audio)
    self._pending_len = self._pending_len + n_new
    # Kept for diagnostics only; holds just the most recent chunk.
    self.audio_buffer = audio
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
    """Run one processing step, never letting an exception escape.

    Delegates to ``_process_iter_inner``. Any exception is logged as a
    warning (with traceback) and an empty word list is returned instead.

    Returns:
        ``(new_words, self.end)``.
    """
    try:
        outcome = self._process_iter_inner(is_last)
    except Exception as exc:
        logger.warning(f"[voxtral-hf] process_iter exception: {exc}", exc_info=True)
        outcome = ([], self.end)
    return outcome
def get_buffer(self) -> Transcript:
    """Return the not-yet-committed words as a single ``Transcript``.

    The streamer is drained first so that text which arrived since the
    last call is included. When there is no uncommitted text an empty
    transcript with ``start``/``end`` set to ``None`` is returned;
    otherwise both timestamps are ``self.end``.
    """
    self._drain_streamer()
    with self._text_lock:
        text = self._get_accumulated_text()
        pending_words = text.split()[self._n_committed_words:] if text else []

    if not pending_words:
        return Transcript(start=None, end=None, text="")
    return Transcript(start=self.end, end=self.end, text=" ".join(pending_words))
def start_silence(self) -> Tuple[List[ASRToken], float]:
    """Flush every uncommitted word at the onset of silence.

    If no generate thread is running, whatever text is already available
    is drained and flushed. Otherwise remaining audio is fed, followed by
    zero-valued right-padding so the model has enough future context to
    emit its trailing tokens, and the streamer is drained in a few
    blocking rounds until a round yields no new words.

    Returns:
        ``(flushed_words, self.end)``.
    """
    thread_active = self._generate_started and not self._generate_finished
    if not thread_active:
        self._drain_streamer()
        words = self._flush_all_pending_words()
        logger.info(f"[voxtral-hf] start_silence (no thread): flushed {len(words)} words")
        return words, self.end

    # Push whatever real audio is still buffered.
    self._feed_pending_audio()

    pad_len = self._right_pad_samples
    if pad_len > 0:
        # Padding is not real audio, so the fed-token counter is restored
        # after feeding it.
        fed_before_padding = self._n_audio_tokens_fed
        silence = np.zeros(pad_len, dtype=np.float32)
        self._pending_chunks.append(silence)
        self._pending_len += len(silence)
        self._feed_pending_audio()
        self._n_audio_tokens_fed = fed_before_padding

    # The model may keep producing text after the audio queue empties, so
    # drain repeatedly (bounded) until a round brings nothing new.
    max_rounds = 5
    flushed: List[ASRToken] = []
    for _ in range(max_rounds):
        self._drain_streamer_blocking(timeout=5.0)
        fresh = self._flush_all_pending_words()
        flushed.extend(fresh)
        if not fresh:
            break  # no new text — model has caught up

    logger.info(f"[voxtral-hf] start_silence: flushed {len(flushed)} words")
    return flushed, self.end
def end_silence(self, silence_duration: float, offset: float):
    """Advance the time base by a silence gap.

    Both the global timestamp offset and ``self.end`` move forward by
    ``silence_duration``. ``offset`` is accepted but not used.
    """
    self.end = self.end + silence_duration
    self._global_time_offset = self._global_time_offset + silence_duration
def new_speaker(self, change_speaker):
    """Handle a speaker change by flushing as if silence had started.

    ``change_speaker`` is not inspected.
    NOTE(review): the words returned by ``start_silence`` are dropped here
    even though they have been marked committed — confirm callers do not
    need them.
    """
    _ = self.start_silence()
def warmup(self, audio, init_prompt=""):
    """No-op; both arguments are ignored and ``None`` is returned."""
    return None
def finish(self) -> Tuple[List[ASRToken], float]:
    """Flush the tail of the stream and wait for generation to end.

    Right-padding is appended to the pending audio. Then, if a generate
    thread is running, or one was never started but enough audio is now
    buffered to start it, the remaining audio is fed, the end-of-audio
    sentinel (``None``) is queued and the thread is joined for up to
    30 seconds. Finally all available text is drained and flushed.

    Returns:
        ``(flushed_words, self.end)``.
    """
    pad_len = self._right_pad_samples
    if pad_len > 0:
        # Extra zeros so the model can finish decoding the last words.
        tail_silence = np.zeros(pad_len, dtype=np.float32)
        self._pending_chunks.append(tail_silence)
        self._pending_len += len(tail_silence)

    running = self._generate_started and not self._generate_finished
    startable = (
        not running
        and not self._generate_started
        and self._pending_len >= self._first_chunk_samples
    )
    if running or startable:
        if startable:
            # Never started but enough audio is buffered — start, then finish at once.
            self._start_generate_thread()
        self._feed_pending_audio()
        self._audio_queue.put(None)  # signal end of audio
        worker = self._generate_thread
        if worker is not None:
            worker.join(timeout=30.0)

    self._drain_streamer()
    words = self._flush_all_pending_words()
    logger.info(f"[voxtral-hf] finish: flushed {len(words)} words")
    return words, self.end
# ── Generate thread management ──
def _start_generate_thread(self):
    """Launch ``model.generate()`` on a daemon thread with streaming I/O.

    The first ``_first_chunk_samples`` samples are cut from the pending
    buffer and preprocessed as the first streaming chunk. Later chunks are
    pulled from ``_audio_queue`` by a generator passed as
    ``input_features``; a ``None`` item ends the generator. Decoded text
    is delivered through a ``TextIteratorStreamer`` stored on
    ``self._streamer``.
    """
    import torch
    from transformers import TextIteratorStreamer

    processor = self.asr.processor
    model = self.asr.model
    first_len = self._first_chunk_samples

    # Split the buffered audio into the first chunk and the remainder.
    buffered = self._get_pending_audio()
    head = buffered[:first_len]
    self._set_pending_audio(buffered[first_len:])

    # First chunk covers multiple audio tokens.
    self._n_audio_tokens_fed += max(1, first_len // self._chunk_step)

    first_inputs = processor(
        head,
        is_streaming=True,
        is_first_audio_chunk=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    self._streamer = streamer
    audio_queue = self._audio_queue

    def encode_followup(samples):
        """Preprocess a non-first chunk and return its input features."""
        encoded = processor(
            samples,
            is_streaming=True,
            is_first_audio_chunk=False,
            return_tensors="pt",
        )
        encoded = encoded.to(model.device, dtype=model.dtype)
        return encoded.input_features

    def input_features_gen():
        # Consumption is tracked here because this body runs on the
        # generate thread, i.e. when the model actually pulls audio.
        self._n_audio_tokens_consumed = max(1, self._first_chunk_samples // self._chunk_step)
        yield first_inputs.input_features
        while True:
            samples = audio_queue.get()
            if samples is None:
                return
            self._n_audio_tokens_consumed += 1
            yield encode_followup(samples)

    def run_generate():
        try:
            with torch.no_grad():
                # Pass generator as input_features — the model detects GeneratorType
                # and internally converts it to input_features_generator
                extra_kwargs = {
                    key: value
                    for key, value in first_inputs.items()
                    if key != "input_features"
                }
                model.generate(
                    input_features=input_features_gen(),
                    streamer=streamer,
                    **extra_kwargs,
                )
        except Exception as exc:
            logger.error(f"[voxtral-hf] generate error: {exc}", exc_info=True)
            self._generate_error = exc
        finally:
            self._generate_finished = True
            self._generate_done.set()

    worker = threading.Thread(target=run_generate, daemon=True)
    self._generate_thread = worker
    worker.start()
    self._generate_started = True
    logger.info("[voxtral-hf] generate thread started")
def _feed_pending_audio(self):
    """Slice pending audio into model-sized windows and queue them.

    Windows are ``_chunk_samples`` long and advance by ``_chunk_step``
    samples. Whatever is too short for another window stays pending
    (and is mirrored into ``audio_buffer``).
    """
    window = self._chunk_samples
    hop = self._chunk_step

    leftover = self._get_pending_audio()
    n_queued = 0
    while len(leftover) >= window:
        self._audio_queue.put(leftover[:window])
        leftover = leftover[hop:]
        n_queued += 1

    self._n_audio_tokens_fed += n_queued
    self._set_pending_audio(leftover)
    self.audio_buffer = leftover
def _append_text_fragment(self, text_fragment: str):
    """Record a decoded fragment and where in the audio it was produced.

    Caller must hold ``_text_lock``. The fragment's starting character
    offset is paired with the current consumed-audio-token count so word
    timestamps can be looked up later.
    """
    char_offset = self._text_len
    self._fragment_positions.append((char_offset, self._n_audio_tokens_consumed))
    self._text_fragments.append(text_fragment)
    self._text_len = char_offset + len(text_fragment)
    self._n_text_tokens_received += 1
def _drain_streamer(self):
    """Collect, without blocking, every text fragment the streamer has ready.

    A ``None`` item marks the end of generation: ``_generate_finished`` is
    set and draining stops. Empty-string fragments are skipped.
    """
    if not self._generate_started:
        return

    ready_text = self._streamer.text_queue
    while True:
        try:
            piece = ready_text.get_nowait()
        except queue.Empty:
            return
        if piece is None:
            self._generate_finished = True
            return
        if not piece:
            continue
        with self._text_lock:
            self._append_text_fragment(piece)
def _drain_streamer_blocking(self, timeout=30.0):
    """Drain the streamer, waiting (up to ``timeout`` seconds) for more text.

    Returns early when the generate thread is done, when the end-of-stream
    ``None`` item is read, or when the audio queue is empty and four
    consecutive polls produced no text. Poll timeouts are short (0.1 s)
    while audio is still queued and longer (0.5 s) once it is empty.
    """
    if not self._generate_started or self._generate_finished:
        self._drain_streamer()
        return

    ready_text = self._streamer.text_queue
    stop_at = time.time() + timeout
    idle_polls = 0  # consecutive polls that returned nothing

    while time.time() < stop_at:
        budget = max(stop_at - time.time(), 0.01)

        # Thread finished: take whatever is left and leave.
        if self._generate_done.is_set() or self._generate_finished:
            self._drain_streamer()
            return

        poll_cap = 0.5 if self._audio_queue.empty() else 0.1
        try:
            piece = ready_text.get(timeout=min(budget, poll_cap))
        except queue.Empty:
            idle_polls += 1
            # Give up only once the audio has been consumed AND the model
            # has stayed quiet for several polls (it may simply be slow).
            if idle_polls >= 4 and self._audio_queue.empty():
                return
            continue

        idle_polls = 0
        if piece is None:
            self._generate_finished = True
            return
        if piece:
            with self._text_lock:
                self._append_text_fragment(piece)
# ── Word extraction ──
def _pos_to_time(self, token_position: int) -> float:
    """Map an audio-token index to seconds, shifted by the global time offset."""
    elapsed = token_position * self._seconds_per_token
    return elapsed + self._global_time_offset
def _audio_pos_for_char(self, char_idx: int) -> int:
    """Look up the audio token position for a character index in the text.

    ``_fragment_positions`` holds ``(char_offset, audio_tok_pos)`` pairs in
    increasing ``char_offset`` order (one per received fragment). The
    result is the audio position of the last fragment whose offset is
    ``<= char_idx``, i.e. the fragment containing that character.

    The list grows for the whole utterance and this lookup runs twice per
    word, so a binary search is used instead of a linear scan.

    Args:
        char_idx: character index into the accumulated text.

    Returns:
        The recorded audio token position, or ``0`` when no fragment has
        been recorded or ``char_idx`` precedes the first fragment.
    """
    from bisect import bisect_right

    positions = self._fragment_positions
    if not positions:
        return 0
    # (char_idx, +inf) sorts after every entry whose offset equals char_idx,
    # so the insertion point is just past the last entry with offset <= char_idx.
    insert_at = bisect_right(positions, (char_idx, float("inf")))
    if insert_at == 0:
        return 0
    return positions[insert_at - 1][1]
def _word_timestamps(self, text: str, words: List[str], start_idx: int, end_idx: int) -> List[Tuple[int, int]]:
    """Compute (tok_start, tok_end) for words[start_idx:end_idx] using fragment positions.

    Each word is located at its real character offset inside ``text`` so
    the offsets line up with the raw-text offsets recorded per fragment.
    Assuming "word i starts after i single spaces" drifts whenever the raw
    text has leading whitespace or runs of whitespace, because ``words``
    is produced by ``text.split()``.

    Args:
        text: the accumulated raw text the words were split from.
        words: whitespace-split words of ``text``.
        start_idx: first word index to report (inclusive).
        end_idx: last word index to report (exclusive).

    Returns:
        One ``(tok_start, tok_end)`` pair per reported word: the audio
        position of the fragment containing the word's first character and
        of the fragment containing the character just past the word.
    """
    spans: List[Tuple[int, int]] = []
    cursor = 0  # index in `text` just past the previous word
    for i, word in enumerate(words):
        if i >= end_idx:
            break
        word_start = text.find(word, cursor)
        if word_start < 0:
            # `words` was not derived from `text`: fall back to the
            # single-space layout (" ".join(words)).
            word_start = cursor if i == 0 else cursor + 1
        word_end = word_start + len(word)
        if i >= start_idx:
            spans.append(
                (self._audio_pos_for_char(word_start), self._audio_pos_for_char(word_end))
            )
        cursor = word_end
    return spans
def _extract_new_words(self) -> List[ASRToken]:
    """Commit the words that are certainly complete.

    The last word of the accumulated text is always held back because
    later fragments may still extend it.

    Returns:
        Newly committed tokens in order (every token except the very
        first committed word carries a leading space); an empty list when
        there is nothing new to commit.
    """
    with self._text_lock:
        text = self._get_accumulated_text()
        words = text.split() if text else []

        # Everything but the trailing word is eligible.
        commit_limit = len(words) - 1
        if commit_limit <= self._n_committed_words:
            return []

        spans = self._word_timestamps(text, words, self._n_committed_words, commit_limit)
        committed: List[ASRToken] = []
        for first_tok, last_tok in spans:
            idx = self._n_committed_words
            rendered = words[idx] if idx == 0 else " " + words[idx]
            committed.append(
                ASRToken(
                    start=self._pos_to_time(first_tok),
                    # Guarantee a non-zero duration of at least one token.
                    end=self._pos_to_time(max(last_tok, first_tok + 1)),
                    text=rendered,
                )
            )
            self._n_committed_words = idx + 1
        return committed
def _flush_all_pending_words(self) -> List[ASRToken]:
    """Commit every uncommitted word, the trailing (possibly partial) one included.

    Returns:
        Newly committed tokens in order (every token except the very
        first committed word carries a leading space); an empty list when
        all words were already committed.
    """
    with self._text_lock:
        text = self._get_accumulated_text()
        words = text.split() if text else []

        total = len(words)
        if self._n_committed_words >= total:
            return []

        spans = self._word_timestamps(text, words, self._n_committed_words, total)
        flushed: List[ASRToken] = []
        for first_tok, last_tok in spans:
            idx = self._n_committed_words
            rendered = words[idx] if idx == 0 else " " + words[idx]
            flushed.append(
                ASRToken(
                    start=self._pos_to_time(first_tok),
                    # Guarantee a non-zero duration of at least one token.
                    end=self._pos_to_time(max(last_tok, first_tok + 1)),
                    text=rendered,
                )
            )
            self._n_committed_words = idx + 1
        return flushed
# ── Core processing ──
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
    """Advance the streaming pipeline by one step and return new words.

    Steps, in order:
      1. Start the generate thread once enough audio for the first chunk
         is buffered (otherwise return nothing yet).
      2. Feed newly buffered audio to a running thread.
      3. If generation has ended but a first chunk's worth of new audio is
         buffered, flush the old text and restart generation on that audio.
      4. Otherwise drain the streamer and commit the completed words.

    Args:
        is_last: accepted for interface compatibility; not used here.

    Returns:
        ``(new_words, self.end)``.
    """
    # Start generate thread when enough audio is buffered
    if not self._generate_started:
        if self._pending_len < self._first_chunk_samples:
            return [], self.end
        self._start_generate_thread()
        self._feed_pending_audio()

    # Feed any new pending audio
    if self._generate_started and not self._generate_finished:
        self._feed_pending_audio()

    # If generate finished unexpectedly (EOS) but new audio arrived, restart
    if self._generate_finished and self._pending_len >= self._first_chunk_samples:
        self._drain_streamer()
        flush_words = self._flush_all_pending_words()

        # _reset_state() also clears the pending-audio buffer, so the audio
        # that triggered this restart must be carried across the reset;
        # otherwise the new generation would start on an empty first chunk.
        old_offset = self._global_time_offset
        carried_audio = self._get_pending_audio()
        self._reset_state()
        self._global_time_offset = old_offset
        self._set_pending_audio(carried_audio)
        # NOTE(review): audio-token positions restart at 0 after the reset
        # while the offset only accumulates silence durations — confirm
        # whether already-consumed audio time should be added to the offset.

        self._start_generate_thread()
        self._feed_pending_audio()
        return flush_words, self.end

    # Drain available text from streamer
    self._drain_streamer()

    # Extract complete words
    new_words = self._extract_new_words()
    if new_words:
        logger.info(f"[voxtral-hf] returning {len(new_words)} words: {[w.text for w in new_words]}")

    self.buffer = []
    return new_words, self.end