diff --git a/whisperlivekit/voxtral_hf_streaming.py b/whisperlivekit/voxtral_hf_streaming.py index 89ffbd7..c530193 100644 --- a/whisperlivekit/voxtral_hf_streaming.py +++ b/whisperlivekit/voxtral_hf_streaming.py @@ -86,6 +86,7 @@ class VoxtralHFStreamingOnlineProcessor: self._first_chunk_samples = processor.num_samples_first_audio_chunk self._chunk_samples = processor.num_samples_per_audio_chunk self._chunk_step = processor.raw_audio_length_per_tok + # num_right_pad_tokens is a method in some transformers versions, a property in others n_right_pad = processor.num_right_pad_tokens if callable(n_right_pad): n_right_pad = n_right_pad() @@ -112,6 +113,7 @@ class VoxtralHFStreamingOnlineProcessor: # Text accumulation and word extraction self._accumulated_text = "" self._n_text_tokens_received = 0 + self._n_audio_tokens_fed = 0 self._n_committed_words = 0 self._global_time_offset = 0.0 @@ -133,7 +135,12 @@ class VoxtralHFStreamingOnlineProcessor: return [], self.end def get_buffer(self) -> Transcript: - """Return all uncommitted text as buffer.""" + """Return all uncommitted text as buffer. + + Drains the streamer first so late-arriving tokens (common on + slower devices like MPS) are picked up even between audio chunks. + """ + self._drain_streamer() with self._text_lock: text = self._accumulated_text if not text: @@ -146,11 +153,45 @@ class VoxtralHFStreamingOnlineProcessor: return Transcript(start=None, end=None, text="") def start_silence(self) -> Tuple[List[ASRToken], float]: - """Flush all uncommitted words when silence starts.""" - self._drain_streamer() - words = self._flush_all_pending_words() - logger.info(f"[voxtral-hf] start_silence: flushed {len(words)} words") - return words, self.end + """Flush all uncommitted words when silence starts. + + Feeds right-padding (silence) so the model has enough future context + to emit the last few tokens, then drains repeatedly until the model + has finished producing text. Without right-padding the model holds + back the last few words because it hasn't seen enough audio yet. + """ + if not self._generate_started or self._generate_finished: + self._drain_streamer() + words = self._flush_all_pending_words() + logger.info(f"[voxtral-hf] start_silence (no thread): flushed {len(words)} words") + return words, self.end + + # Feed any remaining real audio + self._feed_pending_audio() + + # Add right-padding so the model can decode trailing tokens. + # Don't count these toward _n_audio_tokens_fed — they're not + # real audio and shouldn't affect word timestamp calculations. + if self._right_pad_samples > 0: + right_pad = np.zeros(self._right_pad_samples, dtype=np.float32) + self._pending_audio = np.append(self._pending_audio, right_pad) + saved_count = self._n_audio_tokens_fed + self._feed_pending_audio() + self._n_audio_tokens_fed = saved_count + + # Drain in a loop: the model may still be processing right-padding + # chunks after the first drain returns. Keep draining until no new + # text appears for two consecutive rounds. + all_words: List[ASRToken] = [] + for _ in range(5): # at most 5 drain+flush cycles + self._drain_streamer_blocking(timeout=5.0) + batch = self._flush_all_pending_words() + all_words.extend(batch) + if not batch: + break # no new text — model has caught up + + logger.info(f"[voxtral-hf] start_silence: flushed {len(all_words)} words") + return all_words, self.end def end_silence(self, silence_duration: float, offset: float): self._global_time_offset += silence_duration @@ -203,6 +244,8 @@ class VoxtralHFStreamingOnlineProcessor: # Extract first chunk first_chunk_audio = self._pending_audio[:self._first_chunk_samples] self._pending_audio = self._pending_audio[self._first_chunk_samples:] + # First chunk covers multiple audio tokens + self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step) first_inputs = processor( first_chunk_audio, @@ -270,6 +313,7 @@ class VoxtralHFStreamingOnlineProcessor: chunk = self._pending_audio[:chunk_size] self._audio_queue.put(chunk) self._pending_audio = self._pending_audio[step_size:] + self._n_audio_tokens_fed += 1 self.audio_buffer = self._pending_audio @@ -284,14 +328,49 @@ class VoxtralHFStreamingOnlineProcessor: text_fragment = text_queue.get_nowait() except queue.Empty: break - # TextIteratorStreamer uses None as end-of-stream sentinel if text_fragment is None: self._generate_finished = True break if text_fragment: with self._text_lock: self._accumulated_text += text_fragment - self._n_text_tokens_received += 1 + self._n_text_tokens_received += 1 + + def _drain_streamer_blocking(self, timeout=30.0): + """Blocking drain: wait for the generate thread to process all queued + audio and produce the corresponding text. + + Polls the text queue while the audio queue has items (model still + processing). Once the audio queue is empty, waits for trailing + tokens, then returns. + + This is critical for start_silence(): without it, the non-blocking + drain races with the generate thread and the last words get stuck. + """ + if not self._generate_started or self._generate_finished: + self._drain_streamer() + return + + text_queue = self._streamer.text_queue + deadline = time.time() + timeout + + while time.time() < deadline: + # Short poll while model is still processing queued audio; + # longer wait once the audio queue is empty (trailing tokens). + wait = 2.0 if self._audio_queue.empty() else 0.1 + try: + text_fragment = text_queue.get(timeout=wait) + except queue.Empty: + if self._audio_queue.empty(): + break # Audio done + no text for 2s → fully caught up + continue # Audio still queued, model still working + if text_fragment is None: + self._generate_finished = True + break + if text_fragment: + with self._text_lock: + self._accumulated_text += text_fragment + self._n_text_tokens_received += 1 # ── Word extraction ── @@ -308,15 +387,15 @@ class VoxtralHFStreamingOnlineProcessor: words = text.split() new_words: List[ASRToken] = [] - n_tokens = self._n_text_tokens_received n_words_total = len(words) + n_audio_toks = max(self._n_audio_tokens_fed, 1) while len(words) > self._n_committed_words + 1: word = words[self._n_committed_words] word_idx = self._n_committed_words - tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0 - tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0 + tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0 + tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0 start_time = self._pos_to_time(tok_start) end_time = self._pos_to_time(tok_end) @@ -336,15 +415,15 @@ class VoxtralHFStreamingOnlineProcessor: words = text.split() new_words: List[ASRToken] = [] - n_tokens = max(self._n_text_tokens_received, 1) n_words_total = max(len(words), 1) + n_audio_toks = max(self._n_audio_tokens_fed, 1) while self._n_committed_words < len(words): word = words[self._n_committed_words] word_idx = self._n_committed_words - tok_start = int(word_idx / n_words_total * n_tokens) - tok_end = int((word_idx + 1) / n_words_total * n_tokens) + tok_start = int(word_idx / n_words_total * n_audio_toks) + tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) start_time = self._pos_to_time(tok_start) end_time = self._pos_to_time(tok_end) diff --git a/whisperlivekit/voxtral_mlx/model.py b/whisperlivekit/voxtral_mlx/model.py index 0a637f8..3bc999d 100644 --- a/whisperlivekit/voxtral_mlx/model.py +++ b/whisperlivekit/voxtral_mlx/model.py @@ -14,7 +14,6 @@ import math import mlx.core as mx import mlx.nn as nn - # --------------------------------------------------------------------------- # KV Cache # --------------------------------------------------------------------------- diff --git a/whisperlivekit/voxtral_mlx_asr.py b/whisperlivekit/voxtral_mlx_asr.py index 4c62f80..f666c0f 100644 --- a/whisperlivekit/voxtral_mlx_asr.py +++ b/whisperlivekit/voxtral_mlx_asr.py @@ -20,12 +20,12 @@ import numpy as np from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from whisperlivekit.timed_objects import ASRToken, Transcript -from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID +from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model from whisperlivekit.voxtral_mlx.model import SlidingKVCache from whisperlivekit.voxtral_mlx.spectrogram import ( - SAMPLES_PER_TOKEN, LEFT_PAD_TOKENS, RIGHT_PAD_TOKENS, + SAMPLES_PER_TOKEN, compute_mel_streaming, )