diff --git a/benchmark_scatter_en_aware.png b/benchmark_scatter_en_aware.png index fdcca2a..c9e8599 100644 Binary files a/benchmark_scatter_en_aware.png and b/benchmark_scatter_en_aware.png differ diff --git a/benchmark_scatter_fr_aware.png b/benchmark_scatter_fr_aware.png index 67106ca..bf3608c 100644 Binary files a/benchmark_scatter_fr_aware.png and b/benchmark_scatter_fr_aware.png differ diff --git a/scripts/run_scatter_benchmark.py b/scripts/run_scatter_benchmark.py index e461f32..4c4bc7c 100644 --- a/scripts/run_scatter_benchmark.py +++ b/scripts/run_scatter_benchmark.py @@ -266,6 +266,8 @@ def generate_scatter(results, system_info, output_path, n_samples, lang="en", "mlx SS small": (-55, -5), "voxtral mlx": (10, -14), "qwen3 0.6B": (10, 8), + "qwen3-mlx 0.6B": (10, -14), + "qwen3-mlx 1.7B": (10, 8), "fw LA large-v3": (8, -5), "fw SS large-v3": (8, 5), } diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d248871..e2851bf 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -51,6 +51,12 @@ try: except (ImportError, Exception): pass +try: + import mlx_qwen3_asr # noqa: F401 + AVAILABLE_BACKENDS.append("qwen3-mlx") +except ImportError: + pass + BACKEND_CONFIG = { "whisper": {"model_size": "tiny", "lan": "en"}, "voxtral-mlx": {"backend": "voxtral-mlx", "lan": "en"}, @@ -61,6 +67,7 @@ BACKEND_CONFIG = { "lan": "en", "custom_alignment_heads": "scripts/alignment_heads_qwen3_asr_1.7B.json", }, + "qwen3-mlx": {"backend": "qwen3-mlx", "lan": "en"}, } # Voxtral backends flush all words at once with proportionally-distributed @@ -70,7 +77,7 @@ BACKEND_CONFIG = { VOXTRAL_BACKENDS = {"voxtral-mlx", "voxtral-hf"} # Backends that use batch-flush and may have non-monotonic timestamps -BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul"} +BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul", "qwen3-mlx"} def backend_kwargs(backend: str) -> dict: diff --git a/whisperlivekit/cli.py 
b/whisperlivekit/cli.py index 9feb2a4..f878709 100644 --- a/whisperlivekit/cli.py +++ b/whisperlivekit/cli.py @@ -109,6 +109,16 @@ BACKENDS = [ "streaming": "chunk", "devices": ["cuda", "mps", "cpu"], }, + { + "id": "qwen3-mlx", + "name": "Qwen3 MLX", + "module": "mlx_qwen3_asr", + "install": "pip install mlx-qwen3-asr", + "description": "Qwen3-ASR on Apple Silicon (MLX, native streaming)", + "platform": "darwin-arm64", + "streaming": "native", + "devices": ["mlx"], + }, { "id": "openai-api", "name": "OpenAI API", @@ -193,6 +203,9 @@ MODEL_CATALOG = [ # Qwen3 ASR {"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"}, {"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"}, + # Qwen3 MLX (native streaming on Apple Silicon) + {"name": "qwen3-mlx:1.7b", "family": "qwen3-mlx", "params": "1.7B", "disk": "1.8 GB", "languages": 12, "quality": "good", "speed": "fast"}, + {"name": "qwen3-mlx:0.6b", "family": "qwen3-mlx", "params": "0.6B", "disk": "0.7 GB", "languages": 12, "quality": "fair", "speed": "fastest"}, ] @@ -310,6 +323,9 @@ def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool: elif family == "qwen3": size = name.split(":")[1] if ":" in name else "1.7b" return QWEN3_REPOS.get(size, "") in downloaded + elif family == "qwen3-mlx": + size = name.split(":")[1] if ":" in name else "1.7b" + return QWEN3_REPOS.get(size, "") in downloaded return False @@ -324,6 +340,8 @@ def _best_backend_for_model(model_entry: dict) -> str: return "voxtral" elif family == "qwen3": return "qwen3" + elif family == "qwen3-mlx": + return "qwen3-mlx" elif family == "whisper": if is_apple and _module_available("mlx_whisper"): return "mlx-whisper" @@ -383,6 +401,8 @@ def cmd_models(): # Skip platform-incompatible models if name == "voxtral-mlx" and not is_apple_silicon: continue + if m["family"] == "qwen3-mlx" and not 
is_apple_silicon: + continue is_dl = _model_is_downloaded(m, downloaded) @@ -447,6 +467,18 @@ def _resolve_pull_target(spec: str): targets.append(("voxtral-mlx", VOXTRAL_MLX_REPO, "Voxtral Mini (MLX)")) return targets + # Handle qwen3-mlx (must check before generic qwen3) + if backend_part == "qwen3-mlx" or size_part.startswith("qwen3-mlx"): + qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b" + if qwen_size.startswith("qwen3"): + qwen_size = "1.7b" # default + repo = QWEN3_REPOS.get(qwen_size) + if not repo: + print(f" Unknown Qwen3 size: {qwen_size}. Available: {', '.join(QWEN3_REPOS.keys())}") + return [] + targets.append(("qwen3-mlx", repo, f"Qwen3-ASR MLX {qwen_size}")) + return targets + # Handle qwen3 if backend_part == "qwen3" or size_part.startswith("qwen3"): qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b" @@ -503,7 +535,7 @@ def _resolve_pull_target(spec: str): else: print(f" Unknown model: {spec}") print(f" Available sizes: {', '.join(WHISPER_SIZES)}") - print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b") + print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b, qwen3-mlx:1.7b, qwen3-mlx:0.6b") return [] return targets @@ -986,6 +1018,9 @@ def _resolve_run_spec(spec: str): if spec == "voxtral-mlx": return "voxtral-mlx", None + if spec == "qwen3-mlx": + return "qwen3-mlx", None + if spec in WHISPER_SIZES: return None, spec @@ -1231,6 +1266,12 @@ def _probe_backend_state(processor) -> dict: elif hasattr(transcription, "_mlx_processor"): info["backend_type"] = "voxtral-mlx" + # Qwen3 MLX specifics + elif hasattr(transcription, "_session") and hasattr(transcription, "_state"): + info["backend_type"] = "qwen3-mlx" + info["samples_fed"] = getattr(transcription, "_samples_fed", 0) + info["committed_words"] = getattr(transcription, "_n_committed_words", 0) + # SimulStreaming specifics elif hasattr(transcription, "prev_output"): info["backend_type"] = "simulstreaming" diff --git a/whisperlivekit/core.py 
b/whisperlivekit/core.py index d789e69..6a67c55 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -121,6 +121,11 @@ class TranscriptionEngine: self.tokenizer = None self.asr = VoxtralHFStreamingASR(**transcription_common_params) logger.info("Using Voxtral HF Transformers streaming backend") + elif config.backend == "qwen3-mlx": + from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR + self.tokenizer = None + self.asr = Qwen3MLXASR(**transcription_common_params) + logger.info("Using Qwen3 MLX native backend") elif config.backend == "qwen3-simul": from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR self.tokenizer = None @@ -230,6 +235,9 @@ def online_factory(args, asr, language=None): if backend == "vllm-realtime": from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor return VLLMRealtimeOnlineProcessor(asr) + if backend == "qwen3-mlx": + from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor + return Qwen3MLXOnlineProcessor(asr) if backend == "qwen3-simul": from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor return Qwen3SimulStreamingOnlineProcessor(asr) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index a6f23f6..726ee89 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -147,8 +147,8 @@ def parse_args(): "--backend", type=str, default="auto", - choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"], - help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.", + choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"], + help="Select the ASR backend implementation. 
Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.", ) parser.add_argument( "--no-vac",
diff --git a/whisperlivekit/qwen3_mlx_asr.py b/whisperlivekit/qwen3_mlx_asr.py
new file mode 100644
index 0000000..71f4d11
--- /dev/null
+++ b/whisperlivekit/qwen3_mlx_asr.py
@@ -0,0 +1,437 @@
+"""
+MLX-accelerated Qwen3-ASR backend for WhisperLiveKit.
+
+Provides ``Qwen3MLXASR`` (model holder) and ``Qwen3MLXOnlineProcessor``
+(batch-based processor) that plug into WhisperLiveKit's audio processing
+pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
+
+Uses the ``mlx-qwen3-asr`` package for fast Qwen3 inference on Apple Silicon.
+The batch ``session.transcribe()`` API is called on the full accumulated audio
+buffer, and LocalAgreement-style diffing (in the spirit of HypothesisBuffer)
+commits stable words across consecutive inferences.
+"""
+
+import logging
+import sys
+import time
+from typing import List, Tuple
+
+import numpy as np
+
+from whisperlivekit.timed_objects import ASRToken, Transcript
+
+logger = logging.getLogger(__name__)
+
+# Whisper language codes -> Qwen3 canonical language names
+# (duplicated from qwen3_asr.py to avoid importing torch at module level)
+WHISPER_TO_QWEN3_LANGUAGE = {
+    "zh": "Chinese", "en": "English", "yue": "Cantonese",
+    "ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
+    "pt": "Portuguese", "id": "Indonesian", "it": "Italian",
+    "ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
+    "ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
+    "nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
+    "pl": "Polish", "cs": "Czech", "fa": "Persian",
+    "el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
+}
+
+# Model size aliases -> HuggingFace model IDs
+QWEN3_MLX_MODEL_MAPPING = {
+    "base": "Qwen/Qwen3-ASR-0.6B",
+    "tiny": "Qwen/Qwen3-ASR-0.6B",
+    "small": "Qwen/Qwen3-ASR-0.6B",
+    "large": "Qwen/Qwen3-ASR-1.7B",
+    "medium": "Qwen/Qwen3-ASR-1.7B",
+    "large-v3": "Qwen/Qwen3-ASR-1.7B",
+    "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
+    "qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
+    "qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
+    "qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
+    "1.7b": "Qwen/Qwen3-ASR-1.7B",
+    "0.6b": "Qwen/Qwen3-ASR-0.6B",
+}
+
+
+# ---------------------------------------------------------------------------
+# Model holder
+# ---------------------------------------------------------------------------
+
+
+class Qwen3MLXASR:
+    """Lightweight model holder -- loads the mlx-qwen3-asr model once and
+    keeps it alive for the lifetime of the server.
+
+    The model is taken from ``model_dir`` / ``model_path`` when given;
+    otherwise ``model_size`` is used verbatim if it looks like a path or
+    repo ID, or looked up in ``QWEN3_MLX_MODEL_MAPPING``.
+    """
+
+    sep = ""
+    SAMPLING_RATE = 16_000
+
+    def __init__(self, logfile=sys.stderr, **kwargs):
+        # Imported here so importing this module does not require MLX.
+        import mlx.core as mx
+        import mlx_qwen3_asr
+
+        self.logfile = logfile
+        self.transcribe_kargs = {}
+
+        lan = kwargs.get("lan", "auto")
+        self.original_language = None if lan == "auto" else lan
+
+        # Resolve model ID from size aliases or explicit path
+        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
+        if not model_path:
+            model_size = kwargs.get("model_size", "")
+            if model_size and ("/" in model_size or model_size.startswith(".")):
+                model_path = model_size
+            else:
+                alias = (model_size or "base").lower()
+                model_path = QWEN3_MLX_MODEL_MAPPING.get(alias)
+                if model_path is None:
+                    # Unknown alias: keep the 0.6B fallback, but make it visible.
+                    model_path = "Qwen/Qwen3-ASR-0.6B"
+                    logger.warning(
+                        "Unknown Qwen3 MLX model size '%s'; falling back to '%s'",
+                        model_size, model_path,
+                    )
+
+        t0 = time.perf_counter()
+        logger.info("Loading Qwen3 MLX model '%s' ...", model_path)
+        self.session = mlx_qwen3_asr.Session(model_path, dtype=mx.float16)
+        logger.info("Qwen3 MLX model loaded in %.2fs", time.perf_counter() - t0)
+
+        self.backend_choice = "qwen3-mlx"
+        self.tokenizer = None
+
+    def transcribe(self, audio):
+        pass  # all work happens in the online processor
+
+
+# ---------------------------------------------------------------------------
+# Online processor
+# ---------------------------------------------------------------------------
+
+
+class Qwen3MLXOnlineProcessor:
+    """Batch-based processor that accumulates audio and periodically calls
+    ``session.transcribe()`` on the full buffer.
+
+    Uses LocalAgreement-style diffing (HypothesisBuffer) to commit stable
+    words across consecutive inferences, exactly like the PyTorch Qwen3
+    backend with ``OnlineASRProcessor``.
+
+    Lifecycle (called by ``AudioProcessor.transcription_processor``):
+
+        insert_audio_chunk(pcm, time) -> process_iter() -> get_buffer()
+        ... repeat ...
+        start_silence() / end_silence()
+        finish()
+    """
+
+    SAMPLING_RATE = 16_000
+
+    def __init__(self, asr: Qwen3MLXASR, logfile=sys.stderr):
+        self.asr = asr
+        self.logfile = logfile
+        self.end = 0.0
+
+        self._session = asr.session
+        lan = asr.original_language
+        if not lan:
+            # No explicit language: ``None`` is handed to the session as-is.
+            self._language = None
+        else:
+            self._language = WHISPER_TO_QWEN3_LANGUAGE.get(lan)
+            if self._language is None:
+                # Keep the English fallback, but do not apply it silently.
+                logger.warning(
+                    "[qwen3-mlx] unknown language code '%s'; falling back to English",
+                    lan,
+                )
+                self._language = "English"
+
+        # Audio accumulation
+        self.audio_buffer = np.array([], dtype=np.float32)
+        self._buffer_time_offset: float = 0.0  # absolute time of audio_buffer[0]
+
+        # Throttle: minimum new audio (in samples) before re-running inference
+        self._min_new_samples: int = int(1.0 * self.SAMPLING_RATE)  # 1 second
+        self._samples_since_last_inference: int = 0
+
+        # Buffer trimming — keep buffer short for fast re-transcription.
+        # The model produces ~0.2x RTF, so 15s buffer = ~3s per call.
+        self._max_buffer_sec: float = 15.0
+        self._trim_sec: float = 10.0  # keep this many seconds after trimming
+
+        # HypothesisBuffer for LocalAgreement diffing
+        self._committed: List[ASRToken] = []
+        self._prev_tokens: List[ASRToken] = []  # previous hypothesis (buffer role)
+        self._last_committed_time: float = 0.0
+
+        # Global time tracking
+        # NOTE(review): accumulated in end_silence() but never read in this
+        # module, so token timestamps do not include it -- confirm intent.
+        self._global_time_offset: float = 0.0  # extra offset from silences
+
+    # -- audio ingestion --
+
+    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
+        """Append ``audio`` to the buffer and record the stream end time."""
+        self.end = audio_stream_end_time
+        self.audio_buffer = np.append(self.audio_buffer, audio)
+        self._samples_since_last_inference += len(audio)
+
+    # -- batch transcription --
+
+    def _transcribe_buffer(self) -> List[ASRToken]:
+        """Run batch transcription on the full audio buffer and return tokens.
+
+        Returns an empty list when the buffer holds fewer than 400 samples,
+        when the session raises, or when the transcribed text is empty.
+        """
+        if len(self.audio_buffer) < 400:  # too short for meaningful transcription
+            return []
+
+        t0 = time.perf_counter()
+        try:
+            result = self._session.transcribe(
+                self.audio_buffer,
+                language=self._language,
+                return_timestamps=True,
+            )
+        except Exception as e:
+            logger.warning("[qwen3-mlx] transcribe error: %s", e, exc_info=True)
+            return []
+        dur = time.perf_counter() - t0
+        audio_dur = len(self.audio_buffer) / self.SAMPLING_RATE
+        logger.debug(
+            "[qwen3-mlx] transcribed %.1fs audio in %.2fs (%.2fx RTF)",
+            audio_dur, dur, dur / max(audio_dur, 0.01),
+        )
+
+        text = (result.text or "").strip()
+        if not text:
+            return []
+
+        # Segment times are relative to the buffer; shift them to stream time.
+        offset = self._buffer_time_offset
+
+        # Build tokens from segments (word-level timestamps)
+        tokens: List[ASRToken] = []
+        if result.segments:
+            for i, seg in enumerate(result.segments):
+                word = seg["text"]
+                label = word if i == 0 else " " + word
+                tokens.append(ASRToken(
+                    start=offset + seg["start"], end=offset + seg["end"], text=label,
+                ))
+        else:
+            # Fallback: estimate timestamps from word count
+            words = text.split()
+            step = audio_dur / max(len(words), 1)
+            for i, w in enumerate(words):
+                label = w if i == 0 else " " + w
+                tokens.append(ASRToken(
+                    start=offset + i * step, end=offset + (i + 1) * step, text=label,
+                ))
+
+        return tokens
+
+    def _local_agreement(self, new_tokens: List[ASRToken]) -> List[ASRToken]:
+        """LocalAgreement diffing: commit the longest common prefix between
+        the previous hypothesis (``self._prev_tokens``) and the new tokens.
+
+        Before comparing, strips tokens that correspond to already-committed
+        audio (i.e., tokens whose start time is before ``_last_committed_time``).
+        Also deduplicates boundary tokens (ngram matching) to avoid re-committing
+        the tail of the previous committed output.
+
+        Returns the newly committed tokens.
+        """
+        # Step 1: Only keep tokens that are roughly "new" (after last committed time)
+        fresh_tokens = [
+            t for t in new_tokens
+            if t.start > self._last_committed_time - 0.1
+        ]
+
+        # Step 2: Remove duplicates at the boundary with committed tokens
+        # (like HypothesisBuffer.insert's ngram dedup)
+        if fresh_tokens and self._committed:
+            max_ngram = min(len(self._committed), len(fresh_tokens), 5)
+            for n in range(1, max_ngram + 1):
+                committed_ngram = " ".join(
+                    t.text.strip() for t in self._committed[-n:]
+                )
+                fresh_ngram = " ".join(
+                    t.text.strip() for t in fresh_tokens[:n]
+                )
+                if committed_ngram == fresh_ngram:
+                    fresh_tokens = fresh_tokens[n:]
+                    break
+
+        # Step 3: LocalAgreement -- longest common prefix between prev and fresh
+        n_agreed = 0
+        for new_tok, old_tok in zip(fresh_tokens, self._prev_tokens):
+            if new_tok.text.strip() != old_tok.text.strip():
+                break
+            n_agreed += 1
+
+        # Agreed tokens are committed with the new hypothesis' timestamps;
+        # the remaining fresh tokens become the new "previous hypothesis".
+        self._prev_tokens = fresh_tokens[n_agreed:]
+        return fresh_tokens[:n_agreed]
+
+    def _trim_buffer_if_needed(self):
+        """Trim the audio buffer if it exceeds max_buffer_sec.
+
+        Keeps the last ``_trim_sec`` seconds of audio. Also adjusts
+        committed token tracking and buffer_time_offset.
+        """
+        buffer_dur = len(self.audio_buffer) / self.SAMPLING_RATE
+        if buffer_dur <= self._max_buffer_sec:
+            return
+
+        keep_samples = int(self._trim_sec * self.SAMPLING_RATE)
+        cut_samples = len(self.audio_buffer) - keep_samples
+        if cut_samples <= 0:
+            return
+
+        cut_sec = cut_samples / self.SAMPLING_RATE
+        self.audio_buffer = self.audio_buffer[cut_samples:]
+        self._buffer_time_offset += cut_sec
+
+        # Remove committed tokens that are before the new buffer start
+        self._committed = [
+            t for t in self._committed if t.end > self._buffer_time_offset
+        ]
+
+        logger.debug(
+            "[qwen3-mlx] trimmed buffer: cut %.1fs, new offset %.1f, buffer %.1fs",
+            cut_sec, self._buffer_time_offset, len(self.audio_buffer) / self.SAMPLING_RATE,
+        )
+
+    # -- interface methods --
+
+    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
+        """Process the current audio buffer.
+
+        Throttles inference to at least 1s of new audio between calls.
+        Returns (newly_committed_tokens, audio_processed_upto_time).
+        """
+        try:
+            # Throttle: skip if not enough new audio since last inference
+            if (not is_last
+                    and self._samples_since_last_inference < self._min_new_samples):
+                return [], self.end
+
+            self._samples_since_last_inference = 0
+
+            # Trim buffer if too long
+            self._trim_buffer_if_needed()
+
+            # Run batch transcription
+            new_tokens = self._transcribe_buffer()
+
+            # LocalAgreement diffing
+            committed = self._local_agreement(new_tokens)
+
+            if committed:
+                self._committed.extend(committed)
+                self._last_committed_time = committed[-1].end
+
+            return committed, self.end
+        except Exception as e:
+            logger.warning("[qwen3-mlx] process_iter error: %s", e, exc_info=True)
+            return [], self.end
+
+    def get_buffer(self) -> Transcript:
+        """Return the unconfirmed text (the tail of the last hypothesis
+        that was not committed by LocalAgreement)."""
+        if not self._prev_tokens:
+            return Transcript(start=None, end=None, text="")
+
+        text = "".join(t.text for t in self._prev_tokens)
+        start = self._prev_tokens[0].start
+        end = self._prev_tokens[-1].end
+        return Transcript(start=start, end=end, text=text)
+
+    def _flush_all(self) -> List[ASRToken]:
+        """Force a final transcription and commit all remaining words.
+
+        If the final transcription yields nothing (buffer too short, session
+        error, empty text), the pending hypothesis is committed as-is rather
+        than being discarded.
+        """
+        pending = self._prev_tokens
+
+        # Run one last transcription on the full buffer. Inference has now
+        # seen everything in the buffer, so restart the throttle counter.
+        new_tokens = self._transcribe_buffer()
+        self._samples_since_last_inference = 0
+
+        if new_tokens:
+            # Commit everything: first the agreed prefix, then the remainder
+            committed = self._local_agreement(new_tokens)
+            remaining = self._prev_tokens
+        else:
+            # _local_agreement([]) would wipe the pending hypothesis.
+            committed, remaining = [], pending
+        self._prev_tokens = []
+
+        all_new = committed + remaining
+        if all_new:
+            self._committed.extend(all_new)
+            self._last_committed_time = all_new[-1].end
+
+        return all_new
+
+    def _reset_for_new_utterance(self):
+        """Reset buffers for a new utterance, preserving time continuity.
+
+        ``self.end`` is left untouched.
+        """
+        # The next audio sample continues where the dropped buffer ended.
+        self._buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
+        self.audio_buffer = np.array([], dtype=np.float32)
+        self._samples_since_last_inference = 0
+        self._committed = []
+        self._prev_tokens = []
+
+    def start_silence(self) -> Tuple[List[ASRToken], float]:
+        """Flush pending words when silence starts.
+
+        Unlike other backends, does NOT reset the audio buffer — the model
+        produces better results re-transcribing the full accumulated audio.
+        Buffer trimming (once the buffer exceeds ``_max_buffer_sec``) handles
+        memory naturally.
+        """
+        words = self._flush_all()
+        logger.info("[qwen3-mlx] start_silence: flushed %d words", len(words))
+        return words, self.end
+
+    def end_silence(self, silence_duration: float, offset: float):
+        """Advance the time bookkeeping by ``silence_duration`` seconds.
+
+        ``offset`` is currently unused.
+        """
+        self._global_time_offset += silence_duration
+        self.end += silence_duration
+
+    def new_speaker(self, change_speaker):
+        """Flush pending words on a speaker change.
+
+        NOTE(review): the flushed words are recorded internally but not
+        returned to the caller -- confirm the pipeline does not need them.
+        """
+        self.start_silence()
+
+    def warmup(self, audio, init_prompt=""):
+        pass
+
+    def finish(self) -> Tuple[List[ASRToken], float]:
+        """Flush and return all remaining words at end of stream."""
+        words = self._flush_all()
+        logger.info("[qwen3-mlx] finish: flushed %d words", len(words))
+        return words, self.end