573 lines
22 KiB
Python
573 lines
22 KiB
Python
"""
Voxtral Mini Realtime streaming backend using HuggingFace Transformers.

Uses VoxtralRealtimeForConditionalGeneration with a background generate thread
and queue-based audio feeding for real-time streaming transcription.
Supports CUDA, CPU, and MPS devices.
"""
|
|
|
|
import logging
|
|
import queue
|
|
import sys
|
|
import threading
|
|
import time
|
|
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VoxtralHFStreamingASR:
    """Holds the Voxtral model and its processor, loaded via HuggingFace Transformers."""

    sep = " "

    def __init__(self, logfile=sys.stderr, **kwargs):
        """Load the processor and the model.

        Keyword arguments read here (all optional):
            lan: language code; ``"auto"`` (the default) is stored as ``None``.
            model_dir / model_path: explicit model location (``model_dir`` wins).
            model_size: used as the model location only when it looks like a
                path or repo id (contains ``/`` or starts with ``.``).
        """
        # torch / transformers are imported inside the constructor rather than
        # at module import time.
        import torch
        from transformers import (
            AutoProcessor,
            VoxtralRealtimeForConditionalGeneration,
        )

        self.logfile = logfile
        self.transcribe_kargs = {}

        language = kwargs.get("lan", "auto")
        self.original_language = language if language != "auto" else None

        DEFAULT_MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"
        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
        if not model_path:
            candidate = kwargs.get("model_size", "")
            looks_like_location = bool(candidate) and (
                "/" in candidate or candidate.startswith(".")
            )
            model_path = candidate if looks_like_location else DEFAULT_MODEL

        t = time.time()
        logger.info(f"Loading Voxtral model '{model_path}' via HF Transformers...")
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        logger.info(f"Voxtral HF model loaded in {time.time() - t:.2f}s on {self.model.device}")

        self.backend_choice = "voxtral"
        # Sentence tokenizer — not needed for streaming.
        self.tokenizer = None

    def transcribe(self, audio):
        """No-op: this class only holds the model; it does not transcribe."""
        return None
|
|
|
|
|
|
class VoxtralHFStreamingOnlineProcessor:
    """
    Online processor for Voxtral streaming ASR via HuggingFace Transformers.

    Uses a background thread running model.generate() with a queue-based
    input_features_generator and TextIteratorStreamer for real-time output.
    Each decoded token corresponds to ~80ms of audio.
    """

    SAMPLING_RATE = 16000

    def __init__(self, asr: "VoxtralHFStreamingASR", logfile=sys.stderr):
        """Bind to a loaded ASR holder and read chunking geometry from its processor.

        Args:
            asr: object exposing ``processor`` and ``model`` attributes.
            logfile: stored on the instance; not otherwise used here.
        """
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer = []
        self.audio_buffer = np.array([], dtype=np.float32)

        processor = asr.processor
        self._first_chunk_samples = processor.num_samples_first_audio_chunk
        self._chunk_samples = processor.num_samples_per_audio_chunk
        self._chunk_step = processor.raw_audio_length_per_tok
        # num_right_pad_tokens is a method in some transformers versions, a property in others
        n_right_pad = processor.num_right_pad_tokens
        if callable(n_right_pad):
            n_right_pad = n_right_pad()
        self._right_pad_samples = int(n_right_pad * processor.raw_audio_length_per_tok)
        self._seconds_per_token = processor.raw_audio_length_per_tok / self.SAMPLING_RATE

        self._reset_state()

        logger.info(
            f"[voxtral-hf] Initialized. first_chunk={self._first_chunk_samples} samples, "
            f"chunk={self._chunk_samples}, step={self._chunk_step}, "
            f"right_pad={self._right_pad_samples}"
        )

    def _reset_state(self):
        """Reset every piece of per-generation state, including pending audio."""
        self._pending_chunks: List[np.ndarray] = []
        self._pending_len = 0
        self._audio_queue: queue.Queue = queue.Queue()
        self._streamer_texts: List[str] = []
        # Set by _start_generate_thread(); only read once _generate_started is True.
        self._streamer = None
        self._generate_thread: Optional[threading.Thread] = None
        self._generate_started = False
        self._generate_finished = False
        self._generate_error: Optional[Exception] = None

        # Text accumulation (list of fragments, joined on demand)
        self._text_fragments: List[str] = []
        self._text_len = 0
        # Fragment position tracking for word timestamps:
        # each entry is (char_offset_in_full_text, audio_tok_pos_consumed)
        self._fragment_positions: List[Tuple[int, int]] = []
        self._n_text_tokens_received = 0
        self._n_audio_tokens_fed = 0
        # Audio tokens actually consumed by the model (tracked inside generator)
        self._n_audio_tokens_consumed = 0
        self._n_committed_words = 0
        self._global_time_offset = 0.0

        # Event signalled by the generate thread when it finishes
        self._generate_done = threading.Event()

        # Lock guarding the accumulated text and the fragment position index
        self._text_lock = threading.Lock()

    def _reset_for_restart(self):
        """Reset generation state but keep pending audio and the time offset.

        ``_reset_state()`` clears the pending audio buffer; when restarting
        after the previous generation ended, that buffer holds exactly the
        audio the new generation must start from, so it is carried over.
        """
        carried_chunks = self._pending_chunks
        carried_len = self._pending_len
        carried_offset = self._global_time_offset
        self._reset_state()
        self._pending_chunks = carried_chunks
        self._pending_len = carried_len
        # NOTE(review): the audio-token counter restarts at zero, so timestamps
        # of the new run are relative to this offset only — confirm whether the
        # audio time consumed by the previous run should be added here.
        self._global_time_offset = carried_offset

    # ── Audio / text helpers ──

    def _get_pending_audio(self) -> np.ndarray:
        """Flatten pending audio chunks into a single array."""
        if not self._pending_chunks:
            return np.zeros(0, dtype=np.float32)
        if len(self._pending_chunks) == 1:
            return self._pending_chunks[0]
        flat = np.concatenate(self._pending_chunks)
        # Cache the flattened form so repeated calls do not re-concatenate.
        self._pending_chunks = [flat]
        return flat

    def _set_pending_audio(self, arr: np.ndarray):
        """Replace pending audio with a single array."""
        if len(arr) == 0:
            self._pending_chunks = []
            self._pending_len = 0
        else:
            self._pending_chunks = [arr]
            self._pending_len = len(arr)

    def _append_right_padding(self):
        """Append ``_right_pad_samples`` zeros to the pending audio (no-op if 0)."""
        if self._right_pad_samples > 0:
            right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
            self._pending_chunks.append(right_pad)
            self._pending_len += len(right_pad)

    def _get_accumulated_text(self) -> str:
        """Get the full accumulated text (joins fragments if needed)."""
        if not self._text_fragments:
            return ""
        if len(self._text_fragments) == 1:
            return self._text_fragments[0]
        joined = "".join(self._text_fragments)
        self._text_fragments = [joined]
        return joined

    # ── Interface methods ──

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Queue ``audio`` for processing and record the stream end time."""
        self.end = audio_stream_end_time
        self._pending_chunks.append(audio)
        self._pending_len += len(audio)
        self.audio_buffer = audio  # diagnostic only

    def process_iter(self, is_last=False) -> "Tuple[List[ASRToken], float]":
        """Run one processing step; on any exception log it and return no words."""
        try:
            return self._process_iter_inner(is_last)
        except Exception as e:
            # Deliberate best-effort: a failed step must not break the caller's loop.
            logger.warning(f"[voxtral-hf] process_iter exception: {e}", exc_info=True)
            return [], self.end

    def get_buffer(self) -> "Transcript":
        """Return all uncommitted text as buffer.

        Drains the streamer first so late-arriving tokens (common on
        slower devices like MPS) are picked up even between audio chunks.
        """
        self._drain_streamer()
        with self._text_lock:
            text = self._get_accumulated_text()
            if not text:
                return Transcript(start=None, end=None, text="")

            uncommitted = text.split()[self._n_committed_words:]
            if uncommitted:
                return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
            return Transcript(start=None, end=None, text="")

    def start_silence(self) -> "Tuple[List[ASRToken], float]":
        """Flush all uncommitted words when silence starts.

        Feeds right-padding (silence) so the model has enough future context
        to emit the last few tokens, then drains repeatedly until the model
        has finished producing text. Without right-padding the model holds
        back the last few words because it hasn't seen enough audio yet.
        """
        if not self._generate_started or self._generate_finished:
            self._drain_streamer()
            words = self._flush_all_pending_words()
            logger.info(f"[voxtral-hf] start_silence (no thread): flushed {len(words)} words")
            return words, self.end

        # Feed any remaining real audio
        self._feed_pending_audio()

        # Add right-padding so the model can decode trailing tokens.
        # The padding is not real audio, so _n_audio_tokens_fed is restored
        # after feeding it.
        if self._right_pad_samples > 0:
            self._append_right_padding()
            saved_count = self._n_audio_tokens_fed
            self._feed_pending_audio()
            self._n_audio_tokens_fed = saved_count

        # Drain in a loop: the model may continue producing text tokens after
        # the audio queue is empty (autoregressive generation).
        all_words: List[ASRToken] = []
        for _ in range(5):
            self._drain_streamer_blocking(timeout=5.0)
            batch = self._flush_all_pending_words()
            all_words.extend(batch)
            if not batch:
                break  # no new text — model has caught up

        logger.info(f"[voxtral-hf] start_silence: flushed {len(all_words)} words")
        return all_words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Shift the time base forward by ``silence_duration`` (``offset`` is unused)."""
        self._global_time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """Flush pending words on a speaker change (``change_speaker`` is unused).

        NOTE(review): the words returned by start_silence() are discarded here
        although they are marked committed — confirm callers do not need them.
        """
        self.start_silence()

    def warmup(self, audio, init_prompt=""):
        """No-op warmup hook."""
        pass

    def _join_generate_thread(self, timeout: float = 30.0):
        """Wait for the generate thread, logging if it is still running afterwards."""
        thread = self._generate_thread
        if thread is None:
            return
        thread.join(timeout=timeout)
        if thread.is_alive():
            logger.warning(f"[voxtral-hf] generate thread still running after {timeout}s")

    def finish(self) -> "Tuple[List[ASRToken], float]":
        """Flush remaining audio with right-padding and stop the generate thread."""
        # Add right-padding so the model can finish decoding
        self._append_right_padding()

        if self._generate_started and not self._generate_finished:
            self._feed_pending_audio()
            # None is the end-of-audio sentinel for the feeding generator.
            self._audio_queue.put(None)
            self._join_generate_thread(timeout=30.0)
        elif not self._generate_started and self._pending_len >= self._first_chunk_samples:
            # Never started but have enough audio — start and immediately finish
            self._start_generate_thread()
            self._feed_pending_audio()
            self._audio_queue.put(None)
            self._join_generate_thread(timeout=30.0)

        self._drain_streamer()
        words = self._flush_all_pending_words()
        logger.info(f"[voxtral-hf] finish: flushed {len(words)} words")
        return words, self.end

    # ── Generate thread management ──

    def _start_generate_thread(self):
        """Start model.generate() in a background thread with streaming."""
        import torch
        from transformers import TextIteratorStreamer

        processor = self.asr.processor
        model = self.asr.model

        # Split the first chunk off the pending audio.
        pending = self._get_pending_audio()
        first_chunk_audio = pending[:self._first_chunk_samples]
        self._set_pending_audio(pending[self._first_chunk_samples:])
        # First chunk covers multiple audio tokens
        self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)

        first_inputs = processor(
            first_chunk_audio,
            is_streaming=True,
            is_first_audio_chunk=True,
            return_tensors="pt",
        )
        first_inputs = first_inputs.to(model.device, dtype=model.dtype)

        streamer = TextIteratorStreamer(
            processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self._streamer = streamer

        audio_queue = self._audio_queue

        def input_features_gen():
            # Track audio consumption inside the generator (runs in generate thread)
            self._n_audio_tokens_consumed = max(1, self._first_chunk_samples // self._chunk_step)
            yield first_inputs.input_features
            while True:
                chunk_audio = audio_queue.get()
                if chunk_audio is None:
                    break
                self._n_audio_tokens_consumed += 1
                inputs = processor(
                    chunk_audio,
                    is_streaming=True,
                    is_first_audio_chunk=False,
                    return_tensors="pt",
                )
                inputs = inputs.to(model.device, dtype=model.dtype)
                yield inputs.input_features

        def run_generate():
            try:
                with torch.no_grad():
                    # The generator is passed as input_features; every other
                    # key of the first processor output is forwarded as-is.
                    generate_kwargs = {
                        k: v for k, v in first_inputs.items()
                        if k != "input_features"
                    }
                    model.generate(
                        input_features=input_features_gen(),
                        streamer=streamer,
                        **generate_kwargs,
                    )
            except Exception as e:
                logger.error(f"[voxtral-hf] generate error: {e}", exc_info=True)
                self._generate_error = e
            finally:
                self._generate_finished = True
                self._generate_done.set()

        self._generate_thread = threading.Thread(target=run_generate, daemon=True)
        self._generate_thread.start()
        self._generate_started = True
        logger.info("[voxtral-hf] generate thread started")

    def _feed_pending_audio(self):
        """Convert pending audio into properly-sized chunks for the generator.

        Emits windows of ``_chunk_samples`` samples, advancing by
        ``_chunk_step`` samples each time; the unconsumed tail stays pending.
        """
        chunk_size = self._chunk_samples
        step_size = self._chunk_step

        pending = self._get_pending_audio()
        while len(pending) >= chunk_size:
            self._audio_queue.put(pending[:chunk_size])
            pending = pending[step_size:]
            self._n_audio_tokens_fed += 1

        self._set_pending_audio(pending)
        self.audio_buffer = pending

    def _append_text_fragment(self, text_fragment: str):
        """Append a text fragment with its audio position (must hold _text_lock)."""
        self._fragment_positions.append((self._text_len, self._n_audio_tokens_consumed))
        self._text_fragments.append(text_fragment)
        self._text_len += len(text_fragment)
        self._n_text_tokens_received += 1

    def _drain_streamer(self):
        """Non-blocking drain of all available text from the streamer."""
        if not self._generate_started:
            return

        text_queue = self._streamer.text_queue
        while True:
            try:
                text_fragment = text_queue.get_nowait()
            except queue.Empty:
                break
            if text_fragment is None:
                # None in the text queue is treated as end-of-generation.
                self._generate_finished = True
                break
            if text_fragment:
                with self._text_lock:
                    self._append_text_fragment(text_fragment)

    def _drain_streamer_blocking(self, timeout=30.0):
        """Blocking drain: wait for the generate thread to finish producing text.

        Returns when the generate thread is done, when ``timeout`` seconds
        have elapsed, or when the audio queue is empty and four consecutive
        polls of the text queue produced nothing.
        """
        if not self._generate_started or self._generate_finished:
            self._drain_streamer()
            return

        text_queue = self._streamer.text_queue
        deadline = time.time() + timeout
        # Count consecutive empty polls to detect when model has caught up
        empty_streak = 0

        while time.time() < deadline:
            remaining = max(deadline - time.time(), 0.01)

            # If generate thread is done, do a final flush and exit
            if self._generate_done.is_set() or self._generate_finished:
                self._drain_streamer()
                return

            # Adaptive wait: short while audio is queued, longer once queue is empty
            wait = min(remaining, 0.5 if self._audio_queue.empty() else 0.1)

            try:
                text_fragment = text_queue.get(timeout=wait)
            except queue.Empty:
                empty_streak += 1
                # Only exit if audio queue is empty AND we've had enough empty polls;
                # this prevents premature exit when the model is slow.
                if self._audio_queue.empty() and empty_streak >= 4:
                    break
                continue

            empty_streak = 0
            if text_fragment is None:
                self._generate_finished = True
                break
            if text_fragment:
                with self._text_lock:
                    self._append_text_fragment(text_fragment)

    # ── Word extraction ──

    def _pos_to_time(self, token_position: int) -> float:
        """Convert audio token position to seconds."""
        return token_position * self._seconds_per_token + self._global_time_offset

    def _audio_pos_for_char(self, char_idx: int) -> int:
        """Look up the audio token position for a character index in the text.

        Returns the audio position recorded for the last fragment whose
        character offset is <= ``char_idx``, or 0 when no fragment qualifies.
        """
        pos = 0
        # _fragment_positions is appended in increasing char-offset order.
        for offset, audio_tok in self._fragment_positions:
            if offset > char_idx:
                break
            pos = audio_tok
        return pos

    def _word_timestamps(self, text: str, words: List[str], start_idx: int, end_idx: int) -> List[Tuple[int, int]]:
        """Compute (tok_start, tok_end) for words[start_idx:end_idx] using fragment positions.

        Each word is located in ``text`` itself, so leading whitespace and
        runs of whitespace between words do not shift the character offsets.
        """
        result = []
        cursor = 0
        for i, word in enumerate(words):
            if i >= end_idx:
                break
            found = text.find(word, cursor)
            # Fallback keeps the scan monotonic if `words` was not derived from `text`.
            word_start = found if found >= 0 else cursor
            word_end = word_start + len(word)
            if i >= start_idx:
                result.append(
                    (self._audio_pos_for_char(word_start), self._audio_pos_for_char(word_end))
                )
            cursor = word_end
        return result

    def _commit_words(self, keep_last: bool) -> "List[ASRToken]":
        """Turn not-yet-committed words into tokens and mark them committed.

        Args:
            keep_last: when True the final word is held back because it may
                still grow as more text arrives.
        """
        with self._text_lock:
            text = self._get_accumulated_text()
            if not text:
                return []

            words = text.split()
            limit = len(words) - 1 if keep_last else len(words)
            first = self._n_committed_words
            if limit <= first:
                return []

            spans = self._word_timestamps(text, words, first, limit)
            committed: List[ASRToken] = []
            for word, (tok_start, tok_end) in zip(words[first:limit], spans):
                start_time = self._pos_to_time(tok_start)
                # Guarantee a non-zero duration of at least one audio token.
                end_time = self._pos_to_time(max(tok_end, tok_start + 1))
                # Every word except the very first carries a leading space.
                text_out = word if self._n_committed_words == 0 else " " + word
                committed.append(ASRToken(start=start_time, end=end_time, text=text_out))
                self._n_committed_words += 1
            return committed

    def _extract_new_words(self) -> "List[ASRToken]":
        """Extract complete words (all but the last, which may still be growing)."""
        return self._commit_words(keep_last=True)

    def _flush_all_pending_words(self) -> "List[ASRToken]":
        """Flush ALL words including the last partial one."""
        return self._commit_words(keep_last=False)

    # ── Core processing ──

    def _process_iter_inner(self, is_last: bool) -> "Tuple[List[ASRToken], float]":
        """Feed pending audio, (re)start generation when needed, return new words."""
        # Start generate thread when enough audio is buffered
        if not self._generate_started:
            if self._pending_len < self._first_chunk_samples:
                return [], self.end
            self._start_generate_thread()
            self._feed_pending_audio()

        # Feed any new pending audio
        if self._generate_started and not self._generate_finished:
            self._feed_pending_audio()

        # If generate finished unexpectedly (EOS) but new audio arrived, restart
        if self._generate_finished and self._pending_len >= self._first_chunk_samples:
            self._drain_streamer()
            flush_words = self._flush_all_pending_words()
            # Keep the buffered audio: it is the input of the new generation.
            self._reset_for_restart()
            self._start_generate_thread()
            self._feed_pending_audio()
            return flush_words, self.end

        # Drain available text from streamer
        self._drain_streamer()

        # Extract complete words
        new_words = self._extract_new_words()

        if new_words:
            logger.info(f"[voxtral-hf] returning {len(new_words)} words: {[w.text for w in new_words]}")

        self.buffer = []
        return new_words, self.end
|