Fix voxtral streaming drain and silence flush
This commit is contained in:
parent
d58365421f
commit
2fe34427ef
3 changed files with 95 additions and 17 deletions
|
|
@ -86,6 +86,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
self._first_chunk_samples = processor.num_samples_first_audio_chunk
|
||||
self._chunk_samples = processor.num_samples_per_audio_chunk
|
||||
self._chunk_step = processor.raw_audio_length_per_tok
|
||||
# num_right_pad_tokens is a method in some transformers versions, a property in others
|
||||
n_right_pad = processor.num_right_pad_tokens
|
||||
if callable(n_right_pad):
|
||||
n_right_pad = n_right_pad()
|
||||
|
|
@ -112,6 +113,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
# Text accumulation and word extraction
|
||||
self._accumulated_text = ""
|
||||
self._n_text_tokens_received = 0
|
||||
self._n_audio_tokens_fed = 0
|
||||
self._n_committed_words = 0
|
||||
self._global_time_offset = 0.0
|
||||
|
||||
|
|
@ -133,7 +135,12 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
return [], self.end
|
||||
|
||||
def get_buffer(self) -> Transcript:
|
||||
"""Return all uncommitted text as buffer."""
|
||||
"""Return all uncommitted text as buffer.
|
||||
|
||||
Drains the streamer first so late-arriving tokens (common on
|
||||
slower devices like MPS) are picked up even between audio chunks.
|
||||
"""
|
||||
self._drain_streamer()
|
||||
with self._text_lock:
|
||||
text = self._accumulated_text
|
||||
if not text:
|
||||
|
|
@ -146,11 +153,45 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
return Transcript(start=None, end=None, text="")
|
||||
|
||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
    """Flush all uncommitted words when silence starts.

    Feeds right-padding (silence) so the model has enough future context
    to emit the last few tokens, then drains repeatedly until the model
    has finished producing text. Without right-padding the model holds
    back the last few words because it hasn't seen enough audio yet.

    Returns:
        A ``(words, end)`` tuple: the words flushed by this call and the
        current value of ``self.end``.
    """
    if not self._generate_started or self._generate_finished:
        # No live generate thread: there is nothing to wait for, so a
        # single non-blocking drain picks up whatever is already queued.
        self._drain_streamer()
        words = self._flush_all_pending_words()
        logger.info(f"[voxtral-hf] start_silence (no thread): flushed {len(words)} words")
        return words, self.end

    # Feed any remaining real audio
    self._feed_pending_audio()

    # Add right-padding so the model can decode trailing tokens.
    # Don't count these toward _n_audio_tokens_fed — they're not
    # real audio and shouldn't affect word timestamp calculations.
    if self._right_pad_samples > 0:
        right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
        self._pending_audio = np.append(self._pending_audio, right_pad)
        saved_count = self._n_audio_tokens_fed
        try:
            self._feed_pending_audio()
        finally:
            # Restore even if feeding fails, otherwise the counter stays
            # inflated by padding and skews every later timestamp.
            self._n_audio_tokens_fed = saved_count

    # Drain in a loop: the model may still be processing right-padding
    # chunks after the first drain returns. Keep draining until one
    # drain+flush round yields no new words, capped at a fixed number
    # of rounds so a stalled model cannot hold this call forever.
    max_rounds = 5
    per_round_timeout = 5.0
    all_words: List[ASRToken] = []
    for _ in range(max_rounds):
        self._drain_streamer_blocking(timeout=per_round_timeout)
        batch = self._flush_all_pending_words()
        if not batch:
            break  # no new text — model has caught up
        all_words.extend(batch)

    logger.info(f"[voxtral-hf] start_silence: flushed {len(all_words)} words")
    return all_words, self.end
|
||||
|
||||
def end_silence(self, silence_duration: float, offset: float):
|
||||
self._global_time_offset += silence_duration
|
||||
|
|
@ -203,6 +244,8 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
# Extract first chunk
|
||||
first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
|
||||
self._pending_audio = self._pending_audio[self._first_chunk_samples:]
|
||||
# First chunk covers multiple audio tokens
|
||||
self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)
|
||||
|
||||
first_inputs = processor(
|
||||
first_chunk_audio,
|
||||
|
|
@ -270,6 +313,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
chunk = self._pending_audio[:chunk_size]
|
||||
self._audio_queue.put(chunk)
|
||||
self._pending_audio = self._pending_audio[step_size:]
|
||||
self._n_audio_tokens_fed += 1
|
||||
|
||||
self.audio_buffer = self._pending_audio
|
||||
|
||||
|
|
@ -284,14 +328,49 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
text_fragment = text_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
# TextIteratorStreamer uses None as end-of-stream sentinel
|
||||
if text_fragment is None:
|
||||
self._generate_finished = True
|
||||
break
|
||||
if text_fragment:
|
||||
with self._text_lock:
|
||||
self._accumulated_text += text_fragment
|
||||
self._n_text_tokens_received += 1
|
||||
self._n_text_tokens_received += 1
|
||||
|
||||
def _drain_streamer_blocking(self, timeout=30.0):
    """Blocking drain: wait for the generate thread to process all queued
    audio and produce the corresponding text.

    Polls the text queue while the audio queue has items (model still
    processing). Once the audio queue is empty, waits for trailing
    tokens, then returns.

    This is critical for start_silence(): without it, the non-blocking
    drain races with the generate thread and the last words get stuck.

    Args:
        timeout: Upper bound, in seconds, on the total time spent
            waiting. Individual polls are clamped to the time that is
            left, so the call does not overrun this budget by a full
            poll interval.
    """
    if not self._generate_started or self._generate_finished:
        # No live generate thread: nothing new can arrive, so a plain
        # non-blocking drain is sufficient.
        self._drain_streamer()
        return

    idle_wait = 2.0   # audio queue empty: allow time for trailing tokens
    busy_wait = 0.1   # audio still queued: poll quickly

    text_queue = self._streamer.text_queue
    # Monotonic clock: the deadline must not move if the wall clock is
    # adjusted while we are waiting.
    deadline = time.monotonic() + timeout

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break

        # Short poll while model is still processing queued audio;
        # longer wait once the audio queue is empty (trailing tokens).
        wait = idle_wait if self._audio_queue.empty() else busy_wait
        try:
            text_fragment = text_queue.get(timeout=min(wait, remaining))
        except queue.Empty:
            if self._audio_queue.empty():
                break  # Audio done + no text within the wait → caught up
            continue  # Audio still queued, model still working

        # None marks end-of-stream on the text queue.
        if text_fragment is None:
            self._generate_finished = True
            break

        if text_fragment:
            with self._text_lock:
                self._accumulated_text += text_fragment
                self._n_text_tokens_received += 1
|
||||
|
||||
# ── Word extraction ──
|
||||
|
||||
|
|
@ -308,15 +387,15 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
|
||||
words = text.split()
|
||||
new_words: List[ASRToken] = []
|
||||
n_tokens = self._n_text_tokens_received
|
||||
n_words_total = len(words)
|
||||
n_audio_toks = max(self._n_audio_tokens_fed, 1)
|
||||
|
||||
while len(words) > self._n_committed_words + 1:
|
||||
word = words[self._n_committed_words]
|
||||
word_idx = self._n_committed_words
|
||||
|
||||
tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0
|
||||
tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0
|
||||
tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0
|
||||
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0
|
||||
|
||||
start_time = self._pos_to_time(tok_start)
|
||||
end_time = self._pos_to_time(tok_end)
|
||||
|
|
@ -336,15 +415,15 @@ class VoxtralHFStreamingOnlineProcessor:
|
|||
|
||||
words = text.split()
|
||||
new_words: List[ASRToken] = []
|
||||
n_tokens = max(self._n_text_tokens_received, 1)
|
||||
n_words_total = max(len(words), 1)
|
||||
n_audio_toks = max(self._n_audio_tokens_fed, 1)
|
||||
|
||||
while self._n_committed_words < len(words):
|
||||
word = words[self._n_committed_words]
|
||||
word_idx = self._n_committed_words
|
||||
|
||||
tok_start = int(word_idx / n_words_total * n_tokens)
|
||||
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
|
||||
tok_start = int(word_idx / n_words_total * n_audio_toks)
|
||||
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks)
|
||||
|
||||
start_time = self._pos_to_time(tok_start)
|
||||
end_time = self._pos_to_time(tok_end)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import math
|
|||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV Cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -20,12 +20,12 @@ import numpy as np
|
|||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
|
||||
from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model
|
||||
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
|
||||
from whisperlivekit.voxtral_mlx.spectrogram import (
|
||||
SAMPLES_PER_TOKEN,
|
||||
LEFT_PAD_TOKENS,
|
||||
RIGHT_PAD_TOKENS,
|
||||
SAMPLES_PER_TOKEN,
|
||||
compute_mel_streaming,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue