""" MLX-accelerated Qwen3-ASR backend for WhisperLiveKit. Provides ``Qwen3MLXASR`` (model holder) and ``Qwen3MLXOnlineProcessor`` (batch-based processor) that plug into WhisperLiveKit's audio processing pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc. Uses the ``mlx-qwen3-asr`` package for fast Qwen3 inference on Apple Silicon. The batch ``session.transcribe()`` API is called on the full accumulated audio buffer, and LocalAgreement-style diffing (HypothesisBuffer) commits stable words across consecutive inferences. """ import logging import sys import time from typing import List, Tuple import numpy as np from whisperlivekit.timed_objects import ASRToken, Transcript logger = logging.getLogger(__name__) # Whisper language codes -> Qwen3 canonical language names # (duplicated from qwen3_asr.py to avoid importing torch at module level) WHISPER_TO_QWEN3_LANGUAGE = { "zh": "Chinese", "en": "English", "yue": "Cantonese", "ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish", "pt": "Portuguese", "id": "Indonesian", "it": "Italian", "ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese", "ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay", "nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish", "pl": "Polish", "cs": "Czech", "fa": "Persian", "el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian", } # Model size aliases -> HuggingFace model IDs QWEN3_MLX_MODEL_MAPPING = { "base": "Qwen/Qwen3-ASR-0.6B", "tiny": "Qwen/Qwen3-ASR-0.6B", "small": "Qwen/Qwen3-ASR-0.6B", "large": "Qwen/Qwen3-ASR-1.7B", "medium": "Qwen/Qwen3-ASR-1.7B", "large-v3": "Qwen/Qwen3-ASR-1.7B", "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B", "qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B", "qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B", "qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B", "1.7b": "Qwen/Qwen3-ASR-1.7B", "0.6b": "Qwen/Qwen3-ASR-0.6B", } # --------------------------------------------------------------------------- # Model holder # --------------------------------------------------------------------------- class Qwen3MLXASR: """Lightweight model holder -- loads the mlx-qwen3-asr model once and keeps it alive for the lifetime of the server.""" sep = "" SAMPLING_RATE = 16_000 def __init__(self, logfile=sys.stderr, **kwargs): import mlx.core as mx import mlx_qwen3_asr self.logfile = logfile self.transcribe_kargs = {} lan = kwargs.get("lan", "auto") self.original_language = None if lan == "auto" else lan # Resolve model ID from size aliases or explicit path model_path = kwargs.get("model_dir") or kwargs.get("model_path") if not model_path: model_size = kwargs.get("model_size", "") if model_size and ("/" in model_size or model_size.startswith(".")): model_path = model_size else: model_path = QWEN3_MLX_MODEL_MAPPING.get( (model_size or "base").lower(), "Qwen/Qwen3-ASR-0.6B" ) t0 = time.time() logger.info("Loading Qwen3 MLX model '%s' ...", model_path) self.session = mlx_qwen3_asr.Session(model_path, dtype=mx.float16) logger.info("Qwen3 MLX model loaded in %.2fs", time.time() - t0) self.backend_choice = "qwen3-mlx" self.tokenizer = None def transcribe(self, audio): pass # all work happens in the online processor # --------------------------------------------------------------------------- # Online processor # --------------------------------------------------------------------------- class Qwen3MLXOnlineProcessor: """Batch-based processor that accumulates audio and periodically calls ``session.transcribe()`` on the full buffer. Uses LocalAgreement-style diffing (HypothesisBuffer) to commit stable words across consecutive inferences, exactly like the PyTorch Qwen3 backend with ``OnlineASRProcessor``. Lifecycle (called by ``AudioProcessor.transcription_processor``): insert_audio_chunk(pcm, time) -> process_iter() -> get_buffer() ... repeat ... start_silence() / end_silence() finish() """ SAMPLING_RATE = 16_000 def __init__(self, asr: Qwen3MLXASR, logfile=sys.stderr): self.asr = asr self.logfile = logfile self.end = 0.0 self._session = asr.session lan = asr.original_language self._language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, "English") if lan else None # Audio accumulation self.audio_buffer = np.array([], dtype=np.float32) self._buffer_time_offset: float = 0.0 # absolute time of audio_buffer[0] # Throttle: minimum new audio (in samples) before re-running inference self._min_new_samples: int = int(1.0 * self.SAMPLING_RATE) # 1 second self._samples_since_last_inference: int = 0 # Buffer trimming — keep buffer short for fast re-transcription. # The model produces ~0.2x RTF, so 15s buffer = ~3s per call. self._max_buffer_sec: float = 15.0 self._trim_sec: float = 10.0 # keep this many seconds after trimming # HypothesisBuffer for LocalAgreement diffing self._committed: List[ASRToken] = [] self._prev_tokens: List[ASRToken] = [] # previous hypothesis (buffer role) self._last_committed_time: float = 0.0 # Global time tracking self._global_time_offset: float = 0.0 # extra offset from silences # -- audio ingestion -- def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): self.end = audio_stream_end_time self.audio_buffer = np.append(self.audio_buffer, audio) self._samples_since_last_inference += len(audio) # -- batch transcription -- def _transcribe_buffer(self) -> List[ASRToken]: """Run batch transcription on the full audio buffer and return tokens.""" if len(self.audio_buffer) < 400: # too short for meaningful transcription return [] t0 = time.time() try: result = self._session.transcribe( self.audio_buffer, language=self._language, return_timestamps=True, ) except Exception as e: logger.warning("[qwen3-mlx] transcribe error: %s", e, exc_info=True) return [] dur = time.time() - t0 audio_dur = len(self.audio_buffer) / self.SAMPLING_RATE logger.debug( "[qwen3-mlx] transcribed %.1fs audio in %.2fs (%.2fx RTF)", audio_dur, dur, dur / max(audio_dur, 0.01), ) text = (result.text or "").strip() if not text: return [] # Build tokens from segments (word-level timestamps) tokens: List[ASRToken] = [] if result.segments: for i, seg in enumerate(result.segments): word = seg["text"] start = self._buffer_time_offset + seg["start"] end = self._buffer_time_offset + seg["end"] label = word if i == 0 else " " + word tokens.append(ASRToken(start=start, end=end, text=label)) else: # Fallback: estimate timestamps from word count words = text.split() step = audio_dur / max(len(words), 1) for i, w in enumerate(words): t_start = self._buffer_time_offset + i * step t_end = self._buffer_time_offset + (i + 1) * step label = w if i == 0 else " " + w tokens.append(ASRToken(start=t_start, end=t_end, text=label)) return tokens def _local_agreement(self, new_tokens: List[ASRToken]) -> List[ASRToken]: """LocalAgreement diffing: commit the longest common prefix between the previous hypothesis (``self._prev_tokens``) and the new tokens. Before comparing, strips tokens that correspond to already-committed audio (i.e., tokens whose start time is before ``_last_committed_time``). Also deduplicates boundary tokens (ngram matching) to avoid re-committing the tail of the previous committed output. Returns the newly committed tokens. """ # Step 1: Only keep tokens that are roughly "new" (after last committed time) fresh_tokens = [ t for t in new_tokens if t.start > self._last_committed_time - 0.1 ] # Step 2: Remove duplicates at the boundary with committed tokens # (like HypothesisBuffer.insert's ngram dedup) if fresh_tokens and self._committed: max_ngram = min(len(self._committed), len(fresh_tokens), 5) for n in range(1, max_ngram + 1): committed_ngram = " ".join( t.text.strip() for t in self._committed[-n:] ) fresh_ngram = " ".join( t.text.strip() for t in fresh_tokens[:n] ) if committed_ngram == fresh_ngram: fresh_tokens = fresh_tokens[n:] break # Step 3: LocalAgreement -- longest common prefix between prev and fresh committed: List[ASRToken] = [] prev = self._prev_tokens i = 0 j = 0 while i < len(fresh_tokens) and j < len(prev): if fresh_tokens[i].text.strip() == prev[j].text.strip(): # Agreement: commit this token (use the new token's timestamps) committed.append(fresh_tokens[i]) i += 1 j += 1 else: break # The remaining fresh tokens become the new "previous hypothesis" self._prev_tokens = fresh_tokens[i:] if i < len(fresh_tokens) else [] return committed def _trim_buffer_if_needed(self): """Trim the audio buffer if it exceeds max_buffer_sec. Keeps the last ``_trim_sec`` seconds of audio. Also adjusts committed token tracking and buffer_time_offset. """ buffer_dur = len(self.audio_buffer) / self.SAMPLING_RATE if buffer_dur <= self._max_buffer_sec: return keep_sec = self._trim_sec keep_samples = int(keep_sec * self.SAMPLING_RATE) cut_samples = len(self.audio_buffer) - keep_samples if cut_samples <= 0: return cut_sec = cut_samples / self.SAMPLING_RATE self.audio_buffer = self.audio_buffer[cut_samples:] self._buffer_time_offset += cut_sec # Remove committed tokens that are before the new buffer start self._committed = [ t for t in self._committed if t.end > self._buffer_time_offset ] logger.debug( "[qwen3-mlx] trimmed buffer: cut %.1fs, new offset %.1f, buffer %.1fs", cut_sec, self._buffer_time_offset, len(self.audio_buffer) / self.SAMPLING_RATE, ) # -- interface methods -- def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: """Process the current audio buffer. Throttles inference to at least 1s of new audio between calls. Returns (newly_committed_tokens, audio_processed_upto_time). """ try: # Throttle: skip if not enough new audio since last inference if (not is_last and self._samples_since_last_inference < self._min_new_samples): return [], self.end self._samples_since_last_inference = 0 # Trim buffer if too long self._trim_buffer_if_needed() # Run batch transcription new_tokens = self._transcribe_buffer() # LocalAgreement diffing committed = self._local_agreement(new_tokens) if committed: self._committed.extend(committed) self._last_committed_time = committed[-1].end return committed, self.end except Exception as e: logger.warning("[qwen3-mlx] process_iter error: %s", e, exc_info=True) return [], self.end def get_buffer(self) -> Transcript: """Return the unconfirmed text (the tail of the last hypothesis that was not committed by LocalAgreement).""" if not self._prev_tokens: return Transcript(start=None, end=None, text="") text = "".join(t.text for t in self._prev_tokens) start = self._prev_tokens[0].start end = self._prev_tokens[-1].end return Transcript(start=start, end=end, text=text) def _flush_all(self) -> List[ASRToken]: """Force a final transcription and commit all remaining words.""" # Run one last transcription on the full buffer self._samples_since_last_inference = self._min_new_samples # bypass throttle new_tokens = self._transcribe_buffer() # Commit everything: first the agreed prefix, then the remainder committed = self._local_agreement(new_tokens) # Also commit any remaining buffer tokens remaining = self._prev_tokens self._prev_tokens = [] all_new = committed + remaining if all_new: self._committed.extend(all_new) self._last_committed_time = all_new[-1].end return all_new def _reset_for_new_utterance(self): """Reset buffers for a new utterance, preserving time continuity.""" new_offset = self._buffer_time_offset + len(self.audio_buffer) / self.SAMPLING_RATE saved_end = self.end self.audio_buffer = np.array([], dtype=np.float32) self._buffer_time_offset = new_offset self._samples_since_last_inference = 0 self._committed = [] self._prev_tokens = [] self.end = saved_end def start_silence(self) -> Tuple[List[ASRToken], float]: """Flush pending words when silence starts. Unlike other backends, does NOT reset the audio buffer — the model produces better results re-transcribing the full accumulated audio. Buffer trimming at 30s handles memory naturally. """ words = self._flush_all() logger.info("[qwen3-mlx] start_silence: flushed %d words", len(words)) return words, self.end def end_silence(self, silence_duration: float, offset: float): self._global_time_offset += silence_duration self.end += silence_duration def new_speaker(self, change_speaker): self.start_silence() def warmup(self, audio, init_prompt=""): pass def finish(self) -> Tuple[List[ASRToken], float]: words = self._flush_all() logger.info("[qwen3-mlx] finish: flushed %d words", len(words)) return words, self.end