Fix voxtral streaming drain and silence flush

This commit is contained in:
Quentin Fuxa 2026-01-31 11:12:00 +01:00
parent d58365421f
commit 2fe34427ef
3 changed files with 95 additions and 17 deletions

View file

@ -86,6 +86,7 @@ class VoxtralHFStreamingOnlineProcessor:
self._first_chunk_samples = processor.num_samples_first_audio_chunk self._first_chunk_samples = processor.num_samples_first_audio_chunk
self._chunk_samples = processor.num_samples_per_audio_chunk self._chunk_samples = processor.num_samples_per_audio_chunk
self._chunk_step = processor.raw_audio_length_per_tok self._chunk_step = processor.raw_audio_length_per_tok
# num_right_pad_tokens is a method in some transformers versions, a property in others
n_right_pad = processor.num_right_pad_tokens n_right_pad = processor.num_right_pad_tokens
if callable(n_right_pad): if callable(n_right_pad):
n_right_pad = n_right_pad() n_right_pad = n_right_pad()
@ -112,6 +113,7 @@ class VoxtralHFStreamingOnlineProcessor:
# Text accumulation and word extraction # Text accumulation and word extraction
self._accumulated_text = "" self._accumulated_text = ""
self._n_text_tokens_received = 0 self._n_text_tokens_received = 0
self._n_audio_tokens_fed = 0
self._n_committed_words = 0 self._n_committed_words = 0
self._global_time_offset = 0.0 self._global_time_offset = 0.0
@ -133,7 +135,12 @@ class VoxtralHFStreamingOnlineProcessor:
return [], self.end return [], self.end
def get_buffer(self) -> Transcript: def get_buffer(self) -> Transcript:
"""Return all uncommitted text as buffer.""" """Return all uncommitted text as buffer.
Drains the streamer first so late-arriving tokens (common on
slower devices like MPS) are picked up even between audio chunks.
"""
self._drain_streamer()
with self._text_lock: with self._text_lock:
text = self._accumulated_text text = self._accumulated_text
if not text: if not text:
@ -146,11 +153,45 @@ class VoxtralHFStreamingOnlineProcessor:
return Transcript(start=None, end=None, text="") return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]: def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all uncommitted words when silence starts.""" """Flush all uncommitted words when silence starts.
self._drain_streamer()
words = self._flush_all_pending_words() Feeds right-padding (silence) so the model has enough future context
logger.info(f"[voxtral-hf] start_silence: flushed {len(words)} words") to emit the last few tokens, then drains repeatedly until the model
return words, self.end has finished producing text. Without right-padding the model holds
back the last few words because it hasn't seen enough audio yet.
"""
if not self._generate_started or self._generate_finished:
self._drain_streamer()
words = self._flush_all_pending_words()
logger.info(f"[voxtral-hf] start_silence (no thread): flushed {len(words)} words")
return words, self.end
# Feed any remaining real audio
self._feed_pending_audio()
# Add right-padding so the model can decode trailing tokens.
# Don't count these toward _n_audio_tokens_fed — they're not
# real audio and shouldn't affect word timestamp calculations.
if self._right_pad_samples > 0:
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
self._pending_audio = np.append(self._pending_audio, right_pad)
saved_count = self._n_audio_tokens_fed
self._feed_pending_audio()
self._n_audio_tokens_fed = saved_count
# Drain in a loop: the model may still be processing right-padding
# chunks after the first drain returns. Keep draining until no new
# text appears for two consecutive rounds.
all_words: List[ASRToken] = []
for _ in range(5): # at most 5 drain+flush cycles
self._drain_streamer_blocking(timeout=5.0)
batch = self._flush_all_pending_words()
all_words.extend(batch)
if not batch:
break # no new text — model has caught up
logger.info(f"[voxtral-hf] start_silence: flushed {len(all_words)} words")
return all_words, self.end
def end_silence(self, silence_duration: float, offset: float): def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration self._global_time_offset += silence_duration
@ -203,6 +244,8 @@ class VoxtralHFStreamingOnlineProcessor:
# Extract first chunk # Extract first chunk
first_chunk_audio = self._pending_audio[:self._first_chunk_samples] first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
self._pending_audio = self._pending_audio[self._first_chunk_samples:] self._pending_audio = self._pending_audio[self._first_chunk_samples:]
# First chunk covers multiple audio tokens
self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)
first_inputs = processor( first_inputs = processor(
first_chunk_audio, first_chunk_audio,
@ -270,6 +313,7 @@ class VoxtralHFStreamingOnlineProcessor:
chunk = self._pending_audio[:chunk_size] chunk = self._pending_audio[:chunk_size]
self._audio_queue.put(chunk) self._audio_queue.put(chunk)
self._pending_audio = self._pending_audio[step_size:] self._pending_audio = self._pending_audio[step_size:]
self._n_audio_tokens_fed += 1
self.audio_buffer = self._pending_audio self.audio_buffer = self._pending_audio
@ -284,14 +328,49 @@ class VoxtralHFStreamingOnlineProcessor:
text_fragment = text_queue.get_nowait() text_fragment = text_queue.get_nowait()
except queue.Empty: except queue.Empty:
break break
# TextIteratorStreamer uses None as end-of-stream sentinel
if text_fragment is None: if text_fragment is None:
self._generate_finished = True self._generate_finished = True
break break
if text_fragment: if text_fragment:
with self._text_lock: with self._text_lock:
self._accumulated_text += text_fragment self._accumulated_text += text_fragment
self._n_text_tokens_received += 1 self._n_text_tokens_received += 1
def _drain_streamer_blocking(self, timeout=30.0):
"""Blocking drain: wait for the generate thread to process all queued
audio and produce the corresponding text.
Polls the text queue while the audio queue has items (model still
processing). Once the audio queue is empty, waits for trailing
tokens, then returns.
This is critical for start_silence(): without it, the non-blocking
drain races with the generate thread and the last words get stuck.
"""
if not self._generate_started or self._generate_finished:
self._drain_streamer()
return
text_queue = self._streamer.text_queue
deadline = time.time() + timeout
while time.time() < deadline:
# Short poll while model is still processing queued audio;
# longer wait once the audio queue is empty (trailing tokens).
wait = 2.0 if self._audio_queue.empty() else 0.1
try:
text_fragment = text_queue.get(timeout=wait)
except queue.Empty:
if self._audio_queue.empty():
break # Audio done + no text for 2s → fully caught up
continue # Audio still queued, model still working
if text_fragment is None:
self._generate_finished = True
break
if text_fragment:
with self._text_lock:
self._accumulated_text += text_fragment
self._n_text_tokens_received += 1
# ── Word extraction ── # ── Word extraction ──
@ -308,15 +387,15 @@ class VoxtralHFStreamingOnlineProcessor:
words = text.split() words = text.split()
new_words: List[ASRToken] = [] new_words: List[ASRToken] = []
n_tokens = self._n_text_tokens_received
n_words_total = len(words) n_words_total = len(words)
n_audio_toks = max(self._n_audio_tokens_fed, 1)
while len(words) > self._n_committed_words + 1: while len(words) > self._n_committed_words + 1:
word = words[self._n_committed_words] word = words[self._n_committed_words]
word_idx = self._n_committed_words word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0 tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0
tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0 tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0
start_time = self._pos_to_time(tok_start) start_time = self._pos_to_time(tok_start)
end_time = self._pos_to_time(tok_end) end_time = self._pos_to_time(tok_end)
@ -336,15 +415,15 @@ class VoxtralHFStreamingOnlineProcessor:
words = text.split() words = text.split()
new_words: List[ASRToken] = [] new_words: List[ASRToken] = []
n_tokens = max(self._n_text_tokens_received, 1)
n_words_total = max(len(words), 1) n_words_total = max(len(words), 1)
n_audio_toks = max(self._n_audio_tokens_fed, 1)
while self._n_committed_words < len(words): while self._n_committed_words < len(words):
word = words[self._n_committed_words] word = words[self._n_committed_words]
word_idx = self._n_committed_words word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens) tok_start = int(word_idx / n_words_total * n_audio_toks)
tok_end = int((word_idx + 1) / n_words_total * n_tokens) tok_end = int((word_idx + 1) / n_words_total * n_audio_toks)
start_time = self._pos_to_time(tok_start) start_time = self._pos_to_time(tok_start)
end_time = self._pos_to_time(tok_end) end_time = self._pos_to_time(tok_end)

View file

@ -14,7 +14,6 @@ import math
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# KV Cache # KV Cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -20,12 +20,12 @@ import numpy as np
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from whisperlivekit.timed_objects import ASRToken, Transcript from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model
from whisperlivekit.voxtral_mlx.model import SlidingKVCache from whisperlivekit.voxtral_mlx.model import SlidingKVCache
from whisperlivekit.voxtral_mlx.spectrogram import ( from whisperlivekit.voxtral_mlx.spectrogram import (
SAMPLES_PER_TOKEN,
LEFT_PAD_TOKENS, LEFT_PAD_TOKENS,
RIGHT_PAD_TOKENS, RIGHT_PAD_TOKENS,
SAMPLES_PER_TOKEN,
compute_mel_streaming, compute_mel_streaming,
) )