Fix voxtral streaming drain and silence flush
This commit is contained in:
parent
d58365421f
commit
2fe34427ef
3 changed files with 95 additions and 17 deletions
|
|
@ -86,6 +86,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
self._first_chunk_samples = processor.num_samples_first_audio_chunk
|
self._first_chunk_samples = processor.num_samples_first_audio_chunk
|
||||||
self._chunk_samples = processor.num_samples_per_audio_chunk
|
self._chunk_samples = processor.num_samples_per_audio_chunk
|
||||||
self._chunk_step = processor.raw_audio_length_per_tok
|
self._chunk_step = processor.raw_audio_length_per_tok
|
||||||
|
# num_right_pad_tokens is a method in some transformers versions, a property in others
|
||||||
n_right_pad = processor.num_right_pad_tokens
|
n_right_pad = processor.num_right_pad_tokens
|
||||||
if callable(n_right_pad):
|
if callable(n_right_pad):
|
||||||
n_right_pad = n_right_pad()
|
n_right_pad = n_right_pad()
|
||||||
|
|
@ -112,6 +113,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
# Text accumulation and word extraction
|
# Text accumulation and word extraction
|
||||||
self._accumulated_text = ""
|
self._accumulated_text = ""
|
||||||
self._n_text_tokens_received = 0
|
self._n_text_tokens_received = 0
|
||||||
|
self._n_audio_tokens_fed = 0
|
||||||
self._n_committed_words = 0
|
self._n_committed_words = 0
|
||||||
self._global_time_offset = 0.0
|
self._global_time_offset = 0.0
|
||||||
|
|
||||||
|
|
@ -133,7 +135,12 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
return [], self.end
|
return [], self.end
|
||||||
|
|
||||||
def get_buffer(self) -> Transcript:
|
def get_buffer(self) -> Transcript:
|
||||||
"""Return all uncommitted text as buffer."""
|
"""Return all uncommitted text as buffer.
|
||||||
|
|
||||||
|
Drains the streamer first so late-arriving tokens (common on
|
||||||
|
slower devices like MPS) are picked up even between audio chunks.
|
||||||
|
"""
|
||||||
|
self._drain_streamer()
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
text = self._accumulated_text
|
text = self._accumulated_text
|
||||||
if not text:
|
if not text:
|
||||||
|
|
@ -146,11 +153,45 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
return Transcript(start=None, end=None, text="")
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||||
"""Flush all uncommitted words when silence starts."""
|
"""Flush all uncommitted words when silence starts.
|
||||||
self._drain_streamer()
|
|
||||||
words = self._flush_all_pending_words()
|
Feeds right-padding (silence) so the model has enough future context
|
||||||
logger.info(f"[voxtral-hf] start_silence: flushed {len(words)} words")
|
to emit the last few tokens, then drains repeatedly until the model
|
||||||
return words, self.end
|
has finished producing text. Without right-padding the model holds
|
||||||
|
back the last few words because it hasn't seen enough audio yet.
|
||||||
|
"""
|
||||||
|
if not self._generate_started or self._generate_finished:
|
||||||
|
self._drain_streamer()
|
||||||
|
words = self._flush_all_pending_words()
|
||||||
|
logger.info(f"[voxtral-hf] start_silence (no thread): flushed {len(words)} words")
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
# Feed any remaining real audio
|
||||||
|
self._feed_pending_audio()
|
||||||
|
|
||||||
|
# Add right-padding so the model can decode trailing tokens.
|
||||||
|
# Don't count these toward _n_audio_tokens_fed — they're not
|
||||||
|
# real audio and shouldn't affect word timestamp calculations.
|
||||||
|
if self._right_pad_samples > 0:
|
||||||
|
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
|
||||||
|
self._pending_audio = np.append(self._pending_audio, right_pad)
|
||||||
|
saved_count = self._n_audio_tokens_fed
|
||||||
|
self._feed_pending_audio()
|
||||||
|
self._n_audio_tokens_fed = saved_count
|
||||||
|
|
||||||
|
# Drain in a loop: the model may still be processing right-padding
|
||||||
|
# chunks after the first drain returns. Keep draining until no new
|
||||||
|
# text appears in a drain+flush round (first empty batch stops the loop).
|
||||||
|
all_words: List[ASRToken] = []
|
||||||
|
for _ in range(5): # at most 5 drain+flush cycles
|
||||||
|
self._drain_streamer_blocking(timeout=5.0)
|
||||||
|
batch = self._flush_all_pending_words()
|
||||||
|
all_words.extend(batch)
|
||||||
|
if not batch:
|
||||||
|
break # no new text — model has caught up
|
||||||
|
|
||||||
|
logger.info(f"[voxtral-hf] start_silence: flushed {len(all_words)} words")
|
||||||
|
return all_words, self.end
|
||||||
|
|
||||||
def end_silence(self, silence_duration: float, offset: float):
|
def end_silence(self, silence_duration: float, offset: float):
|
||||||
self._global_time_offset += silence_duration
|
self._global_time_offset += silence_duration
|
||||||
|
|
@ -203,6 +244,8 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
# Extract first chunk
|
# Extract first chunk
|
||||||
first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
|
first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
|
||||||
self._pending_audio = self._pending_audio[self._first_chunk_samples:]
|
self._pending_audio = self._pending_audio[self._first_chunk_samples:]
|
||||||
|
# First chunk covers multiple audio tokens
|
||||||
|
self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)
|
||||||
|
|
||||||
first_inputs = processor(
|
first_inputs = processor(
|
||||||
first_chunk_audio,
|
first_chunk_audio,
|
||||||
|
|
@ -270,6 +313,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
chunk = self._pending_audio[:chunk_size]
|
chunk = self._pending_audio[:chunk_size]
|
||||||
self._audio_queue.put(chunk)
|
self._audio_queue.put(chunk)
|
||||||
self._pending_audio = self._pending_audio[step_size:]
|
self._pending_audio = self._pending_audio[step_size:]
|
||||||
|
self._n_audio_tokens_fed += 1
|
||||||
|
|
||||||
self.audio_buffer = self._pending_audio
|
self.audio_buffer = self._pending_audio
|
||||||
|
|
||||||
|
|
@ -284,14 +328,49 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
text_fragment = text_queue.get_nowait()
|
text_fragment = text_queue.get_nowait()
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
break
|
break
|
||||||
# TextIteratorStreamer uses None as end-of-stream sentinel
|
|
||||||
if text_fragment is None:
|
if text_fragment is None:
|
||||||
self._generate_finished = True
|
self._generate_finished = True
|
||||||
break
|
break
|
||||||
if text_fragment:
|
if text_fragment:
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
self._accumulated_text += text_fragment
|
self._accumulated_text += text_fragment
|
||||||
self._n_text_tokens_received += 1
|
self._n_text_tokens_received += 1
|
||||||
|
|
||||||
|
def _drain_streamer_blocking(self, timeout=30.0):
|
||||||
|
"""Blocking drain: wait for the generate thread to process all queued
|
||||||
|
audio and produce the corresponding text.
|
||||||
|
|
||||||
|
Polls the text queue while the audio queue has items (model still
|
||||||
|
processing). Once the audio queue is empty, waits for trailing
|
||||||
|
tokens, then returns.
|
||||||
|
|
||||||
|
This is critical for start_silence(): without it, the non-blocking
|
||||||
|
drain races with the generate thread and the last words get stuck.
|
||||||
|
"""
|
||||||
|
if not self._generate_started or self._generate_finished:
|
||||||
|
self._drain_streamer()
|
||||||
|
return
|
||||||
|
|
||||||
|
text_queue = self._streamer.text_queue
|
||||||
|
deadline = time.time() + timeout
|
||||||
|
|
||||||
|
while time.time() < deadline:
|
||||||
|
# Short poll while model is still processing queued audio;
|
||||||
|
# longer wait once the audio queue is empty (trailing tokens).
|
||||||
|
wait = 2.0 if self._audio_queue.empty() else 0.1
|
||||||
|
try:
|
||||||
|
text_fragment = text_queue.get(timeout=wait)
|
||||||
|
except queue.Empty:
|
||||||
|
if self._audio_queue.empty():
|
||||||
|
break # Audio queue empty + no text within this poll (2s; only 0.1s if the queue drained mid-poll) → treat as caught up
|
||||||
|
continue # Audio still queued, model still working
|
||||||
|
if text_fragment is None:
|
||||||
|
self._generate_finished = True
|
||||||
|
break
|
||||||
|
if text_fragment:
|
||||||
|
with self._text_lock:
|
||||||
|
self._accumulated_text += text_fragment
|
||||||
|
self._n_text_tokens_received += 1
|
||||||
|
|
||||||
# ── Word extraction ──
|
# ── Word extraction ──
|
||||||
|
|
||||||
|
|
@ -308,15 +387,15 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
|
|
||||||
words = text.split()
|
words = text.split()
|
||||||
new_words: List[ASRToken] = []
|
new_words: List[ASRToken] = []
|
||||||
n_tokens = self._n_text_tokens_received
|
|
||||||
n_words_total = len(words)
|
n_words_total = len(words)
|
||||||
|
n_audio_toks = max(self._n_audio_tokens_fed, 1)
|
||||||
|
|
||||||
while len(words) > self._n_committed_words + 1:
|
while len(words) > self._n_committed_words + 1:
|
||||||
word = words[self._n_committed_words]
|
word = words[self._n_committed_words]
|
||||||
word_idx = self._n_committed_words
|
word_idx = self._n_committed_words
|
||||||
|
|
||||||
tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0
|
tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0
|
||||||
tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0
|
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0
|
||||||
|
|
||||||
start_time = self._pos_to_time(tok_start)
|
start_time = self._pos_to_time(tok_start)
|
||||||
end_time = self._pos_to_time(tok_end)
|
end_time = self._pos_to_time(tok_end)
|
||||||
|
|
@ -336,15 +415,15 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
|
|
||||||
words = text.split()
|
words = text.split()
|
||||||
new_words: List[ASRToken] = []
|
new_words: List[ASRToken] = []
|
||||||
n_tokens = max(self._n_text_tokens_received, 1)
|
|
||||||
n_words_total = max(len(words), 1)
|
n_words_total = max(len(words), 1)
|
||||||
|
n_audio_toks = max(self._n_audio_tokens_fed, 1)
|
||||||
|
|
||||||
while self._n_committed_words < len(words):
|
while self._n_committed_words < len(words):
|
||||||
word = words[self._n_committed_words]
|
word = words[self._n_committed_words]
|
||||||
word_idx = self._n_committed_words
|
word_idx = self._n_committed_words
|
||||||
|
|
||||||
tok_start = int(word_idx / n_words_total * n_tokens)
|
tok_start = int(word_idx / n_words_total * n_audio_toks)
|
||||||
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
|
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks)
|
||||||
|
|
||||||
start_time = self._pos_to_time(tok_start)
|
start_time = self._pos_to_time(tok_start)
|
||||||
end_time = self._pos_to_time(tok_end)
|
end_time = self._pos_to_time(tok_end)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ import math
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# KV Cache
|
# KV Cache
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,12 @@ import numpy as np
|
||||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
|
from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model
|
||||||
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
|
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
|
||||||
from whisperlivekit.voxtral_mlx.spectrogram import (
|
from whisperlivekit.voxtral_mlx.spectrogram import (
|
||||||
SAMPLES_PER_TOKEN,
|
|
||||||
LEFT_PAD_TOKENS,
|
LEFT_PAD_TOKENS,
|
||||||
RIGHT_PAD_TOKENS,
|
RIGHT_PAD_TOKENS,
|
||||||
|
SAMPLES_PER_TOKEN,
|
||||||
compute_mel_streaming,
|
compute_mel_streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue