voxtral mlx : improved chunking

This commit is contained in:
Quentin Fuxa 2026-03-14 00:13:29 +01:00
parent 9d8db7ab38
commit dfd5bf417c
11 changed files with 1812 additions and 171 deletions

View file

@ -13,6 +13,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
logging.getLogger("whisperlivekit.qwen3_asr").setLevel(logging.DEBUG)
config = parse_args() config = parse_args()
transcription_engine = None transcription_engine = None

View file

@ -57,6 +57,8 @@ BACKENDS = [
"install": "pip install faster-whisper", "install": "pip install faster-whisper",
"description": "CTranslate2-based Whisper (fast, CPU/CUDA)", "description": "CTranslate2-based Whisper (fast, CPU/CUDA)",
"policy": "localagreement", "policy": "localagreement",
"streaming": "chunk", # batch inference with LocalAgreement/SimulStreaming
"devices": ["cpu", "cuda"],
}, },
{ {
"id": "whisper", "id": "whisper",
@ -65,6 +67,8 @@ BACKENDS = [
"install": "pip install openai-whisper", "install": "pip install openai-whisper",
"description": "Original OpenAI Whisper (PyTorch)", "description": "Original OpenAI Whisper (PyTorch)",
"policy": "simulstreaming", "policy": "simulstreaming",
"streaming": "chunk",
"devices": ["cpu", "cuda"],
}, },
{ {
"id": "mlx-whisper", "id": "mlx-whisper",
@ -74,21 +78,27 @@ BACKENDS = [
"description": "Apple Silicon native Whisper (MLX)", "description": "Apple Silicon native Whisper (MLX)",
"policy": "localagreement", "policy": "localagreement",
"platform": "darwin-arm64", "platform": "darwin-arm64",
"streaming": "chunk",
"devices": ["mlx"],
}, },
{ {
"id": "voxtral-mlx", "id": "voxtral-mlx",
"name": "Voxtral MLX", "name": "Voxtral MLX",
"module": "mlx", "module": "mlx",
"install": "pip install whisperlivekit[voxtral-mlx]", "install": "pip install whisperlivekit[voxtral-mlx]",
"description": "Mistral Voxtral Mini on Apple Silicon (MLX)", "description": "Mistral Voxtral Mini on Apple Silicon (MLX, native streaming)",
"platform": "darwin-arm64", "platform": "darwin-arm64",
"streaming": "native", # truly streaming (token-by-token)
"devices": ["mlx"],
}, },
{ {
"id": "voxtral", "id": "voxtral",
"name": "Voxtral HF", "name": "Voxtral HF",
"module": "transformers", "module": "transformers",
"install": "pip install whisperlivekit[voxtral-hf]", "install": "pip install whisperlivekit[voxtral-hf]",
"description": "Mistral Voxtral Mini (HF Transformers, CUDA/CPU/MPS)", "description": "Mistral Voxtral Mini (HF Transformers, native streaming)",
"streaming": "native",
"devices": ["cuda", "mps", "cpu"],
}, },
{ {
"id": "qwen3", "id": "qwen3",
@ -96,6 +106,8 @@ BACKENDS = [
"module": "qwen_asr", "module": "qwen_asr",
"install": "pip install qwen-asr", "install": "pip install qwen-asr",
"description": "Qwen3-ASR with ForcedAligner timestamps", "description": "Qwen3-ASR with ForcedAligner timestamps",
"streaming": "chunk",
"devices": ["cuda", "mps", "cpu"],
}, },
{ {
"id": "openai-api", "id": "openai-api",
@ -103,6 +115,8 @@ BACKENDS = [
"module": "openai", "module": "openai",
"install": "pip install openai", "install": "pip install openai",
"description": "Cloud-based transcription via OpenAI API", "description": "Cloud-based transcription via OpenAI API",
"streaming": "cloud",
"devices": ["cloud"],
}, },
] ]
@ -159,6 +173,28 @@ QWEN3_REPOS = {
} }
QWEN3_ALIGNER_REPO = "Qwen/Qwen3-ForcedAligner-0.6B" QWEN3_ALIGNER_REPO = "Qwen/Qwen3-ForcedAligner-0.6B"
# Model catalog: metadata for display in `wlk models`
# params = approximate parameter count, disk = approximate download size
MODEL_CATALOG = [
# Whisper family (available across faster-whisper, mlx-whisper, whisper backends)
{"name": "tiny", "family": "whisper", "params": "39M", "disk": "75 MB", "languages": 99, "quality": "low", "speed": "fastest"},
{"name": "tiny.en", "family": "whisper", "params": "39M", "disk": "75 MB", "languages": 1, "quality": "low", "speed": "fastest"},
{"name": "base", "family": "whisper", "params": "74M", "disk": "142 MB", "languages": 99, "quality": "fair", "speed": "fast"},
{"name": "base.en", "family": "whisper", "params": "74M", "disk": "142 MB", "languages": 1, "quality": "fair", "speed": "fast"},
{"name": "small", "family": "whisper", "params": "244M", "disk": "466 MB", "languages": 99, "quality": "good", "speed": "medium"},
{"name": "small.en", "family": "whisper", "params": "244M", "disk": "466 MB", "languages": 1, "quality": "good", "speed": "medium"},
{"name": "medium", "family": "whisper", "params": "769M", "disk": "1.5 GB", "languages": 99, "quality": "great", "speed": "slow"},
{"name": "medium.en", "family": "whisper", "params": "769M", "disk": "1.5 GB", "languages": 1, "quality": "great", "speed": "slow"},
{"name": "large-v3", "family": "whisper", "params": "1.5B", "disk": "3.1 GB", "languages": 99, "quality": "best", "speed": "slowest"},
{"name": "large-v3-turbo", "family": "whisper", "params": "809M", "disk": "1.6 GB", "languages": 99, "quality": "great", "speed": "medium"},
# Voxtral (native streaming, single model)
{"name": "voxtral", "family": "voxtral", "params": "4B", "disk": "8.2 GB", "languages": 15, "quality": "great", "speed": "medium"},
{"name": "voxtral-mlx", "family": "voxtral", "params": "4B", "disk": "2.7 GB", "languages": 15, "quality": "great", "speed": "medium"},
# Qwen3 ASR
{"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"},
{"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
]
def _check_platform(backend: dict) -> bool: def _check_platform(backend: dict) -> bool:
"""Check if backend is compatible with current platform.""" """Check if backend is compatible with current platform."""
@ -254,93 +290,124 @@ def print_banner(config, host: str, port: int, ssl: bool = False):
# `wlk models` subcommand # `wlk models` subcommand
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def cmd_models(): def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool:
"""List available backends and their installation status.""" """Check if a model catalog entry has been downloaded."""
is_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64" name = model_entry["name"]
family = model_entry["family"]
print("\nAvailable backends:\n") if family == "whisper":
# Check all whisper backends
repos = [
FASTER_WHISPER_REPOS.get(name),
MLX_WHISPER_REPOS.get(name),
f"openai/whisper-{name}",
]
return any(r in downloaded for r in repos if r)
elif name == "voxtral":
return VOXTRAL_HF_REPO in downloaded
elif name == "voxtral-mlx":
return VOXTRAL_MLX_REPO in downloaded
elif family == "qwen3":
size = name.split(":")[1] if ":" in name else "1.7b"
return QWEN3_REPOS.get(size, "") in downloaded
return False
def _best_backend_for_model(model_entry: dict) -> str:
"""Suggest the best available backend for a model."""
family = model_entry["family"]
is_apple = platform.system() == "Darwin" and platform.machine() == "arm64"
if family == "voxtral":
if "mlx" in model_entry["name"]:
return "voxtral-mlx"
return "voxtral"
elif family == "qwen3":
return "qwen3"
elif family == "whisper":
if is_apple and _module_available("mlx_whisper"):
return "mlx-whisper"
if _module_available("faster_whisper"):
return "faster-whisper"
if _module_available("whisper"):
return "whisper"
# Suggest best installable
return "mlx-whisper" if is_apple else "faster-whisper"
return "auto"
def cmd_models():
"""List available models and backends (ollama-style)."""
is_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
downloaded = _scan_downloaded_models()
# --- Installed backends ---
print("\n Backends:\n")
max_name = max(len(b["name"]) for b in BACKENDS) max_name = max(len(b["name"]) for b in BACKENDS)
for b in BACKENDS: for b in BACKENDS:
compatible = _check_platform(b) compatible = _check_platform(b)
installed = _is_installed(b) installed = _is_installed(b)
streaming = b.get("streaming", "chunk")
stream_label = {"native": "streaming", "chunk": "chunked", "cloud": "cloud"}.get(streaming, streaming)
if installed: if installed:
status = "\033[32m installed\033[0m" status = "\033[32m+\033[0m"
elif not compatible: elif not compatible:
status = "\033[90m n/a (wrong platform)\033[0m" status = "\033[90m-\033[0m"
else: else:
status = "\033[33m not installed\033[0m" status = "\033[33m-\033[0m"
name_pad = b["name"].ljust(max_name) name_pad = b["name"].ljust(max_name)
print(f" {name_pad} [{status}] {b['description']}") desc_short = b["description"]
print(f" {status} {name_pad} {desc_short} [{stream_label}]")
if not installed and compatible: if not installed and compatible:
print(f" {''.ljust(max_name)} └─ {b['install']}") print(f" {''.ljust(max_name)} \033[90m{b['install']}\033[0m")
# System info # --- System info ---
print(f"\n Platform: {platform.system()} {platform.machine()}") print(f"\n Platform: {platform.system()} {platform.machine()}")
print(f" Python: {platform.python_version()}")
print(f" Accelerator: {_gpu_info()}") print(f" Accelerator: {_gpu_info()}")
print(f" ffmpeg: {'found' if _check_ffmpeg() else 'NOT FOUND (required)'}") print(f" ffmpeg: {'found' if _check_ffmpeg() else '\033[31mNOT FOUND\033[0m (required)'}")
# --- Model catalog ---
print("\n Models:\n")
# Table header
hdr = f" {'NAME':<20} {'PARAMS':>7} {'SIZE':>8} {'QUALITY':<8} {'SPEED':<8} {'LANGS':>5} {'STATUS':<10}"
print(hdr)
print(f" {'' * 20} {'' * 7} {'' * 8} {'' * 8} {'' * 8} {'' * 5} {'' * 10}")
for m in MODEL_CATALOG:
name = m["name"]
# Skip platform-incompatible models
if name == "voxtral-mlx" and not is_apple_silicon:
continue
is_dl = _model_is_downloaded(m, downloaded)
if is_dl:
status = "\033[32mpulled\033[0m "
else:
status = "\033[90mavailable\033[0m "
langs = str(m["languages"]) if m["languages"] < 99 else "99+"
print(
f" {name:<20} {m['params']:>7} {m['disk']:>8} "
f"{m['quality']:<8} {m['speed']:<8} {langs:>5} {status}"
)
# --- Quick start ---
print(f"\n Quick start:\n")
if is_apple_silicon: if is_apple_silicon:
print("\n Tip: On Apple Silicon, mlx-whisper and voxtral-mlx offer the best performance.") print(" wlk run voxtral-mlx # Best streaming on Apple Silicon")
print(" wlk run large-v3-turbo # Best quality/speed balance")
# Scan for downloaded models else:
downloaded = _scan_downloaded_models() print(" wlk run large-v3-turbo # Best quality/speed balance")
print(" wlk run voxtral # Native streaming (CUDA/CPU)")
print("\n Downloaded models:\n") print(" wlk pull base # Download smallest multilingual model")
found_any = False print(" wlk transcribe audio.mp3 # Offline transcription")
# Check Whisper-family models
all_repos = {
"faster-whisper": FASTER_WHISPER_REPOS,
"mlx-whisper": MLX_WHISPER_REPOS,
}
for backend_name, repos in all_repos.items():
for size, repo_id in repos.items():
if repo_id in downloaded:
found_any = True
print(f" \033[32m*\033[0m {backend_name}:{size} ({repo_id})")
# Check native whisper
for size in WHISPER_SIZES:
key = f"openai/whisper-{size}"
if key in downloaded:
found_any = True
print(f" \033[32m*\033[0m whisper:{size}")
# Check voxtral / qwen3
if VOXTRAL_HF_REPO in downloaded:
found_any = True
print(f" \033[32m*\033[0m voxtral ({VOXTRAL_HF_REPO})")
if VOXTRAL_MLX_REPO in downloaded:
found_any = True
print(f" \033[32m*\033[0m voxtral-mlx ({VOXTRAL_MLX_REPO})")
for qsize, repo_id in QWEN3_REPOS.items():
if repo_id in downloaded:
found_any = True
print(f" \033[32m*\033[0m qwen3:{qsize} ({repo_id})")
if QWEN3_ALIGNER_REPO in downloaded:
found_any = True
print(f" \033[32m*\033[0m qwen3-aligner ({QWEN3_ALIGNER_REPO})")
if not found_any:
print(" (none — models download automatically on first use, or use 'wlk pull')")
# Show pullable models
print("\n Available models (use 'wlk pull <name>'):\n")
print(" Whisper sizes: " + ", ".join(WHISPER_SIZES))
print(" Voxtral: voxtral, voxtral-mlx")
print(" Qwen3: qwen3:1.7b, qwen3:0.6b")
print()
print(" Examples:")
print(" wlk pull base # Download for best available backend")
print(" wlk pull faster-whisper:large-v3 # Specific backend + model")
print(" wlk pull voxtral # Voxtral HF model")
print(" wlk pull qwen3:1.7b # Qwen3-ASR 1.7B")
print() print()
@ -1010,6 +1077,23 @@ def cmd_run(args: list):
if parsed.model: if parsed.model:
backend_flag, model_flag = _resolve_run_spec(parsed.model) backend_flag, model_flag = _resolve_run_spec(parsed.model)
# Show what we resolved
catalog_match = next(
(m for m in MODEL_CATALOG if m["name"] == parsed.model),
None,
)
if catalog_match:
print(
f"\n Model: {catalog_match['name']} "
f"({catalog_match['params']} params, {catalog_match['disk']})",
file=sys.stderr,
)
if backend_flag:
print(f" Backend: {backend_flag}", file=sys.stderr)
else:
best = _best_backend_for_model(catalog_match)
print(f" Backend: {best} (auto-detected)", file=sys.stderr)
# Auto-pull if needed # Auto-pull if needed
downloaded = _scan_downloaded_models() downloaded = _scan_downloaded_models()
targets = _resolve_pull_target(parsed.model) targets = _resolve_pull_target(parsed.model)
@ -1198,9 +1282,9 @@ def _probe_backend_state(processor) -> dict:
info["n_audio_tokens_fed"] = transcription._n_audio_tokens_fed info["n_audio_tokens_fed"] = transcription._n_audio_tokens_fed
info["n_text_tokens_received"] = transcription._n_text_tokens_received info["n_text_tokens_received"] = transcription._n_text_tokens_received
info["n_committed_words"] = transcription._n_committed_words info["n_committed_words"] = transcription._n_committed_words
info["pending_audio_samples"] = len(transcription._pending_audio) info["pending_audio_samples"] = transcription._pending_len
with transcription._text_lock: with transcription._text_lock:
info["accumulated_text"] = transcription._accumulated_text info["accumulated_text"] = transcription._get_accumulated_text()
if transcription._generate_error: if transcription._generate_error:
info["generate_error"] = str(transcription._generate_error) info["generate_error"] = str(transcription._generate_error)
# Audio queue depth # Audio queue depth

View file

@ -72,6 +72,10 @@ class WhisperLiveKitConfig:
nllb_backend: str = "transformers" nllb_backend: str = "transformers"
nllb_size: str = "600M" nllb_size: str = "600M"
# vLLM Realtime backend
vllm_url: str = "ws://localhost:8000/v1/realtime"
vllm_model: str = ""
def __post_init__(self): def __post_init__(self):
# .en model suffix forces English # .en model suffix forces English
if self.model_size and self.model_size.endswith(".en"): if self.model_size and self.model_size.endswith(".en"):

View file

@ -102,7 +102,16 @@ class TranscriptionEngine:
} }
if config.transcription: if config.transcription:
if config.backend == "voxtral-mlx": if config.backend == "vllm-realtime":
from whisperlivekit.vllm_realtime import VLLMRealtimeASR
self.tokenizer = None
self.asr = VLLMRealtimeASR(
vllm_url=config.vllm_url,
model_name=config.vllm_model or "Qwen/Qwen3-ASR-1.7B",
lan=config.lan,
)
logger.info("Using vLLM Realtime streaming backend at %s", config.vllm_url)
elif config.backend == "voxtral-mlx":
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
self.tokenizer = None self.tokenizer = None
self.asr = VoxtralMLXASR(**transcription_common_params) self.asr = VoxtralMLXASR(**transcription_common_params)
@ -112,6 +121,14 @@ class TranscriptionEngine:
self.tokenizer = None self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params) self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend") logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR
self.tokenizer = None
self.asr = Qwen3SimulStreamingASR(
**transcription_common_params,
alignment_heads_path=config.custom_alignment_heads,
)
logger.info("Using Qwen3-ASR backend with SimulStreaming policy")
elif config.backend == "qwen3": elif config.backend == "qwen3":
from whisperlivekit.qwen3_asr import Qwen3ASR from whisperlivekit.qwen3_asr import Qwen3ASR
self.asr = Qwen3ASR(**transcription_common_params) self.asr = Qwen3ASR(**transcription_common_params)
@ -210,6 +227,12 @@ def online_factory(args, asr, language=None):
asr = SessionASRProxy(asr, language) asr = SessionASRProxy(asr, language)
backend = getattr(args, 'backend', None) backend = getattr(args, 'backend', None)
if backend == "vllm-realtime":
from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor
return VLLMRealtimeOnlineProcessor(asr)
if backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor
return Qwen3SimulStreamingOnlineProcessor(asr)
if backend == "voxtral-mlx": if backend == "voxtral-mlx":
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
return VoxtralMLXOnlineProcessor(asr) return VoxtralMLXOnlineProcessor(asr)

View file

@ -147,8 +147,8 @@ def parse_args():
"--backend", "--backend",
type=str, type=str,
default="auto", default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3"], choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon. Use 'qwen3' for Qwen3-ASR.", help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
) )
parser.add_argument( parser.add_argument(
"--no-vac", "--no-vac",
@ -196,6 +196,22 @@ def parse_args():
default=False, default=False,
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder." help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
) )
# vLLM Realtime backend arguments
parser.add_argument(
"--vllm-url",
type=str,
default="ws://localhost:8000/v1/realtime",
dest="vllm_url",
help="URL of the vLLM realtime WebSocket endpoint.",
)
parser.add_argument(
"--vllm-model",
type=str,
default="",
dest="vllm_model",
help="Model name to use with vLLM (e.g. Qwen/Qwen3-ASR-1.7B).",
)
# SimulStreaming-specific arguments # SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)') simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')

View file

@ -1,4 +1,5 @@
import logging import logging
import re
import sys import sys
from typing import List, Optional from typing import List, Optional
@ -11,12 +12,10 @@ logger = logging.getLogger(__name__)
def _patch_transformers_compat(): def _patch_transformers_compat():
"""Patch transformers for qwen_asr compatibility. """Patch transformers for qwen_asr 0.0.6 + transformers >= 5.3 compatibility."""
import torch
qwen_asr imports ``check_model_inputs`` from ``transformers.utils.generic``, # 1. check_model_inputs was removed
but this decorator hasn't been released yet in any public transformers
version. We inject a no-op stub so the import succeeds.
"""
try: try:
import transformers.utils.generic as _g import transformers.utils.generic as _g
if not hasattr(_g, "check_model_inputs"): if not hasattr(_g, "check_model_inputs"):
@ -28,6 +27,63 @@ def _patch_transformers_compat():
except ImportError: except ImportError:
pass pass
# 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS
try:
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
if "default" not in ROPE_INIT_FUNCTIONS:
def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
partial = getattr(config, "partial_rotary_factor", 1.0)
dim = int(head_dim * partial)
base = config.rope_theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, 1.0
ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters
except ImportError:
pass
# 3. pad_token_id missing on thinker config
try:
from qwen_asr.core.transformers_backend.configuration_qwen3_asr import (
Qwen3ASRThinkerConfig,
)
if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"):
Qwen3ASRThinkerConfig.pad_token_id = None
except ImportError:
pass
# 4. fix_mistral_regex kwarg not accepted by newer transformers
try:
from transformers.models.auto import processing_auto
_orig_ap_from_pretrained = processing_auto.AutoProcessor.from_pretrained.__func__
@classmethod
def _patched_ap_from_pretrained(cls, *args, **kwargs):
kwargs.pop("fix_mistral_regex", None)
return _orig_ap_from_pretrained(cls, *args, **kwargs)
processing_auto.AutoProcessor.from_pretrained = _patched_ap_from_pretrained
except Exception:
pass
# 5. compute_default_rope_parameters missing on RotaryEmbedding
try:
from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
Qwen3ASRThinkerTextRotaryEmbedding,
)
if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"):
@staticmethod
def _rope_params(config=None, device=None, seq_len=None, **kwargs):
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
partial = getattr(config, "partial_rotary_factor", 1.0)
dim = int(head_dim * partial)
base = config.rope_theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, 1.0
Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = _rope_params
except ImportError:
pass
_patch_transformers_compat() _patch_transformers_compat()
@ -62,6 +118,9 @@ QWEN3_MODEL_MAPPING = {
} }
_PUNCTUATION_ENDS = set(".!?。!?;;") _PUNCTUATION_ENDS = set(".!?。!?;;")
# Qwen3 raw output starts with "language <Name>" metadata before <asr_text> tag.
# When the tag is missing (silence/noise), this metadata leaks as transcription text.
_GARBAGE_RE = re.compile(r"^language\s+\S+$", re.IGNORECASE)
class Qwen3ASR(ASRBase): class Qwen3ASR(ASRBase):
@ -88,8 +147,12 @@ class Qwen3ASR(ASRBase):
else: else:
model_id = "Qwen/Qwen3-ASR-1.7B" model_id = "Qwen/Qwen3-ASR-1.7B"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 if torch.cuda.is_available():
device = "cuda:0" if torch.cuda.is_available() else "cpu" dtype, device = torch.bfloat16, "cuda:0"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
dtype, device = torch.float32, "mps"
else:
dtype, device = torch.float32, "cpu"
logger.info(f"Loading Qwen3-ASR: {model_id} ({dtype}, {device})") logger.info(f"Loading Qwen3-ASR: {model_id} ({dtype}, {device})")
model = Qwen3ASRModel.from_pretrained( model = Qwen3ASRModel.from_pretrained(
@ -126,17 +189,32 @@ class Qwen3ASR(ASRBase):
result = results[0] result = results[0]
# Stash audio length for timestamp estimation fallback # Stash audio length for timestamp estimation fallback
result._audio_duration = len(audio) / 16000 result._audio_duration = len(audio) / 16000
logger.info(
"Qwen3 result: language=%r text=%r ts=%s",
result.language, result.text[:80] if result.text else "",
bool(result.time_stamps),
)
return result return result
@staticmethod @staticmethod
def _detected_language(result) -> Optional[str]: def _detected_language(result) -> Optional[str]:
"""Extract Whisper-style language code from Qwen3 result.""" """Extract Whisper-style language code from Qwen3 result."""
lang = getattr(result, 'language', None) lang = getattr(result, 'language', None)
if lang: if not lang or lang.lower() == "none":
return QWEN3_TO_WHISPER_LANGUAGE.get(lang, lang.lower()) return None
return None # merge_languages may return comma-separated; take the first
first = lang.split(",")[0].strip()
if not first or first.lower() == "none":
return None
return QWEN3_TO_WHISPER_LANGUAGE.get(first, first.lower())
def ts_words(self, result) -> List[ASRToken]: def ts_words(self, result) -> List[ASRToken]:
# Filter garbage model output (e.g. "language None" for silence/noise)
text = (result.text or "").strip()
if not text or _GARBAGE_RE.match(text):
if text:
logger.info("Filtered garbage Qwen3 output: %r", text)
return []
detected = self._detected_language(result) detected = self._detected_language(result)
if result.time_stamps: if result.time_stamps:
tokens = [] tokens = []

View file

@ -0,0 +1,837 @@
"""
SimulStreaming-style online processor for Qwen3-ASR.
Architecture overview
---------------------
Qwen3-ASR is a decoder-only multimodal model. Audio is encoded by an audio
encoder (Whisper-style) into a sequence of embeddings that replace <|audio_pad|>
placeholder tokens in the input sequence. The text decoder then uses causal
self-attention over the combined audio + text tokens.
Unlike Whisper (which has explicit cross-attention between decoder and encoder),
Qwen3-ASR uses self-attention where generated text tokens attend to earlier
audio tokens and previously generated text. This means "alignment heads" here
are self-attention heads whose attention over the *audio-token region* tracks
the monotonic audio-to-text alignment.
The border-distance policy works as follows:
- After each generated token, extract the attention weights from the
selected alignment heads, restricted to the audio-token region
- Find which audio frame each head attends to most strongly (argmax)
- If the most-attended audio frame is approaching the end of the available
audio, pause generation and wait for more audio
- If the most-attended frame jumps backward (rewind), discard recent tokens
This module loads the Qwen3-ASR model *directly* via transformers (not through
the qwen_asr package's Qwen3ASRModel wrapper), giving us full control over
forward passes, KV caches, and attention extraction.
Requires:
- A pre-computed alignment heads JSON file (from detect_alignment_heads_qwen3.py)
- OR will fall back to all heads in a configurable set of layers
"""
import json
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000
@dataclass
class Qwen3SimulConfig:
"""Configuration for Qwen3 SimulStreaming."""
model_id: str = "Qwen/Qwen3-ASR-1.7B"
alignment_heads_path: Optional[str] = None
language: str = "auto"
# Border/rewind thresholds as fraction of audio tokens (not absolute frames).
# Qwen3 has ~13 audio tokens/sec vs Whisper's ~50, so absolute thresholds
# don't transfer. 0.15 = pause when attention is within last 15% of audio.
border_fraction: float = 0.15 # Fraction of audio tokens from end to trigger pause
rewind_fraction: float = 0.12 # Max backward jump as fraction of audio tokens
audio_min_len: float = 0.5 # Minimum audio length before starting decode
audio_max_len: float = 15.0 # Maximum audio buffer length in seconds
max_context_tokens: int = 30 # Max committed tokens to include as context
init_prompt: Optional[str] = None
max_alignment_heads: int = 20 # Use only top N alignment heads
@dataclass
class Qwen3SimulState:
"""Per-session mutable state for Qwen3 SimulStreaming."""
# Audio
audio_buffer: np.ndarray = field(
default_factory=lambda: np.array([], dtype=np.float32)
)
cumulative_time_offset: float = 0.0
global_time_offset: float = 0.0
speaker: int = -1
# Decode state
last_attend_frame: int = -15
generated_tokens: List[int] = field(default_factory=list)
committed_text: str = ""
committed_word_count: int = 0 # How many words already emitted
committed_token_ids: List[int] = field(default_factory=list) # token IDs for prompt context
# Tracking
first_timestamp: Optional[float] = None
detected_language: Optional[str] = None
last_infer_samples: int = 0 # audio_buffer length at last inference
class Qwen3SimulStreamingASR:
"""
Shared backend for Qwen3-ASR SimulStreaming.
Loads the model once and is shared across sessions. Each session gets
its own Qwen3SimulStreamingOnlineProcessor with independent state.
"""
sep = ""
def __init__(
self,
model_size: str = None,
model_dir: str = None,
lan: str = "auto",
alignment_heads_path: Optional[str] = None,
border_fraction: float = 0.15,
min_chunk_size: float = 0.1,
warmup_file: Optional[str] = None,
model_cache_dir: Optional[str] = None,
model_path: Optional[str] = None,
lora_path: Optional[str] = None,
direct_english_translation: bool = False,
**kwargs,
):
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.warmup_file = warmup_file
self.cfg = Qwen3SimulConfig(
language=lan,
alignment_heads_path=alignment_heads_path,
border_fraction=border_fraction,
)
# Load model directly via transformers
self._load_model(model_size, model_dir, model_cache_dir, model_path)
# Load alignment heads
self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
# Warmup
if warmup_file:
from whisperlivekit.warmup import load_file
audio = load_file(warmup_file)
if audio is not None:
logger.info("Warming up Qwen3 SimulStreaming model")
# Simple warmup: just encode a short audio
self._warmup(audio)
def _load_model(self, model_size, model_dir, model_cache_dir, model_path):
"""Load Qwen3-ASR via transformers (SDPA attention for speed)."""
from whisperlivekit.qwen3_asr import (
QWEN3_MODEL_MAPPING,
_patch_transformers_compat,
)
_patch_transformers_compat()
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
if model_dir:
model_id = model_dir
elif model_path:
model_id = model_path
elif model_size:
model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
else:
model_id = "Qwen/Qwen3-ASR-1.7B"
if torch.cuda.is_available():
dtype, device = torch.bfloat16, "cuda:0"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
dtype, device = torch.float32, "mps"
else:
dtype, device = torch.float32, "cpu"
logger.info("Loading Qwen3-ASR for SimulStreaming: %s (sdpa attention)", model_id)
self.model = AutoModel.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
)
self.model.eval()
self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)
# Cache model properties
thinker = self.model.thinker
text_config = thinker.config.text_config
self.num_layers = text_config.num_hidden_layers
self.num_heads = text_config.num_attention_heads
self.num_kv_heads = text_config.num_key_value_heads
self.audio_token_id = thinker.config.audio_token_id
self.device = next(self.model.parameters()).device
self.dtype = next(self.model.parameters()).dtype
# Cache special token IDs for metadata stripping
self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("<asr_text>")
logger.info(
"Qwen3-ASR loaded: %d layers x %d heads, device=%s, <asr_text> id=%d",
self.num_layers, self.num_heads, self.device, self.asr_text_token_id,
)
def _load_alignment_heads(
self, path: Optional[str],
) -> List[Tuple[int, int]]:
"""Load alignment heads from JSON or use defaults.
Only loads the top N heads (sorted by TS score) for efficiency.
The Qwen3-ASR model has alignment info spread across most heads
(decoder-only, no cross-attention), so we pick the strongest ones.
"""
max_heads = self.cfg.max_alignment_heads
if path and Path(path).exists():
with open(path) as f:
data = json.load(f)
# alignment_heads_compact is pre-sorted by TS score (descending)
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
heads = all_heads[:max_heads]
logger.info(
"Loaded top %d alignment heads from %s (of %d total)",
len(heads), path, len(all_heads),
)
return heads
# Default: use heads from the last quarter of layers
default_heads = []
start_layer = self.num_layers * 3 // 4
for layer in range(start_layer, self.num_layers):
for head in range(self.num_heads):
default_heads.append((layer, head))
logger.warning(
"No alignment heads file found. Using default heuristic: "
"%d heads from layers %d-%d. Run detect_alignment_heads_qwen3.py "
"to find optimal heads.",
len(default_heads), start_layer, self.num_layers - 1,
)
return default_heads[:max_heads]
def _warmup(self, audio: np.ndarray):
"""Run a short inference to warmup the model."""
try:
audio = audio[:SAMPLE_RATE * 2] # Max 2 seconds
msgs = [
{"role": "system", "content": ""},
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
]
text_prompt = self.processor.apply_chat_template(
msgs, add_generation_prompt=True, tokenize=False,
)
inputs = self.processor(
text=[text_prompt],
audio=[audio],
return_tensors="pt",
padding=True,
)
inputs = inputs.to(self.device).to(self.dtype)
with torch.inference_mode():
self.model.thinker.generate(
**inputs, max_new_tokens=5, do_sample=False,
)
logger.info("Qwen3 SimulStreaming warmup complete")
except Exception as e:
logger.warning("Warmup failed: %s", e)
def transcribe(self, audio):
"""No-op -- SimulStreaming uses the online processor directly."""
pass
class Qwen3SimulStreamingOnlineProcessor:
"""
Per-session online processor for Qwen3-ASR SimulStreaming.
Implements the same interface as SimulStreamingOnlineProcessor:
- insert_audio_chunk(audio, time)
- process_iter(is_last=False) -> (List[ASRToken], float)
- get_buffer() -> Transcript
- start_silence() -> (List[ASRToken], float)
- end_silence(duration, offset)
- finish() -> (List[ASRToken], float)
"""
SAMPLING_RATE = 16000
MIN_DURATION_REAL_SILENCE = 5
def __init__(self, asr: Qwen3SimulStreamingASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: List[ASRToken] = []
# Per-session state
self.state = Qwen3SimulState()
# Build the prompt template once
self._build_prompt_template()
def _build_prompt_template(self):
"""Build the base text prompt for Qwen3-ASR."""
from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE
msgs = [
{"role": "system", "content": ""},
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
]
self._base_prompt = self.asr.processor.apply_chat_template(
msgs, add_generation_prompt=True, tokenize=False,
)
# Add language forcing if configured
lan = self.asr.cfg.language
if lan and lan != "auto":
lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
self._base_prompt += f"language {lang_name}<asr_text>"
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
"""Append an audio chunk to be processed."""
self.end = audio_stream_end_time
self.state.audio_buffer = np.append(self.state.audio_buffer, audio)
# Trim audio if too long
max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
if len(self.state.audio_buffer) > max_samples:
trim = len(self.state.audio_buffer) - max_samples
self.state.audio_buffer = self.state.audio_buffer[trim:]
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
# Adjust throttle counter so it tracks position within the trimmed buffer
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Handle start of silence -- flush all pending tokens.
Loops inference until the model produces no new tokens, since a
single is_last call may not exhaust all text for the buffered audio.
"""
all_tokens = []
for _ in range(5): # safety limit
tokens, processed_upto = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end
def end_silence(self, silence_duration: float, offset: float):
"""Handle silence period."""
self.end += silence_duration
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(self.SAMPLING_RATE * silence_duration)
if gap_len > 0:
gap_silence = np.zeros(gap_len, dtype=np.float32)
self.state.audio_buffer = np.append(
self.state.audio_buffer, gap_silence,
)
else:
# Long silence: reset
self.state = Qwen3SimulState()
self.state.global_time_offset = silence_duration + offset
def new_speaker(self, change_speaker: ChangeSpeaker):
"""Handle speaker change event."""
self.process_iter(is_last=True)
self.state = Qwen3SimulState()
self.state.speaker = change_speaker.speaker
self.state.global_time_offset = change_speaker.start
def get_buffer(self) -> Transcript:
"""Get the current unvalidated buffer."""
return Transcript.from_tokens(tokens=self.buffer, sep='')
@torch.inference_mode()
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio using SimulStreaming with alignment heads.
This performs a full forward pass (encode audio + greedy decode with
attention extraction), applying the border-distance policy to decide
when to stop generating.
Returns:
Tuple of (committed ASRToken list, audio processed up to time).
"""
audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
if audio_duration < self.asr.cfg.audio_min_len:
return [], self.end
# Throttle: skip inference if less than 1s of new audio since last run.
# Each inference re-encodes the full buffer, so calling too often wastes
# GPU/CPU time and causes lag to spiral.
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
min_new_seconds = 1.0
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
return [], self.end
logger.info("Running SimulStreaming inference on %.2fs of audio (%.2fs new)", audio_duration, new_samples / self.SAMPLING_RATE)
self.state.last_infer_samples = len(self.state.audio_buffer)
try:
timestamped_words = self._infer(is_last)
except Exception as e:
logger.exception("Qwen3 SimulStreaming inference error: %s", e)
return [], self.end
logger.info("SimulStreaming produced %d words", len(timestamped_words))
if not timestamped_words:
return [], self.end
self.buffer = []
return timestamped_words, self.end
def _infer(self, is_last: bool) -> List[ASRToken]:
"""Run one inference cycle with alignment-head-based stopping.
Uses forward hooks on self_attn modules to capture attention weights
during generation. The Qwen3-ASR decoder layer discards attention
weights (hidden_states, _ = self.self_attn(...)), so output_attentions
via generate() would return None. Hooks capture them before discard.
"""
asr = self.asr
state = self.state
# Prepare inputs
inputs = asr.processor(
text=[self._base_prompt],
audio=[state.audio_buffer],
return_tensors="pt",
padding=True,
)
inputs = inputs.to(asr.device).to(asr.dtype)
# Append committed token IDs as context so generate() continues from
# where it left off. Cap at max_context_tokens to prevent prompt growth.
if state.committed_token_ids:
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
ctx_ids = torch.tensor(
[ctx], dtype=inputs.input_ids.dtype,
device=inputs.input_ids.device,
)
inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1)
if "attention_mask" in inputs:
ctx_mask = torch.ones_like(ctx_ids)
inputs["attention_mask"] = torch.cat(
[inputs.attention_mask, ctx_mask], dim=1,
)
prompt_len = inputs.input_ids.shape[1]
# Find audio token range
input_ids = inputs.input_ids[0]
audio_mask = (input_ids == asr.audio_token_id)
audio_positions = audio_mask.nonzero(as_tuple=True)[0]
if len(audio_positions) == 0:
return []
audio_start = audio_positions[0].item()
audio_end = audio_positions[-1].item() + 1
n_audio_tokens = audio_end - audio_start
audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE
# Install forward hooks to capture alignment attention from Q and K.
# With SDPA attention (fast), attn_weights are not returned. Instead,
# we hook self_attn to compute Q*K^T attention ONLY for alignment heads
# during autoregressive steps (q_len == 1). This is cheap because we
# only compute dot products for ~20 heads, not full attention for all.
#
# Key detail: self_attn is called with ALL keyword arguments from the
# decoder layer, so hidden_states/position_embeddings/past_key_values
# are all in kwargs, not args.
per_step_frames: List[List[int]] = []
current_step_frames: List[int] = []
heads_by_layer: dict = {}
for layer_idx, head_idx in asr.alignment_heads:
heads_by_layer.setdefault(layer_idx, []).append(head_idx)
decoder_layers = asr.model.thinker.model.layers
num_kv_heads = asr.num_kv_heads
num_heads = asr.num_heads
gqa_ratio = num_heads // num_kv_heads # GQA group size
# Import RoPE function used by this model's attention
from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
apply_rotary_pos_emb,
)
hooks = []
def _make_attn_hook(layer_idx):
"""Forward hook on self_attn that computes Q*K^T for alignment heads.
After the forward pass, we recompute Q (with RoPE) for the current
token and dot it against the cached K (which already has RoPE) in
the audio region. This gives us per-head alignment frames.
"""
head_indices = heads_by_layer[layer_idx]
def hook_fn(module, args, kwargs, output):
# All arguments are keyword-passed from the decoder layer
hidden_states = kwargs.get('hidden_states')
if hidden_states is None:
hidden_states = args[0] if args else None
if hidden_states is None or hidden_states.shape[1] != 1:
return # Skip prefill (seq_len > 1)
position_embeddings = kwargs.get('position_embeddings')
if position_embeddings is None and len(args) > 1:
position_embeddings = args[1]
past_kv = kwargs.get('past_key_values')
if position_embeddings is None or past_kv is None:
return
# Recompute Q with RoPE (cheap: single token through q_proj + RoPE)
hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim)
q = module.q_norm(
module.q_proj(hidden_states).view(hidden_shape)
).transpose(1, 2)
cos, sin = position_embeddings
q, _ = apply_rotary_pos_emb(q, q, cos, sin)
# K from cache already has RoPE applied
cache_layer = past_kv.layers[module.layer_idx]
k = cache_layer.keys # (batch, n_kv_heads, kv_len, head_dim)
if k is None or audio_end > k.shape[2]:
return
# Compute attention scores for alignment heads only
for h_idx in head_indices:
if h_idx >= q.shape[1]:
continue
kv_h_idx = h_idx // gqa_ratio
q_h = q[0, h_idx, 0] # (head_dim,)
k_audio = k[0, kv_h_idx, audio_start:audio_end] # (n_audio, head_dim)
scores = torch.matmul(k_audio, q_h) # (n_audio,)
frame = scores.argmax().item()
current_step_frames.append(frame)
return hook_fn
for layer_idx in heads_by_layer:
if layer_idx < len(decoder_layers):
h = decoder_layers[layer_idx].self_attn.register_forward_hook(
_make_attn_hook(layer_idx),
with_kwargs=True,
)
hooks.append(h)
# Step boundary hook on lm_head to separate per-step frames
# and check border-distance stopping criteria in real-time.
# This is CRITICAL for performance: instead of generating 200 tokens
# then truncating, we stop as soon as attention hits the audio border.
# On MPS, each token costs ~50ms, so stopping at 10 tokens vs 200
# means ~0.5s vs ~10s inference.
last_attend_frame = state.last_attend_frame
border_stop_step: Optional[int] = None
# Compute absolute thresholds from fractional config
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
def _step_boundary_hook(module, args, output):
nonlocal current_step_frames, last_attend_frame, border_stop_step
if current_step_frames:
per_step_frames.append(current_step_frames)
current_step_frames = []
# Check border distance on each step.
# Allow at least 3 steps before checking, so short buffers
# can still produce some tokens during streaming.
if not is_last and border_stop_step is None and len(per_step_frames) >= 3:
latest = per_step_frames[-1]
if latest:
frames_sorted = sorted(latest)
attended = frames_sorted[len(frames_sorted) // 2]
# Rewind check
if last_attend_frame - attended > rewind_threshold:
border_stop_step = max(0, len(per_step_frames) - 2)
return
last_attend_frame = attended
# Border check
if (n_audio_tokens - attended) <= border_threshold:
border_stop_step = len(per_step_frames) - 1
return
lm_head = asr.model.thinker.lm_head
step_hook = lm_head.register_forward_hook(_step_boundary_hook)
hooks.append(step_hook)
# StoppingCriteria that stops generation when border distance is hit
from transformers import StoppingCriteria, StoppingCriteriaList
class BorderStop(StoppingCriteria):
def __call__(self, input_ids, scores, **kwargs):
return border_stop_step is not None
stopping = StoppingCriteriaList([BorderStop()])
# Limit max tokens to what's reasonable for the audio duration.
# On MPS, each token costs ~50-100ms, so tight limits are critical.
# Speech produces ~4-6 tokens/sec; +5 for metadata prefix tokens.
# With is_last, allow slightly more for flushing remaining text.
new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
tokens_per_sec = 6
if is_last:
max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120)
else:
max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40)
try:
outputs = asr.model.thinker.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=False,
stopping_criteria=stopping,
)
finally:
for h in hooks:
h.remove()
# Flush any remaining frames
if current_step_frames:
per_step_frames.append(current_step_frames)
state.last_attend_frame = last_attend_frame
# Extract generated tokens
all_generated = outputs[0, prompt_len:]
eos_ids = {151645, 151643}
if asr.processor.tokenizer.eos_token_id is not None:
eos_ids.add(asr.processor.tokenizer.eos_token_id)
num_gen = len(all_generated)
for i, tid in enumerate(all_generated):
if tid.item() in eos_ids:
num_gen = i
break
raw_text = asr.processor.tokenizer.decode(all_generated[:num_gen], skip_special_tokens=True)
logger.info(
"SimulStreaming raw output: %d tokens (stopped at step %s), text=%r",
num_gen, border_stop_step, raw_text[:100],
)
if num_gen == 0:
return []
# Strip metadata prefix: when language is "auto", the model generates
# "language <Name><asr_text>..." before actual transcription text.
# Find <asr_text> token and skip everything before it (including itself).
asr_text_id = asr.asr_text_token_id
metadata_offset = 0
for i in range(min(num_gen, 10)): # metadata is at most ~3-4 tokens
if all_generated[i].item() == asr_text_id:
# Detect language from the metadata prefix before stripping
if state.detected_language is None and i > 0:
from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE
prefix_text = asr.processor.tokenizer.decode(
all_generated[:i].tolist(), skip_special_tokens=True,
).strip()
parts = prefix_text.split()
if len(parts) >= 2:
lang_name = parts[-1]
if lang_name.lower() != "none":
state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
lang_name, lang_name.lower(),
)
logger.info("Auto-detected language: %s", state.detected_language)
metadata_offset = i + 1
break
if metadata_offset > 0:
logger.info(
"Stripping %d metadata prefix tokens (before <asr_text>)",
metadata_offset,
)
all_generated = all_generated[metadata_offset:]
num_gen -= metadata_offset
per_step_frames = per_step_frames[metadata_offset:]
if num_gen <= 0:
return []
# Determine how many tokens to emit based on border stopping
step_frames = [f for f in per_step_frames if f]
if border_stop_step is not None:
emit_up_to = min(border_stop_step, num_gen)
else:
emit_up_to = num_gen
# Build timestamped words from the emitted tokens
generated_ids = all_generated[:emit_up_to]
if len(generated_ids) == 0:
return []
all_words = self._build_timestamped_words(
generated_ids, step_frames, emit_up_to,
n_audio_tokens, audio_duration,
)
new_words = all_words
# Update committed word count for space-prefix logic in next batch
state.committed_word_count += len(new_words)
# Append newly emitted token IDs to committed context for next call
new_emitted = outputs[0, prompt_len:prompt_len + emit_up_to + metadata_offset]
state.committed_token_ids.extend(new_emitted.tolist())
return new_words
def _build_timestamped_words(
self,
generated_ids: torch.Tensor,
step_frames: List[List[int]],
emit_up_to: int,
n_audio_tokens: int,
audio_duration: float,
) -> List[ASRToken]:
"""Build timestamped ASRToken list from generated tokens and hook-captured frames."""
asr = self.asr
state = self.state
# Get per-token attended audio frame (median of alignment head votes)
per_token_frame: List[Optional[int]] = []
for step in range(emit_up_to):
if step < len(step_frames) and step_frames[step]:
frames = sorted(step_frames[step])
per_token_frame.append(frames[len(frames) // 2])
else:
per_token_frame.append(None)
# Decode the full generated sequence at once, then split into words.
# This is more robust than per-token Ġ detection, which can fail when
# committed context causes the model to generate sub-word continuations.
tokenizer = asr.processor.tokenizer
full_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
text_words = full_text.split()
# Map each text word to an approximate frame using token-level alignment.
# Distribute frames evenly across words (since exact token→word mapping
# is imprecise with BPE sub-words anyway).
all_frames = [f for f in per_token_frame if f is not None]
words = []
for wi, word in enumerate(text_words):
if all_frames:
# Proportionally assign frames to words
frac = wi / max(len(text_words), 1)
frame_idx = int(frac * len(all_frames))
frame_idx = min(frame_idx, len(all_frames) - 1)
frame = all_frames[frame_idx]
else:
frame = None
words.append((word, frame))
# Convert to ASRToken with timestamps
tokens = []
for i, (text, frame) in enumerate(words):
text = text.strip()
if not text:
continue
if frame is not None and n_audio_tokens > 0:
timestamp = (
frame / n_audio_tokens * audio_duration
+ state.cumulative_time_offset
)
else:
timestamp = (
(i / max(len(words), 1)) * audio_duration
+ state.cumulative_time_offset
)
# Prefix space: first word of the very first batch has no space;
# all subsequent words (same batch or later batches) get a space.
is_very_first_word = (i == 0 and state.committed_word_count == 0)
display_text = text if is_very_first_word else " " + text
token = ASRToken(
start=round(timestamp, 2),
end=round(timestamp + 0.1, 2),
text=display_text,
speaker=state.speaker,
detected_language=state.detected_language,
).with_offset(state.global_time_offset)
tokens.append(token)
return tokens
@staticmethod
def _median_frame(frames: List[int]) -> Optional[int]:
"""Return median of frame list, or None if empty."""
if not frames:
return None
frames_sorted = sorted(frames)
return frames_sorted[len(frames_sorted) // 2]
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
"""Warmup the model with a short audio clip."""
try:
self.state.audio_buffer = audio[:SAMPLE_RATE]
self.process_iter(is_last=True)
self.state = Qwen3SimulState()
logger.info("Qwen3 SimulStreaming online processor warmed up")
except Exception as e:
logger.warning("Warmup failed: %s", e)
self.state = Qwen3SimulState()
def finish(self) -> Tuple[List[ASRToken], float]:
"""Flush remaining audio at end of stream."""
all_tokens = []
for _ in range(5): # safety limit
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end

View file

@ -0,0 +1,416 @@
"""
vLLM Realtime WebSocket streaming backend for WhisperLiveKit.
Connects to a vLLM server's ``/v1/realtime`` WebSocket endpoint to stream
audio and receive transcription deltas. Uses ``websockets.sync.client``
for simplicity since ``process_iter`` runs inside ``asyncio.to_thread``.
Provides ``VLLMRealtimeASR`` (lightweight model holder) and
``VLLMRealtimeOnlineProcessor`` (streaming processor) that plug into
WhisperLiveKit's audio processing pipeline.
"""
import base64
import json
import logging
import threading
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
class VLLMRealtimeASR:
"""Lightweight model holder — stores connection info for the vLLM server."""
sep = " "
SAMPLING_RATE = 16000
backend_choice = "vllm-realtime"
def __init__(self, vllm_url="ws://localhost:8000/v1/realtime",
model_name="Qwen/Qwen3-ASR-1.7B", lan="auto", **kwargs):
self.vllm_url = vllm_url
self.model_name = model_name
self.original_language = None if lan == "auto" else lan
self.tokenizer = None
def transcribe(self, audio):
pass
class VLLMRealtimeOnlineProcessor:
"""
Online processor that streams audio to a vLLM Realtime WebSocket.
Uses a background thread for WebSocket receiving and
``websockets.sync.client`` for the sync WebSocket connection.
"""
SAMPLING_RATE = 16000
# Minimum audio samples before connecting (0.5s of audio)
_MIN_CONNECT_SAMPLES = SAMPLING_RATE // 2
def __init__(self, asr: VLLMRealtimeASR):
self.asr = asr
self.end = 0.0
self.buffer = []
self.audio_buffer = np.array([], dtype=np.float32)
self._reset_state()
logger.info(
"[vllm-realtime] Initialized. url=%s model=%s",
asr.vllm_url, asr.model_name,
)
def _reset_state(self):
self._pending_audio = np.zeros(0, dtype=np.float32)
self._ws = None
self._recv_thread: Optional[threading.Thread] = None
self._connected = False
self._done = False
self._recv_error: Optional[Exception] = None
# Text accumulation and word extraction
self._accumulated_text = ""
self._n_committed_words = 0
self._total_audio_duration = 0.0
self._global_time_offset = 0.0
# Lock for text state accessed from both recv thread and main thread
self._text_lock = threading.Lock()
# ── Interface methods ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending_audio = np.append(self._pending_audio, audio)
self.audio_buffer = self._pending_audio
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._process_iter_inner(is_last)
except Exception as e:
logger.warning("[vllm-realtime] process_iter exception: %s", e, exc_info=True)
return [], self.end
def get_buffer(self) -> Transcript:
"""Return all uncommitted text as buffer."""
self._drain_deltas()
with self._text_lock:
text = self._accumulated_text
if not text:
return Transcript(start=None, end=None, text="")
words = text.split()
uncommitted = words[self._n_committed_words:]
if uncommitted:
return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all pending words when silence starts.
Sends commit(final=true) to signal end of utterance, waits for
transcription.done, flushes all words, then prepares for reconnection
on the next utterance.
"""
if not self._connected or self._done:
words = self._flush_all_pending_words()
logger.info("[vllm-realtime] start_silence (not connected): flushed %d words", len(words))
return words, self.end
# Send any remaining buffered audio
self._send_pending_audio()
# Signal end of stream
self._send_commit(final=True)
# Wait for transcription.done
self._wait_for_done(timeout=10.0)
# Flush all remaining words
words = self._flush_all_pending_words()
# Close and reset for next utterance
self._close_ws()
old_offset = self._global_time_offset + self._total_audio_duration
self._reset_state()
self._global_time_offset = old_offset
logger.info("[vllm-realtime] start_silence: flushed %d words", len(words))
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
"""Close connection and flush all remaining words."""
if self._connected and not self._done:
# Send remaining audio
self._send_pending_audio()
# Signal final commit
self._send_commit(final=True)
# Wait for transcription.done
self._wait_for_done(timeout=30.0)
# Flush all words
words = self._flush_all_pending_words()
# Close WebSocket
self._close_ws()
logger.info("[vllm-realtime] finish: flushed %d words", len(words))
return words, self.end
# ── WebSocket connection management ──
def _connect(self):
"""Connect to the vLLM realtime WebSocket and start the receive thread."""
from websockets.sync.client import connect
url = self.asr.vllm_url
logger.info("[vllm-realtime] Connecting to %s", url)
self._ws = connect(url)
# Send session.update to select model
self._ws.send(json.dumps({
"type": "session.update",
"model": self.asr.model_name,
}))
# Send initial commit(final=false) to start generation
self._send_commit(final=False)
# Start receive thread
self._recv_thread = threading.Thread(target=self._recv_loop, daemon=True)
self._recv_thread.start()
self._connected = True
logger.info("[vllm-realtime] Connected and started receive thread")
def _close_ws(self):
"""Close the WebSocket connection and join the receive thread."""
if self._ws is not None:
try:
self._ws.close()
except Exception:
pass
self._ws = None
if self._recv_thread is not None:
self._recv_thread.join(timeout=5.0)
self._recv_thread = None
def _recv_loop(self):
"""Background thread: receive messages from the vLLM WebSocket."""
try:
while not self._done and self._ws is not None:
try:
raw = self._ws.recv(timeout=0.1)
except TimeoutError:
continue
except Exception:
break
try:
msg = json.loads(raw)
except (json.JSONDecodeError, TypeError):
continue
msg_type = msg.get("type", "")
if msg_type == "transcription.delta":
delta = msg.get("delta", "")
if delta:
with self._text_lock:
self._accumulated_text += delta
elif msg_type == "transcription.done":
done_text = msg.get("text", "")
if done_text:
with self._text_lock:
# Replace accumulated text with final text
self._accumulated_text = done_text
self._done = True
break
except Exception as e:
logger.error("[vllm-realtime] recv_loop error: %s", e, exc_info=True)
self._recv_error = e
self._done = True
# ── Protocol messages ──
def _send_commit(self, final: bool):
"""Send input_audio_buffer.commit message."""
if self._ws is None:
return
try:
self._ws.send(json.dumps({
"type": "input_audio_buffer.commit",
"final": final,
}))
except Exception as e:
logger.warning("[vllm-realtime] Failed to send commit: %s", e)
def _send_audio(self, audio: np.ndarray):
"""Send audio as a base64-encoded PCM16 append message."""
if self._ws is None:
return
# Convert float32 [-1, 1] to int16 PCM
pcm16 = (audio * 32767).astype(np.int16)
audio_bytes = pcm16.tobytes()
audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
try:
self._ws.send(json.dumps({
"type": "input_audio_buffer.append",
"audio": audio_b64,
}))
except Exception as e:
logger.warning("[vllm-realtime] Failed to send audio: %s", e)
def _send_pending_audio(self):
"""Send all pending audio to the vLLM server."""
if len(self._pending_audio) == 0:
return
# Track total audio duration for timestamp estimation
self._total_audio_duration += len(self._pending_audio) / self.SAMPLING_RATE
# Send in chunks of 0.5s to avoid overwhelming the WebSocket
chunk_samples = self.SAMPLING_RATE // 2
while len(self._pending_audio) >= chunk_samples:
chunk = self._pending_audio[:chunk_samples]
self._send_audio(chunk)
self._pending_audio = self._pending_audio[chunk_samples:]
# Send remaining audio if any
if len(self._pending_audio) > 0:
self._send_audio(self._pending_audio)
self._pending_audio = np.zeros(0, dtype=np.float32)
self.audio_buffer = self._pending_audio
# ── Receive helpers ──
def _drain_deltas(self):
"""No-op since the recv thread accumulates text directly."""
pass
def _wait_for_done(self, timeout: float = 10.0):
"""Wait for transcription.done message from the server."""
deadline = time.time() + timeout
while not self._done and time.time() < deadline:
time.sleep(0.05)
if not self._done:
logger.warning("[vllm-realtime] Timed out waiting for transcription.done")
# ── Word extraction (same approach as VoxtralHF) ──
def _time_for_word(self, word_idx: int, n_words_total: int) -> Tuple[float, float]:
"""Estimate timestamps by linearly distributing words across audio duration."""
duration = max(self._total_audio_duration, 0.001)
n_total = max(n_words_total, 1)
start_time = (word_idx / n_total) * duration + self._global_time_offset
end_time = ((word_idx + 1) / n_total) * duration + self._global_time_offset
return start_time, end_time
def _extract_new_words(self) -> List[ASRToken]:
"""Extract complete words (all but the last, which may still grow)."""
with self._text_lock:
text = self._accumulated_text
if not text:
return []
words = text.split()
new_words: List[ASRToken] = []
n_words_total = len(words)
while len(words) > self._n_committed_words + 1:
word = words[self._n_committed_words]
start_time, end_time = self._time_for_word(self._n_committed_words, n_words_total)
text_out = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
self._n_committed_words += 1
return new_words
def _flush_all_pending_words(self) -> List[ASRToken]:
"""Flush ALL words including the last partial one."""
with self._text_lock:
text = self._accumulated_text
if not text:
return []
words = text.split()
new_words: List[ASRToken] = []
n_words_total = max(len(words), 1)
while self._n_committed_words < len(words):
word = words[self._n_committed_words]
start_time, end_time = self._time_for_word(self._n_committed_words, n_words_total)
text_out = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
self._n_committed_words += 1
return new_words
# ── Core processing ──
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# Connect when we have enough audio buffered
if not self._connected:
if len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES:
self._connect()
self._send_pending_audio()
else:
return [], self.end
# Send any new pending audio
if self._connected and not self._done:
self._send_pending_audio()
# If connection closed unexpectedly but new audio arrived, reconnect
if self._done and len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES:
flush_words = self._flush_all_pending_words()
old_offset = self._global_time_offset + self._total_audio_duration
self._close_ws()
self._reset_state()
self._global_time_offset = old_offset
self._connect()
self._send_pending_audio()
return flush_words, self.end
# Extract complete words
new_words = self._extract_new_words()
if new_words:
logger.info(
"[vllm-realtime] returning %d words: %s",
len(new_words), [w.text for w in new_words],
)
self.buffer = []
return new_words, self.end

View file

@ -102,7 +102,8 @@ class VoxtralHFStreamingOnlineProcessor:
) )
def _reset_state(self): def _reset_state(self):
self._pending_audio = np.zeros(0, dtype=np.float32) self._pending_chunks: List[np.ndarray] = []
self._pending_len = 0
self._audio_queue: queue.Queue = queue.Queue() self._audio_queue: queue.Queue = queue.Queue()
self._streamer_texts: List[str] = [] self._streamer_texts: List[str] = []
self._generate_thread: Optional[threading.Thread] = None self._generate_thread: Optional[threading.Thread] = None
@ -110,22 +111,63 @@ class VoxtralHFStreamingOnlineProcessor:
self._generate_finished = False self._generate_finished = False
self._generate_error: Optional[Exception] = None self._generate_error: Optional[Exception] = None
# Text accumulation and word extraction # Text accumulation (list of fragments, joined on demand)
self._accumulated_text = "" self._text_fragments: List[str] = []
self._text_len = 0
# Fragment position tracking for accurate word timestamps:
# each entry is (char_offset_in_full_text, audio_tok_pos_consumed)
self._fragment_positions: List[Tuple[int, int]] = []
self._n_text_tokens_received = 0 self._n_text_tokens_received = 0
self._n_audio_tokens_fed = 0 self._n_audio_tokens_fed = 0
# Audio tokens actually consumed by the model (tracked inside generator)
self._n_audio_tokens_consumed = 0
self._n_committed_words = 0 self._n_committed_words = 0
self._global_time_offset = 0.0 self._global_time_offset = 0.0
# Event signalled by the generate thread when it finishes
self._generate_done = threading.Event()
# Lock for text state accessed from both generate thread and main thread # Lock for text state accessed from both generate thread and main thread
self._text_lock = threading.Lock() self._text_lock = threading.Lock()
# ── Audio / text helpers ──
def _get_pending_audio(self) -> np.ndarray:
"""Flatten pending audio chunks into a single array."""
if not self._pending_chunks:
return np.zeros(0, dtype=np.float32)
if len(self._pending_chunks) == 1:
return self._pending_chunks[0]
flat = np.concatenate(self._pending_chunks)
self._pending_chunks = [flat]
return flat
def _set_pending_audio(self, arr: np.ndarray):
"""Replace pending audio with a single array."""
if len(arr) == 0:
self._pending_chunks = []
self._pending_len = 0
else:
self._pending_chunks = [arr]
self._pending_len = len(arr)
def _get_accumulated_text(self) -> str:
"""Get the full accumulated text (joins fragments if needed)."""
if not self._text_fragments:
return ""
if len(self._text_fragments) == 1:
return self._text_fragments[0]
joined = "".join(self._text_fragments)
self._text_fragments = [joined]
return joined
# ── Interface methods ── # ── Interface methods ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time self.end = audio_stream_end_time
self._pending_audio = np.append(self._pending_audio, audio) self._pending_chunks.append(audio)
self.audio_buffer = self._pending_audio self._pending_len += len(audio)
self.audio_buffer = audio # diagnostic only
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try: try:
@ -142,7 +184,7 @@ class VoxtralHFStreamingOnlineProcessor:
""" """
self._drain_streamer() self._drain_streamer()
with self._text_lock: with self._text_lock:
text = self._accumulated_text text = self._get_accumulated_text()
if not text: if not text:
return Transcript(start=None, end=None, text="") return Transcript(start=None, end=None, text="")
@ -174,16 +216,17 @@ class VoxtralHFStreamingOnlineProcessor:
# real audio and shouldn't affect word timestamp calculations. # real audio and shouldn't affect word timestamp calculations.
if self._right_pad_samples > 0: if self._right_pad_samples > 0:
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32) right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
self._pending_audio = np.append(self._pending_audio, right_pad) self._pending_chunks.append(right_pad)
self._pending_len += len(right_pad)
saved_count = self._n_audio_tokens_fed saved_count = self._n_audio_tokens_fed
self._feed_pending_audio() self._feed_pending_audio()
self._n_audio_tokens_fed = saved_count self._n_audio_tokens_fed = saved_count
# Drain in a loop: the model may still be processing right-padding # Drain in a loop: the model may continue producing text tokens after
# chunks after the first drain returns. Keep draining until no new # the audio queue is empty (autoregressive generation). Each iteration
# text appears for two consecutive rounds. # uses an event-driven blocking drain with short timeouts.
all_words: List[ASRToken] = [] all_words: List[ASRToken] = []
for _ in range(5): # at most 5 drain+flush cycles for _ in range(5):
self._drain_streamer_blocking(timeout=5.0) self._drain_streamer_blocking(timeout=5.0)
batch = self._flush_all_pending_words() batch = self._flush_all_pending_words()
all_words.extend(batch) all_words.extend(batch)
@ -208,7 +251,8 @@ class VoxtralHFStreamingOnlineProcessor:
# Add right-padding so the model can finish decoding # Add right-padding so the model can finish decoding
if self._right_pad_samples > 0: if self._right_pad_samples > 0:
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32) right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
self._pending_audio = np.append(self._pending_audio, right_pad) self._pending_chunks.append(right_pad)
self._pending_len += len(right_pad)
# Feed remaining audio # Feed remaining audio
if self._generate_started and not self._generate_finished: if self._generate_started and not self._generate_finished:
@ -218,7 +262,7 @@ class VoxtralHFStreamingOnlineProcessor:
# Wait for generate to finish # Wait for generate to finish
if self._generate_thread is not None: if self._generate_thread is not None:
self._generate_thread.join(timeout=30.0) self._generate_thread.join(timeout=30.0)
elif not self._generate_started and len(self._pending_audio) >= self._first_chunk_samples: elif not self._generate_started and self._pending_len >= self._first_chunk_samples:
# Never started but have enough audio — start and immediately finish # Never started but have enough audio — start and immediately finish
self._start_generate_thread() self._start_generate_thread()
self._feed_pending_audio() self._feed_pending_audio()
@ -242,8 +286,9 @@ class VoxtralHFStreamingOnlineProcessor:
model = self.asr.model model = self.asr.model
# Extract first chunk # Extract first chunk
first_chunk_audio = self._pending_audio[:self._first_chunk_samples] pending = self._get_pending_audio()
self._pending_audio = self._pending_audio[self._first_chunk_samples:] first_chunk_audio = pending[:self._first_chunk_samples]
self._set_pending_audio(pending[self._first_chunk_samples:])
# First chunk covers multiple audio tokens # First chunk covers multiple audio tokens
self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step) self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)
@ -265,11 +310,14 @@ class VoxtralHFStreamingOnlineProcessor:
audio_queue = self._audio_queue audio_queue = self._audio_queue
def input_features_gen(): def input_features_gen():
# Track audio consumption inside the generator (runs in generate thread)
self._n_audio_tokens_consumed = max(1, self._first_chunk_samples // self._chunk_step)
yield first_inputs.input_features yield first_inputs.input_features
while True: while True:
chunk_audio = audio_queue.get() chunk_audio = audio_queue.get()
if chunk_audio is None: if chunk_audio is None:
break break
self._n_audio_tokens_consumed += 1
inputs = processor( inputs = processor(
chunk_audio, chunk_audio,
is_streaming=True, is_streaming=True,
@ -298,6 +346,7 @@ class VoxtralHFStreamingOnlineProcessor:
self._generate_error = e self._generate_error = e
finally: finally:
self._generate_finished = True self._generate_finished = True
self._generate_done.set()
self._generate_thread = threading.Thread(target=run_generate, daemon=True) self._generate_thread = threading.Thread(target=run_generate, daemon=True)
self._generate_thread.start() self._generate_thread.start()
@ -309,13 +358,22 @@ class VoxtralHFStreamingOnlineProcessor:
chunk_size = self._chunk_samples chunk_size = self._chunk_samples
step_size = self._chunk_step step_size = self._chunk_step
while len(self._pending_audio) >= chunk_size: pending = self._get_pending_audio()
chunk = self._pending_audio[:chunk_size] while len(pending) >= chunk_size:
chunk = pending[:chunk_size]
self._audio_queue.put(chunk) self._audio_queue.put(chunk)
self._pending_audio = self._pending_audio[step_size:] pending = pending[step_size:]
self._n_audio_tokens_fed += 1 self._n_audio_tokens_fed += 1
self.audio_buffer = self._pending_audio self._set_pending_audio(pending)
self.audio_buffer = pending
def _append_text_fragment(self, text_fragment: str):
"""Append a text fragment with its audio position (must hold _text_lock)."""
self._fragment_positions.append((self._text_len, self._n_audio_tokens_consumed))
self._text_fragments.append(text_fragment)
self._text_len += len(text_fragment)
self._n_text_tokens_received += 1
def _drain_streamer(self): def _drain_streamer(self):
"""Non-blocking drain of all available text from the streamer.""" """Non-blocking drain of all available text from the streamer."""
@ -333,19 +391,13 @@ class VoxtralHFStreamingOnlineProcessor:
break break
if text_fragment: if text_fragment:
with self._text_lock: with self._text_lock:
self._accumulated_text += text_fragment self._append_text_fragment(text_fragment)
self._n_text_tokens_received += 1
def _drain_streamer_blocking(self, timeout=30.0): def _drain_streamer_blocking(self, timeout=30.0):
"""Blocking drain: wait for the generate thread to process all queued """Blocking drain: wait for the generate thread to finish producing text.
audio and produce the corresponding text.
Polls the text queue while the audio queue has items (model still Uses the _generate_done event to know when the model is truly finished.
processing). Once the audio queue is empty, waits for trailing Falls back to text-queue polling with adaptive timeouts.
tokens, then returns.
This is critical for start_silence(): without it, the non-blocking
drain races with the generate thread and the last words get stuck.
""" """
if not self._generate_started or self._generate_finished: if not self._generate_started or self._generate_finished:
self._drain_streamer() self._drain_streamer()
@ -353,52 +405,101 @@ class VoxtralHFStreamingOnlineProcessor:
text_queue = self._streamer.text_queue text_queue = self._streamer.text_queue
deadline = time.time() + timeout deadline = time.time() + timeout
# Count consecutive empty polls to detect when model has caught up
empty_streak = 0
while time.time() < deadline: while time.time() < deadline:
# Short poll while model is still processing queued audio; remaining = max(deadline - time.time(), 0.01)
# longer wait once the audio queue is empty (trailing tokens).
wait = 2.0 if self._audio_queue.empty() else 0.1 # If generate thread is done, do a final flush and exit
if self._generate_done.is_set() or self._generate_finished:
self._drain_streamer()
return
# Adaptive wait: short while audio is queued, longer once queue is empty
if self._audio_queue.empty():
wait = min(remaining, 0.5)
else:
wait = min(remaining, 0.1)
try: try:
text_fragment = text_queue.get(timeout=wait) text_fragment = text_queue.get(timeout=wait)
except queue.Empty: except queue.Empty:
if self._audio_queue.empty(): empty_streak += 1
break # Audio done + no text for 2s → fully caught up # Only exit if audio queue is empty AND we've had enough empty polls
continue # Audio still queued, model still working # This prevents premature exit when the model is slow
if self._audio_queue.empty() and empty_streak >= 4:
break
continue
empty_streak = 0
if text_fragment is None: if text_fragment is None:
self._generate_finished = True self._generate_finished = True
break break
if text_fragment: if text_fragment:
with self._text_lock: with self._text_lock:
self._accumulated_text += text_fragment self._append_text_fragment(text_fragment)
self._n_text_tokens_received += 1
# ── Word extraction ── # ── Word extraction ──
def _pos_to_time(self, token_position: int) -> float: def _pos_to_time(self, token_position: int) -> float:
"""Convert token position to seconds.""" """Convert audio token position to seconds."""
return token_position * self._seconds_per_token + self._global_time_offset return token_position * self._seconds_per_token + self._global_time_offset
def _audio_pos_for_char(self, char_idx: int) -> int:
"""Look up the audio token position for a character index in the text.
Uses the fragment position index recorded when text arrives from the
generate thread. Returns the audio position of the fragment that
contains ``char_idx``, giving much better word timestamps than the
old uniform-distribution heuristic.
"""
if not self._fragment_positions:
return 0
# _fragment_positions is sorted by char_offset — find the last entry
# whose char_offset <= char_idx (the fragment containing this char).
pos = 0
for offset, audio_tok in self._fragment_positions:
if offset > char_idx:
break
pos = audio_tok
return pos
def _word_timestamps(self, text: str, words: List[str], start_idx: int, end_idx: int) -> List[Tuple[int, int]]:
"""Compute (tok_start, tok_end) for words[start_idx:end_idx] using fragment positions."""
# Build char offsets for each word
result = []
char_pos = 0
for i, word in enumerate(words):
if i > 0:
char_pos += 1 # space separator
if start_idx <= i < end_idx:
tok_start = self._audio_pos_for_char(char_pos)
tok_end = self._audio_pos_for_char(char_pos + len(word))
result.append((tok_start, tok_end))
char_pos += len(word)
return result
def _extract_new_words(self) -> List[ASRToken]: def _extract_new_words(self) -> List[ASRToken]:
"""Extract complete words (all but the last, which may still be growing).""" """Extract complete words (all but the last, which may still be growing)."""
with self._text_lock: with self._text_lock:
text = self._accumulated_text text = self._get_accumulated_text()
if not text: if not text:
return [] return []
words = text.split() words = text.split()
new_words: List[ASRToken] = [] new_words: List[ASRToken] = []
n_words_total = len(words) n_to_commit = len(words) - 1 # keep last word (may still grow)
n_audio_toks = max(self._n_audio_tokens_fed, 1)
while len(words) > self._n_committed_words + 1: if n_to_commit <= self._n_committed_words:
return []
timestamps = self._word_timestamps(text, words, self._n_committed_words, n_to_commit)
for tok_start, tok_end in timestamps:
word = words[self._n_committed_words] word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0
start_time = self._pos_to_time(tok_start) start_time = self._pos_to_time(tok_start)
end_time = self._pos_to_time(tok_end) end_time = self._pos_to_time(max(tok_end, tok_start + 1))
text_out = word if self._n_committed_words == 0 else " " + word text_out = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out)) new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
@ -409,24 +510,22 @@ class VoxtralHFStreamingOnlineProcessor:
def _flush_all_pending_words(self) -> List[ASRToken]: def _flush_all_pending_words(self) -> List[ASRToken]:
"""Flush ALL words including the last partial one.""" """Flush ALL words including the last partial one."""
with self._text_lock: with self._text_lock:
text = self._accumulated_text text = self._get_accumulated_text()
if not text: if not text:
return [] return []
words = text.split() words = text.split()
new_words: List[ASRToken] = [] new_words: List[ASRToken] = []
n_words_total = max(len(words), 1)
n_audio_toks = max(self._n_audio_tokens_fed, 1)
while self._n_committed_words < len(words): if self._n_committed_words >= len(words):
return []
timestamps = self._word_timestamps(text, words, self._n_committed_words, len(words))
for tok_start, tok_end in timestamps:
word = words[self._n_committed_words] word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_audio_toks)
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks)
start_time = self._pos_to_time(tok_start) start_time = self._pos_to_time(tok_start)
end_time = self._pos_to_time(tok_end) end_time = self._pos_to_time(max(tok_end, tok_start + 1))
text_out = word if self._n_committed_words == 0 else " " + word text_out = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out)) new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
@ -439,7 +538,7 @@ class VoxtralHFStreamingOnlineProcessor:
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]: def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# Start generate thread when enough audio is buffered # Start generate thread when enough audio is buffered
if not self._generate_started: if not self._generate_started:
if len(self._pending_audio) >= self._first_chunk_samples: if self._pending_len >= self._first_chunk_samples:
self._start_generate_thread() self._start_generate_thread()
self._feed_pending_audio() self._feed_pending_audio()
else: else:
@ -450,7 +549,7 @@ class VoxtralHFStreamingOnlineProcessor:
self._feed_pending_audio() self._feed_pending_audio()
# If generate finished unexpectedly (EOS) but new audio arrived, restart # If generate finished unexpectedly (EOS) but new audio arrived, restart
if self._generate_finished and len(self._pending_audio) >= self._first_chunk_samples: if self._generate_finished and self._pending_len >= self._first_chunk_samples:
self._drain_streamer() self._drain_streamer()
flush_words = self._flush_all_pending_words() flush_words = self._flush_all_pending_words()
# Reset for new utterance # Reset for new utterance

View file

@ -91,20 +91,33 @@ def _mel_filters() -> mx.array:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# DFT helpers # DFT helpers (cached — these are constant for a given WINDOW_SIZE)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_CACHED_WINDOW: mx.array | None = None
_CACHED_DFT_RE: mx.array | None = None
_CACHED_DFT_IM: mx.array | None = None
def _hann_window() -> mx.array: def _hann_window() -> mx.array:
return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32)) global _CACHED_WINDOW
if _CACHED_WINDOW is None:
_CACHED_WINDOW = mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
return _CACHED_WINDOW
def _dft_matrices(): def _dft_matrices():
"""Pre-compute the real / imaginary DFT basis matrices.""" """Return cached real / imaginary DFT basis matrices."""
n_bins = WINDOW_SIZE // 2 + 1 global _CACHED_DFT_RE, _CACHED_DFT_IM
k = mx.arange(n_bins, dtype=mx.float32)[:, None] if _CACHED_DFT_RE is None:
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :] n_bins = WINDOW_SIZE // 2 + 1
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE k = mx.arange(n_bins, dtype=mx.float32)[:, None]
return mx.cos(phase), mx.sin(phase) n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
_CACHED_DFT_RE = mx.cos(phase)
_CACHED_DFT_IM = mx.sin(phase)
mx.eval(_CACHED_DFT_RE, _CACHED_DFT_IM)
return _CACHED_DFT_RE, _CACHED_DFT_IM
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array: def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:

View file

@ -135,8 +135,9 @@ class VoxtralMLXOnlineProcessor:
def _reset_state(self): def _reset_state(self):
"""Reset all incremental state for a fresh utterance.""" """Reset all incremental state for a fresh utterance."""
# Audio accumulation # Audio accumulation (list of chunks, concatenated on demand)
self._pending = np.zeros(0, dtype=np.float32) self._pending_chunks: list[np.ndarray] = []
self._pending_len = 0
# Mel overlap # Mel overlap
self._mel_overlap: np.ndarray | None = None self._mel_overlap: np.ndarray | None = None
# Encoder incremental state # Encoder incremental state
@ -167,10 +168,30 @@ class VoxtralMLXOnlineProcessor:
# -- audio ingestion -- # -- audio ingestion --
def _get_pending(self) -> np.ndarray:
"""Flatten pending chunks into a single array."""
if not self._pending_chunks:
return np.zeros(0, dtype=np.float32)
if len(self._pending_chunks) == 1:
return self._pending_chunks[0]
flat = np.concatenate(self._pending_chunks)
self._pending_chunks = [flat]
return flat
def _set_pending(self, arr: np.ndarray):
"""Replace pending audio with a single array."""
if len(arr) == 0:
self._pending_chunks = []
self._pending_len = 0
else:
self._pending_chunks = [arr]
self._pending_len = len(arr)
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time self.end = audio_stream_end_time
self._pending = np.append(self._pending, audio) self._pending_chunks.append(audio)
self.audio_buffer = self._pending self._pending_len += len(audio)
self.audio_buffer = audio # diagnostic only
# -- core processing -- # -- core processing --
@ -231,22 +252,24 @@ class VoxtralMLXOnlineProcessor:
def _encode_pending(self): def _encode_pending(self):
"""Feed pending audio through the incremental encoder.""" """Feed pending audio through the incremental encoder."""
available = len(self._pending) if self._pending_len < SAMPLES_PER_TOKEN:
if available < SAMPLES_PER_TOKEN:
return return
pending = self._get_pending()
available = len(pending)
if self._first_chunk: if self._first_chunk:
# First chunk: prepend silence for left-padding # First chunk: prepend silence for left-padding
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32) left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
chunk = np.concatenate([left_pad, self._pending[:n_take]]) chunk = np.concatenate([left_pad, pending[:n_take]])
self._pending = self._pending[n_take:] self._set_pending(pending[n_take:])
self._samples_encoded += n_take self._samples_encoded += n_take
self._first_chunk = False self._first_chunk = False
else: else:
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
chunk = self._pending[:n_take] chunk = pending[:n_take]
self._pending = self._pending[n_take:] self._set_pending(pending[n_take:])
self._samples_encoded += n_take self._samples_encoded += n_take
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap) mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
@ -261,11 +284,10 @@ class VoxtralMLXOnlineProcessor:
mx.eval(embeds) mx.eval(embeds)
if self._audio_embeds is not None: if self._audio_embeds is not None:
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds]) self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
mx.eval(self._audio_embeds)
else: else:
self._audio_embeds = embeds self._audio_embeds = embeds
self.audio_buffer = self._pending
def _do_prefill(self): def _do_prefill(self):
"""Run the decoder prefill pass over the prompt + first audio embeddings.""" """Run the decoder prefill pass over the prompt + first audio embeddings."""
n_dec_layers = len(self._model.decoder.blocks) n_dec_layers = len(self._model.decoder.blocks)
@ -430,6 +452,55 @@ class VoxtralMLXOnlineProcessor:
return Transcript(start=None, end=None, text="") return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]: def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all pending words when silence starts.
Adds right-padding silence and forces a full decode pass so the
decoder emits tokens for the last words of speech. Without this,
the model holds back the final tokens waiting for future context.
"""
# Align pending audio to SAMPLES_PER_TOKEN boundary
remainder = self._pending_len % SAMPLES_PER_TOKEN
align_pad = (SAMPLES_PER_TOKEN - remainder) if remainder > 0 else 0
# Add alignment + right-padding silence to provide future context
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
if total_pad > 0:
self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
self._pending_len += total_pad
# Encode remaining audio (including right-padding)
self._encode_pending()
# Decode everything that's left
if self._audio_embeds is not None and self._prefilled:
self._decode_positions(self._audio_embeds.shape[0])
# Flush last token if it wasn't EOS
if self._last_token is not None:
tid = self._last_token.item()
if tid != self._eos_id:
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
last_pos = self._positions_decoded - self._prefix_len
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
# Close the last word if still open
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
words = self._flush_all_words() words = self._flush_all_words()
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words)) logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
return words, self.end return words, self.end
@ -448,7 +519,7 @@ class VoxtralMLXOnlineProcessor:
logger.debug( logger.debug(
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, " "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'", "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
len(self._pending), self._pending_len,
self._audio_embeds.shape if self._audio_embeds is not None else None, self._audio_embeds.shape if self._audio_embeds is not None else None,
self._samples_encoded, self._samples_encoded,
self._positions_decoded, self._positions_decoded,
@ -457,7 +528,7 @@ class VoxtralMLXOnlineProcessor:
) )
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost # Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
remainder = len(self._pending) % SAMPLES_PER_TOKEN remainder = self._pending_len % SAMPLES_PER_TOKEN
if remainder > 0: if remainder > 0:
align_pad = SAMPLES_PER_TOKEN - remainder align_pad = SAMPLES_PER_TOKEN - remainder
else: else:
@ -466,9 +537,8 @@ class VoxtralMLXOnlineProcessor:
# Add alignment + right-padding silence # Add alignment + right-padding silence
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
if total_pad > 0: if total_pad > 0:
self._pending = np.append( self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
self._pending, np.zeros(total_pad, dtype=np.float32) self._pending_len += total_pad
)
# Encode remaining audio (including right-padding) # Encode remaining audio (including right-padding)
self._encode_pending() self._encode_pending()
@ -476,7 +546,7 @@ class VoxtralMLXOnlineProcessor:
logger.debug( logger.debug(
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d", "[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
self._audio_embeds.shape if self._audio_embeds is not None else None, self._audio_embeds.shape if self._audio_embeds is not None else None,
len(self._pending), self._pending_len,
) )
hit_eos = False hit_eos = False