diff --git a/whisperlivekit/basic_server.py b/whisperlivekit/basic_server.py index a8d9773..9621e76 100644 --- a/whisperlivekit/basic_server.py +++ b/whisperlivekit/basic_server.py @@ -13,6 +13,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %( logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +logging.getLogger("whisperlivekit.qwen3_asr").setLevel(logging.DEBUG) config = parse_args() transcription_engine = None diff --git a/whisperlivekit/cli.py b/whisperlivekit/cli.py index c1e6699..7c8dc98 100644 --- a/whisperlivekit/cli.py +++ b/whisperlivekit/cli.py @@ -57,6 +57,8 @@ BACKENDS = [ "install": "pip install faster-whisper", "description": "CTranslate2-based Whisper (fast, CPU/CUDA)", "policy": "localagreement", + "streaming": "chunk", # batch inference with LocalAgreement/SimulStreaming + "devices": ["cpu", "cuda"], }, { "id": "whisper", @@ -65,6 +67,8 @@ BACKENDS = [ "install": "pip install openai-whisper", "description": "Original OpenAI Whisper (PyTorch)", "policy": "simulstreaming", + "streaming": "chunk", + "devices": ["cpu", "cuda"], }, { "id": "mlx-whisper", @@ -74,21 +78,27 @@ BACKENDS = [ "description": "Apple Silicon native Whisper (MLX)", "policy": "localagreement", "platform": "darwin-arm64", + "streaming": "chunk", + "devices": ["mlx"], }, { "id": "voxtral-mlx", "name": "Voxtral MLX", "module": "mlx", "install": "pip install whisperlivekit[voxtral-mlx]", - "description": "Mistral Voxtral Mini on Apple Silicon (MLX)", + "description": "Mistral Voxtral Mini on Apple Silicon (MLX, native streaming)", "platform": "darwin-arm64", + "streaming": "native", # truly streaming (token-by-token) + "devices": ["mlx"], }, { "id": "voxtral", "name": "Voxtral HF", "module": "transformers", "install": "pip install whisperlivekit[voxtral-hf]", - "description": "Mistral Voxtral Mini (HF Transformers, CUDA/CPU/MPS)", + "description": "Mistral Voxtral Mini (HF Transformers, native streaming)", + "streaming": "native", + "devices": ["cuda", "mps", "cpu"], }, { "id": "qwen3", @@ -96,6 +106,8 @@ BACKENDS = [ "module": "qwen_asr", "install": "pip install qwen-asr", "description": "Qwen3-ASR with ForcedAligner timestamps", + "streaming": "chunk", + "devices": ["cuda", "mps", "cpu"], }, { "id": "openai-api", @@ -103,6 +115,8 @@ BACKENDS = [ "module": "openai", "install": "pip install openai", "description": "Cloud-based transcription via OpenAI API", + "streaming": "cloud", + "devices": ["cloud"], }, ] @@ -159,6 +173,28 @@ QWEN3_REPOS = { } QWEN3_ALIGNER_REPO = "Qwen/Qwen3-ForcedAligner-0.6B" +# Model catalog: metadata for display in `wlk models` +# params = approximate parameter count, disk = approximate download size +MODEL_CATALOG = [ + # Whisper family (available across faster-whisper, mlx-whisper, whisper backends) + {"name": "tiny", "family": "whisper", "params": "39M", "disk": "75 MB", "languages": 99, "quality": "low", "speed": "fastest"}, + {"name": "tiny.en", "family": "whisper", "params": "39M", "disk": "75 MB", "languages": 1, "quality": "low", "speed": "fastest"}, + {"name": "base", "family": "whisper", "params": "74M", "disk": "142 MB", "languages": 99, "quality": "fair", "speed": "fast"}, + {"name": "base.en", "family": "whisper", "params": "74M", "disk": "142 MB", "languages": 1, "quality": "fair", "speed": "fast"}, + {"name": "small", "family": "whisper", "params": "244M", "disk": "466 MB", "languages": 99, "quality": "good", "speed": "medium"}, + {"name": "small.en", "family": "whisper", "params": "244M", "disk": "466 MB", "languages": 1, "quality": "good", "speed": "medium"}, + {"name": "medium", "family": "whisper", "params": "769M", "disk": "1.5 GB", "languages": 99, "quality": "great", "speed": "slow"}, + {"name": "medium.en", "family": "whisper", "params": "769M", "disk": "1.5 GB", "languages": 1, "quality": "great", "speed": "slow"}, + {"name": "large-v3", "family": "whisper", "params": "1.5B", "disk": "3.1 GB", "languages": 99, "quality": "best", "speed": "slowest"}, + {"name": "large-v3-turbo", "family": "whisper", "params": "809M", "disk": "1.6 GB", "languages": 99, "quality": "great", "speed": "medium"}, + # Voxtral (native streaming, single model) + {"name": "voxtral", "family": "voxtral", "params": "4B", "disk": "8.2 GB", "languages": 15, "quality": "great", "speed": "medium"}, + {"name": "voxtral-mlx", "family": "voxtral", "params": "4B", "disk": "2.7 GB", "languages": 15, "quality": "great", "speed": "medium"}, + # Qwen3 ASR + {"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"}, + {"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"}, +] + def _check_platform(backend: dict) -> bool: """Check if backend is compatible with current platform.""" @@ -254,93 +290,124 @@ def print_banner(config, host: str, port: int, ssl: bool = False): # `wlk models` subcommand # --------------------------------------------------------------------------- -def cmd_models(): - """List available backends and their installation status.""" - is_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64" +def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool: + """Check if a model catalog entry has been downloaded.""" + name = model_entry["name"] + family = model_entry["family"] - print("\nAvailable backends:\n") + if family == "whisper": + # Check all whisper backends + repos = [ + FASTER_WHISPER_REPOS.get(name), + MLX_WHISPER_REPOS.get(name), + f"openai/whisper-{name}", + ] + return any(r in downloaded for r in repos if r) + elif name == "voxtral": + return VOXTRAL_HF_REPO in downloaded + elif name == "voxtral-mlx": + return VOXTRAL_MLX_REPO in downloaded + elif family == "qwen3": + size = name.split(":")[1] if ":" in name else "1.7b" + return QWEN3_REPOS.get(size, "") in downloaded + return False + + +def _best_backend_for_model(model_entry: dict) -> str: + """Suggest the best available backend for a model.""" + family = model_entry["family"] + is_apple = platform.system() == "Darwin" and platform.machine() == "arm64" + + if family == "voxtral": + if "mlx" in model_entry["name"]: + return "voxtral-mlx" + return "voxtral" + elif family == "qwen3": + return "qwen3" + elif family == "whisper": + if is_apple and _module_available("mlx_whisper"): + return "mlx-whisper" + if _module_available("faster_whisper"): + return "faster-whisper" + if _module_available("whisper"): + return "whisper" + # Suggest best installable + return "mlx-whisper" if is_apple else "faster-whisper" + return "auto" + + +def cmd_models(): + """List available models and backends (ollama-style).""" + is_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64" + downloaded = _scan_downloaded_models() + + # --- Installed backends --- + print("\n Backends:\n") max_name = max(len(b["name"]) for b in BACKENDS) - for b in BACKENDS: compatible = _check_platform(b) installed = _is_installed(b) + streaming = b.get("streaming", "chunk") + stream_label = {"native": "streaming", "chunk": "chunked", "cloud": "cloud"}.get(streaming, streaming) if installed: - status = "\033[32m installed\033[0m" + status = "\033[32m+\033[0m" elif not compatible: - status = "\033[90m n/a (wrong platform)\033[0m" + status = "\033[90m-\033[0m" else: - status = "\033[33m not installed\033[0m" + status = "\033[33m-\033[0m" name_pad = b["name"].ljust(max_name) - print(f" {name_pad} [{status}] {b['description']}") + desc_short = b["description"] + print(f" {status} {name_pad} {desc_short} [{stream_label}]") if not installed and compatible: - print(f" {''.ljust(max_name)} └─ {b['install']}") + print(f" {''.ljust(max_name)} \033[90m{b['install']}\033[0m") - # System info + # --- System info --- print(f"\n Platform: {platform.system()} {platform.machine()}") - print(f" Python: {platform.python_version()}") print(f" Accelerator: {_gpu_info()}") - print(f" ffmpeg: {'found' if _check_ffmpeg() else 'NOT FOUND (required)'}") + print(f" ffmpeg: {'found' if _check_ffmpeg() else '\033[31mNOT FOUND\033[0m (required)'}") + # --- Model catalog --- + print("\n Models:\n") + + # Table header + hdr = f" {'NAME':<20} {'PARAMS':>7} {'SIZE':>8} {'QUALITY':<8} {'SPEED':<8} {'LANGS':>5} {'STATUS':<10}" + print(hdr) + print(f" {'─' * 20} {'─' * 7} {'─' * 8} {'─' * 8} {'─' * 8} {'─' * 5} {'─' * 10}") + + for m in MODEL_CATALOG: + name = m["name"] + # Skip platform-incompatible models + if name == "voxtral-mlx" and not is_apple_silicon: + continue + + is_dl = _model_is_downloaded(m, downloaded) + + if is_dl: + status = "\033[32mpulled\033[0m " + else: + status = "\033[90mavailable\033[0m " + + langs = str(m["languages"]) if m["languages"] < 99 else "99+" + + print( + f" {name:<20} {m['params']:>7} {m['disk']:>8} " + f"{m['quality']:<8} {m['speed']:<8} {langs:>5} {status}" + ) + + # --- Quick start --- + print(f"\n Quick start:\n") if is_apple_silicon: - print("\n Tip: On Apple Silicon, mlx-whisper and voxtral-mlx offer the best performance.") - - # Scan for downloaded models - downloaded = _scan_downloaded_models() - - print("\n Downloaded models:\n") - found_any = False - - # Check Whisper-family models - all_repos = { - "faster-whisper": FASTER_WHISPER_REPOS, - "mlx-whisper": MLX_WHISPER_REPOS, - } - for backend_name, repos in all_repos.items(): - for size, repo_id in repos.items(): - if repo_id in downloaded: - found_any = True - print(f" \033[32m*\033[0m {backend_name}:{size} ({repo_id})") - - # Check native whisper - for size in WHISPER_SIZES: - key = f"openai/whisper-{size}" - if key in downloaded: - found_any = True - print(f" \033[32m*\033[0m whisper:{size}") - - # Check voxtral / qwen3 - if VOXTRAL_HF_REPO in downloaded: - found_any = True - print(f" \033[32m*\033[0m voxtral ({VOXTRAL_HF_REPO})") - if VOXTRAL_MLX_REPO in downloaded: - found_any = True - print(f" \033[32m*\033[0m voxtral-mlx ({VOXTRAL_MLX_REPO})") - for qsize, repo_id in QWEN3_REPOS.items(): - if repo_id in downloaded: - found_any = True - print(f" \033[32m*\033[0m qwen3:{qsize} ({repo_id})") - if QWEN3_ALIGNER_REPO in downloaded: - found_any = True - print(f" \033[32m*\033[0m qwen3-aligner ({QWEN3_ALIGNER_REPO})") - - if not found_any: - print(" (none — models download automatically on first use, or use 'wlk pull')") - - # Show pullable models - print("\n Available models (use 'wlk pull '):\n") - print(" Whisper sizes: " + ", ".join(WHISPER_SIZES)) - print(" Voxtral: voxtral, voxtral-mlx") - print(" Qwen3: qwen3:1.7b, qwen3:0.6b") - print() - print(" Examples:") - print(" wlk pull base # Download for best available backend") - print(" wlk pull faster-whisper:large-v3 # Specific backend + model") - print(" wlk pull voxtral # Voxtral HF model") - print(" wlk pull qwen3:1.7b # Qwen3-ASR 1.7B") + print(" wlk run voxtral-mlx # Best streaming on Apple Silicon") + print(" wlk run large-v3-turbo # Best quality/speed balance") + else: + print(" wlk run large-v3-turbo # Best quality/speed balance") + print(" wlk run voxtral # Native streaming (CUDA/CPU)") + print(" wlk pull base # Download smallest multilingual model") + print(" wlk transcribe audio.mp3 # Offline transcription") print() @@ -1010,6 +1077,23 @@ def cmd_run(args: list): if parsed.model: backend_flag, model_flag = _resolve_run_spec(parsed.model) + # Show what we resolved + catalog_match = next( + (m for m in MODEL_CATALOG if m["name"] == parsed.model), + None, + ) + if catalog_match: + print( + f"\n Model: {catalog_match['name']} " + f"({catalog_match['params']} params, {catalog_match['disk']})", + file=sys.stderr, + ) + if backend_flag: + print(f" Backend: {backend_flag}", file=sys.stderr) + else: + best = _best_backend_for_model(catalog_match) + print(f" Backend: {best} (auto-detected)", file=sys.stderr) + # Auto-pull if needed downloaded = _scan_downloaded_models() targets = _resolve_pull_target(parsed.model) @@ -1198,9 +1282,9 @@ def _probe_backend_state(processor) -> dict: info["n_audio_tokens_fed"] = transcription._n_audio_tokens_fed info["n_text_tokens_received"] = transcription._n_text_tokens_received info["n_committed_words"] = transcription._n_committed_words - info["pending_audio_samples"] = len(transcription._pending_audio) + info["pending_audio_samples"] = transcription._pending_len with transcription._text_lock: - info["accumulated_text"] = transcription._accumulated_text + info["accumulated_text"] = transcription._get_accumulated_text() if transcription._generate_error: info["generate_error"] = str(transcription._generate_error) # Audio queue depth diff --git a/whisperlivekit/config.py b/whisperlivekit/config.py index b08edeb..af262fb 100644 --- a/whisperlivekit/config.py +++ b/whisperlivekit/config.py @@ -72,6 +72,10 @@ class WhisperLiveKitConfig: nllb_backend: str = "transformers" nllb_size: str = "600M" + # vLLM Realtime backend + vllm_url: str = "ws://localhost:8000/v1/realtime" + vllm_model: str = "" + def __post_init__(self): # .en model suffix forces English if self.model_size and self.model_size.endswith(".en"): diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 30a8da7..d789e69 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -102,7 +102,16 @@ class TranscriptionEngine: } if config.transcription: - if config.backend == "voxtral-mlx": + if config.backend == "vllm-realtime": + from whisperlivekit.vllm_realtime import VLLMRealtimeASR + self.tokenizer = None + self.asr = VLLMRealtimeASR( + vllm_url=config.vllm_url, + model_name=config.vllm_model or "Qwen/Qwen3-ASR-1.7B", + lan=config.lan, + ) + logger.info("Using vLLM Realtime streaming backend at %s", config.vllm_url) + elif config.backend == "voxtral-mlx": from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR self.tokenizer = None self.asr = VoxtralMLXASR(**transcription_common_params) @@ -112,6 +121,14 @@ class TranscriptionEngine: self.tokenizer = None self.asr = VoxtralHFStreamingASR(**transcription_common_params) logger.info("Using Voxtral HF Transformers streaming backend") + elif config.backend == "qwen3-simul": + from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR + self.tokenizer = None + self.asr = Qwen3SimulStreamingASR( + **transcription_common_params, + alignment_heads_path=config.custom_alignment_heads, + ) + logger.info("Using Qwen3-ASR backend with SimulStreaming policy") elif config.backend == "qwen3": from whisperlivekit.qwen3_asr import Qwen3ASR self.asr = Qwen3ASR(**transcription_common_params) @@ -210,6 +227,12 @@ def online_factory(args, asr, language=None): asr = SessionASRProxy(asr, language) backend = getattr(args, 'backend', None) + if backend == "vllm-realtime": + from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor + return VLLMRealtimeOnlineProcessor(asr) + if backend == "qwen3-simul": + from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor + return Qwen3SimulStreamingOnlineProcessor(asr) if backend == "voxtral-mlx": from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor return VoxtralMLXOnlineProcessor(asr) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 2eaeb16..a6f23f6 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -147,8 +147,8 @@ def parse_args(): "--backend", type=str, default="auto", - choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3"], - help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon. Use 'qwen3' for Qwen3-ASR.", + choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"], + help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.", ) parser.add_argument( "--no-vac", @@ -196,6 +196,22 @@ def parse_args(): default=False, help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder." ) + # vLLM Realtime backend arguments + parser.add_argument( + "--vllm-url", + type=str, + default="ws://localhost:8000/v1/realtime", + dest="vllm_url", + help="URL of the vLLM realtime WebSocket endpoint.", + ) + parser.add_argument( + "--vllm-model", + type=str, + default="", + dest="vllm_model", + help="Model name to use with vLLM (e.g. Qwen/Qwen3-ASR-1.7B).", + ) + # SimulStreaming-specific arguments simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)') diff --git a/whisperlivekit/qwen3_asr.py b/whisperlivekit/qwen3_asr.py index 967a430..ec6b0ef 100644 --- a/whisperlivekit/qwen3_asr.py +++ b/whisperlivekit/qwen3_asr.py @@ -1,4 +1,5 @@ import logging +import re import sys from typing import List, Optional @@ -11,12 +12,10 @@ logger = logging.getLogger(__name__) def _patch_transformers_compat(): - """Patch transformers for qwen_asr compatibility. + """Patch transformers for qwen_asr 0.0.6 + transformers >= 5.3 compatibility.""" + import torch - qwen_asr imports ``check_model_inputs`` from ``transformers.utils.generic``, - but this decorator hasn't been released yet in any public transformers - version. We inject a no-op stub so the import succeeds. - """ + # 1. check_model_inputs was removed try: import transformers.utils.generic as _g if not hasattr(_g, "check_model_inputs"): @@ -28,6 +27,63 @@ def _patch_transformers_compat(): except ImportError: pass + # 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS + try: + from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + if "default" not in ROPE_INIT_FUNCTIONS: + def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs): + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + partial = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial) + base = config.rope_theta + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, 1.0 + ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters + except ImportError: + pass + + # 3. pad_token_id missing on thinker config + try: + from qwen_asr.core.transformers_backend.configuration_qwen3_asr import ( + Qwen3ASRThinkerConfig, + ) + if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"): + Qwen3ASRThinkerConfig.pad_token_id = None + except ImportError: + pass + + # 4. fix_mistral_regex kwarg not accepted by newer transformers + try: + from transformers.models.auto import processing_auto + _orig_ap_from_pretrained = processing_auto.AutoProcessor.from_pretrained.__func__ + + @classmethod + def _patched_ap_from_pretrained(cls, *args, **kwargs): + kwargs.pop("fix_mistral_regex", None) + return _orig_ap_from_pretrained(cls, *args, **kwargs) + + processing_auto.AutoProcessor.from_pretrained = _patched_ap_from_pretrained + except Exception: + pass + + # 5. compute_default_rope_parameters missing on RotaryEmbedding + try: + from qwen_asr.core.transformers_backend.modeling_qwen3_asr import ( + Qwen3ASRThinkerTextRotaryEmbedding, + ) + if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"): + @staticmethod + def _rope_params(config=None, device=None, seq_len=None, **kwargs): + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + partial = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial) + base = config.rope_theta + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, 1.0 + Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = _rope_params + except ImportError: + pass + _patch_transformers_compat() @@ -62,6 +118,9 @@ QWEN3_MODEL_MAPPING = { } _PUNCTUATION_ENDS = set(".!?。!?;;") +# Qwen3 raw output starts with "language " metadata before tag. +# When the tag is missing (silence/noise), this metadata leaks as transcription text. +_GARBAGE_RE = re.compile(r"^language\s+\S+$", re.IGNORECASE) class Qwen3ASR(ASRBase): @@ -88,8 +147,12 @@ class Qwen3ASR(ASRBase): else: model_id = "Qwen/Qwen3-ASR-1.7B" - dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 - device = "cuda:0" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + dtype, device = torch.bfloat16, "cuda:0" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + dtype, device = torch.float32, "mps" + else: + dtype, device = torch.float32, "cpu" logger.info(f"Loading Qwen3-ASR: {model_id} ({dtype}, {device})") model = Qwen3ASRModel.from_pretrained( @@ -126,17 +189,32 @@ class Qwen3ASR(ASRBase): result = results[0] # Stash audio length for timestamp estimation fallback result._audio_duration = len(audio) / 16000 + logger.info( + "Qwen3 result: language=%r text=%r ts=%s", + result.language, result.text[:80] if result.text else "", + bool(result.time_stamps), + ) return result @staticmethod def _detected_language(result) -> Optional[str]: """Extract Whisper-style language code from Qwen3 result.""" lang = getattr(result, 'language', None) - if lang: - return QWEN3_TO_WHISPER_LANGUAGE.get(lang, lang.lower()) - return None + if not lang or lang.lower() == "none": + return None + # merge_languages may return comma-separated; take the first + first = lang.split(",")[0].strip() + if not first or first.lower() == "none": + return None + return QWEN3_TO_WHISPER_LANGUAGE.get(first, first.lower()) def ts_words(self, result) -> List[ASRToken]: + # Filter garbage model output (e.g. "language None" for silence/noise) + text = (result.text or "").strip() + if not text or _GARBAGE_RE.match(text): + if text: + logger.info("Filtered garbage Qwen3 output: %r", text) + return [] detected = self._detected_language(result) if result.time_stamps: tokens = [] diff --git a/whisperlivekit/qwen3_simul.py b/whisperlivekit/qwen3_simul.py new file mode 100644 index 0000000..eaac07e --- /dev/null +++ b/whisperlivekit/qwen3_simul.py @@ -0,0 +1,837 @@ +""" +SimulStreaming-style online processor for Qwen3-ASR. + +Architecture overview +--------------------- +Qwen3-ASR is a decoder-only multimodal model. Audio is encoded by an audio +encoder (Whisper-style) into a sequence of embeddings that replace <|audio_pad|> +placeholder tokens in the input sequence. The text decoder then uses causal +self-attention over the combined audio + text tokens. + +Unlike Whisper (which has explicit cross-attention between decoder and encoder), +Qwen3-ASR uses self-attention where generated text tokens attend to earlier +audio tokens and previously generated text. This means "alignment heads" here +are self-attention heads whose attention over the *audio-token region* tracks +the monotonic audio-to-text alignment. + +The border-distance policy works as follows: + - After each generated token, extract the attention weights from the + selected alignment heads, restricted to the audio-token region + - Find which audio frame each head attends to most strongly (argmax) + - If the most-attended audio frame is approaching the end of the available + audio, pause generation and wait for more audio + - If the most-attended frame jumps backward (rewind), discard recent tokens + +This module loads the Qwen3-ASR model *directly* via transformers (not through +the qwen_asr package's Qwen3ASRModel wrapper), giving us full control over +forward passes, KV caches, and attention extraction. + +Requires: + - A pre-computed alignment heads JSON file (from detect_alignment_heads_qwen3.py) + - OR will fall back to all heads in a configurable set of layers +""" + +import json +import logging +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import torch + +from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript + +logger = logging.getLogger(__name__) + +SAMPLE_RATE = 16000 + + +@dataclass +class Qwen3SimulConfig: + """Configuration for Qwen3 SimulStreaming.""" + model_id: str = "Qwen/Qwen3-ASR-1.7B" + alignment_heads_path: Optional[str] = None + language: str = "auto" + # Border/rewind thresholds as fraction of audio tokens (not absolute frames). + # Qwen3 has ~13 audio tokens/sec vs Whisper's ~50, so absolute thresholds + # don't transfer. 0.15 = pause when attention is within last 15% of audio. + border_fraction: float = 0.15 # Fraction of audio tokens from end to trigger pause + rewind_fraction: float = 0.12 # Max backward jump as fraction of audio tokens + audio_min_len: float = 0.5 # Minimum audio length before starting decode + audio_max_len: float = 15.0 # Maximum audio buffer length in seconds + max_context_tokens: int = 30 # Max committed tokens to include as context + init_prompt: Optional[str] = None + max_alignment_heads: int = 20 # Use only top N alignment heads + + +@dataclass +class Qwen3SimulState: + """Per-session mutable state for Qwen3 SimulStreaming.""" + # Audio + audio_buffer: np.ndarray = field( + default_factory=lambda: np.array([], dtype=np.float32) + ) + cumulative_time_offset: float = 0.0 + global_time_offset: float = 0.0 + speaker: int = -1 + + # Decode state + last_attend_frame: int = -15 + generated_tokens: List[int] = field(default_factory=list) + committed_text: str = "" + committed_word_count: int = 0 # How many words already emitted + committed_token_ids: List[int] = field(default_factory=list) # token IDs for prompt context + + # Tracking + first_timestamp: Optional[float] = None + detected_language: Optional[str] = None + last_infer_samples: int = 0 # audio_buffer length at last inference + + +class Qwen3SimulStreamingASR: + """ + Shared backend for Qwen3-ASR SimulStreaming. + + Loads the model once and is shared across sessions. Each session gets + its own Qwen3SimulStreamingOnlineProcessor with independent state. + """ + + sep = "" + + def __init__( + self, + model_size: str = None, + model_dir: str = None, + lan: str = "auto", + alignment_heads_path: Optional[str] = None, + border_fraction: float = 0.15, + min_chunk_size: float = 0.1, + warmup_file: Optional[str] = None, + model_cache_dir: Optional[str] = None, + model_path: Optional[str] = None, + lora_path: Optional[str] = None, + direct_english_translation: bool = False, + **kwargs, + ): + self.transcribe_kargs = {} + self.original_language = None if lan == "auto" else lan + self.warmup_file = warmup_file + + self.cfg = Qwen3SimulConfig( + language=lan, + alignment_heads_path=alignment_heads_path, + border_fraction=border_fraction, + ) + + # Load model directly via transformers + self._load_model(model_size, model_dir, model_cache_dir, model_path) + + # Load alignment heads + self.alignment_heads = self._load_alignment_heads(alignment_heads_path) + + # Warmup + if warmup_file: + from whisperlivekit.warmup import load_file + audio = load_file(warmup_file) + if audio is not None: + logger.info("Warming up Qwen3 SimulStreaming model") + # Simple warmup: just encode a short audio + self._warmup(audio) + + def _load_model(self, model_size, model_dir, model_cache_dir, model_path): + """Load Qwen3-ASR via transformers (SDPA attention for speed).""" + from whisperlivekit.qwen3_asr import ( + QWEN3_MODEL_MAPPING, + _patch_transformers_compat, + ) + _patch_transformers_compat() + + from qwen_asr.core.transformers_backend import ( + Qwen3ASRConfig, + Qwen3ASRForConditionalGeneration, + Qwen3ASRProcessor, + ) + from transformers import AutoConfig, AutoModel, AutoProcessor + + AutoConfig.register("qwen3_asr", Qwen3ASRConfig) + AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration) + AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor) + + if model_dir: + model_id = model_dir + elif model_path: + model_id = model_path + elif model_size: + model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size) + else: + model_id = "Qwen/Qwen3-ASR-1.7B" + + if torch.cuda.is_available(): + dtype, device = torch.bfloat16, "cuda:0" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + dtype, device = torch.float32, "mps" + else: + dtype, device = torch.float32, "cpu" + + logger.info("Loading Qwen3-ASR for SimulStreaming: %s (sdpa attention)", model_id) + self.model = AutoModel.from_pretrained( + model_id, + torch_dtype=dtype, + device_map=device, + ) + self.model.eval() + self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True) + + # Cache model properties + thinker = self.model.thinker + text_config = thinker.config.text_config + self.num_layers = text_config.num_hidden_layers + self.num_heads = text_config.num_attention_heads + self.num_kv_heads = text_config.num_key_value_heads + self.audio_token_id = thinker.config.audio_token_id + self.device = next(self.model.parameters()).device + self.dtype = next(self.model.parameters()).dtype + + # Cache special token IDs for metadata stripping + self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("") + + logger.info( + "Qwen3-ASR loaded: %d layers x %d heads, device=%s, id=%d", + self.num_layers, self.num_heads, self.device, self.asr_text_token_id, + ) + + def _load_alignment_heads( + self, path: Optional[str], + ) -> List[Tuple[int, int]]: + """Load alignment heads from JSON or use defaults. + + Only loads the top N heads (sorted by TS score) for efficiency. + The Qwen3-ASR model has alignment info spread across most heads + (decoder-only, no cross-attention), so we pick the strongest ones. + """ + max_heads = self.cfg.max_alignment_heads + + if path and Path(path).exists(): + with open(path) as f: + data = json.load(f) + # alignment_heads_compact is pre-sorted by TS score (descending) + all_heads = [tuple(h) for h in data["alignment_heads_compact"]] + heads = all_heads[:max_heads] + logger.info( + "Loaded top %d alignment heads from %s (of %d total)", + len(heads), path, len(all_heads), + ) + return heads + + # Default: use heads from the last quarter of layers + default_heads = [] + start_layer = self.num_layers * 3 // 4 + for layer in range(start_layer, self.num_layers): + for head in range(self.num_heads): + default_heads.append((layer, head)) + logger.warning( + "No alignment heads file found. Using default heuristic: " + "%d heads from layers %d-%d. Run detect_alignment_heads_qwen3.py " + "to find optimal heads.", + len(default_heads), start_layer, self.num_layers - 1, + ) + return default_heads[:max_heads] + + def _warmup(self, audio: np.ndarray): + """Run a short inference to warmup the model.""" + try: + audio = audio[:SAMPLE_RATE * 2] # Max 2 seconds + msgs = [ + {"role": "system", "content": ""}, + {"role": "user", "content": [{"type": "audio", "audio": ""}]}, + ] + text_prompt = self.processor.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=False, + ) + inputs = self.processor( + text=[text_prompt], + audio=[audio], + return_tensors="pt", + padding=True, + ) + inputs = inputs.to(self.device).to(self.dtype) + + with torch.inference_mode(): + self.model.thinker.generate( + **inputs, max_new_tokens=5, do_sample=False, + ) + logger.info("Qwen3 SimulStreaming warmup complete") + except Exception as e: + logger.warning("Warmup failed: %s", e) + + def transcribe(self, audio): + """No-op -- SimulStreaming uses the online processor directly.""" + pass + + +class Qwen3SimulStreamingOnlineProcessor: + """ + Per-session online processor for Qwen3-ASR SimulStreaming. + + Implements the same interface as SimulStreamingOnlineProcessor: + - insert_audio_chunk(audio, time) + - process_iter(is_last=False) -> (List[ASRToken], float) + - get_buffer() -> Transcript + - start_silence() -> (List[ASRToken], float) + - end_silence(duration, offset) + - finish() -> (List[ASRToken], float) + """ + + SAMPLING_RATE = 16000 + MIN_DURATION_REAL_SILENCE = 5 + + def __init__(self, asr: Qwen3SimulStreamingASR, logfile=sys.stderr): + self.asr = asr + self.logfile = logfile + self.end = 0.0 + self.buffer: List[ASRToken] = [] + + # Per-session state + self.state = Qwen3SimulState() + + # Build the prompt template once + self._build_prompt_template() + + def _build_prompt_template(self): + """Build the base text prompt for Qwen3-ASR.""" + from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE + + msgs = [ + {"role": "system", "content": ""}, + {"role": "user", "content": [{"type": "audio", "audio": ""}]}, + ] + self._base_prompt = self.asr.processor.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=False, + ) + + # Add language forcing if configured + lan = self.asr.cfg.language + if lan and lan != "auto": + lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan) + self._base_prompt += f"language {lang_name}" + + @property + def speaker(self): + return self.state.speaker + + @speaker.setter + def speaker(self, value): + self.state.speaker = value + + @property + def global_time_offset(self): + return self.state.global_time_offset + + @global_time_offset.setter + def global_time_offset(self, value): + self.state.global_time_offset = value + + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): + """Append an audio chunk to be processed.""" + self.end = audio_stream_end_time + self.state.audio_buffer = np.append(self.state.audio_buffer, audio) + + # Trim audio if too long + max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE) + if len(self.state.audio_buffer) > max_samples: + trim = len(self.state.audio_buffer) - max_samples + self.state.audio_buffer = self.state.audio_buffer[trim:] + self.state.cumulative_time_offset += trim / self.SAMPLING_RATE + # Adjust throttle counter so it tracks position within the trimmed buffer + self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim) + + def start_silence(self) -> Tuple[List[ASRToken], float]: + """Handle start of silence -- flush all pending tokens. + + Loops inference until the model produces no new tokens, since a + single is_last call may not exhaust all text for the buffered audio. + """ + all_tokens = [] + for _ in range(5): # safety limit + tokens, processed_upto = self.process_iter(is_last=True) + if not tokens: + break + all_tokens.extend(tokens) + return all_tokens, self.end + + def end_silence(self, silence_duration: float, offset: float): + """Handle silence period.""" + self.end += silence_duration + long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE + if not long_silence: + gap_len = int(self.SAMPLING_RATE * silence_duration) + if gap_len > 0: + gap_silence = np.zeros(gap_len, dtype=np.float32) + self.state.audio_buffer = np.append( + self.state.audio_buffer, gap_silence, + ) + else: + # Long silence: reset + self.state = Qwen3SimulState() + self.state.global_time_offset = silence_duration + offset + + def new_speaker(self, change_speaker: ChangeSpeaker): + """Handle speaker change event.""" + self.process_iter(is_last=True) + self.state = Qwen3SimulState() + self.state.speaker = change_speaker.speaker + self.state.global_time_offset = change_speaker.start + + def get_buffer(self) -> Transcript: + """Get the current unvalidated buffer.""" + return Transcript.from_tokens(tokens=self.buffer, sep='') + + @torch.inference_mode() + def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: + """ + Process accumulated audio using SimulStreaming with alignment heads. + + This performs a full forward pass (encode audio + greedy decode with + attention extraction), applying the border-distance policy to decide + when to stop generating. + + Returns: + Tuple of (committed ASRToken list, audio processed up to time). + """ + audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE + if audio_duration < self.asr.cfg.audio_min_len: + return [], self.end + + # Throttle: skip inference if less than 1s of new audio since last run. + # Each inference re-encodes the full buffer, so calling too often wastes + # GPU/CPU time and causes lag to spiral. + new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples + min_new_seconds = 1.0 + if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE): + return [], self.end + + logger.info("Running SimulStreaming inference on %.2fs of audio (%.2fs new)", audio_duration, new_samples / self.SAMPLING_RATE) + self.state.last_infer_samples = len(self.state.audio_buffer) + + try: + timestamped_words = self._infer(is_last) + except Exception as e: + logger.exception("Qwen3 SimulStreaming inference error: %s", e) + return [], self.end + + logger.info("SimulStreaming produced %d words", len(timestamped_words)) + if not timestamped_words: + return [], self.end + + self.buffer = [] + return timestamped_words, self.end + + def _infer(self, is_last: bool) -> List[ASRToken]: + """Run one inference cycle with alignment-head-based stopping. + + Uses forward hooks on self_attn modules to capture attention weights + during generation. The Qwen3-ASR decoder layer discards attention + weights (hidden_states, _ = self.self_attn(...)), so output_attentions + via generate() would return None. Hooks capture them before discard. + """ + asr = self.asr + state = self.state + + # Prepare inputs + inputs = asr.processor( + text=[self._base_prompt], + audio=[state.audio_buffer], + return_tensors="pt", + padding=True, + ) + inputs = inputs.to(asr.device).to(asr.dtype) + + # Append committed token IDs as context so generate() continues from + # where it left off. Cap at max_context_tokens to prevent prompt growth. + if state.committed_token_ids: + ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:] + ctx_ids = torch.tensor( + [ctx], dtype=inputs.input_ids.dtype, + device=inputs.input_ids.device, + ) + inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1) + if "attention_mask" in inputs: + ctx_mask = torch.ones_like(ctx_ids) + inputs["attention_mask"] = torch.cat( + [inputs.attention_mask, ctx_mask], dim=1, + ) + + prompt_len = inputs.input_ids.shape[1] + + # Find audio token range + input_ids = inputs.input_ids[0] + audio_mask = (input_ids == asr.audio_token_id) + audio_positions = audio_mask.nonzero(as_tuple=True)[0] + if len(audio_positions) == 0: + return [] + audio_start = audio_positions[0].item() + audio_end = audio_positions[-1].item() + 1 + n_audio_tokens = audio_end - audio_start + + audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE + + # Install forward hooks to capture alignment attention from Q and K. + # With SDPA attention (fast), attn_weights are not returned. Instead, + # we hook self_attn to compute Q*K^T attention ONLY for alignment heads + # during autoregressive steps (q_len == 1). This is cheap because we + # only compute dot products for ~20 heads, not full attention for all. + # + # Key detail: self_attn is called with ALL keyword arguments from the + # decoder layer, so hidden_states/position_embeddings/past_key_values + # are all in kwargs, not args. + per_step_frames: List[List[int]] = [] + current_step_frames: List[int] = [] + + heads_by_layer: dict = {} + for layer_idx, head_idx in asr.alignment_heads: + heads_by_layer.setdefault(layer_idx, []).append(head_idx) + + decoder_layers = asr.model.thinker.model.layers + num_kv_heads = asr.num_kv_heads + num_heads = asr.num_heads + gqa_ratio = num_heads // num_kv_heads # GQA group size + + # Import RoPE function used by this model's attention + from qwen_asr.core.transformers_backend.modeling_qwen3_asr import ( + apply_rotary_pos_emb, + ) + + hooks = [] + + def _make_attn_hook(layer_idx): + """Forward hook on self_attn that computes Q*K^T for alignment heads. + + After the forward pass, we recompute Q (with RoPE) for the current + token and dot it against the cached K (which already has RoPE) in + the audio region. This gives us per-head alignment frames. + """ + head_indices = heads_by_layer[layer_idx] + + def hook_fn(module, args, kwargs, output): + # All arguments are keyword-passed from the decoder layer + hidden_states = kwargs.get('hidden_states') + if hidden_states is None: + hidden_states = args[0] if args else None + if hidden_states is None or hidden_states.shape[1] != 1: + return # Skip prefill (seq_len > 1) + + position_embeddings = kwargs.get('position_embeddings') + if position_embeddings is None and len(args) > 1: + position_embeddings = args[1] + past_kv = kwargs.get('past_key_values') + if position_embeddings is None or past_kv is None: + return + + # Recompute Q with RoPE (cheap: single token through q_proj + RoPE) + hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim) + q = module.q_norm( + module.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + cos, sin = position_embeddings + q, _ = apply_rotary_pos_emb(q, q, cos, sin) + + # K from cache already has RoPE applied + cache_layer = past_kv.layers[module.layer_idx] + k = cache_layer.keys # (batch, n_kv_heads, kv_len, head_dim) + if k is None or audio_end > k.shape[2]: + return + + # Compute attention scores for alignment heads only + for h_idx in head_indices: + if h_idx >= q.shape[1]: + continue + kv_h_idx = h_idx // gqa_ratio + q_h = q[0, h_idx, 0] # (head_dim,) + k_audio = k[0, kv_h_idx, audio_start:audio_end] # (n_audio, head_dim) + scores = torch.matmul(k_audio, q_h) # (n_audio,) + frame = scores.argmax().item() + current_step_frames.append(frame) + + return hook_fn + + for layer_idx in heads_by_layer: + if layer_idx < len(decoder_layers): + h = decoder_layers[layer_idx].self_attn.register_forward_hook( + _make_attn_hook(layer_idx), + with_kwargs=True, + ) + hooks.append(h) + + # Step boundary hook on lm_head to separate per-step frames + # and check border-distance stopping criteria in real-time. + # This is CRITICAL for performance: instead of generating 200 tokens + # then truncating, we stop as soon as attention hits the audio border. + # On MPS, each token costs ~50ms, so stopping at 10 tokens vs 200 + # means ~0.5s vs ~10s inference. + last_attend_frame = state.last_attend_frame + border_stop_step: Optional[int] = None + + # Compute absolute thresholds from fractional config + border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction)) + rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction)) + + def _step_boundary_hook(module, args, output): + nonlocal current_step_frames, last_attend_frame, border_stop_step + if current_step_frames: + per_step_frames.append(current_step_frames) + current_step_frames = [] + + # Check border distance on each step. + # Allow at least 3 steps before checking, so short buffers + # can still produce some tokens during streaming. + if not is_last and border_stop_step is None and len(per_step_frames) >= 3: + latest = per_step_frames[-1] + if latest: + frames_sorted = sorted(latest) + attended = frames_sorted[len(frames_sorted) // 2] + + # Rewind check + if last_attend_frame - attended > rewind_threshold: + border_stop_step = max(0, len(per_step_frames) - 2) + return + + last_attend_frame = attended + + # Border check + if (n_audio_tokens - attended) <= border_threshold: + border_stop_step = len(per_step_frames) - 1 + return + + lm_head = asr.model.thinker.lm_head + step_hook = lm_head.register_forward_hook(_step_boundary_hook) + hooks.append(step_hook) + + # StoppingCriteria that stops generation when border distance is hit + from transformers import StoppingCriteria, StoppingCriteriaList + + class BorderStop(StoppingCriteria): + def __call__(self, input_ids, scores, **kwargs): + return border_stop_step is not None + + stopping = StoppingCriteriaList([BorderStop()]) + + # Limit max tokens to what's reasonable for the audio duration. + # On MPS, each token costs ~50-100ms, so tight limits are critical. + # Speech produces ~4-6 tokens/sec; +5 for metadata prefix tokens. + # With is_last, allow slightly more for flushing remaining text. + new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE + tokens_per_sec = 6 + if is_last: + max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120) + else: + max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40) + + try: + outputs = asr.model.thinker.generate( + **inputs, + max_new_tokens=max_tokens, + do_sample=False, + stopping_criteria=stopping, + ) + finally: + for h in hooks: + h.remove() + # Flush any remaining frames + if current_step_frames: + per_step_frames.append(current_step_frames) + + state.last_attend_frame = last_attend_frame + + # Extract generated tokens + all_generated = outputs[0, prompt_len:] + eos_ids = {151645, 151643} + if asr.processor.tokenizer.eos_token_id is not None: + eos_ids.add(asr.processor.tokenizer.eos_token_id) + + num_gen = len(all_generated) + for i, tid in enumerate(all_generated): + if tid.item() in eos_ids: + num_gen = i + break + + raw_text = asr.processor.tokenizer.decode(all_generated[:num_gen], skip_special_tokens=True) + logger.info( + "SimulStreaming raw output: %d tokens (stopped at step %s), text=%r", + num_gen, border_stop_step, raw_text[:100], + ) + + if num_gen == 0: + return [] + + # Strip metadata prefix: when language is "auto", the model generates + # "language ..." before actual transcription text. + # Find token and skip everything before it (including itself). + asr_text_id = asr.asr_text_token_id + metadata_offset = 0 + for i in range(min(num_gen, 10)): # metadata is at most ~3-4 tokens + if all_generated[i].item() == asr_text_id: + # Detect language from the metadata prefix before stripping + if state.detected_language is None and i > 0: + from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE + prefix_text = asr.processor.tokenizer.decode( + all_generated[:i].tolist(), skip_special_tokens=True, + ).strip() + parts = prefix_text.split() + if len(parts) >= 2: + lang_name = parts[-1] + if lang_name.lower() != "none": + state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get( + lang_name, lang_name.lower(), + ) + logger.info("Auto-detected language: %s", state.detected_language) + metadata_offset = i + 1 + break + + if metadata_offset > 0: + logger.info( + "Stripping %d metadata prefix tokens (before )", + metadata_offset, + ) + all_generated = all_generated[metadata_offset:] + num_gen -= metadata_offset + per_step_frames = per_step_frames[metadata_offset:] + + if num_gen <= 0: + return [] + + # Determine how many tokens to emit based on border stopping + step_frames = [f for f in per_step_frames if f] + if border_stop_step is not None: + emit_up_to = min(border_stop_step, num_gen) + else: + emit_up_to = num_gen + + # Build timestamped words from the emitted tokens + generated_ids = all_generated[:emit_up_to] + if len(generated_ids) == 0: + return [] + + all_words = self._build_timestamped_words( + generated_ids, step_frames, emit_up_to, + n_audio_tokens, audio_duration, + ) + + new_words = all_words + + # Update committed word count for space-prefix logic in next batch + state.committed_word_count += len(new_words) + + # Append newly emitted token IDs to committed context for next call + new_emitted = outputs[0, prompt_len:prompt_len + emit_up_to + metadata_offset] + state.committed_token_ids.extend(new_emitted.tolist()) + + return new_words + + def _build_timestamped_words( + self, + generated_ids: torch.Tensor, + step_frames: List[List[int]], + emit_up_to: int, + n_audio_tokens: int, + audio_duration: float, + ) -> List[ASRToken]: + """Build timestamped ASRToken list from generated tokens and hook-captured frames.""" + asr = self.asr + state = self.state + + # Get per-token attended audio frame (median of alignment head votes) + per_token_frame: List[Optional[int]] = [] + for step in range(emit_up_to): + if step < len(step_frames) and step_frames[step]: + frames = sorted(step_frames[step]) + per_token_frame.append(frames[len(frames) // 2]) + else: + per_token_frame.append(None) + + # Decode the full generated sequence at once, then split into words. + # This is more robust than per-token Ġ detection, which can fail when + # committed context causes the model to generate sub-word continuations. + tokenizer = asr.processor.tokenizer + full_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True) + text_words = full_text.split() + + # Map each text word to an approximate frame using token-level alignment. + # Distribute frames evenly across words (since exact token→word mapping + # is imprecise with BPE sub-words anyway). + all_frames = [f for f in per_token_frame if f is not None] + words = [] + for wi, word in enumerate(text_words): + if all_frames: + # Proportionally assign frames to words + frac = wi / max(len(text_words), 1) + frame_idx = int(frac * len(all_frames)) + frame_idx = min(frame_idx, len(all_frames) - 1) + frame = all_frames[frame_idx] + else: + frame = None + words.append((word, frame)) + + # Convert to ASRToken with timestamps + tokens = [] + for i, (text, frame) in enumerate(words): + text = text.strip() + if not text: + continue + + if frame is not None and n_audio_tokens > 0: + timestamp = ( + frame / n_audio_tokens * audio_duration + + state.cumulative_time_offset + ) + else: + timestamp = ( + (i / max(len(words), 1)) * audio_duration + + state.cumulative_time_offset + ) + + # Prefix space: first word of the very first batch has no space; + # all subsequent words (same batch or later batches) get a space. + is_very_first_word = (i == 0 and state.committed_word_count == 0) + display_text = text if is_very_first_word else " " + text + + token = ASRToken( + start=round(timestamp, 2), + end=round(timestamp + 0.1, 2), + text=display_text, + speaker=state.speaker, + detected_language=state.detected_language, + ).with_offset(state.global_time_offset) + tokens.append(token) + + return tokens + + @staticmethod + def _median_frame(frames: List[int]) -> Optional[int]: + """Return median of frame list, or None if empty.""" + if not frames: + return None + frames_sorted = sorted(frames) + return frames_sorted[len(frames_sorted) // 2] + + def warmup(self, audio: np.ndarray, init_prompt: str = ""): + """Warmup the model with a short audio clip.""" + try: + self.state.audio_buffer = audio[:SAMPLE_RATE] + self.process_iter(is_last=True) + self.state = Qwen3SimulState() + logger.info("Qwen3 SimulStreaming online processor warmed up") + except Exception as e: + logger.warning("Warmup failed: %s", e) + self.state = Qwen3SimulState() + + def finish(self) -> Tuple[List[ASRToken], float]: + """Flush remaining audio at end of stream.""" + all_tokens = [] + for _ in range(5): # safety limit + tokens, _ = self.process_iter(is_last=True) + if not tokens: + break + all_tokens.extend(tokens) + return all_tokens, self.end diff --git a/whisperlivekit/vllm_realtime.py b/whisperlivekit/vllm_realtime.py new file mode 100644 index 0000000..53c022b --- /dev/null +++ b/whisperlivekit/vllm_realtime.py @@ -0,0 +1,416 @@ +""" +vLLM Realtime WebSocket streaming backend for WhisperLiveKit. + +Connects to a vLLM server's ``/v1/realtime`` WebSocket endpoint to stream +audio and receive transcription deltas. Uses ``websockets.sync.client`` +for simplicity since ``process_iter`` runs inside ``asyncio.to_thread``. + +Provides ``VLLMRealtimeASR`` (lightweight model holder) and +``VLLMRealtimeOnlineProcessor`` (streaming processor) that plug into +WhisperLiveKit's audio processing pipeline. +""" + +import base64 +import json +import logging +import threading +import time +from typing import List, Optional, Tuple + +import numpy as np + +from whisperlivekit.timed_objects import ASRToken, Transcript + +logger = logging.getLogger(__name__) + + +class VLLMRealtimeASR: + """Lightweight model holder — stores connection info for the vLLM server.""" + + sep = " " + SAMPLING_RATE = 16000 + backend_choice = "vllm-realtime" + + def __init__(self, vllm_url="ws://localhost:8000/v1/realtime", + model_name="Qwen/Qwen3-ASR-1.7B", lan="auto", **kwargs): + self.vllm_url = vllm_url + self.model_name = model_name + self.original_language = None if lan == "auto" else lan + self.tokenizer = None + + def transcribe(self, audio): + pass + + +class VLLMRealtimeOnlineProcessor: + """ + Online processor that streams audio to a vLLM Realtime WebSocket. + + Uses a background thread for WebSocket receiving and + ``websockets.sync.client`` for the sync WebSocket connection. + """ + + SAMPLING_RATE = 16000 + # Minimum audio samples before connecting (0.5s of audio) + _MIN_CONNECT_SAMPLES = SAMPLING_RATE // 2 + + def __init__(self, asr: VLLMRealtimeASR): + self.asr = asr + self.end = 0.0 + self.buffer = [] + self.audio_buffer = np.array([], dtype=np.float32) + + self._reset_state() + + logger.info( + "[vllm-realtime] Initialized. url=%s model=%s", + asr.vllm_url, asr.model_name, + ) + + def _reset_state(self): + self._pending_audio = np.zeros(0, dtype=np.float32) + self._ws = None + self._recv_thread: Optional[threading.Thread] = None + self._connected = False + self._done = False + self._recv_error: Optional[Exception] = None + + # Text accumulation and word extraction + self._accumulated_text = "" + self._n_committed_words = 0 + self._total_audio_duration = 0.0 + self._global_time_offset = 0.0 + + # Lock for text state accessed from both recv thread and main thread + self._text_lock = threading.Lock() + + # ── Interface methods ── + + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): + self.end = audio_stream_end_time + self._pending_audio = np.append(self._pending_audio, audio) + self.audio_buffer = self._pending_audio + + def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: + try: + return self._process_iter_inner(is_last) + except Exception as e: + logger.warning("[vllm-realtime] process_iter exception: %s", e, exc_info=True) + return [], self.end + + def get_buffer(self) -> Transcript: + """Return all uncommitted text as buffer.""" + self._drain_deltas() + with self._text_lock: + text = self._accumulated_text + if not text: + return Transcript(start=None, end=None, text="") + + words = text.split() + uncommitted = words[self._n_committed_words:] + if uncommitted: + return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted)) + return Transcript(start=None, end=None, text="") + + def start_silence(self) -> Tuple[List[ASRToken], float]: + """Flush all pending words when silence starts. + + Sends commit(final=true) to signal end of utterance, waits for + transcription.done, flushes all words, then prepares for reconnection + on the next utterance. + """ + if not self._connected or self._done: + words = self._flush_all_pending_words() + logger.info("[vllm-realtime] start_silence (not connected): flushed %d words", len(words)) + return words, self.end + + # Send any remaining buffered audio + self._send_pending_audio() + + # Signal end of stream + self._send_commit(final=True) + + # Wait for transcription.done + self._wait_for_done(timeout=10.0) + + # Flush all remaining words + words = self._flush_all_pending_words() + + # Close and reset for next utterance + self._close_ws() + old_offset = self._global_time_offset + self._total_audio_duration + self._reset_state() + self._global_time_offset = old_offset + + logger.info("[vllm-realtime] start_silence: flushed %d words", len(words)) + return words, self.end + + def end_silence(self, silence_duration: float, offset: float): + self._global_time_offset += silence_duration + self.end += silence_duration + + def new_speaker(self, change_speaker): + self.start_silence() + + def warmup(self, audio, init_prompt=""): + pass + + def finish(self) -> Tuple[List[ASRToken], float]: + """Close connection and flush all remaining words.""" + if self._connected and not self._done: + # Send remaining audio + self._send_pending_audio() + + # Signal final commit + self._send_commit(final=True) + + # Wait for transcription.done + self._wait_for_done(timeout=30.0) + + # Flush all words + words = self._flush_all_pending_words() + + # Close WebSocket + self._close_ws() + + logger.info("[vllm-realtime] finish: flushed %d words", len(words)) + return words, self.end + + # ── WebSocket connection management ── + + def _connect(self): + """Connect to the vLLM realtime WebSocket and start the receive thread.""" + from websockets.sync.client import connect + + url = self.asr.vllm_url + logger.info("[vllm-realtime] Connecting to %s", url) + + self._ws = connect(url) + + # Send session.update to select model + self._ws.send(json.dumps({ + "type": "session.update", + "model": self.asr.model_name, + })) + + # Send initial commit(final=false) to start generation + self._send_commit(final=False) + + # Start receive thread + self._recv_thread = threading.Thread(target=self._recv_loop, daemon=True) + self._recv_thread.start() + + self._connected = True + logger.info("[vllm-realtime] Connected and started receive thread") + + def _close_ws(self): + """Close the WebSocket connection and join the receive thread.""" + if self._ws is not None: + try: + self._ws.close() + except Exception: + pass + self._ws = None + + if self._recv_thread is not None: + self._recv_thread.join(timeout=5.0) + self._recv_thread = None + + def _recv_loop(self): + """Background thread: receive messages from the vLLM WebSocket.""" + try: + while not self._done and self._ws is not None: + try: + raw = self._ws.recv(timeout=0.1) + except TimeoutError: + continue + except Exception: + break + + try: + msg = json.loads(raw) + except (json.JSONDecodeError, TypeError): + continue + + msg_type = msg.get("type", "") + + if msg_type == "transcription.delta": + delta = msg.get("delta", "") + if delta: + with self._text_lock: + self._accumulated_text += delta + + elif msg_type == "transcription.done": + done_text = msg.get("text", "") + if done_text: + with self._text_lock: + # Replace accumulated text with final text + self._accumulated_text = done_text + self._done = True + break + + except Exception as e: + logger.error("[vllm-realtime] recv_loop error: %s", e, exc_info=True) + self._recv_error = e + self._done = True + + # ── Protocol messages ── + + def _send_commit(self, final: bool): + """Send input_audio_buffer.commit message.""" + if self._ws is None: + return + try: + self._ws.send(json.dumps({ + "type": "input_audio_buffer.commit", + "final": final, + })) + except Exception as e: + logger.warning("[vllm-realtime] Failed to send commit: %s", e) + + def _send_audio(self, audio: np.ndarray): + """Send audio as a base64-encoded PCM16 append message.""" + if self._ws is None: + return + + # Convert float32 [-1, 1] to int16 PCM + pcm16 = (audio * 32767).astype(np.int16) + audio_bytes = pcm16.tobytes() + audio_b64 = base64.b64encode(audio_bytes).decode("ascii") + + try: + self._ws.send(json.dumps({ + "type": "input_audio_buffer.append", + "audio": audio_b64, + })) + except Exception as e: + logger.warning("[vllm-realtime] Failed to send audio: %s", e) + + def _send_pending_audio(self): + """Send all pending audio to the vLLM server.""" + if len(self._pending_audio) == 0: + return + + # Track total audio duration for timestamp estimation + self._total_audio_duration += len(self._pending_audio) / self.SAMPLING_RATE + + # Send in chunks of 0.5s to avoid overwhelming the WebSocket + chunk_samples = self.SAMPLING_RATE // 2 + while len(self._pending_audio) >= chunk_samples: + chunk = self._pending_audio[:chunk_samples] + self._send_audio(chunk) + self._pending_audio = self._pending_audio[chunk_samples:] + + # Send remaining audio if any + if len(self._pending_audio) > 0: + self._send_audio(self._pending_audio) + self._pending_audio = np.zeros(0, dtype=np.float32) + + self.audio_buffer = self._pending_audio + + # ── Receive helpers ── + + def _drain_deltas(self): + """No-op since the recv thread accumulates text directly.""" + pass + + def _wait_for_done(self, timeout: float = 10.0): + """Wait for transcription.done message from the server.""" + deadline = time.time() + timeout + while not self._done and time.time() < deadline: + time.sleep(0.05) + + if not self._done: + logger.warning("[vllm-realtime] Timed out waiting for transcription.done") + + # ── Word extraction (same approach as VoxtralHF) ── + + def _time_for_word(self, word_idx: int, n_words_total: int) -> Tuple[float, float]: + """Estimate timestamps by linearly distributing words across audio duration.""" + duration = max(self._total_audio_duration, 0.001) + n_total = max(n_words_total, 1) + + start_time = (word_idx / n_total) * duration + self._global_time_offset + end_time = ((word_idx + 1) / n_total) * duration + self._global_time_offset + + return start_time, end_time + + def _extract_new_words(self) -> List[ASRToken]: + """Extract complete words (all but the last, which may still grow).""" + with self._text_lock: + text = self._accumulated_text + if not text: + return [] + + words = text.split() + new_words: List[ASRToken] = [] + n_words_total = len(words) + + while len(words) > self._n_committed_words + 1: + word = words[self._n_committed_words] + start_time, end_time = self._time_for_word(self._n_committed_words, n_words_total) + + text_out = word if self._n_committed_words == 0 else " " + word + new_words.append(ASRToken(start=start_time, end=end_time, text=text_out)) + self._n_committed_words += 1 + + return new_words + + def _flush_all_pending_words(self) -> List[ASRToken]: + """Flush ALL words including the last partial one.""" + with self._text_lock: + text = self._accumulated_text + if not text: + return [] + + words = text.split() + new_words: List[ASRToken] = [] + n_words_total = max(len(words), 1) + + while self._n_committed_words < len(words): + word = words[self._n_committed_words] + start_time, end_time = self._time_for_word(self._n_committed_words, n_words_total) + + text_out = word if self._n_committed_words == 0 else " " + word + new_words.append(ASRToken(start=start_time, end=end_time, text=text_out)) + self._n_committed_words += 1 + + return new_words + + # ── Core processing ── + + def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]: + # Connect when we have enough audio buffered + if not self._connected: + if len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES: + self._connect() + self._send_pending_audio() + else: + return [], self.end + + # Send any new pending audio + if self._connected and not self._done: + self._send_pending_audio() + + # If connection closed unexpectedly but new audio arrived, reconnect + if self._done and len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES: + flush_words = self._flush_all_pending_words() + old_offset = self._global_time_offset + self._total_audio_duration + self._close_ws() + self._reset_state() + self._global_time_offset = old_offset + self._connect() + self._send_pending_audio() + return flush_words, self.end + + # Extract complete words + new_words = self._extract_new_words() + + if new_words: + logger.info( + "[vllm-realtime] returning %d words: %s", + len(new_words), [w.text for w in new_words], + ) + + self.buffer = [] + return new_words, self.end diff --git a/whisperlivekit/voxtral_hf_streaming.py b/whisperlivekit/voxtral_hf_streaming.py index c530193..a1bd088 100644 --- a/whisperlivekit/voxtral_hf_streaming.py +++ b/whisperlivekit/voxtral_hf_streaming.py @@ -102,7 +102,8 @@ class VoxtralHFStreamingOnlineProcessor: ) def _reset_state(self): - self._pending_audio = np.zeros(0, dtype=np.float32) + self._pending_chunks: List[np.ndarray] = [] + self._pending_len = 0 self._audio_queue: queue.Queue = queue.Queue() self._streamer_texts: List[str] = [] self._generate_thread: Optional[threading.Thread] = None @@ -110,22 +111,63 @@ class VoxtralHFStreamingOnlineProcessor: self._generate_finished = False self._generate_error: Optional[Exception] = None - # Text accumulation and word extraction - self._accumulated_text = "" + # Text accumulation (list of fragments, joined on demand) + self._text_fragments: List[str] = [] + self._text_len = 0 + # Fragment position tracking for accurate word timestamps: + # each entry is (char_offset_in_full_text, audio_tok_pos_consumed) + self._fragment_positions: List[Tuple[int, int]] = [] self._n_text_tokens_received = 0 self._n_audio_tokens_fed = 0 + # Audio tokens actually consumed by the model (tracked inside generator) + self._n_audio_tokens_consumed = 0 self._n_committed_words = 0 self._global_time_offset = 0.0 + # Event signalled by the generate thread when it finishes + self._generate_done = threading.Event() + # Lock for text state accessed from both generate thread and main thread self._text_lock = threading.Lock() + # ── Audio / text helpers ── + + def _get_pending_audio(self) -> np.ndarray: + """Flatten pending audio chunks into a single array.""" + if not self._pending_chunks: + return np.zeros(0, dtype=np.float32) + if len(self._pending_chunks) == 1: + return self._pending_chunks[0] + flat = np.concatenate(self._pending_chunks) + self._pending_chunks = [flat] + return flat + + def _set_pending_audio(self, arr: np.ndarray): + """Replace pending audio with a single array.""" + if len(arr) == 0: + self._pending_chunks = [] + self._pending_len = 0 + else: + self._pending_chunks = [arr] + self._pending_len = len(arr) + + def _get_accumulated_text(self) -> str: + """Get the full accumulated text (joins fragments if needed).""" + if not self._text_fragments: + return "" + if len(self._text_fragments) == 1: + return self._text_fragments[0] + joined = "".join(self._text_fragments) + self._text_fragments = [joined] + return joined + # ── Interface methods ── def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): self.end = audio_stream_end_time - self._pending_audio = np.append(self._pending_audio, audio) - self.audio_buffer = self._pending_audio + self._pending_chunks.append(audio) + self._pending_len += len(audio) + self.audio_buffer = audio # diagnostic only def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: try: @@ -142,7 +184,7 @@ class VoxtralHFStreamingOnlineProcessor: """ self._drain_streamer() with self._text_lock: - text = self._accumulated_text + text = self._get_accumulated_text() if not text: return Transcript(start=None, end=None, text="") @@ -174,16 +216,17 @@ class VoxtralHFStreamingOnlineProcessor: # real audio and shouldn't affect word timestamp calculations. if self._right_pad_samples > 0: right_pad = np.zeros(self._right_pad_samples, dtype=np.float32) - self._pending_audio = np.append(self._pending_audio, right_pad) + self._pending_chunks.append(right_pad) + self._pending_len += len(right_pad) saved_count = self._n_audio_tokens_fed self._feed_pending_audio() self._n_audio_tokens_fed = saved_count - # Drain in a loop: the model may still be processing right-padding - # chunks after the first drain returns. Keep draining until no new - # text appears for two consecutive rounds. + # Drain in a loop: the model may continue producing text tokens after + # the audio queue is empty (autoregressive generation). Each iteration + # uses an event-driven blocking drain with short timeouts. all_words: List[ASRToken] = [] - for _ in range(5): # at most 5 drain+flush cycles + for _ in range(5): self._drain_streamer_blocking(timeout=5.0) batch = self._flush_all_pending_words() all_words.extend(batch) @@ -208,7 +251,8 @@ class VoxtralHFStreamingOnlineProcessor: # Add right-padding so the model can finish decoding if self._right_pad_samples > 0: right_pad = np.zeros(self._right_pad_samples, dtype=np.float32) - self._pending_audio = np.append(self._pending_audio, right_pad) + self._pending_chunks.append(right_pad) + self._pending_len += len(right_pad) # Feed remaining audio if self._generate_started and not self._generate_finished: @@ -218,7 +262,7 @@ class VoxtralHFStreamingOnlineProcessor: # Wait for generate to finish if self._generate_thread is not None: self._generate_thread.join(timeout=30.0) - elif not self._generate_started and len(self._pending_audio) >= self._first_chunk_samples: + elif not self._generate_started and self._pending_len >= self._first_chunk_samples: # Never started but have enough audio — start and immediately finish self._start_generate_thread() self._feed_pending_audio() @@ -242,8 +286,9 @@ class VoxtralHFStreamingOnlineProcessor: model = self.asr.model # Extract first chunk - first_chunk_audio = self._pending_audio[:self._first_chunk_samples] - self._pending_audio = self._pending_audio[self._first_chunk_samples:] + pending = self._get_pending_audio() + first_chunk_audio = pending[:self._first_chunk_samples] + self._set_pending_audio(pending[self._first_chunk_samples:]) # First chunk covers multiple audio tokens self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step) @@ -265,11 +310,14 @@ class VoxtralHFStreamingOnlineProcessor: audio_queue = self._audio_queue def input_features_gen(): + # Track audio consumption inside the generator (runs in generate thread) + self._n_audio_tokens_consumed = max(1, self._first_chunk_samples // self._chunk_step) yield first_inputs.input_features while True: chunk_audio = audio_queue.get() if chunk_audio is None: break + self._n_audio_tokens_consumed += 1 inputs = processor( chunk_audio, is_streaming=True, @@ -298,6 +346,7 @@ class VoxtralHFStreamingOnlineProcessor: self._generate_error = e finally: self._generate_finished = True + self._generate_done.set() self._generate_thread = threading.Thread(target=run_generate, daemon=True) self._generate_thread.start() @@ -309,13 +358,22 @@ class VoxtralHFStreamingOnlineProcessor: chunk_size = self._chunk_samples step_size = self._chunk_step - while len(self._pending_audio) >= chunk_size: - chunk = self._pending_audio[:chunk_size] + pending = self._get_pending_audio() + while len(pending) >= chunk_size: + chunk = pending[:chunk_size] self._audio_queue.put(chunk) - self._pending_audio = self._pending_audio[step_size:] + pending = pending[step_size:] self._n_audio_tokens_fed += 1 - self.audio_buffer = self._pending_audio + self._set_pending_audio(pending) + self.audio_buffer = pending + + def _append_text_fragment(self, text_fragment: str): + """Append a text fragment with its audio position (must hold _text_lock).""" + self._fragment_positions.append((self._text_len, self._n_audio_tokens_consumed)) + self._text_fragments.append(text_fragment) + self._text_len += len(text_fragment) + self._n_text_tokens_received += 1 def _drain_streamer(self): """Non-blocking drain of all available text from the streamer.""" @@ -333,19 +391,13 @@ class VoxtralHFStreamingOnlineProcessor: break if text_fragment: with self._text_lock: - self._accumulated_text += text_fragment - self._n_text_tokens_received += 1 + self._append_text_fragment(text_fragment) def _drain_streamer_blocking(self, timeout=30.0): - """Blocking drain: wait for the generate thread to process all queued - audio and produce the corresponding text. + """Blocking drain: wait for the generate thread to finish producing text. - Polls the text queue while the audio queue has items (model still - processing). Once the audio queue is empty, waits for trailing - tokens, then returns. - - This is critical for start_silence(): without it, the non-blocking - drain races with the generate thread and the last words get stuck. + Uses the _generate_done event to know when the model is truly finished. + Falls back to text-queue polling with adaptive timeouts. """ if not self._generate_started or self._generate_finished: self._drain_streamer() @@ -353,52 +405,101 @@ class VoxtralHFStreamingOnlineProcessor: text_queue = self._streamer.text_queue deadline = time.time() + timeout + # Count consecutive empty polls to detect when model has caught up + empty_streak = 0 while time.time() < deadline: - # Short poll while model is still processing queued audio; - # longer wait once the audio queue is empty (trailing tokens). - wait = 2.0 if self._audio_queue.empty() else 0.1 + remaining = max(deadline - time.time(), 0.01) + + # If generate thread is done, do a final flush and exit + if self._generate_done.is_set() or self._generate_finished: + self._drain_streamer() + return + + # Adaptive wait: short while audio is queued, longer once queue is empty + if self._audio_queue.empty(): + wait = min(remaining, 0.5) + else: + wait = min(remaining, 0.1) + try: text_fragment = text_queue.get(timeout=wait) except queue.Empty: - if self._audio_queue.empty(): - break # Audio done + no text for 2s → fully caught up - continue # Audio still queued, model still working + empty_streak += 1 + # Only exit if audio queue is empty AND we've had enough empty polls + # This prevents premature exit when the model is slow + if self._audio_queue.empty() and empty_streak >= 4: + break + continue + + empty_streak = 0 if text_fragment is None: self._generate_finished = True break if text_fragment: with self._text_lock: - self._accumulated_text += text_fragment - self._n_text_tokens_received += 1 + self._append_text_fragment(text_fragment) # ── Word extraction ── def _pos_to_time(self, token_position: int) -> float: - """Convert token position to seconds.""" + """Convert audio token position to seconds.""" return token_position * self._seconds_per_token + self._global_time_offset + def _audio_pos_for_char(self, char_idx: int) -> int: + """Look up the audio token position for a character index in the text. + + Uses the fragment position index recorded when text arrives from the + generate thread. Returns the audio position of the fragment that + contains ``char_idx``, giving much better word timestamps than the + old uniform-distribution heuristic. + """ + if not self._fragment_positions: + return 0 + # _fragment_positions is sorted by char_offset — find the last entry + # whose char_offset <= char_idx (the fragment containing this char). + pos = 0 + for offset, audio_tok in self._fragment_positions: + if offset > char_idx: + break + pos = audio_tok + return pos + + def _word_timestamps(self, text: str, words: List[str], start_idx: int, end_idx: int) -> List[Tuple[int, int]]: + """Compute (tok_start, tok_end) for words[start_idx:end_idx] using fragment positions.""" + # Build char offsets for each word + result = [] + char_pos = 0 + for i, word in enumerate(words): + if i > 0: + char_pos += 1 # space separator + if start_idx <= i < end_idx: + tok_start = self._audio_pos_for_char(char_pos) + tok_end = self._audio_pos_for_char(char_pos + len(word)) + result.append((tok_start, tok_end)) + char_pos += len(word) + return result + def _extract_new_words(self) -> List[ASRToken]: """Extract complete words (all but the last, which may still be growing).""" with self._text_lock: - text = self._accumulated_text + text = self._get_accumulated_text() if not text: return [] words = text.split() new_words: List[ASRToken] = [] - n_words_total = len(words) - n_audio_toks = max(self._n_audio_tokens_fed, 1) + n_to_commit = len(words) - 1 # keep last word (may still grow) - while len(words) > self._n_committed_words + 1: + if n_to_commit <= self._n_committed_words: + return [] + + timestamps = self._word_timestamps(text, words, self._n_committed_words, n_to_commit) + + for tok_start, tok_end in timestamps: word = words[self._n_committed_words] - word_idx = self._n_committed_words - - tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0 - tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0 - start_time = self._pos_to_time(tok_start) - end_time = self._pos_to_time(tok_end) + end_time = self._pos_to_time(max(tok_end, tok_start + 1)) text_out = word if self._n_committed_words == 0 else " " + word new_words.append(ASRToken(start=start_time, end=end_time, text=text_out)) @@ -409,24 +510,22 @@ class VoxtralHFStreamingOnlineProcessor: def _flush_all_pending_words(self) -> List[ASRToken]: """Flush ALL words including the last partial one.""" with self._text_lock: - text = self._accumulated_text + text = self._get_accumulated_text() if not text: return [] words = text.split() new_words: List[ASRToken] = [] - n_words_total = max(len(words), 1) - n_audio_toks = max(self._n_audio_tokens_fed, 1) - while self._n_committed_words < len(words): + if self._n_committed_words >= len(words): + return [] + + timestamps = self._word_timestamps(text, words, self._n_committed_words, len(words)) + + for tok_start, tok_end in timestamps: word = words[self._n_committed_words] - word_idx = self._n_committed_words - - tok_start = int(word_idx / n_words_total * n_audio_toks) - tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) - start_time = self._pos_to_time(tok_start) - end_time = self._pos_to_time(tok_end) + end_time = self._pos_to_time(max(tok_end, tok_start + 1)) text_out = word if self._n_committed_words == 0 else " " + word new_words.append(ASRToken(start=start_time, end=end_time, text=text_out)) @@ -439,7 +538,7 @@ class VoxtralHFStreamingOnlineProcessor: def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]: # Start generate thread when enough audio is buffered if not self._generate_started: - if len(self._pending_audio) >= self._first_chunk_samples: + if self._pending_len >= self._first_chunk_samples: self._start_generate_thread() self._feed_pending_audio() else: @@ -450,7 +549,7 @@ class VoxtralHFStreamingOnlineProcessor: self._feed_pending_audio() # If generate finished unexpectedly (EOS) but new audio arrived, restart - if self._generate_finished and len(self._pending_audio) >= self._first_chunk_samples: + if self._generate_finished and self._pending_len >= self._first_chunk_samples: self._drain_streamer() flush_words = self._flush_all_pending_words() # Reset for new utterance diff --git a/whisperlivekit/voxtral_mlx/spectrogram.py b/whisperlivekit/voxtral_mlx/spectrogram.py index 0fdf463..0647aef 100644 --- a/whisperlivekit/voxtral_mlx/spectrogram.py +++ b/whisperlivekit/voxtral_mlx/spectrogram.py @@ -91,20 +91,33 @@ def _mel_filters() -> mx.array: # --------------------------------------------------------------------------- -# DFT helpers +# DFT helpers (cached — these are constant for a given WINDOW_SIZE) # --------------------------------------------------------------------------- +_CACHED_WINDOW: mx.array | None = None +_CACHED_DFT_RE: mx.array | None = None +_CACHED_DFT_IM: mx.array | None = None + + def _hann_window() -> mx.array: - return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32)) + global _CACHED_WINDOW + if _CACHED_WINDOW is None: + _CACHED_WINDOW = mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32)) + return _CACHED_WINDOW def _dft_matrices(): - """Pre-compute the real / imaginary DFT basis matrices.""" - n_bins = WINDOW_SIZE // 2 + 1 - k = mx.arange(n_bins, dtype=mx.float32)[:, None] - n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :] - phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE - return mx.cos(phase), mx.sin(phase) + """Return cached real / imaginary DFT basis matrices.""" + global _CACHED_DFT_RE, _CACHED_DFT_IM + if _CACHED_DFT_RE is None: + n_bins = WINDOW_SIZE // 2 + 1 + k = mx.arange(n_bins, dtype=mx.float32)[:, None] + n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :] + phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE + _CACHED_DFT_RE = mx.cos(phase) + _CACHED_DFT_IM = mx.sin(phase) + mx.eval(_CACHED_DFT_RE, _CACHED_DFT_IM) + return _CACHED_DFT_RE, _CACHED_DFT_IM def _stft_frames(audio: mx.array, window: mx.array) -> mx.array: diff --git a/whisperlivekit/voxtral_mlx_asr.py b/whisperlivekit/voxtral_mlx_asr.py index f666c0f..5c35bea 100644 --- a/whisperlivekit/voxtral_mlx_asr.py +++ b/whisperlivekit/voxtral_mlx_asr.py @@ -135,8 +135,9 @@ class VoxtralMLXOnlineProcessor: def _reset_state(self): """Reset all incremental state for a fresh utterance.""" - # Audio accumulation - self._pending = np.zeros(0, dtype=np.float32) + # Audio accumulation (list of chunks, concatenated on demand) + self._pending_chunks: list[np.ndarray] = [] + self._pending_len = 0 # Mel overlap self._mel_overlap: np.ndarray | None = None # Encoder incremental state @@ -167,10 +168,30 @@ class VoxtralMLXOnlineProcessor: # -- audio ingestion -- + def _get_pending(self) -> np.ndarray: + """Flatten pending chunks into a single array.""" + if not self._pending_chunks: + return np.zeros(0, dtype=np.float32) + if len(self._pending_chunks) == 1: + return self._pending_chunks[0] + flat = np.concatenate(self._pending_chunks) + self._pending_chunks = [flat] + return flat + + def _set_pending(self, arr: np.ndarray): + """Replace pending audio with a single array.""" + if len(arr) == 0: + self._pending_chunks = [] + self._pending_len = 0 + else: + self._pending_chunks = [arr] + self._pending_len = len(arr) + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): self.end = audio_stream_end_time - self._pending = np.append(self._pending, audio) - self.audio_buffer = self._pending + self._pending_chunks.append(audio) + self._pending_len += len(audio) + self.audio_buffer = audio # diagnostic only # -- core processing -- @@ -231,22 +252,24 @@ class VoxtralMLXOnlineProcessor: def _encode_pending(self): """Feed pending audio through the incremental encoder.""" - available = len(self._pending) - if available < SAMPLES_PER_TOKEN: + if self._pending_len < SAMPLES_PER_TOKEN: return + pending = self._get_pending() + available = len(pending) + if self._first_chunk: # First chunk: prepend silence for left-padding n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32) - chunk = np.concatenate([left_pad, self._pending[:n_take]]) - self._pending = self._pending[n_take:] + chunk = np.concatenate([left_pad, pending[:n_take]]) + self._set_pending(pending[n_take:]) self._samples_encoded += n_take self._first_chunk = False else: n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN - chunk = self._pending[:n_take] - self._pending = self._pending[n_take:] + chunk = pending[:n_take] + self._set_pending(pending[n_take:]) self._samples_encoded += n_take mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap) @@ -261,11 +284,10 @@ class VoxtralMLXOnlineProcessor: mx.eval(embeds) if self._audio_embeds is not None: self._audio_embeds = mx.concatenate([self._audio_embeds, embeds]) + mx.eval(self._audio_embeds) else: self._audio_embeds = embeds - self.audio_buffer = self._pending - def _do_prefill(self): """Run the decoder prefill pass over the prompt + first audio embeddings.""" n_dec_layers = len(self._model.decoder.blocks) @@ -430,6 +452,55 @@ class VoxtralMLXOnlineProcessor: return Transcript(start=None, end=None, text="") def start_silence(self) -> Tuple[List[ASRToken], float]: + """Flush all pending words when silence starts. + + Adds right-padding silence and forces a full decode pass so the + decoder emits tokens for the last words of speech. Without this, + the model holds back the final tokens waiting for future context. + """ + # Align pending audio to SAMPLES_PER_TOKEN boundary + remainder = self._pending_len % SAMPLES_PER_TOKEN + align_pad = (SAMPLES_PER_TOKEN - remainder) if remainder > 0 else 0 + + # Add alignment + right-padding silence to provide future context + total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN + if total_pad > 0: + self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32)) + self._pending_len += total_pad + + # Encode remaining audio (including right-padding) + self._encode_pending() + + # Decode everything that's left + if self._audio_embeds is not None and self._prefilled: + self._decode_positions(self._audio_embeds.shape[0]) + + # Flush last token if it wasn't EOS + if self._last_token is not None: + tid = self._last_token.item() + if tid != self._eos_id: + text = self._tokenizer.decode( + [tid], special_token_policy=SpecialTokenPolicy.IGNORE + ) + if text: + last_pos = self._positions_decoded - self._prefix_len + if text.lstrip() != text or not self._full_text: + if self._current_word_pos is not None: + self._word_audio_ends.append(last_pos) + self._word_audio_starts.append(last_pos) + self._current_word_pos = last_pos + elif self._current_word_pos is None: + self._word_audio_starts.append(last_pos) + self._current_word_pos = last_pos + self._full_text += text + self._n_text_tokens += 1 + + # Close the last word if still open + if self._current_word_pos is not None: + last_pos = self._positions_decoded - self._prefix_len + self._word_audio_ends.append(last_pos) + self._current_word_pos = None + words = self._flush_all_words() logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words)) return words, self.end @@ -448,7 +519,7 @@ class VoxtralMLXOnlineProcessor: logger.debug( "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, " "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'", - len(self._pending), + self._pending_len, self._audio_embeds.shape if self._audio_embeds is not None else None, self._samples_encoded, self._positions_decoded, @@ -457,7 +528,7 @@ class VoxtralMLXOnlineProcessor: ) # Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost - remainder = len(self._pending) % SAMPLES_PER_TOKEN + remainder = self._pending_len % SAMPLES_PER_TOKEN if remainder > 0: align_pad = SAMPLES_PER_TOKEN - remainder else: @@ -466,9 +537,8 @@ class VoxtralMLXOnlineProcessor: # Add alignment + right-padding silence total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN if total_pad > 0: - self._pending = np.append( - self._pending, np.zeros(total_pad, dtype=np.float32) - ) + self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32)) + self._pending_len += total_pad # Encode remaining audio (including right-padding) self._encode_pending() @@ -476,7 +546,7 @@ class VoxtralMLXOnlineProcessor: logger.debug( "[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d", self._audio_embeds.shape if self._audio_embeds is not None else None, - len(self._pending), + self._pending_len, ) hit_eos = False