WhisperLiveKit/whisperlivekit/voxtral_mlx_asr.py
2026-03-15 22:32:01 +01:00

640 lines
25 KiB
Python

"""
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
(streaming processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
Unlike the HuggingFace backend, this runs the full inference loop in-process
(no background thread / queue) — MLX operations on Apple Silicon are fast
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
from whisperlivekit.voxtral_mlx.spectrogram import (
LEFT_PAD_TOKENS,
RIGHT_PAD_TOKENS,
SAMPLES_PER_TOKEN,
compute_mel_streaming,
)
logger = logging.getLogger(__name__)

# Capacity of each decoder layer's sliding KV-cache window; matches the
# window size used in the model's training configuration.
_DECODER_WINDOW: int = 8_192

# Upper bound on decoder positions processed continuously before a forced
# flush + reset. Beyond roughly 20 s of uninterrupted audio the
# autoregressive context drifts and starts to hallucinate; at 80 ms per
# token that budget is 20 000 ms / 80 ms = 250 positions.
_MAX_CONTINUOUS_POSITIONS: int = 20_000 // 80
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
    """Return ``(token_ids, n_delay)`` for the streaming decoder prompt.

    ``token_ids`` is the tokenizer's BOS id followed by
    ``n_left_pad + n_delay`` copies of the ``[STREAMING_PAD]`` special token.
    ``n_delay`` is echoed back unchanged so callers get both values together.
    """
    streaming_pad = tokenizer.get_special_token("[STREAMING_PAD]")
    n_pads = n_left_pad + n_delay
    token_ids = [tokenizer.bos_id, *([streaming_pad] * n_pads)]
    return token_ids, n_delay
# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------
class VoxtralMLXASR:
    """Long-lived holder for the MLX Voxtral model.

    The model, tokenizer and config are loaded once when the object is built
    and are then read by the online processors. This class does no
    transcription work itself.
    """

    sep = " "
    SAMPLING_RATE = 16_000

    def __init__(self, logfile=sys.stderr, **kwargs):
        self.logfile = logfile
        self.transcribe_kargs = {}

        # "auto" (the default) means: let the model detect the language.
        language = kwargs.get("lan", "auto")
        self.original_language = language if language != "auto" else None

        model_path = self._resolve_model_path(kwargs)

        started = time.time()
        logger.info("Loading Voxtral MLX model '%s' ...", model_path)
        self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
        logger.info("Voxtral MLX model loaded in %.2fs", time.time() - started)

        self.backend_choice = "voxtral-mlx"

    @staticmethod
    def _resolve_model_path(options):
        """Pick the model location from constructor keyword options.

        Precedence: ``model_dir``, then ``model_path``, then ``model_size``
        when it looks like a path (contains ``/`` or starts with ``.``),
        otherwise ``DEFAULT_MODEL_ID``.
        """
        explicit = options.get("model_dir") or options.get("model_path")
        if explicit:
            return explicit
        size = options.get("model_size", "")
        looks_like_path = bool(size) and ("/" in size or size.startswith("."))
        return size if looks_like_path else DEFAULT_MODEL_ID

    def transcribe(self, audio):
        """No-op; all work happens in the online processor."""
        return None
# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------
class VoxtralMLXOnlineProcessor:
    """Streaming processor that incrementally encodes audio and decodes text
    using the MLX Voxtral model.

    Lifecycle (called by ``AudioProcessor.transcription_processor``)::

        insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
        ... repeat ...
        start_silence() / end_silence()
        finish()

    ``new_speaker`` has no return value, so the words it flushes are queued
    and prepended to the result of the next ``process_iter``,
    ``start_silence`` or ``finish`` call rather than being dropped.
    """

    SAMPLING_RATE = 16_000

    def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer: list = []
        self.audio_buffer = np.array([], dtype=np.float32)
        self._model = asr.model
        self._tokenizer = asr.tokenizer

        # Pre-compute prompt tokens and delay conditioning (constant across utterances).
        self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
        self._prefix_len = len(self._prompt_ids)
        self._delay_cond = self._model.delay_embedding(
            mx.array([self._n_delay], dtype=mx.float32)
        )
        mx.eval(self._delay_cond)
        self._prompt_embeds = self._model.decoder.embed(
            mx.array([self._prompt_ids])
        )[0]  # [prefix_len, dim]
        mx.eval(self._prompt_embeds)

        self._eos_id = self._tokenizer.eos_id
        self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
        # The streaming model has an inherent delay: text for audio at position P
        # is generated at decoder position P + n_delay. Compensate timestamps.
        self._delay_secs = self._n_delay * self._secs_per_token

        # Words flushed by new_speaker() that still have to be handed to the
        # caller. Kept outside _reset_state() so a reset does not drop them.
        self._carried_words: List[ASRToken] = []
        self._reset_state()

    # -- state management --

    def _reset_state(self):
        """Reset all incremental state for a fresh utterance.

        ``self.end`` and ``self._carried_words`` are not modified here.
        """
        # Audio accumulation (list of chunks, concatenated on demand)
        self._pending_chunks: list[np.ndarray] = []
        self._pending_len = 0
        # Mel overlap
        self._mel_overlap: np.ndarray | None = None
        # Encoder incremental state
        self._conv_tail1 = None
        self._conv_tail2 = None
        self._enc_cache = None
        self._ds_remainder = None
        # Audio embeddings not yet decoded
        self._audio_embeds: mx.array | None = None
        # Decoder state
        self._dec_cache: list[SlidingKVCache] | None = None
        self._last_token: mx.array | None = None
        # Bookkeeping
        self._samples_encoded = 0  # samples fed to the encoder, incl. flush padding
        self._real_samples_encoded = 0  # samples received via insert_audio_chunk only
        self._positions_decoded = 0
        self._prefilled = False
        self._first_chunk = True
        # Text state
        self._full_text = ""
        self._n_text_tokens = 0
        self._n_committed_words = 0
        self._time_offset = 0.0
        # Per-word audio position tracking: decoder position (relative to prefix)
        # where each word in _full_text started and ended
        self._word_audio_starts: list[int] = []  # audio pos where word i started
        self._word_audio_ends: list[int] = []  # audio pos where word i was closed
        self._current_word_pos: Optional[int] = None  # audio pos of current (incomplete) word's first token

    def _reset_keeping_timeline(self):
        """Fully reset encoder/decoder state while advancing ``_time_offset``.

        The new offset is the old one plus the duration of all audio received
        since the last reset, so the next utterance's timestamps continue
        where this one stopped. ``self.end`` survives because
        ``_reset_state`` does not touch it.
        """
        new_offset = self._time_offset + self._real_samples_encoded / self.SAMPLING_RATE
        self._reset_state()
        self._time_offset = new_offset
        # Free MLX caches eagerly
        mx.clear_cache()

    def _take_carried_words(self) -> List[ASRToken]:
        """Return and clear the words queued by ``new_speaker``."""
        carried, self._carried_words = self._carried_words, []
        return carried

    # -- audio ingestion --

    def _get_pending(self) -> np.ndarray:
        """Flatten pending chunks into a single array (and cache the result)."""
        if not self._pending_chunks:
            return np.zeros(0, dtype=np.float32)
        if len(self._pending_chunks) == 1:
            return self._pending_chunks[0]
        flat = np.concatenate(self._pending_chunks)
        self._pending_chunks = [flat]
        return flat

    def _set_pending(self, arr: np.ndarray):
        """Replace pending audio with a single array."""
        if len(arr) == 0:
            self._pending_chunks = []
            self._pending_len = 0
        else:
            self._pending_chunks = [arr]
            self._pending_len = len(arr)

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Queue *audio* for encoding and record the stream end time."""
        self.end = audio_stream_end_time
        self._pending_chunks.append(audio)
        self._pending_len += len(audio)
        self._real_samples_encoded += len(audio)
        self.audio_buffer = audio  # diagnostic only

    # -- core processing --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Run one streaming step and return ``(new_words, end_time)``.

        Errors are logged and yield no new words, so one bad step does not
        end the stream. Words queued by ``new_speaker`` are always delivered
        first, even when the step fails.
        """
        carried = self._take_carried_words()
        try:
            words, end = self._step(is_last)
        except Exception as e:
            logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
            return carried, self.end
        return carried + words, end

    def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
        cap = _MAX_CONTINUOUS_POSITIONS + self._prefix_len

        # 0. Safety cap: if continuous decoding exceeds the limit, force a
        #    flush+reset to prevent hallucination even without VAD silence.
        if self._prefilled and self._positions_decoded >= cap:
            logger.info(
                "[voxtral-mlx] continuous decoding cap hit at %d positions — "
                "forcing flush+reset",
                self._positions_decoded,
            )
            return self._flush_and_reset(), self.end

        # 1. Encode any new audio
        self._encode_pending()
        if self._audio_embeds is None:
            return [], self.end

        # 2. Compute how many positions we can safely decode.
        #    The safe boundary keeps the decoder from running ahead of the
        #    real audio. It uses _real_samples_encoded, which counts only
        #    samples received through insert_audio_chunk (never the flush
        #    padding); positions beyond it produce hallucination.
        total_safe = LEFT_PAD_TOKENS + self._real_samples_encoded // SAMPLES_PER_TOKEN
        n_available = self._audio_embeds.shape[0]
        n_decodable = min(n_available, total_safe - self._positions_decoded)
        if n_decodable <= 0:
            return [], self.end

        # 3. Prefill if needed
        if not self._prefilled:
            if self._positions_decoded + n_available < self._prefix_len:
                return [], self.end
            self._do_prefill()
            # Re-check after consuming prefix embeddings
            n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
            n_decodable = min(n_available, total_safe - self._positions_decoded)
            if n_decodable <= 0 or self._audio_embeds is None:
                return [], self.end

        # Clamp to the continuous decoding cap so we don't overshoot
        max_left = cap - self._positions_decoded
        if max_left <= 0:
            # Will be caught by the cap check on the next call
            return self._extract_committed_words(), self.end
        n_decodable = min(n_decodable, max_left)

        # 4. Decode available positions
        if self._decode_positions(n_decodable):
            # EOS: flush words, then full reset for the next utterance
            words = self._flush_all_words()
            logger.debug(
                "[voxtral-mlx] EOS hit during stream: flushed %d words, "
                "samples_encoded=%d (%.2fs), text='%s'",
                len(words), self._samples_encoded,
                self._samples_encoded / self.SAMPLING_RATE,
                self._full_text[-60:] if self._full_text else "",
            )
            self._reset_keeping_timeline()
            return words, self.end

        # 5. Extract committed words (all but the last, which may still grow)
        return self._extract_committed_words(), self.end

    def _encode_pending(self):
        """Feed whole tokens' worth of pending audio through the incremental encoder.

        Any tail shorter than ``SAMPLES_PER_TOKEN`` stays pending. The very
        first chunk after a reset is prefixed with ``LEFT_PAD_TOKENS`` tokens
        of silence (not counted in ``_samples_encoded``).
        """
        if self._pending_len < SAMPLES_PER_TOKEN:
            return
        pending = self._get_pending()
        n_take = (len(pending) // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
        chunk = pending[:n_take]
        if self._first_chunk:
            # First chunk: prepend silence for left-padding
            left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
            chunk = np.concatenate([left_pad, chunk])
            self._first_chunk = False
        self._set_pending(pending[n_take:])
        self._samples_encoded += n_take

        mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
        embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
            self._model.encode_incremental(
                mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
            )
        )
        if embeds is not None:
            mx.eval(embeds)
            if self._audio_embeds is not None:
                self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
                mx.eval(self._audio_embeds)
            else:
                self._audio_embeds = embeds

    def _do_prefill(self):
        """Run the decoder prefill pass over the prompt + first audio embeddings."""
        n_dec_layers = len(self._model.decoder.blocks)
        self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]
        prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
        prefix_embeds = prefix_embeds[None, :, :]  # [1, prefix_len, dim]
        logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
        mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])
        self._last_token = self._sample(logits)
        mx.async_eval(self._last_token)
        # Remove consumed prefix embeddings
        self._audio_embeds = self._audio_embeds[self._prefix_len :]
        if self._audio_embeds.shape[0] == 0:
            self._audio_embeds = None
        self._positions_decoded = self._prefix_len
        self._prefilled = True

    def _decode_positions(self, n: int) -> bool:
        """Autoregressively decode *n* positions. Returns True on EOS.

        On EOS, only the positions before the EOS are counted as decoded and
        ``_last_token`` is left holding the EOS token.
        """
        base_pos = self._positions_decoded  # absolute position before this batch
        for i in range(n):
            tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
            combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
            logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
            next_tok = self._sample(logits)
            # Start computing the next token while the current one is handled.
            mx.async_eval(next_tok)
            token_id = self._last_token.item()
            audio_pos = base_pos + i - self._prefix_len
            if token_id == self._eos_id:
                # Close the current word if one is being built
                if self._current_word_pos is not None:
                    self._word_audio_ends.append(audio_pos)
                    self._current_word_pos = None
                self._trim_embeds(i)
                self._positions_decoded += i
                return True
            text = self._tokenizer.decode(
                [token_id], special_token_policy=SpecialTokenPolicy.IGNORE
            )
            if text:
                self._append_token_text(text, audio_pos)
            if i > 0 and i % 256 == 0:
                mx.clear_cache()
            self._last_token = next_tok
        self._positions_decoded += n
        self._trim_embeds(n)
        return False

    def _trim_embeds(self, n_consumed: int):
        """Drop the first *n_consumed* audio embeddings (``None`` when exhausted)."""
        if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
            self._audio_embeds = self._audio_embeds[n_consumed:]
        else:
            self._audio_embeds = None

    def _sample(self, logits: mx.array) -> mx.array:
        """Greedy (argmax) choice over the logits of the last position."""
        return mx.argmax(logits[0, -1:], axis=-1).squeeze()

    # -- text / word tracking --

    def _append_token_text(self, text: str, audio_pos: int) -> None:
        """Append one decoded token's *text* and update per-word tracking.

        Word boundaries are tracked character by character with the same
        whitespace rule as ``str.split()``:
        - a whitespace character closes the word being built, recording its
          end at *audio_pos*;
        - the first non-whitespace character after a boundary opens a new word
          starting at *audio_pos*.
        This keeps ``_word_audio_starts[k]`` aligned with the k-th word of
        ``self._full_text.split()``, including for whitespace-only tokens and
        tokens with trailing or internal whitespace.
        """
        for ch in text:
            if ch.isspace():
                if self._current_word_pos is not None:
                    self._word_audio_ends.append(audio_pos)
                    self._current_word_pos = None
            elif self._current_word_pos is None:
                self._word_audio_starts.append(audio_pos)
                self._current_word_pos = audio_pos
        self._full_text += text
        self._n_text_tokens += 1

    def _flush_last_token_text(self):
        """Add the last pending token's text (if not EOS) to _full_text."""
        if self._last_token is None:
            return
        tid = self._last_token.item()
        if tid == self._eos_id:
            return
        text = self._tokenizer.decode(
            [tid], special_token_policy=SpecialTokenPolicy.IGNORE
        )
        if text:
            self._append_token_text(text, self._positions_decoded - self._prefix_len)

    def _close_current_word(self):
        """Close the last word if one is being built."""
        if self._current_word_pos is not None:
            last_pos = self._positions_decoded - self._prefix_len
            self._word_audio_ends.append(last_pos)
            self._current_word_pos = None

    # -- word extraction --

    def _audio_pos_to_time(self, pos: int) -> float:
        """Convert an audio position (relative to prefix end) to seconds.

        The result is ``pos`` tokens, shifted back by ``_delay_secs`` and
        forward by ``_time_offset``, clamped at 0.

        NOTE(review): ``pos`` is already measured from the end of the prompt
        prefix, and that prefix contains the ``n_delay`` streaming pads
        (see ``_prompt_tokens``); confirm the delay is not compensated twice.
        """
        return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)

    def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
        """Return ``(start, end)`` seconds for word *word_idx*.

        ``n_words`` is accepted for call-site compatibility but is unused.
        The returned end is never earlier than the start.
        """
        starts = self._word_audio_starts
        ends = self._word_audio_ends
        if not starts:
            return self._time_offset, self._time_offset

        if word_idx < len(starts):
            start_pos = starts[word_idx]
        else:
            # No tracked start: place the word just after the last known position.
            start_pos = (ends[-1] if ends else starts[-1]) + 1

        if word_idx + 1 < len(starts):
            # The word runs until the next word begins.
            end_pos = starts[word_idx + 1]
        elif word_idx < len(ends):
            end_pos = ends[word_idx] + 1
        else:
            # Last word, still being built (or untracked): one token long.
            end_pos = start_pos + 1

        t0 = self._audio_pos_to_time(start_pos)
        t1 = self._audio_pos_to_time(end_pos)
        return t0, max(t0, t1)

    def _emit_words(self, keep_last: bool) -> List[ASRToken]:
        """Turn not-yet-committed words of ``_full_text`` into ``ASRToken``s.

        With *keep_last* the final word is held back because it may still
        grow. Every emitted word is marked committed. Words after the first
        get a leading space in their label.
        """
        if not self._full_text:
            return []
        words = self._full_text.split()
        limit = len(words) - 1 if keep_last else len(words)
        n_total = max(len(words), 1)
        tokens: List[ASRToken] = []
        while self._n_committed_words < limit:
            idx = self._n_committed_words
            t0, t1 = self._word_time_range(idx, n_total)
            label = words[idx] if idx == 0 else " " + words[idx]
            tokens.append(ASRToken(start=t0, end=t1, text=label))
            self._n_committed_words += 1
        return tokens

    def _extract_committed_words(self) -> List[ASRToken]:
        """Return complete words (all except the last which may still grow)."""
        return self._emit_words(keep_last=True)

    def _flush_all_words(self) -> List[ASRToken]:
        """Flush every word including the last partial one."""
        return self._emit_words(keep_last=False)

    # -- interface methods --

    def get_buffer(self) -> Transcript:
        """Return the uncommitted tail of the text, stamped at ``self.end``."""
        if not self._full_text:
            return Transcript(start=None, end=None, text="")
        words = self._full_text.split()
        remaining = words[self._n_committed_words :]
        if remaining:
            return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
        return Transcript(start=None, end=None, text="")

    def _safe_decode_remaining(self):
        """Decode the embeddings left after the final encode of a flush.

        Unlike ``_step``, the boundary here uses ``_samples_encoded``, which
        also counts the alignment and right-padding silence the flush appends
        before encoding. Those extra positions let the decoder emit tokens
        for the last words. The count is capped at ``RIGHT_PAD_TOKENS`` so
        not all accumulated padding is decoded. The EOS flag returned by the
        decode is ignored.

        NOTE(review): nothing is decoded when the decoder was never prefilled.
        A very short utterance that never reached the prefill size during
        streaming therefore yields no words; confirm this is intended.
        """
        if self._audio_embeds is None or not self._prefilled:
            return
        total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
        n_available = self._audio_embeds.shape[0]
        n_decodable = min(
            n_available,
            max(0, total_safe - self._positions_decoded),
            RIGHT_PAD_TOKENS,
        )
        if n_decodable > 0:
            self._decode_positions(n_decodable)

    def _drain_remaining(self) -> List[ASRToken]:
        """Pad, encode and decode leftover audio, then flush every word.

        This is the shared drain used by ``_flush_and_reset`` and ``finish``.
        When no audio was received since the last reset there is nothing to
        transcribe. In that case the encoder pass over pure padding is
        skipped, because its embeddings would never be decoded: the decoder
        is not prefilled in that state.
        """
        if self._real_samples_encoded > 0:
            # Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is
            # lost, then add right-padding silence to provide future context.
            remainder = self._pending_len % SAMPLES_PER_TOKEN
            align_pad = SAMPLES_PER_TOKEN - remainder if remainder else 0
            total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
            if total_pad > 0:
                self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
                self._pending_len += total_pad
            # Encode remaining audio (including right-padding)
            self._encode_pending()
            # Decode the remaining real positions plus bounded padding
            self._safe_decode_remaining()
            self._flush_last_token_text()
        self._close_current_word()
        return self._flush_all_words()

    def _flush_and_reset(self) -> List[ASRToken]:
        """Flush pending audio, decode remaining, extract all words, then
        fully reset both encoder and decoder state.

        Used at silence boundaries and when the continuous decoding cap is
        hit. A full reset (encoder + decoder) is necessary because the
        encoder's incremental state (conv tails, KV caches) contains history
        that would produce embeddings incompatible with a freshly-initialised
        decoder. After reset ``_first_chunk=True``, so the next audio chunk
        receives proper left-padding and both encoder and decoder start in
        sync.
        """
        words = self._drain_remaining()
        self._reset_keeping_timeline()
        return words

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush all pending words when silence starts, then fully reset.

        Adds right-padding silence and forces a decode pass so the
        decoder emits tokens for the last words of speech. After flushing,
        resets both encoder and decoder state to prevent hallucination from
        accumulated autoregressive context drift on long audio. Words queued
        by ``new_speaker`` are returned first.
        """
        words = self._flush_and_reset()
        logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
        return self._take_carried_words() + words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Advance the timeline by the silence that was skipped (``offset`` unused)."""
        self._time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """Flush and reset at a speaker change.

        This method returns nothing. The flushed words are therefore queued
        and delivered by the next ``process_iter``, ``start_silence`` or
        ``finish`` call instead of being discarded.
        """
        words, _ = self.start_silence()
        self._carried_words.extend(words)

    def warmup(self, audio, init_prompt=""):
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Drain all remaining audio and return every outstanding word."""
        logger.debug(
            "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
            "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
            self._pending_len,
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            self._samples_encoded,
            self._positions_decoded,
            self._prefilled,
            self._full_text[-80:] if self._full_text else "",
        )
        words = self._take_carried_words() + self._drain_remaining()
        logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
        return words, self.end