feat: add voxtral-mlx native backend for Apple Silicon

Pure-MLX implementation of Voxtral Mini 4B Realtime for low-latency
speech transcription on Apple Silicon. Avoids the transformers/torch
overhead and runs at 0.18-0.32x real-time factor.

- voxtral_mlx/model.py: MLX model with spectrogram, encoder, decoder
- voxtral_mlx/loader.py: model loading with 6-bit quantized weights
- voxtral_mlx/spectrogram.py: mel spectrogram computation in MLX
- voxtral_mlx_asr.py: VoxtralASR adapter for the AudioProcessor pipeline
This commit is contained in:
Quentin Fuxa 2026-02-22 23:28:10 +01:00
parent 9b2c3ee844
commit a4da246ea5
5 changed files with 1545 additions and 0 deletions

View file

@ -0,0 +1,6 @@
"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit."""
from .loader import load_voxtral_model
from .model import VoxtralMLXModel
__all__ = ["load_voxtral_model", "VoxtralMLXModel"]

View file

@ -0,0 +1,282 @@
"""
Model weight loading for the MLX Voxtral Realtime backend.
Supports two on-disk formats:
1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
with optional quantisation metadata.
2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
requires weight renaming and conv-weight transposition.
The public entry point is :func:`load_voxtral_model` which returns the
model, tokenizer, and raw config dict.
"""
import json
import logging
import re
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from .model import VoxtralMLXModel
logger = logging.getLogger(__name__)
DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
# ---------------------------------------------------------------------------
# Downloading
# ---------------------------------------------------------------------------
_ALLOWED_PATTERNS = [
"consolidated.safetensors",
"model*.safetensors",
"model.safetensors.index.json",
"params.json",
"config.json",
"tekken.json",
]
def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
"""Download model files from HuggingFace Hub and return the local path."""
return Path(snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS))
# ---------------------------------------------------------------------------
# Weight name remapping (Mistral → our naming)
# ---------------------------------------------------------------------------
_NAME_RULES: list[tuple[str, str]] = [
# Encoder convolutions
(r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
(r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
# Encoder transformer blocks
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
r"encoder.blocks.\1.self_attn.q_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
r"encoder.blocks.\1.self_attn.k_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
r"encoder.blocks.\1.self_attn.v_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
r"encoder.blocks.\1.self_attn.out_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
r"encoder.blocks.\1.pre_attn_norm.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
r"encoder.blocks.\1.ffn.gate.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
r"encoder.blocks.\1.ffn.down.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
r"encoder.blocks.\1.ffn.up.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
r"encoder.blocks.\1.pre_ffn_norm.\2"),
(r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
# Adapter
(r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
(r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
# Decoder embedding
(r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
# Decoder blocks
(r"layers\.(\d+)\.attention\.wq\.weight",
r"decoder.blocks.\1.self_attn.q_proj.weight"),
(r"layers\.(\d+)\.attention\.wk\.weight",
r"decoder.blocks.\1.self_attn.k_proj.weight"),
(r"layers\.(\d+)\.attention\.wv\.weight",
r"decoder.blocks.\1.self_attn.v_proj.weight"),
(r"layers\.(\d+)\.attention\.wo\.weight",
r"decoder.blocks.\1.self_attn.out_proj.weight"),
(r"layers\.(\d+)\.attention_norm\.weight",
r"decoder.blocks.\1.pre_attn_norm.weight"),
(r"layers\.(\d+)\.feed_forward\.w1\.weight",
r"decoder.blocks.\1.ffn.gate.weight"),
(r"layers\.(\d+)\.feed_forward\.w2\.weight",
r"decoder.blocks.\1.ffn.down.weight"),
(r"layers\.(\d+)\.feed_forward\.w3\.weight",
r"decoder.blocks.\1.ffn.up.weight"),
(r"layers\.(\d+)\.ffn_norm\.weight",
r"decoder.blocks.\1.pre_ffn_norm.weight"),
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
r"decoder.blocks.\1.adaptive_scale.proj_in.weight"),
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
r"decoder.blocks.\1.adaptive_scale.proj_out.weight"),
# Decoder final norm
(r"norm\.weight", r"decoder.final_norm.weight"),
]
_PREFIX_STRIP = re.compile(
r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\."
)
def _translate_weight_name(name: str) -> str | None:
name = _PREFIX_STRIP.sub("", name)
for pattern, replacement in _NAME_RULES:
result, n = re.subn(f"^{pattern}$", replacement, name)
if n:
return result
return None
def _is_conv_weight(name: str) -> bool:
return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name
# ---------------------------------------------------------------------------
# Converted-format weight remapping (voxmlx names → our names)
# ---------------------------------------------------------------------------
_CONVERTED_RULES: list[tuple[str, str]] = [
# Adapter
(r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
(r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
# Encoder transformer blocks
(r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"),
(r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"),
(r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"),
(r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
# Decoder embedding
(r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
# Decoder blocks
(r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"),
(r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"),
(r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"),
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
r"decoder.blocks.\1.adaptive_scale.proj_in.\2"),
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
r"decoder.blocks.\1.adaptive_scale.proj_out.\2"),
(r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
]
# Also remap o_proj → out_proj in both encoder and decoder
_POST_RENAME = [
(r"\.o_proj\.", r".out_proj."),
]
def _remap_converted_name(name: str) -> str:
"""Translate a converted-format weight name to our naming convention."""
for pattern, replacement in _CONVERTED_RULES:
result, n = re.subn(f"^{pattern}$", replacement, name)
if n:
name = result
break
for pattern, replacement in _POST_RENAME:
name = re.sub(pattern, replacement, name)
return name
# ---------------------------------------------------------------------------
# Loading strategies
# ---------------------------------------------------------------------------
def _has_converted_layout(path: Path) -> bool:
return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists()
def _load_converted_weights(path: Path):
with open(path / "config.json") as f:
config = json.load(f)
model = VoxtralMLXModel(config)
quant = config.get("quantization")
if quant is not None:
gs = quant["group_size"]
nn.quantize(
model,
group_size=gs,
bits=quant["bits"],
class_predicate=lambda _p, m: (
hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0
),
)
index_file = path / "model.safetensors.index.json"
if index_file.exists():
with open(index_file) as f:
shard_map = json.load(f)
shard_files = sorted(set(shard_map["weight_map"].values()))
weights = {}
for sf in shard_files:
weights.update(mx.load(str(path / sf)))
else:
weights = mx.load(str(path / "model.safetensors"))
remapped = {_remap_converted_name(k): v for k, v in weights.items()}
model.load_weights(list(remapped.items()))
mx.eval(model.parameters())
return model, config
def _load_original_weights(path: Path):
with open(path / "params.json") as f:
config = json.load(f)
model = VoxtralMLXModel(config)
raw = mx.load(str(path / "consolidated.safetensors"))
mapped: dict[str, mx.array] = {}
skipped: list[str] = []
for name, tensor in raw.items():
if name == "output.weight":
continue
new_name = _translate_weight_name(name)
if new_name is None:
skipped.append(name)
continue
# Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in]
if _is_conv_weight(new_name):
tensor = mx.swapaxes(tensor, 1, 2)
mapped[new_name] = tensor
if skipped:
logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])
model.load_weights(list(mapped.items()))
mx.eval(model.parameters())
return model, config
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
def _load_tokenizer(model_dir: Path):
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
return Tekkenizer.from_file(str(model_dir / "tekken.json"))
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
"""Load a Voxtral Realtime model and its tokenizer.
Args:
path_or_id: Local directory path **or** a HuggingFace model ID.
Returns:
``(model, tokenizer, config)``
"""
p = Path(path_or_id)
if not p.exists():
p = download_weights(path_or_id)
if _has_converted_layout(p):
model, config = _load_converted_weights(p)
else:
model, config = _load_original_weights(p)
tokenizer = _load_tokenizer(p)
logger.info("Voxtral MLX model loaded from %s", p)
return model, tokenizer, config

View file

@ -0,0 +1,534 @@
"""
Voxtral Realtime MLX model encoder, decoder, adapter, and top-level model.
Architecture:
audio StreamingEncoder EncoderToDecoderAdapter TextDecoder logits
with DelayEmbedding providing time-conditioning to the decoder.
The model supports both batch inference (full audio) and incremental streaming
(one chunk at a time with cached encoder/decoder state).
"""
import math
import mlx.core as mx
import mlx.nn as nn
# ---------------------------------------------------------------------------
# KV Cache
# ---------------------------------------------------------------------------
class SlidingKVCache:
"""Bounded key-value cache with rotating buffer for sliding-window attention.
Uses in-place writes for single-token autoregressive steps and
concatenation for multi-token prefills. Pre-allocates in blocks of
``alloc_step`` entries to reduce repeated allocation.
"""
alloc_step = 256
def __init__(self, capacity: int):
self.capacity = capacity
self.keys = None
self.values = None
self._offset = 0
self._write_idx = 0
@property
def offset(self) -> int:
return self._offset
# -- helpers --
def _reorder(self, buf):
"""Return *buf* in temporal order (unwrap the circular buffer)."""
if self._write_idx == buf.shape[2]:
return buf
if self._write_idx < self._offset:
return mx.concatenate(
[buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
axis=2,
)
return buf[..., : self._write_idx, :]
def _drop_oldest(self, buf, n_drop, tail=None):
parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
if tail is not None:
parts.append(tail)
return mx.concatenate(parts, axis=2)
# -- update strategies --
def _append_concat(self, k, v):
"""Multi-token update via concatenation (used during prefill)."""
if self.keys is None:
self.keys, self.values = k, v
else:
self.keys = self._reorder(self.keys)
self.values = self._reorder(self.values)
self._write_idx = self.keys.shape[2]
overflow = self._write_idx - self.capacity + 1
self.keys = self._drop_oldest(self.keys, overflow, k)
self.values = self._drop_oldest(self.values, overflow, v)
self._offset += k.shape[2]
self._write_idx = self.keys.shape[2]
return self.keys, self.values
def _write_inplace(self, k, v):
"""Single-token update via in-place write (autoregressive step)."""
B, n_heads, S, dim_k = k.shape
dim_v = v.shape[3]
prev = self._offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
):
n_new = min(self.alloc_step, self.capacity - prev)
fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
self.values = mx.concatenate([self.values, fresh_v], axis=2)
else:
self.keys, self.values = fresh_k, fresh_v
self._write_idx = prev
overflow = self.keys.shape[2] - self.capacity
if overflow > 0:
self.keys = self._drop_oldest(self.keys, overflow)
self.values = self._drop_oldest(self.values, overflow)
self._write_idx = self.capacity
if self._write_idx == self.capacity:
self._write_idx = 0
self.keys[..., self._write_idx : self._write_idx + S, :] = k
self.values[..., self._write_idx : self._write_idx + S, :] = v
self._offset += S
self._write_idx += S
if self._offset < self.capacity:
return (
self.keys[..., : self._offset, :],
self.values[..., : self._offset, :],
)
return self.keys, self.values
# -- public API --
def update_and_fetch(self, k, v):
if k.shape[2] == 1:
return self._write_inplace(k, v)
return self._append_concat(k, v)
# ---------------------------------------------------------------------------
# Encoder components
# ---------------------------------------------------------------------------
class CausalConv(nn.Module):
"""1-D causal convolution (left-padded so no future leakage)."""
def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
super().__init__()
self.stride = stride
self.kernel = kernel
self.left_pad = kernel - stride
self.weight = mx.zeros((channels_out, kernel, channels_in))
self.bias = mx.zeros((channels_out,))
def __call__(self, x: mx.array) -> mx.array:
if self.left_pad > 0:
x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
class _EncoderSelfAttention(nn.Module):
def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope_theta = rope_theta
def __call__(self, x, mask, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _EncoderFFN(nn.Module):
"""SwiGLU feed-forward for encoder layers."""
def __init__(self, dim: int, hidden: int):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=True)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class _EncoderBlock(nn.Module):
def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _EncoderFFN(dim, hidden)
def __call__(self, x, mask, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
x = x + self.ffn(self.pre_ffn_norm(x))
return x
class StreamingEncoder(nn.Module):
"""Causal Whisper-style encoder with two causal convolutions followed by
a stack of transformer blocks. Supports both full-sequence and
incremental (streaming) forward passes."""
def __init__(
self,
mel_channels: int = 128,
dim: int = 1280,
n_layers: int = 32,
n_heads: int = 32,
head_dim: int = 64,
hidden_dim: int = 5120,
rope_theta: float = 1e6,
sliding_window: int = 750,
):
super().__init__()
self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
self.blocks = [
_EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
self.sliding_window = sliding_window
# -- full-sequence --
def _apply_convs(self, mel: mx.array) -> mx.array:
x = mel.T[None, :, :] # [1, T, mel_channels]
x = nn.gelu(self.conv1(x))
x = nn.gelu(self.conv2(x))
return x
def forward(self, mel: mx.array) -> mx.array:
x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
for blk in self.blocks:
x = blk(x, mask="causal")
return self.final_norm(x)
# -- incremental (streaming) --
def forward_conv_incremental(self, x_in, tail1, tail2):
"""Process new mel frames through the two causal convs using cached tails.
Args:
x_in: [1, N, mel_channels]
tail1: [1, pad1, mel_channels] or None (first call)
tail2: [1, pad2, dim] or None (first call)
Returns:
(out, new_tail1, new_tail2)
"""
# Conv1 (kernel=3, stride=1 → left_pad=2)
if tail1 is not None:
c1_in = mx.concatenate([tail1, x_in], axis=1)
else:
c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
new_tail1 = x_in[:, -self.conv1.left_pad :, :]
c1_out = nn.gelu(
mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
)
# Conv2 (kernel=3, stride=2 → left_pad=1)
if tail2 is not None:
c2_in = mx.concatenate([tail2, c1_out], axis=1)
else:
c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
c2_out = nn.gelu(
mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
)
return c2_out, new_tail1, new_tail2
def forward_transformer_incremental(self, x, cache_list):
"""Run transformer blocks with per-layer KV caches."""
for i, blk in enumerate(self.blocks):
x = blk(x, mask="causal", cache=cache_list[i])
return self.final_norm(x)
# ---------------------------------------------------------------------------
# Decoder components
# ---------------------------------------------------------------------------
class _DecoderAttention(nn.Module):
"""Grouped-query attention for the text decoder."""
def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope_theta = rope_theta
def __call__(self, x, mask=None, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _DecoderFFN(nn.Module):
"""SwiGLU feed-forward for decoder layers."""
def __init__(self, dim, hidden):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=False)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class AdaptiveScaling(nn.Module):
"""Small MLP that produces a multiplicative scale from the delay embedding,
used to condition the FFN on the streaming delay."""
def __init__(self, dim, bottleneck):
super().__init__()
self.proj_in = nn.Linear(dim, bottleneck, bias=False)
self.proj_out = nn.Linear(bottleneck, dim, bias=False)
def __call__(self, cond):
return self.proj_out(nn.gelu(self.proj_in(cond)))
class _DecoderBlock(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _DecoderFFN(dim, hidden)
def __call__(self, x, delay_cond, mask=None, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
x = x + self.ffn(scaled)
return x
class TextDecoder(nn.Module):
"""Mistral-style causal language model with adaptive time-conditioning."""
def __init__(
self,
dim: int = 3072,
n_layers: int = 26,
n_heads: int = 32,
n_kv_heads: int = 8,
head_dim: int = 128,
hidden_dim: int = 9216,
vocab_size: int = 131072,
rope_theta: float = 1e6,
cond_dim: int = 32,
):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, dim)
self.blocks = [
_DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
def embed(self, token_ids: mx.array) -> mx.array:
return self.token_embedding(token_ids)
def __call__(self, x, delay_cond, mask=None, cache=None):
delay_cond = delay_cond.astype(x.dtype)
for i, blk in enumerate(self.blocks):
blk_cache = cache[i] if cache is not None else None
x = blk(x, delay_cond, mask, blk_cache)
x = self.final_norm(x)
return self.token_embedding.as_linear(x)
# ---------------------------------------------------------------------------
# Adapter & embeddings
# ---------------------------------------------------------------------------
class EncoderToDecoderAdapter(nn.Module):
"""Two-layer projection from encoder space to decoder space."""
def __init__(self, enc_dim: int, dec_dim: int):
super().__init__()
self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)
def __call__(self, x):
return self.linear2(nn.gelu(self.linear1(x)))
class DelayEmbedding(nn.Module):
"""Sinusoidal embedding that encodes the streaming delay as a conditioning
vector for the decoder's adaptive scaling."""
def __init__(self, dim: int = 3072, theta: float = 10000.0):
super().__init__()
self.dim = dim
half = dim // 2
freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
self._freqs = freqs
def __call__(self, delay: mx.array) -> mx.array:
t = delay.reshape(-1, 1).astype(mx.float32)
angles = t * self._freqs
return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class VoxtralMLXModel(nn.Module):
"""Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""
def __init__(self, config: dict):
super().__init__()
enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
audio_cfg = enc_cfg["audio_encoding_args"]
ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]
self.encoder = StreamingEncoder(
mel_channels=audio_cfg["num_mel_bins"],
dim=enc_cfg["dim"],
n_layers=enc_cfg["n_layers"],
n_heads=enc_cfg["n_heads"],
head_dim=enc_cfg["head_dim"],
hidden_dim=enc_cfg["hidden_dim"],
rope_theta=enc_cfg["rope_theta"],
sliding_window=enc_cfg["sliding_window"],
)
adapter_input_dim = enc_cfg["dim"] * ds_factor
decoder_dim = config["dim"]
cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)
self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)
self.decoder = TextDecoder(
dim=decoder_dim,
n_layers=config["n_layers"],
n_heads=config["n_heads"],
n_kv_heads=config["n_kv_heads"],
head_dim=config["head_dim"],
hidden_dim=config["hidden_dim"],
vocab_size=config["vocab_size"],
rope_theta=config["rope_theta"],
cond_dim=cond_bottleneck,
)
self.delay_embedding = DelayEmbedding(dim=decoder_dim)
self.ds_factor = ds_factor
# -- batch encode --
def encode(self, mel: mx.array) -> mx.array:
T = mel.shape[1]
if T % 2 != 0:
mel = mel[:, 1:]
h = self.encoder.forward(mel) # [1, T/2, enc_dim]
h = h[0]
n = h.shape[0]
trim = n % self.ds_factor
if trim:
h = h[trim:]
n = h.shape[0]
h = h.reshape(n // self.ds_factor, -1)
return self.adapter(h)
# -- incremental encode --
def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
"""Incrementally encode new mel frames.
Returns:
(audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
"""
x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)
x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)
if enc_cache is None:
enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]
x = self.encoder.forward_transformer_incremental(x, enc_cache)
x = x[0] # [N, enc_dim]
if ds_remainder is not None:
x = mx.concatenate([ds_remainder, x])
n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
if n_full == 0:
return None, conv_tail1, conv_tail2, enc_cache, x
leftover = x[n_full:] if x.shape[0] > n_full else None
x = x[:n_full].reshape(n_full // self.ds_factor, -1)
return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover
# -- decode --
def decode(self, embeddings, delay_cond, mask=None, cache=None):
return self.decoder(embeddings, delay_cond, mask, cache)

View file

@ -0,0 +1,202 @@
"""
Mel spectrogram computation for Voxtral Realtime.
Provides both a full-audio function and an incremental streaming variant
that maintains overlap state between calls. The DFT is computed via
matrix multiplication in MLX no external FFT dependency required.
"""
import math
import mlx.core as mx
import numpy as np
# Audio / mel constants matching the Voxtral Realtime model expectations.
SAMPLE_RATE = 16_000
WINDOW_SIZE = 400 # n_fft
HOP = 160
MEL_BANDS = 128
MEL_MAX = 1.5 # global log-mel normalisation ceiling
# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4)
SAMPLES_PER_TOKEN = HOP * 2 * 4 # = 1280 samples = 80 ms
# Padding tokens used by the model prompt structure.
LEFT_PAD_TOKENS = 32
RIGHT_PAD_TOKENS = 17
# ---------------------------------------------------------------------------
# Slaney mel filterbank
# ---------------------------------------------------------------------------
def _build_slaney_filterbank(
sr: int = SAMPLE_RATE,
n_fft: int = WINDOW_SIZE,
n_mels: int = MEL_BANDS,
lo_hz: float = 0.0,
hi_hz: float = 8000.0,
) -> np.ndarray:
"""Compute a Slaney-normalised triangular mel filterbank.
Returns an array of shape ``[n_mels, n_fft//2 + 1]``.
"""
def _hz2mel(f):
threshold = 1000.0
base_mel = 15.0
log_coeff = 27.0 / np.log(6.4)
mel = 3.0 * f / 200.0
if isinstance(f, np.ndarray):
above = f >= threshold
mel[above] = base_mel + np.log(f[above] / threshold) * log_coeff
elif f >= threshold:
mel = base_mel + np.log(f / threshold) * log_coeff
return mel
def _mel2hz(m):
threshold = 1000.0
base_mel = 15.0
log_coeff = np.log(6.4) / 27.0
hz = 200.0 * m / 3.0
above = m >= base_mel
hz[above] = threshold * np.exp(log_coeff * (m[above] - base_mel))
return hz
n_bins = n_fft // 2 + 1
fft_hz = np.linspace(0, sr / 2, n_bins)
mel_lo, mel_hi = _hz2mel(lo_hz), _hz2mel(hi_hz)
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
hz_pts = _mel2hz(mel_pts)
diffs = np.diff(hz_pts)
slopes = np.expand_dims(hz_pts, 0) - np.expand_dims(fft_hz, 1)
rising = -slopes[:, :-2] / diffs[:-1]
falling = slopes[:, 2:] / diffs[1:]
fb = np.maximum(0.0, np.minimum(rising, falling))
# Slaney area normalisation
widths = 2.0 / (hz_pts[2 : n_mels + 2] - hz_pts[:n_mels])
fb *= np.expand_dims(widths, 0)
return fb.T.astype(np.float32)
_CACHED_FILTERS: mx.array | None = None
def _mel_filters() -> mx.array:
global _CACHED_FILTERS
if _CACHED_FILTERS is None:
_CACHED_FILTERS = mx.array(_build_slaney_filterbank())
return _CACHED_FILTERS
# ---------------------------------------------------------------------------
# DFT helpers
# ---------------------------------------------------------------------------
def _hann_window() -> mx.array:
return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
def _dft_matrices():
"""Pre-compute the real / imaginary DFT basis matrices."""
n_bins = WINDOW_SIZE // 2 + 1
k = mx.arange(n_bins, dtype=mx.float32)[:, None]
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
return mx.cos(phase), mx.sin(phase)
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
"""Frame *audio* using the Hann window and compute power spectrogram."""
n_bins = WINDOW_SIZE // 2 + 1
n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP
if n_frames <= 0:
return mx.zeros((0, n_bins))
offsets = (mx.arange(n_frames) * HOP)[:, None]
indices = offsets + mx.arange(WINDOW_SIZE)[None, :]
windowed = audio[indices] * window[None, :]
dft_re, dft_im = _dft_matrices()
real_part = windowed @ dft_re.T
imag_part = windowed @ dft_im.T
return real_part ** 2 + imag_part ** 2
def _apply_mel_and_log(power: mx.array) -> mx.array:
"""Convert a power spectrogram to log-mel and normalise."""
mel = power @ _mel_filters().T
log_mel = mx.log10(mx.maximum(mel, 1e-10))
log_mel = mx.maximum(log_mel, MEL_MAX - 8.0)
return (log_mel + 4.0) / 4.0
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def compute_mel(audio: np.ndarray) -> mx.array:
"""Compute log-mel spectrogram for a complete audio signal.
Args:
audio: 1-D float32 numpy array at ``SAMPLE_RATE``.
Returns:
``[MEL_BANDS, T]`` MLX array.
"""
x = mx.array(audio)
pad = WINDOW_SIZE // 2
x = mx.pad(x, [(pad, pad)])
window = _hann_window()
power = _stft_frames(x, window)
# Drop last frame to match reference STFT behaviour
power = power[:-1]
return _apply_mel_and_log(power).T
def compute_mel_streaming(
chunk: np.ndarray,
overlap: np.ndarray | None,
) -> tuple[mx.array, np.ndarray]:
"""Incrementally compute log-mel for a new audio chunk.
Args:
chunk: New audio samples (float32 numpy).
overlap: The last ``WINDOW_SIZE - HOP`` = 240 samples from the
previous call, or *None* on the first call (uses zero-padding).
Returns:
``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and
*new_overlap* is the 240-sample tail for the next call.
"""
tail_len = WINDOW_SIZE - HOP # 240
if overlap is not None:
combined = np.concatenate([overlap, chunk])
else:
combined = np.concatenate([np.zeros(WINDOW_SIZE // 2, dtype=np.float32), chunk])
new_overlap = combined[-tail_len:].copy()
x = mx.array(combined)
window = _hann_window()
power = _stft_frames(x, window)
if power.shape[0] == 0:
return mx.zeros((MEL_BANDS, 0)), new_overlap
return _apply_mel_and_log(power).T, new_overlap
def pad_audio(
audio: np.ndarray,
n_left: int = LEFT_PAD_TOKENS,
n_right: int = RIGHT_PAD_TOKENS,
) -> np.ndarray:
"""Pad audio with silence for batch (non-streaming) inference."""
left = n_left * SAMPLES_PER_TOKEN
align = (SAMPLES_PER_TOKEN - (len(audio) % SAMPLES_PER_TOKEN)) % SAMPLES_PER_TOKEN
right = align + n_right * SAMPLES_PER_TOKEN
return np.pad(audio, (left, right))

View file

@ -0,0 +1,521 @@
"""
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
(streaming processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
Unlike the HuggingFace backend, this runs the full inference loop in-process
(no background thread / queue) MLX operations on Apple Silicon are fast
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
from whisperlivekit.voxtral_mlx.spectrogram import (
SAMPLES_PER_TOKEN,
LEFT_PAD_TOKENS,
RIGHT_PAD_TOKENS,
compute_mel_streaming,
)
logger = logging.getLogger(__name__)
# Decoder sliding-window size (matches the model's training configuration).
_DECODER_WINDOW = 8192
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
"""Build the prompt token sequence and return ``(token_ids, n_delay)``."""
pad_id = tokenizer.get_special_token("[STREAMING_PAD]")
ids = [tokenizer.bos_id] + [pad_id] * (n_left_pad + n_delay)
return ids, n_delay
# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------
class VoxtralMLXASR:
"""Lightweight model holder — loads the MLX Voxtral model once and keeps
it alive for the lifetime of the server."""
sep = " "
SAMPLING_RATE = 16_000
def __init__(self, logfile=sys.stderr, **kwargs):
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL_ID
t0 = time.time()
logger.info("Loading Voxtral MLX model '%s' ...", model_path)
self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0)
self.backend_choice = "voxtral-mlx"
def transcribe(self, audio):
pass # all work happens in the online processor
# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------
class VoxtralMLXOnlineProcessor:
"""Streaming processor that incrementally encodes audio and decodes text
using the MLX Voxtral model.
Lifecycle (called by ``AudioProcessor.transcription_processor``):
insert_audio_chunk(pcm, time) process_iter() get_buffer()
... repeat ...
start_silence() / end_silence()
finish()
"""
SAMPLING_RATE = 16_000
def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: list = []
self.audio_buffer = np.array([], dtype=np.float32)
self._model = asr.model
self._tokenizer = asr.tokenizer
# Pre-compute prompt tokens and delay conditioning (constant across utterances).
self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
self._prefix_len = len(self._prompt_ids)
self._delay_cond = self._model.delay_embedding(
mx.array([self._n_delay], dtype=mx.float32)
)
mx.eval(self._delay_cond)
self._prompt_embeds = self._model.decoder.embed(
mx.array([self._prompt_ids])
)[0] # [prefix_len, dim]
mx.eval(self._prompt_embeds)
self._eos_id = self._tokenizer.eos_id
self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
# The streaming model has an inherent delay: text for audio at position P
# is generated at decoder position P + n_delay. Compensate timestamps.
self._delay_secs = self._n_delay * self._secs_per_token
self._reset_state()
# -- state management --
def _reset_state(self):
"""Reset all incremental state for a fresh utterance."""
# Audio accumulation
self._pending = np.zeros(0, dtype=np.float32)
# Mel overlap
self._mel_overlap: np.ndarray | None = None
# Encoder incremental state
self._conv_tail1 = None
self._conv_tail2 = None
self._enc_cache = None
self._ds_remainder = None
# Audio embeddings not yet decoded
self._audio_embeds: mx.array | None = None
# Decoder state
self._dec_cache: list[SlidingKVCache] | None = None
self._last_token: mx.array | None = None
# Bookkeeping
self._samples_encoded = 0
self._positions_decoded = 0
self._prefilled = False
self._first_chunk = True
# Text state
self._full_text = ""
self._n_text_tokens = 0
self._n_committed_words = 0
self._time_offset = 0.0
# Per-word audio position tracking: decoder position (relative to prefix)
# where each word in _full_text started and ended
self._word_audio_starts: list[int] = [] # audio pos where word i started
self._word_audio_ends: list[int] = [] # audio pos where word i last produced a token
self._current_word_pos: Optional[int] = None # audio pos of current (incomplete) word's first token
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending = np.append(self._pending, audio)
self.audio_buffer = self._pending
# -- core processing --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._step(is_last)
except Exception as e:
logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
return [], self.end
def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# 1. Encode any new audio
self._encode_pending()
if self._audio_embeds is None:
return [], self.end
# 2. Compute how many positions we can safely decode
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
n_available = self._audio_embeds.shape[0]
n_decodable = min(n_available, total_safe - self._positions_decoded)
if n_decodable <= 0:
return [], self.end
# 3. Prefill if needed
if not self._prefilled:
if self._positions_decoded + n_available < self._prefix_len:
return [], self.end
self._do_prefill()
# Re-check after consuming prefix embeddings
n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
n_decodable = min(n_available, total_safe - self._positions_decoded)
if n_decodable <= 0 or self._audio_embeds is None:
return [], self.end
# 4. Decode available positions
hit_eos = self._decode_positions(n_decodable)
if hit_eos:
# Flush words, reset for next utterance
words = self._flush_all_words()
logger.debug(
"[voxtral-mlx] EOS hit during stream: flushed %d words, "
"samples_encoded=%d (%.2fs), text='%s'",
len(words), self._samples_encoded,
self._samples_encoded / self.SAMPLING_RATE,
self._full_text[-60:] if self._full_text else "",
)
saved_offset = self._time_offset
self._reset_state()
self._time_offset = saved_offset
return words, self.end
# 5. Extract committed words (all but the last, which may still grow)
return self._extract_committed_words(), self.end
def _encode_pending(self):
"""Feed pending audio through the incremental encoder."""
available = len(self._pending)
if available < SAMPLES_PER_TOKEN:
return
if self._first_chunk:
# First chunk: prepend silence for left-padding
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
chunk = np.concatenate([left_pad, self._pending[:n_take]])
self._pending = self._pending[n_take:]
self._samples_encoded += n_take
self._first_chunk = False
else:
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
chunk = self._pending[:n_take]
self._pending = self._pending[n_take:]
self._samples_encoded += n_take
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
self._model.encode_incremental(
mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
)
)
if embeds is not None:
mx.eval(embeds)
if self._audio_embeds is not None:
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
else:
self._audio_embeds = embeds
self.audio_buffer = self._pending
def _do_prefill(self):
"""Run the decoder prefill pass over the prompt + first audio embeddings."""
n_dec_layers = len(self._model.decoder.blocks)
self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]
prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
prefix_embeds = prefix_embeds[None, :, :] # [1, prefix_len, dim]
logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])
self._last_token = self._sample(logits)
mx.async_eval(self._last_token)
# Remove consumed prefix embeddings
self._audio_embeds = self._audio_embeds[self._prefix_len :]
if self._audio_embeds.shape[0] == 0:
self._audio_embeds = None
self._positions_decoded = self._prefix_len
self._prefilled = True
def _decode_positions(self, n: int) -> bool:
"""Autoregressively decode *n* positions. Returns True on EOS."""
base_pos = self._positions_decoded # absolute position before this batch
for i in range(n):
tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
next_tok = self._sample(logits)
mx.async_eval(next_tok)
token_id = self._last_token.item()
if token_id == self._eos_id:
# Close the current word if one is being built
if self._current_word_pos is not None:
self._word_audio_ends.append(base_pos + i - self._prefix_len)
self._current_word_pos = None
self._trim_embeds(i)
self._positions_decoded += i
return True
text = self._tokenizer.decode(
[token_id], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
audio_pos = base_pos + i - self._prefix_len
# Detect word boundary: new word starts with space or is the very first text
if text.lstrip() != text or not self._full_text:
# Close previous word if exists
if self._current_word_pos is not None:
self._word_audio_ends.append(audio_pos)
# Start new word
self._word_audio_starts.append(audio_pos)
self._current_word_pos = audio_pos
elif self._current_word_pos is None:
# First token of first word (no leading space)
self._word_audio_starts.append(audio_pos)
self._current_word_pos = audio_pos
self._full_text += text
self._n_text_tokens += 1
if i > 0 and i % 256 == 0:
mx.clear_cache()
self._last_token = next_tok
self._positions_decoded += n
self._trim_embeds(n)
return False
def _trim_embeds(self, n_consumed: int):
if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
self._audio_embeds = self._audio_embeds[n_consumed:]
else:
self._audio_embeds = None
def _sample(self, logits: mx.array) -> mx.array:
return mx.argmax(logits[0, -1:], axis=-1).squeeze()
# -- word extraction --
def _audio_pos_to_time(self, pos: int) -> float:
"""Convert an audio position (relative to prefix end) to seconds."""
return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)
def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
"""Compute (start, end) time for a word using tracked word positions."""
starts = self._word_audio_starts
ends = self._word_audio_ends
if not starts:
return self._time_offset, self._time_offset
# Get start position for this word
if word_idx < len(starts):
t0 = self._audio_pos_to_time(starts[word_idx])
else:
# Fallback: estimate from last known position
last_pos = ends[-1] if ends else starts[-1]
t0 = self._audio_pos_to_time(last_pos + 1)
# Get end position: use the start of the next word, or the end of this word
if word_idx + 1 < len(starts):
t1 = self._audio_pos_to_time(starts[word_idx + 1])
elif word_idx < len(ends):
t1 = self._audio_pos_to_time(ends[word_idx] + 1)
else:
# Last word, still being built: use last known position + 1 token
last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
t1 = self._audio_pos_to_time(last_pos + 1)
return t0, t1
def _extract_committed_words(self) -> List[ASRToken]:
"""Return complete words (all except the last which may still grow)."""
if not self._full_text:
return []
words = self._full_text.split()
tokens: List[ASRToken] = []
n_total = max(len(words), 1)
while len(words) > self._n_committed_words + 1:
w = words[self._n_committed_words]
idx = self._n_committed_words
t0, t1 = self._word_time_range(idx, n_total)
label = w if idx == 0 else " " + w
tokens.append(ASRToken(start=t0, end=t1, text=label))
self._n_committed_words += 1
return tokens
def _flush_all_words(self) -> List[ASRToken]:
"""Flush every word including the last partial one."""
if not self._full_text:
return []
words = self._full_text.split()
tokens: List[ASRToken] = []
n_total = max(len(words), 1)
while self._n_committed_words < len(words):
w = words[self._n_committed_words]
idx = self._n_committed_words
t0, t1 = self._word_time_range(idx, n_total)
label = w if idx == 0 else " " + w
tokens.append(ASRToken(start=t0, end=t1, text=label))
self._n_committed_words += 1
return tokens
# -- interface methods --
def get_buffer(self) -> Transcript:
if not self._full_text:
return Transcript(start=None, end=None, text="")
words = self._full_text.split()
remaining = words[self._n_committed_words :]
if remaining:
return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
words = self._flush_all_words()
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug(
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
len(self._pending),
self._audio_embeds.shape if self._audio_embeds is not None else None,
self._samples_encoded,
self._positions_decoded,
self._prefilled,
self._full_text[-80:] if self._full_text else "",
)
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
remainder = len(self._pending) % SAMPLES_PER_TOKEN
if remainder > 0:
align_pad = SAMPLES_PER_TOKEN - remainder
else:
align_pad = 0
# Add alignment + right-padding silence
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
if total_pad > 0:
self._pending = np.append(
self._pending, np.zeros(total_pad, dtype=np.float32)
)
# Encode remaining audio (including right-padding)
self._encode_pending()
logger.debug(
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
self._audio_embeds.shape if self._audio_embeds is not None else None,
len(self._pending),
)
hit_eos = False
# Decode everything that's left from right-padding
if self._audio_embeds is not None and self._prefilled:
hit_eos = self._decode_positions(self._audio_embeds.shape[0])
logger.debug(
"[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
hit_eos, self._full_text[-80:] if self._full_text else "",
)
# Flush last token if it wasn't EOS
if self._last_token is not None:
tid = self._last_token.item()
if tid != self._eos_id:
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
last_pos = self._positions_decoded - self._prefix_len
# Check if this starts a new word
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
# Close the last word if still open
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
words = self._flush_all_words()
logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
return words, self.end