diff --git a/whisperlivekit/basic_server.py b/whisperlivekit/basic_server.py index 1694e0c..7adeb59 100644 --- a/whisperlivekit/basic_server.py +++ b/whisperlivekit/basic_server.py @@ -14,15 +14,13 @@ logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -args = parse_args() +config = parse_args() transcription_engine = None @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI): global transcription_engine - transcription_engine = TranscriptionEngine( - **vars(args), - ) + transcription_engine = TranscriptionEngine(config=config) yield app = FastAPI(lifespan=lifespan) @@ -63,7 +61,7 @@ async def websocket_endpoint(websocket: WebSocket): logger.info("WebSocket connection opened.") try: - await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)}) + await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input)}) except Exception as e: logger.warning(f"Failed to send config to client: {e}") @@ -103,26 +101,26 @@ def main(): uvicorn_kwargs = { "app": "whisperlivekit.basic_server:app", - "host":args.host, - "port":args.port, + "host": config.host, + "port": config.port, "reload": False, "log_level": "info", "lifespan": "on", } - + ssl_kwargs = {} - if args.ssl_certfile or args.ssl_keyfile: - if not (args.ssl_certfile and args.ssl_keyfile): + if config.ssl_certfile or config.ssl_keyfile: + if not (config.ssl_certfile and config.ssl_keyfile): raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.") ssl_kwargs = { - "ssl_certfile": args.ssl_certfile, - "ssl_keyfile": args.ssl_keyfile + "ssl_certfile": config.ssl_certfile, + "ssl_keyfile": config.ssl_keyfile, } if ssl_kwargs: uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs} - if args.forwarded_allow_ips: - uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips } + if config.forwarded_allow_ips: + uvicorn_kwargs = {**uvicorn_kwargs, "forwarded_allow_ips": config.forwarded_allow_ips} uvicorn.run(**uvicorn_kwargs) diff --git a/whisperlivekit/config.py b/whisperlivekit/config.py new file mode 100644 index 0000000..da42d17 --- /dev/null +++ b/whisperlivekit/config.py @@ -0,0 +1,102 @@ +"""Typed configuration for the WhisperLiveKit pipeline.""" +import logging +from dataclasses import dataclass, field, fields +from typing import Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class WhisperLiveKitConfig: + """Single source of truth for all WhisperLiveKit configuration. + + Replaces the previous dict-based parameter system in TranscriptionEngine. + All fields have defaults matching the prior behaviour. + """ + + # Server / global + host: str = "localhost" + port: int = 8000 + diarization: bool = False + punctuation_split: bool = False + target_language: str = "" + vac: bool = True + vac_chunk_size: float = 0.04 + log_level: str = "DEBUG" + ssl_certfile: Optional[str] = None + ssl_keyfile: Optional[str] = None + forwarded_allow_ips: Optional[str] = None + transcription: bool = True + vad: bool = True + pcm_input: bool = False + disable_punctuation_split: bool = False + diarization_backend: str = "sortformer" + backend_policy: str = "simulstreaming" + backend: str = "auto" + + # Transcription common + warmup_file: Optional[str] = None + min_chunk_size: float = 0.1 + model_size: str = "base" + model_cache_dir: Optional[str] = None + model_dir: Optional[str] = None + model_path: Optional[str] = None + lora_path: Optional[str] = None + lan: str = "auto" + direct_english_translation: bool = False + + # LocalAgreement-specific + buffer_trimming: str = "segment" + confidence_validation: bool = False + buffer_trimming_sec: float = 15.0 + + # SimulStreaming-specific + disable_fast_encoder: bool = False + custom_alignment_heads: Optional[str] = None + frame_threshold: int = 25 + beams: int = 1 + decoder_type: Optional[str] = None + audio_max_len: float = 20.0 + audio_min_len: float = 0.0 + cif_ckpt_path: Optional[str] = None + never_fire: bool = False + init_prompt: Optional[str] = None + static_init_prompt: Optional[str] = None + max_context_tokens: Optional[int] = None + + # Diarization (diart) + segmentation_model: str = "pyannote/segmentation-3.0" + embedding_model: str = "pyannote/embedding" + + # Translation + nllb_backend: str = "transformers" + nllb_size: str = "600M" + + def __post_init__(self): + # .en model suffix forces English + if self.model_size and self.model_size.endswith(".en"): + self.lan = "en" + # Normalize backend_policy aliases + if self.backend_policy == "1": + self.backend_policy = "simulstreaming" + elif self.backend_policy == "2": + self.backend_policy = "localagreement" + + # ------------------------------------------------------------------ + # Factory helpers + # ------------------------------------------------------------------ + + @classmethod + def from_namespace(cls, ns) -> "WhisperLiveKitConfig": + """Create config from an argparse Namespace, ignoring unknown keys.""" + known = {f.name for f in fields(cls)} + return cls(**{k: v for k, v in vars(ns).items() if k in known}) + + @classmethod + def from_kwargs(cls, **kwargs) -> "WhisperLiveKitConfig": + """Create config from keyword arguments; warns on unknown keys.""" + known = {f.name for f in fields(cls)} + unknown = set(kwargs.keys()) - known + if unknown: + logger.warning("Unknown config keys ignored: %s", unknown) + return cls(**{k: v for k, v in kwargs.items() if k in known}) diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 133ccbf..f96fa3c 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -2,19 +2,13 @@ import logging import sys import threading from argparse import Namespace +from dataclasses import asdict +from whisperlivekit.config import WhisperLiveKitConfig from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor from whisperlivekit.local_agreement.whisper_online import backend_factory from whisperlivekit.simul_whisper import SimulStreamingASR - -def update_with_kwargs(_dict, kwargs): - _dict.update({ - k: v for k, v in kwargs.items() if k in _dict - }) - return _dict - - logger = logging.getLogger(__name__) class TranscriptionEngine: @@ -31,69 +25,51 @@ class TranscriptionEngine: cls._instance = super().__new__(cls) return cls._instance - def __init__(self, **kwargs): + def __init__(self, config=None, **kwargs): # Thread-safe initialization check with TranscriptionEngine._lock: if TranscriptionEngine._initialized: return - # Set flag immediately to prevent re-initialization + + try: + self._do_init(config, **kwargs) + except Exception: + # Reset singleton so a retry is possible + with TranscriptionEngine._lock: + TranscriptionEngine._instance = None + TranscriptionEngine._initialized = False + raise + + with TranscriptionEngine._lock: TranscriptionEngine._initialized = True - # Perform initialization outside lock to avoid holding lock during slow operations - global_params = { - "host": "localhost", - "port": 8000, - "diarization": False, - "punctuation_split": False, - "target_language": "", - "vac": True, - "vac_chunk_size": 0.04, - "log_level": "DEBUG", - "ssl_certfile": None, - "ssl_keyfile": None, - "forwarded_allow_ips": None, - "transcription": True, - "vad": True, - "pcm_input": False, - "disable_punctuation_split" : False, - "diarization_backend": "sortformer", - "backend_policy": "simulstreaming", - "backend": "auto", - } - global_params = update_with_kwargs(global_params, kwargs) - - transcription_common_params = { - "warmup_file": None, - "min_chunk_size": 0.1, - "model_size": "base", - "model_cache_dir": None, - "model_dir": None, - "model_path": None, - "lora_path": None, - "lan": "auto", - "direct_english_translation": False, - } - transcription_common_params = update_with_kwargs(transcription_common_params, kwargs) - - if transcription_common_params['model_size'].endswith(".en"): - transcription_common_params["lan"] = "en" + def _do_init(self, config=None, **kwargs): + # Handle negated kwargs from programmatic API if 'no_transcription' in kwargs: - global_params['transcription'] = not global_params['no_transcription'] + kwargs['transcription'] = not kwargs.pop('no_transcription') if 'no_vad' in kwargs: - global_params['vad'] = not kwargs['no_vad'] + kwargs['vad'] = not kwargs.pop('no_vad') if 'no_vac' in kwargs: - global_params['vac'] = not kwargs['no_vac'] + kwargs['vac'] = not kwargs.pop('no_vac') + + if config is None: + if isinstance(kwargs.get('config'), WhisperLiveKitConfig): + config = kwargs.pop('config') + else: + config = WhisperLiveKitConfig.from_kwargs(**kwargs) + self.config = config + + # Backward compat: expose as self.args (Namespace-like) for AudioProcessor etc. + self.args = Namespace(**asdict(config)) - self.args = Namespace(**{**global_params, **transcription_common_params}) - self.asr = None self.tokenizer = None self.diarization = None self.vac_session = None - - if self.args.vac: + + if config.vac: from whisperlivekit.silero_vad_iterator import is_onnx_available - + if is_onnx_available(): from whisperlivekit.silero_vad_iterator import load_onnx_session self.vac_session = load_onnx_session() @@ -102,46 +78,55 @@ class TranscriptionEngine: "onnxruntime not installed. VAC will use JIT model which is loaded per-session. " "For multi-user scenarios, install onnxruntime: pip install onnxruntime" ) - backend_policy = self.args.backend_policy - if self.args.transcription: - if backend_policy == "simulstreaming": + + transcription_common_params = { + "warmup_file": config.warmup_file, + "min_chunk_size": config.min_chunk_size, + "model_size": config.model_size, + "model_cache_dir": config.model_cache_dir, + "model_dir": config.model_dir, + "model_path": config.model_path, + "lora_path": config.lora_path, + "lan": config.lan, + "direct_english_translation": config.direct_english_translation, + } + + if config.transcription: + if config.backend_policy == "simulstreaming": simulstreaming_params = { - "disable_fast_encoder": False, - "custom_alignment_heads": None, - "frame_threshold": 25, - "beams": 1, - "decoder_type": None, - "audio_max_len": 20.0, - "audio_min_len": 0.0, - "cif_ckpt_path": None, - "never_fire": False, - "init_prompt": None, - "static_init_prompt": None, - "max_context_tokens": None, + "disable_fast_encoder": config.disable_fast_encoder, + "custom_alignment_heads": config.custom_alignment_heads, + "frame_threshold": config.frame_threshold, + "beams": config.beams, + "decoder_type": config.decoder_type, + "audio_max_len": config.audio_max_len, + "audio_min_len": config.audio_min_len, + "cif_ckpt_path": config.cif_ckpt_path, + "never_fire": config.never_fire, + "init_prompt": config.init_prompt, + "static_init_prompt": config.static_init_prompt, + "max_context_tokens": config.max_context_tokens, } - simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs) - - self.tokenizer = None + + self.tokenizer = None self.asr = SimulStreamingASR( **transcription_common_params, **simulstreaming_params, - backend=self.args.backend, + backend=config.backend, ) logger.info( "Using SimulStreaming policy with %s backend", getattr(self.asr, "encoder_backend", "whisper"), ) else: - whisperstreaming_params = { - "buffer_trimming": "segment", - "confidence_validation": False, - "buffer_trimming_sec": 15, + "buffer_trimming": config.buffer_trimming, + "confidence_validation": config.confidence_validation, + "buffer_trimming_sec": config.buffer_trimming_sec, } - whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs) - + self.asr = backend_factory( - backend=self.args.backend, + backend=config.backend, **transcription_common_params, **whisperstreaming_params, ) @@ -150,39 +135,32 @@ class TranscriptionEngine: getattr(self.asr, "backend_choice", self.asr.__class__.__name__), ) - if self.args.diarization: - if self.args.diarization_backend == "diart": - from whisperlivekit.diarization.diart_backend import \ - DiartDiarization - diart_params = { - "segmentation_model": "pyannote/segmentation-3.0", - "embedding_model": "pyannote/embedding", - } - diart_params = update_with_kwargs(diart_params, kwargs) + if config.diarization: + if config.diarization_backend == "diart": + from whisperlivekit.diarization.diart_backend import DiartDiarization self.diarization_model = DiartDiarization( - block_duration=self.args.min_chunk_size, - **diart_params + block_duration=config.min_chunk_size, + segmentation_model=config.segmentation_model, + embedding_model=config.embedding_model, ) - elif self.args.diarization_backend == "sortformer": - from whisperlivekit.diarization.sortformer_backend import \ - SortformerDiarization + elif config.diarization_backend == "sortformer": + from whisperlivekit.diarization.sortformer_backend import SortformerDiarization self.diarization_model = SortformerDiarization() - + self.translation_model = None - if self.args.target_language: - if self.args.lan == 'auto' and backend_policy != "simulstreaming": - raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming') + if config.target_language: + if config.lan == 'auto' and config.backend_policy != "simulstreaming": + raise ValueError('Translation cannot be set with language auto when transcription backend is not simulstreaming') else: try: from nllw import load_model - except: - raise Exception('To use translation, you must install nllw: `pip install nllw`') - translation_params = { - "nllb_backend": "transformers", - "nllb_size": "600M" - } - translation_params = update_with_kwargs(translation_params, kwargs) - self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers + except ImportError: + raise ImportError('To use translation, you must install nllw: `pip install nllw`') + self.translation_model = load_model( + [config.lan], + nllb_backend=config.nllb_backend, + nllb_size=config.nllb_size, + ) def online_factory(args, asr): @@ -196,11 +174,12 @@ def online_diarization_factory(args, diarization_backend): if args.diarization_backend == "diart": online = diarization_backend # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended - - if args.diarization_backend == "sortformer": + elif args.diarization_backend == "sortformer": from whisperlivekit.diarization.sortformer_backend import \ SortformerDiarizationOnline online = SortformerDiarizationOnline(shared_model=diarization_backend) + else: + raise ValueError(f"Unknown diarization backend: {args.diarization_backend}") return online diff --git a/whisperlivekit/diarization/diart_backend.py b/whisperlivekit/diarization/diart_backend.py index 9e65afe..ccfdbfc 100644 --- a/whisperlivekit/diarization/diart_backend.py +++ b/whisperlivekit/diarization/diart_backend.py @@ -1,6 +1,5 @@ import asyncio import logging -import re import threading import time from queue import Empty, SimpleQueue @@ -14,14 +13,11 @@ from diart.sources import AudioSource, MicrophoneAudioSource from pyannote.core import Annotation from rx.core import Observer +from whisperlivekit.diarization.utils import extract_number from whisperlivekit.timed_objects import SpeakerSegment logger = logging.getLogger(__name__) -def extract_number(s: str) -> int: - m = re.search(r'\d+', s) - return int(m.group()) if m else None - class DiarizationObserver(Observer): """Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index 474b6dc..2f60a46 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -287,11 +287,7 @@ class SortformerDiarizationOnline: logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav") -def extract_number(s: str) -> int: - """Extract number from speaker string (compatibility function).""" - import re - m = re.search(r'\d+', s) - return int(m.group()) if m else 0 +from whisperlivekit.diarization.utils import extract_number if __name__ == '__main__': diff --git a/whisperlivekit/diarization/utils.py b/whisperlivekit/diarization/utils.py new file mode 100644 index 0000000..1f929b1 --- /dev/null +++ b/whisperlivekit/diarization/utils.py @@ -0,0 +1,7 @@ +import re + + +def extract_number(s: str) -> int: + """Extract the first integer from a string, e.g. 'speaker_2' -> 2.""" + m = re.search(r'\d+', s) + return int(m.group()) if m else 0 diff --git a/whisperlivekit/local_agreement/backends.py b/whisperlivekit/local_agreement/backends.py index 95ab0e8..b669ef2 100644 --- a/whisperlivekit/local_agreement/backends.py +++ b/whisperlivekit/local_agreement/backends.py @@ -26,13 +26,6 @@ class ASRBase: self.original_language = lan self.model = self.load_model(model_size, cache_dir, model_dir) - def with_offset(self, offset: float) -> ASRToken: - # This method is kept for compatibility (typically you will use ASRToken.with_offset) - return ASRToken(self.start + offset, self.end + offset, self.text) - - def __repr__(self): - return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})" - def load_model(self, model_size, cache_dir, model_dir): raise NotImplementedError("must be implemented in the child class") @@ -187,22 +180,8 @@ class MLXWhisper(ASRBase): return transcribe def translate_model_name(self, model_name): - model_mapping = { - "tiny.en": "mlx-community/whisper-tiny.en-mlx", - "tiny": "mlx-community/whisper-tiny-mlx", - "base.en": "mlx-community/whisper-base.en-mlx", - "base": "mlx-community/whisper-base-mlx", - "small.en": "mlx-community/whisper-small.en-mlx", - "small": "mlx-community/whisper-small-mlx", - "medium.en": "mlx-community/whisper-medium.en-mlx", - "medium": "mlx-community/whisper-medium-mlx", - "large-v1": "mlx-community/whisper-large-v1-mlx", - "large-v2": "mlx-community/whisper-large-v2-mlx", - "large-v3": "mlx-community/whisper-large-v3-mlx", - "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", - "large": "mlx-community/whisper-large-mlx", - } - mlx_model_path = model_mapping.get(model_name) + from whisperlivekit.model_mapping import MLX_MODEL_MAPPING + mlx_model_path = MLX_MODEL_MAPPING.get(model_name) if mlx_model_path: return mlx_model_path else: @@ -227,7 +206,6 @@ class MLXWhisper(ASRBase): if segment.get("no_speech_prob", 0) > 0.9: continue for word in segment.get("words", []): - probability=word["probability"] token = ASRToken(word["start"], word["end"], word["word"]) tokens.append(token) return tokens @@ -238,6 +216,7 @@ class MLXWhisper(ASRBase): def use_vad(self): self.transcribe_kargs["vad_filter"] = True + class OpenaiApiASR(ASRBase): """Uses OpenAI's Whisper API for transcription.""" def __init__(self, lan=None, temperature=0, logfile=sys.stderr): diff --git a/whisperlivekit/local_agreement/online_asr.py b/whisperlivekit/local_agreement/online_asr.py index e5b8632..28869df 100644 --- a/whisperlivekit/local_agreement/online_asr.py +++ b/whisperlivekit/local_agreement/online_asr.py @@ -136,6 +136,11 @@ class OnlineASRProcessor: f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM." ) + def new_speaker(self, change_speaker): + """Handle speaker change event.""" + self.process_iter() + self.init(offset=change_speaker.start) + def init(self, offset: Optional[float] = None): """Initialize or reset the processing buffers.""" self.audio_buffer = np.array([], dtype=np.float32) diff --git a/whisperlivekit/local_agreement/whisper_online.py b/whisperlivekit/local_agreement/whisper_online.py index 5093196..4ab3062 100644 --- a/whisperlivekit/local_agreement/whisper_online.py +++ b/whisperlivekit/local_agreement/whisper_online.py @@ -1,12 +1,7 @@ #!/usr/bin/env python3 import logging import platform -import sys import time -from functools import lru_cache - -import librosa -import numpy as np from whisperlivekit.backend_support import (faster_backend_available, mlx_backend_available) diff --git a/whisperlivekit/model_mapping.py b/whisperlivekit/model_mapping.py new file mode 100644 index 0000000..fbeaadb --- /dev/null +++ b/whisperlivekit/model_mapping.py @@ -0,0 +1,17 @@ +"""Shared MLX model name mapping used by both SimulStreaming and LocalAgreement backends.""" + +MLX_MODEL_MAPPING = { + "tiny.en": "mlx-community/whisper-tiny.en-mlx", + "tiny": "mlx-community/whisper-tiny-mlx", + "base.en": "mlx-community/whisper-base.en-mlx", + "base": "mlx-community/whisper-base-mlx", + "small.en": "mlx-community/whisper-small.en-mlx", + "small": "mlx-community/whisper-small-mlx", + "medium.en": "mlx-community/whisper-medium.en-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "large-v1": "mlx-community/whisper-large-v1-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", + "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + "large": "mlx-community/whisper-large-mlx", +} diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 9b5da4d..0f5f394 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -148,7 +148,7 @@ def parse_args(): type=str, default="auto", choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"], - help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.", + help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper).", ) parser.add_argument( "--no-vac", @@ -318,15 +318,12 @@ def parse_args(): ) args = parser.parse_args() - args.transcription = not args.no_transcription - args.vad = not args.no_vad + args.vad = not args.no_vad + args.vac = not args.no_vac delattr(args, 'no_transcription') delattr(args, 'no_vad') + delattr(args, 'no_vac') - if args.backend_policy == "1": - args.backend_policy = "simulstreaming" - elif args.backend_policy == "2": - args.backend_policy = "localagreement" - - return args + from whisperlivekit.config import WhisperLiveKitConfig + return WhisperLiveKitConfig.from_namespace(args) diff --git a/whisperlivekit/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py index 5b69f5f..05d9acd 100644 --- a/whisperlivekit/silero_vad_iterator.py +++ b/whisperlivekit/silero_vad_iterator.py @@ -115,7 +115,7 @@ class OnnxWrapper(): out, state = ort_outs self._state = torch.from_numpy(state) else: - raise ValueError() + raise ValueError(f"Unsupported sampling rate {sr}. Supported: {self.sample_rates} (with sample sizes 256 for 8000, 512 for 16000)") self._context = x[..., -context_size:] self._last_sr = sr @@ -129,7 +129,7 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat """Get the path to the ONNX model file.""" available_ops = [15, 16] if opset_version not in available_ops: - raise Exception(f'Available ONNX opset_version: {available_ops}') + raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}') if model_path is None: current_dir = Path(__file__).parent @@ -255,8 +255,8 @@ class VADIterator: if not torch.is_tensor(x): try: x = torch.Tensor(x) - except: - raise TypeError("Audio cannot be casted to tensor. Cast it manually") + except (ValueError, TypeError, RuntimeError) as exc: + raise TypeError("Audio cannot be cast to tensor. Cast it manually") from exc window_size_samples = len(x[0]) if x.dim() == 2 else len(x) self.current_sample += window_size_samples diff --git a/whisperlivekit/simul_whisper/align_att_base.py b/whisperlivekit/simul_whisper/align_att_base.py new file mode 100644 index 0000000..fd326c7 --- /dev/null +++ b/whisperlivekit/simul_whisper/align_att_base.py @@ -0,0 +1,483 @@ +"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX).""" +import logging +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Tuple + +from whisperlivekit.timed_objects import ASRToken +from whisperlivekit.whisper import DecodingOptions, tokenizer + +from .config import AlignAttConfig + +DEC_PAD = 50257 +logger = logging.getLogger(__name__) + + +class AlignAttBase(ABC): + """ + Abstract base class for AlignAtt streaming decoders. + + Provides shared logic for both PyTorch and MLX implementations: + - Properties (speaker, global_time_offset) + - Pure-Python methods (warmup, trim_context, refresh_segment, etc.) + - Template infer() with abstract hooks for tensor-specific operations + - Post-decode logic (token splitting, timestamped word building) + + Subclasses must implement ~20 abstract methods for tensor-specific ops. + """ + + # === Properties === + + @property + def speaker(self): + return self.state.speaker + + @speaker.setter + def speaker(self, value): + self.state.speaker = value + + @property + def global_time_offset(self): + return self.state.global_time_offset + + @global_time_offset.setter + def global_time_offset(self, value): + self.state.global_time_offset = value + + # === Constructor helpers === + + def _base_init(self, cfg: AlignAttConfig, model): + """Common initialization — call from subclass __init__.""" + self.model = model + self.cfg = cfg + self.decode_options = DecodingOptions( + language=cfg.language, + without_timestamps=True, + task=cfg.task, + ) + self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual + self.max_text_len = model.dims.n_text_ctx + self.num_decoder_layers = len(model.decoder.blocks) + if cfg.max_context_tokens is None: + self.max_context_tokens = self.max_text_len + else: + self.max_context_tokens = cfg.max_context_tokens + + def _init_state_common(self, cfg: AlignAttConfig): + """Common state initialization — call from subclass _init_state.""" + self.create_tokenizer(cfg.language if cfg.language != "auto" else None) + self.state.tokenizer = self.tokenizer + self.state.detected_language = cfg.language if cfg.language != "auto" else None + self.state.global_time_offset = 0.0 + self.state.last_attend_frame = -cfg.rewind_threshold + self.state.speaker = -1 + + # === Shared concrete methods === + + def warmup(self, audio): + try: + self.insert_audio(audio) + self.infer(is_last=True) + self.refresh_segment(complete=True) + logger.info("Model warmed up successfully") + except Exception as e: + logger.exception(f"Model warmup failed: {e}") + + def create_tokenizer(self, language=None): + self.tokenizer = tokenizer.get_tokenizer( + multilingual=self.tokenizer_is_multilingual, + language=language, + num_languages=self.model.num_languages, + task=self.decode_options.task, + ) + self.state.tokenizer = self.tokenizer + + def trim_context(self): + logger.info("Trimming context") + c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids) + logger.info(f"Context text: {self.state.context.as_text()}") + l = sum(t.shape[1] for t in self.state.tokens) + c + after = 0 if self.cfg.static_init_prompt is None else len(self.cfg.static_init_prompt) + while c > self.max_context_tokens or l > self.max_text_len - 20: + t = self.state.context.trim_words(after=after) + l -= t + c -= t + logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") + if t == 0: + break + logger.info(f"Context after trim: {self.state.context.text} (len: {l})") + + def refresh_segment(self, complete=False): + logger.debug("Refreshing segment:") + self.init_tokens() + self.state.last_attend_frame = -self.cfg.rewind_threshold + self.state.cumulative_time_offset = 0.0 + self.init_context() + logger.debug(f"Context: {self.state.context}") + if not complete and len(self.state.segments) > 2: + self.state.segments = self.state.segments[-2:] + else: + logger.debug("removing all segments.") + self.state.segments = [] + self.state.log_segments += 1 + self.state.pending_incomplete_tokens = [] + + def segments_len(self): + return sum(s.shape[0] for s in self.state.segments) / 16000 + + def _apply_minseglen(self): + segments_len = self.segments_len() + if segments_len < self.cfg.audio_min_len: + logger.debug("waiting for next segment") + return False + return True + + def _clean_cache(self): + self.state.clean_cache() + + def debug_print_tokens(self, tokens): + for i in range(min(self.cfg.beam_size, tokens.shape[0])): + logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist())) + + # === Language detection === + + def _detect_language_if_needed(self, encoder_feature): + if ( + self.cfg.language == "auto" + and self.state.detected_language is None + and self.state.first_timestamp + ): + seconds_since_start = self.segments_len() - self.state.first_timestamp + if seconds_since_start >= 2.0: + language_tokens, language_probs = self.lang_id(encoder_feature) + top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) + print(f"Detected language: {top_lan} with p={p:.4f}") + self.create_tokenizer(top_lan) + self.state.last_attend_frame = -self.cfg.rewind_threshold + self.state.cumulative_time_offset = 0.0 + self.init_tokens() + self.init_context() + self.state.detected_language = top_lan + logger.info(f"Tokenizer language: {self.tokenizer.language}") + + # === Template infer() === + + def infer(self, is_last=False): + """Main inference — template method calling abstract hooks for tensor ops.""" + new_segment = True + + if len(self.state.segments) == 0: + logger.debug("No segments, nothing to do") + return [] + if not self._apply_minseglen(): + logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") + return [] + + input_segments = self._concat_segments() + encoder_feature, content_mel_len = self._encode(input_segments) + self._evaluate(encoder_feature) + + self._detect_language_if_needed(encoder_feature) + self.trim_context() + current_tokens = self._current_tokens() + + fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) + + sum_logprobs = self._init_sum_logprobs() + completed = False + token_len_before = current_tokens.shape[1] + l_absolute_timestamps = [] + accumulated_cross_attns = [] + + audio_duration_s = self.segments_len() + max_tokens = max(50, int(audio_duration_s * 15 * 1.5)) + tokens_produced = 0 + most_attended_frame = None + + while not completed and current_tokens.shape[1] < self.max_text_len: + tokens_produced += 1 + if tokens_produced > max_tokens: + logger.warning( + f"[Loop Detection] Too many tokens ({tokens_produced}) " + f"for {audio_duration_s:.2f}s audio. Breaking." + ) + current_tokens = current_tokens[:, :token_len_before] + break + + tokens_for_logits = current_tokens if new_segment else current_tokens[:, -1:] + logits, cross_attns = self._get_logits_and_cross_attn( + tokens_for_logits, encoder_feature + ) + self._evaluate(logits) + + accumulated_cross_attns.append(cross_attns) + if len(accumulated_cross_attns) > 16: + accumulated_cross_attns = accumulated_cross_attns[-16:] + + if new_segment and self._check_no_speech(logits): + break + + logits = logits[:, -1, :] + + if new_segment: + logits = self._suppress_blank_tokens(logits) + new_segment = False + + logits = self._apply_token_suppression(logits) + current_tokens, completed = self._update_tokens( + current_tokens, logits, sum_logprobs + ) + self._evaluate(current_tokens) + + logger.debug(f"Decoding completed: {completed}") + self.debug_print_tokens(current_tokens) + + attn = self._process_cross_attention(accumulated_cross_attns, content_mel_len) + frames_list, most_attended_frame = self._get_attended_frames(attn) + + absolute_timestamps = [ + (frame * 0.02 + self.state.cumulative_time_offset) + for frame in frames_list + ] + l_absolute_timestamps.append(absolute_timestamps[0]) + logger.debug(f"Absolute timestamps: {absolute_timestamps}") + + if completed: + current_tokens = current_tokens[:, :-1] + break + + # Rewind check + if ( + not is_last + and self.state.last_attend_frame - most_attended_frame + > self.cfg.rewind_threshold + ): + if current_tokens.shape[1] > 1 and self._is_special_token(current_tokens): + logger.debug("omit rewinding from special tokens") + self.state.last_attend_frame = most_attended_frame + else: + logger.debug( + f"[rewind detected] current: {most_attended_frame}, " + f"last: {self.state.last_attend_frame}" + ) + self.state.last_attend_frame = -self.cfg.rewind_threshold + current_tokens = self._rewind_tokens() + break + else: + self.state.last_attend_frame = most_attended_frame + + if content_mel_len - most_attended_frame <= ( + 4 if is_last else self.cfg.frame_threshold + ): + logger.debug( + f"attention reaches the end: {most_attended_frame}/{content_mel_len}" + ) + current_tokens = current_tokens[:, :-1] + break + + # Post-decode: split tokens and build timestamped words + tokens_to_split = self._tokens_to_list(current_tokens, token_len_before) + if self.state.pending_incomplete_tokens: + logger.debug( + f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} " + f"pending tokens: {self.state.pending_incomplete_tokens}" + ) + tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split + + new_hypothesis, split_words, split_tokens = self._split_tokens( + tokens_to_split, fire_detected, is_last + ) + + new_tokens_tensor = self._make_new_tokens_tensor(new_hypothesis) + self.state.tokens.append(new_tokens_tensor) + logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") + + self._clean_cache() + + if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None: + self.state.first_timestamp = l_absolute_timestamps[0] + + timestamped_words = self._build_timestamped_words( + split_words, split_tokens, l_absolute_timestamps + ) + self._handle_pending_tokens(split_words, split_tokens) + + return timestamped_words + + # === Post-decode shared helpers === + + def _split_tokens(self, tokens_list, fire_detected, is_last): + """Split token list into words. Returns (hypothesis, split_words, split_tokens).""" + if fire_detected or is_last: + new_hypothesis = tokens_list + split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) + else: + split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_list) + if len(split_words) > 1: + new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist] + else: + new_hypothesis = [] + return new_hypothesis, split_words, split_tokens + + def _build_timestamped_words(self, split_words, split_tokens, l_absolute_timestamps): + """Build list of timestamped ASRToken from split words.""" + timestamped_words = [] + timestamp_idx = 0 + replacement_char = "\ufffd" + + for word, word_tokens in zip(split_words, split_tokens): + if replacement_char in word: + logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}") + timestamp_idx += len(word_tokens) + continue + + try: + current_timestamp = l_absolute_timestamps[timestamp_idx] + except IndexError: + logger.warning( + f"Timestamp index {timestamp_idx} out of range, using last timestamp" + ) + current_timestamp = ( + l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0 + ) + timestamp_idx += len(word_tokens) + + timestamp_entry = ASRToken( + start=round(current_timestamp, 2), + end=round(current_timestamp + 0.1, 2), + text=word, + speaker=self.state.speaker, + detected_language=self.state.detected_language, + ).with_offset(self.state.global_time_offset) + timestamped_words.append(timestamp_entry) + + return timestamped_words + + def _handle_pending_tokens(self, split_words, split_tokens): + """Handle incomplete UTF-8 tokens for next chunk.""" + self.state.pending_incomplete_tokens = [] + MAX_PENDING_TOKENS = 10 + replacement_char = "\ufffd" + if split_words and replacement_char in split_words[-1]: + if len(split_tokens[-1]) <= MAX_PENDING_TOKENS: + self.state.pending_incomplete_tokens = split_tokens[-1] + logger.debug( + f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} " + f"incomplete tokens for next chunk" + ) + else: + logger.warning( + f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens " + f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)" + ) + + # === Abstract methods — subclass must implement === + + @abstractmethod + def _init_state(self, cfg: AlignAttConfig): + """Initialize per-session decoder state.""" + ... + + @abstractmethod + def init_tokens(self): + """Initialize token sequence with framework-specific tensors.""" + ... + + @abstractmethod + def init_context(self): + """Initialize context buffer with framework-specific TokenBuffer.""" + ... + + @abstractmethod + def insert_audio(self, segment=None): + """Insert audio segment into buffer.""" + ... + + @abstractmethod + def _current_tokens(self): + """Build current token tensor for decoding.""" + ... + + @abstractmethod + def fire_at_boundary(self, feature): + """Check if we should fire at word boundary.""" + ... + + @abstractmethod + def lang_id(self, encoder_features): + """Language detection from encoder features. Returns (tokens, probs).""" + ... + + @abstractmethod + def _concat_segments(self): + """Concatenate audio segments into single array/tensor.""" + ... + + @abstractmethod + def _encode(self, input_segments): + """Encode audio. Returns (encoder_feature, content_mel_len).""" + ... + + @abstractmethod + def _init_sum_logprobs(self): + """Create zero sum_logprobs tensor for beam search.""" + ... + + @abstractmethod + def _get_logits_and_cross_attn(self, tokens, encoder_feature): + """Get logits and cross-attention from decoder. Returns (logits, cross_attns).""" + ... + + @abstractmethod + def _check_no_speech(self, logits): + """Check no_speech probability at start of segment. Returns True to break.""" + ... + + @abstractmethod + def _suppress_blank_tokens(self, logits): + """Suppress blank/EOT tokens at segment start. Returns modified logits.""" + ... + + @abstractmethod + def _apply_token_suppression(self, logits): + """Apply general token suppression. Returns modified logits.""" + ... + + @abstractmethod + def _update_tokens(self, current_tokens, logits, sum_logprobs): + """Update tokens via decoder. Returns (current_tokens, completed).""" + ... + + @abstractmethod + def _process_cross_attention(self, accumulated_cross_attns, content_mel_len): + """Process cross-attention for alignment. Returns attention tensor.""" + ... + + @abstractmethod + def _get_attended_frames(self, attn): + """Get most attended frames. Returns (frames_as_python_list, first_frame_int).""" + ... + + @abstractmethod + def _is_special_token(self, current_tokens): + """Check if second-to-last token is a special token (>= DEC_PAD).""" + ... + + @abstractmethod + def _rewind_tokens(self): + """Concatenate state tokens for rewind. Returns token tensor.""" + ... + + @abstractmethod + def _tokens_to_list(self, current_tokens, start_col): + """Extract tokens as Python list from start_col onwards.""" + ... + + @abstractmethod + def _make_new_tokens_tensor(self, hypothesis): + """Create tensor from hypothesis token list, repeated for beam search.""" + ... + + @abstractmethod + def _evaluate(self, tensor): + """Evaluate lazy tensor (mx.eval for MLX, no-op for PyTorch).""" + ... diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 33b1b81..8f84285 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -350,7 +350,7 @@ class SimulStreamingASR: def set_translate_task(self): """Set up translation task.""" if self.cfg.language == 'auto': - raise Exception('Translation cannot be done with language = auto') + raise ValueError('Translation cannot be done with language = auto') return tokenizer.get_tokenizer( multilingual=True, language=self.cfg.language, diff --git a/whisperlivekit/simul_whisper/mlx/simul_whisper.py b/whisperlivekit/simul_whisper/mlx/simul_whisper.py index 50b327c..3211320 100644 --- a/whisperlivekit/simul_whisper/mlx/simul_whisper.py +++ b/whisperlivekit/simul_whisper/mlx/simul_whisper.py @@ -1,9 +1,6 @@ -""" -MLX whisper AlignAtt streaming decoder -""" +"""MLX whisper AlignAtt streaming decoder.""" import logging -from time import time -from typing import Any, List, Optional, Tuple +from typing import Any, List, Tuple import mlx.core as mx import numpy as np @@ -11,19 +8,18 @@ import numpy as np from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim -from whisperlivekit.timed_objects import ASRToken -from whisperlivekit.whisper import DecodingOptions, tokenizer from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND +from ..align_att_base import DEC_PAD, AlignAttBase from ..config import AlignAttConfig from .decoder_state import MLXDecoderState from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference -DEC_PAD = 50257 + logger = logging.getLogger(__name__) -class MLXTokenBuffer: #should try to make it heritate from classic simul whisper class +class MLXTokenBuffer: """Token buffer for MLX-based decoding.""" def __init__(self, text="", tokenizer=None, prefix_token_ids=None): @@ -40,12 +36,10 @@ class MLXTokenBuffer: #should try to make it heritate from classic simul whisper return self.prefix_token_ids + tokenizer.encode(self.text) def as_mlx_array(self) -> mx.array: - """Return tokens as MLX array.""" tok_ids = self.as_token_ids() return mx.array([tok_ids], dtype=mx.int32) def as_mlx_array_beam(self, beam: int) -> mx.array: - """Return tokens as MLX array repeated for beam search.""" t = self.as_mlx_array() return mx.repeat(t, beam, axis=0) @@ -64,10 +58,8 @@ class MLXTokenBuffer: #should try to make it heritate from classic simul whisper return self.text is None or self.text == "" def trim_words(self, num=1, after=0): - """Trim words from the beginning of the context.""" tokenizer = self.tokenizer assert tokenizer is not None, "Tokenizer is not set." - ids = tokenizer.encode(self.text[after:]) words, wids = self.tokenizer.split_to_word_tokens(ids) if not words: @@ -76,14 +68,11 @@ class MLXTokenBuffer: #should try to make it heritate from classic simul whisper return sum(len(wi) for wi in wids[:num]) def append_token_ids(self, token_ids): - """Append token IDs to the buffer, handling incomplete UTF-8.""" tokenizer = self.tokenizer assert tokenizer is not None, "Tokenizer is not set." - all_tokens = self.pending_token_ids + token_ids decoded = tokenizer.decode(all_tokens) replacement_char = "\ufffd" - if replacement_char in decoded: if len(all_tokens) > 1: decoded_partial = tokenizer.decode(all_tokens[:-1]) @@ -100,106 +89,47 @@ class MLXTokenBuffer: #should try to make it heritate from classic simul whisper def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array: - """ - Apply median filter along the last axis. - - Args: - x: Input array of shape (..., T) - filter_width: Width of the median filter (should be odd) - - Returns: - Filtered array of same shape - """ + """Apply median filter along the last axis.""" if filter_width <= 1: return x - pad_width = filter_width // 2 shape = x.shape - left_pad = mx.repeat(x[..., :1], pad_width, axis=-1) right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1) x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1) - - result_shape = list(shape) result = [] - for i in range(shape[-1]): window = x_padded[..., i:i + filter_width] sorted_window = mx.sort(window, axis=-1) median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1] result.append(median_val) - return mx.concatenate(result, axis=-1) -class MLXAlignAtt: +class MLXAlignAtt(AlignAttBase): """ MLX-native Alignment-based Attention decoder for SimulStreaming. - - This class runs entirely on MLX, with no PyTorch dependencies for inference. + + Runs entirely on MLX, with no PyTorch dependencies for inference. """ - @property - def speaker(self): - return self.state.speaker - - @speaker.setter - def speaker(self, value): - self.state.speaker = value - - @property - def global_time_offset(self): - return self.state.global_time_offset - - @global_time_offset.setter - def global_time_offset(self, value): - self.state.global_time_offset = value - def __init__( self, cfg: AlignAttConfig, mlx_model: Any, ) -> None: - """ - Initialize MLX AlignAtt decoder. - - Args: - cfg: AlignAtt configuration - mlx_model: MLX Whisper model (full model, not just encoder) - """ - self.model = mlx_model - self.cfg = cfg - + # Common init (sets self.model, self.cfg, decode_options, etc.) + self._base_init(cfg, mlx_model) logger.info(f"MLX Model dimensions: {self.model.dims}") - - self.decode_options = DecodingOptions( - language=cfg.language, - without_timestamps=True, - task=cfg.task - ) - self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual - - self.max_text_len = self.model.dims.n_text_ctx - self.num_decoder_layers = len(self.model.decoder.blocks) - if self.cfg.max_context_tokens is None: - self.max_context_tokens = self.max_text_len - else: - self.max_context_tokens = self.cfg.max_context_tokens - - # Initialize per-session state + # Per-session state self.state = MLXDecoderState() self._init_state(cfg) def _init_state(self, cfg: AlignAttConfig): - """Initialize the per-session decoder state.""" - self.create_tokenizer(cfg.language if cfg.language != "auto" else None) - self.state.tokenizer = self.tokenizer - self.state.detected_language = cfg.language if cfg.language != "auto" else None - self.state.global_time_offset = 0.0 - self.state.last_attend_frame = -cfg.rewind_threshold - self.state.speaker = -1 + self._init_state_common(cfg) + # CIF: MLX doesn't support CIF checkpoint loading if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path: if cfg.never_fire: self.state.never_fire = True @@ -208,19 +138,20 @@ class MLXAlignAtt: self.state.always_fire = True self.state.never_fire = False else: - logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True") + logger.warning( + "CIF checkpoint provided but MLX CIF not implemented. " + "Using always_fire=True" + ) self.state.always_fire = True self.state.never_fire = cfg.never_fire self._build_alignment_source() + # Suppress tokens suppress_tokens = [ - self.tokenizer.transcribe, - self.tokenizer.translate, - self.tokenizer.sot, - self.tokenizer.sot_prev, - self.tokenizer.sot_lm, - self.tokenizer.no_timestamps, + self.tokenizer.transcribe, self.tokenizer.translate, + self.tokenizer.sot, self.tokenizer.sot_prev, + self.tokenizer.sot_lm, self.tokenizer.no_timestamps, ] + list(self.tokenizer.all_language_tokens) if self.tokenizer.no_speech is not None: suppress_tokens.append(self.tokenizer.no_speech) @@ -230,35 +161,34 @@ class MLXAlignAtt: self.init_tokens() self.init_context() + # Decoder type self.state.decoder_type = cfg.decoder_type if cfg.decoder_type == "greedy": logger.info("Using MLX greedy decoder") self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot) elif cfg.decoder_type == "beam": logger.info("Using MLX beam decoder") - self.state.inference = MLXInference(self.model, self.state.initial_token_length) + self.state.inference = MLXInference( + self.model, self.state.initial_token_length, + ) self.state.token_decoder = MLXBeamSearchDecoder( inference=self.state.inference, eot=self.tokenizer.eot, - beam_size=cfg.beam_size + beam_size=cfg.beam_size, ) def _build_alignment_source(self): """Build alignment source mapping from model's alignment_heads.""" self.state.align_source = {} self.state.num_align_heads = 0 - alignment_heads = self.model.alignment_heads - if alignment_heads is None: logger.warning("No alignment heads found in model") return - if hasattr(alignment_heads, 'tolist'): heads_list = alignment_heads.tolist() else: heads_list = np.array(alignment_heads).tolist() - for layer_rank, head_id in heads_list: layer_rank = int(layer_rank) head_id = int(head_id) @@ -267,31 +197,23 @@ class MLXAlignAtt: self.state.align_source[layer_rank] = heads self.state.num_align_heads += 1 - def warmup(self, audio: np.ndarray): - """Warmup the model with sample audio.""" - try: - self.insert_audio(audio) - self.infer(is_last=True) - self.refresh_segment(complete=True) - logger.info("MLX model warmed up successfully") - except Exception as e: - logger.exception(f"MLX model warmup failed: {e}") + # === Abstract method implementations === - def create_tokenizer(self, language=None): - """Create tokenizer for the given language.""" - self.tokenizer = tokenizer.get_tokenizer( - multilingual=self.tokenizer_is_multilingual, - language=language, - num_languages=self.model.num_languages, - task=self.decode_options.task + def init_tokens(self): + logger.debug(f"init tokens, {len(self.state.segments)}") + self.state.initial_tokens = mx.array( + [self.tokenizer.sot_sequence_including_notimestamps], + dtype=mx.int32, ) - self.state.tokenizer = self.tokenizer + self.state.initial_token_length = self.state.initial_tokens.shape[1] + self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) + logger.debug(f"init tokens after, {len(self.state.segments)}") + self.state.tokens = [self.state.initial_tokens] def init_context(self): - """Initialize context buffer.""" kw = { 'tokenizer': self.tokenizer, - 'prefix_token_ids': [self.tokenizer.sot_prev] + 'prefix_token_ids': [self.tokenizer.sot_prev], } self.state.context = MLXTokenBuffer.empty(**kw) if self.cfg.static_init_prompt is not None: @@ -299,409 +221,138 @@ class MLXAlignAtt: if self.cfg.init_prompt is not None: self.state.context.text += self.cfg.init_prompt - def init_tokens(self): - """Initialize token sequence.""" - logger.debug(f"init tokens, {len(self.state.segments)}") - self.state.initial_tokens = mx.array( - [self.tokenizer.sot_sequence_including_notimestamps], - dtype=mx.int32 - ) - self.state.initial_token_length = self.state.initial_tokens.shape[1] - self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) - logger.debug(f"init tokens after, {len(self.state.segments)}") - self.state.tokens = [self.state.initial_tokens] - - def trim_context(self): - """Trim context if too long.""" - logger.info("Trimming context") - c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids) - logger.info(f"Context text: {self.state.context.as_text()}") - l = sum(t.shape[1] for t in self.state.tokens) + c - if self.cfg.static_init_prompt is None: - after = 0 - else: - after = len(self.cfg.static_init_prompt) - while c > self.max_context_tokens or l > self.max_text_len - 20: - t = self.state.context.trim_words(after=after) - l -= t - c -= t - logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") - if t == 0: - break - logger.info(f"Context after trim: {self.state.context.text} (len: {l})") - - def refresh_segment(self, complete=False): - """Refresh segment state.""" - logger.debug("Refreshing segment:") - self.init_tokens() - self.state.last_attend_frame = -self.cfg.rewind_threshold - self.state.cumulative_time_offset = 0.0 - self.init_context() - logger.debug(f"Context: {self.state.context}") - if not complete and len(self.state.segments) > 2: - self.state.segments = self.state.segments[-2:] - else: - logger.debug("removing all segments.") - self.state.segments = [] - self.state.log_segments += 1 - self.state.pending_incomplete_tokens = [] - - def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool: - """Check if we should fire at word boundary (CIF-based).""" - if self.state.always_fire: - return True - if self.state.never_fire: - return False - return True - - def _current_tokens(self) -> mx.array: - """Get current token sequence for decoding.""" - toks = self.state.tokens - - if toks[0].shape[0] == 1: - toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0) - - if not self.state.context.is_empty(): - context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size) - toks = [context_toks] + toks - - # Concatenate all tokens - if len(toks) > 1: - current_tokens = mx.concatenate(toks, axis=1) - else: - current_tokens = toks[0] - - logger.debug("debug print current_tokens:") - self.debug_print_tokens(current_tokens) - return current_tokens - - def debug_print_tokens(self, tokens: mx.array): - """Debug print token sequences.""" - tokens_np = np.array(tokens) - for i in range(min(self.cfg.beam_size, tokens_np.shape[0])): - logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist())) - - def segments_len(self) -> float: - """Get total length of audio segments in seconds.""" - return sum(s.shape[0] for s in self.state.segments) / 16000 - - def _apply_minseglen(self) -> bool: - """Check if we have enough audio to process.""" - segments_len = self.segments_len() - if segments_len < self.cfg.audio_min_len: - logger.debug("waiting for next segment") - return False - return True - - def insert_audio(self, segment: np.ndarray = None): - """Insert audio segment into buffer.""" + def insert_audio(self, segment=None): if segment is not None: if hasattr(segment, 'numpy'): segment = segment.numpy() self.state.segments.append(segment) - removed_len = 0 segments_len = self.segments_len() - while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len: removed_len = self.state.segments[0].shape[0] / 16000 segments_len -= removed_len self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len) self.state.cumulative_time_offset += removed_len self.state.segments = self.state.segments[1:] - logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s") - + logger.debug( + f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, " + f"cumulative offset: {self.state.cumulative_time_offset:.2f}s" + ) if len(self.state.tokens) > 1: - # Convert MLX array to list for context token_list = np.array(self.state.tokens[1][0, :]).tolist() self.state.context.append_token_ids(token_list) self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:] - return removed_len - def _clean_cache(self): - """Clean the kv_cache after each inference step.""" - self.state.clean_cache() + def _current_tokens(self) -> mx.array: + toks = self.state.tokens + if toks[0].shape[0] == 1: + toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0) + if not self.state.context.is_empty(): + context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size) + toks = [context_toks] + toks + if len(toks) > 1: + current_tokens = mx.concatenate(toks, axis=1) + else: + current_tokens = toks[0] + logger.debug("debug print current_tokens:") + self.debug_print_tokens(current_tokens) + return current_tokens - def _suppress_tokens(self, logits: mx.array) -> mx.array: - """Apply token suppression to logits.""" - if self.state.suppress_tokens: - suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32) - logits = logits.at[:, suppress_indices].add(-float('inf')) - return logits + def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool: + if self.state.always_fire: + return True + if self.state.never_fire: + return False + return True # MLX CIF not implemented def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]: - """Language detection from encoder features.""" n_audio = encoder_features.shape[0] x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32) - logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None) logits = logits[:, 0] - + mask = mx.ones(logits.shape[-1], dtype=mx.bool_) - language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32) + language_token_indices = mx.array( + list(self.tokenizer.all_language_tokens), dtype=mx.int32, + ) mask = mask.at[language_token_indices].add(False) - logits = mx.where(mask, mx.array(-float('inf')), logits) - + language_tokens = mx.argmax(logits, axis=-1) language_token_probs = mx.softmax(logits, axis=-1) - probs_np = np.array(language_token_probs) - language_probs = [ { c: float(probs_np[i, j]) - for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes) + for j, c in zip( + self.tokenizer.all_language_tokens, + self.tokenizer.all_language_codes, + ) } for i in range(n_audio) ] - self._clean_cache() return language_tokens, language_probs - def infer(self, is_last: bool = False) -> List[ASRToken]: - """ - Main inference method. - - Args: - is_last: Whether this is the final chunk - - Returns: - List of timestamped ASR tokens - """ - new_segment = True - - if len(self.state.segments) == 0: - logger.debug("No segments, nothing to do") - return [] - - if not self._apply_minseglen(): - logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") - return [] - + def _concat_segments(self): if len(self.state.segments) > 1: - input_segments = np.concatenate(self.state.segments, axis=0) - else: - input_segments = self.state.segments[0] + return np.concatenate(self.state.segments, axis=0) + return self.state.segments[0] - beg_encode = time() - + def _encode(self, input_segments): mlx_mel_padded = mlx_log_mel_spectrogram( audio=input_segments, n_mels=self.model.dims.n_mels, - padding=N_SAMPLES + padding=N_SAMPLES, ) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) encoder_feature = self.model.encoder(mlx_mel[None]) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2) - - mx.eval(encoder_feature) - - end_encode = time() - logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s') + return encoder_feature, content_mel_len - if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp: - seconds_since_start = self.segments_len() - self.state.first_timestamp - if seconds_since_start >= 2.0: - language_tokens, language_probs = self.lang_id(encoder_feature) - top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) - print(f"Detected language: {top_lan} with p={p:.4f}") - self.create_tokenizer(top_lan) - self.state.last_attend_frame = -self.cfg.rewind_threshold - self.state.cumulative_time_offset = 0.0 - self.init_tokens() - self.init_context() - self.state.detected_language = top_lan - logger.info(f"Tokenizer language: {self.tokenizer.language}") + def _init_sum_logprobs(self): + return mx.zeros((self.cfg.beam_size,), dtype=mx.float32) - self.trim_context() - current_tokens = self._current_tokens() - - fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) - - sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32) - completed = False - - attn_of_alignment_heads = None - most_attended_frame = None - - token_len_before_decoding = current_tokens.shape[1] - - l_absolute_timestamps = [] - accumulated_cross_attns = [] - - audio_duration_s = self.segments_len() - # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50) - # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops. - max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5)) - tokens_produced_this_chunk = 0 - - while not completed and current_tokens.shape[1] < self.max_text_len: - tokens_produced_this_chunk += 1 - - if tokens_produced_this_chunk > max_tokens_per_chunk: - logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.") - current_tokens = current_tokens[:, :token_len_before_decoding] - break - - if new_segment: - tokens_for_logits = current_tokens - else: - tokens_for_logits = current_tokens[:, -1:] - - if self.state.decoder_type == "greedy": - logits, self.state.kv_cache, cross_qk = self.model.decoder( - tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache - ) - else: - logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature) - - mx.eval(logits) - - accumulated_cross_attns.append(cross_qk) - if len(accumulated_cross_attns) > 16: - accumulated_cross_attns = accumulated_cross_attns[-16:] - - if new_segment and self.tokenizer.no_speech is not None: - probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1) - no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist() - if no_speech_probs[0] > self.cfg.nonspeech_prob: - logger.info("no speech, stop") - break - - logits = logits[:, -1, :] # Last token logits - - # Suppress tokens at segment start - if new_segment: - blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot] - logits = logits.at[:, blank_tokens].add(-float('inf')) - new_segment = False - - logits = self._suppress_tokens(logits) - - current_tokens, completed = self.state.token_decoder.update( - current_tokens, logits, sum_logprobs + def _get_logits_and_cross_attn(self, tokens, encoder_feature): + if self.state.decoder_type == "greedy": + logits, self.state.kv_cache, cross_qk = self.model.decoder( + tokens, encoder_feature, kv_cache=self.state.kv_cache, ) - mx.eval(current_tokens) - - logger.debug(f"Decoding completed: {completed}") - self.debug_print_tokens(current_tokens) - - attn_of_alignment_heads = self._process_cross_attention( - accumulated_cross_attns, content_mel_len - ) - - most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1) - most_attended_frames_np = np.array(most_attended_frames) - - absolute_timestamps = [ - (frame * 0.02 + self.state.cumulative_time_offset) - for frame in most_attended_frames_np.tolist() - ] - - logger.debug(str(most_attended_frames_np.tolist()) + " most att frames") - logger.debug(f"Absolute timestamps: {absolute_timestamps}") - - most_attended_frame = int(most_attended_frames_np[0]) - l_absolute_timestamps.append(absolute_timestamps[0]) - - if completed: - current_tokens = current_tokens[:, :-1] - break - if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold: - current_tokens_np = np.array(current_tokens) - if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD: - logger.debug("omit rewinding from special tokens") - self.state.last_attend_frame = most_attended_frame - else: - logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}") - self.state.last_attend_frame = -self.cfg.rewind_threshold - current_tokens = mx.concatenate(self.state.tokens, axis=1) if len(self.state.tokens) > 0 else self.state.tokens[0] - break - else: - self.state.last_attend_frame = most_attended_frame - if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold): - logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}") - current_tokens = current_tokens[:, :-1] - break - tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist() - if self.state.pending_incomplete_tokens: - logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}") - tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split - - if fire_detected or is_last: - new_hypothesis = tokens_to_split - split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) + return logits, cross_qk else: - split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split) - if len(split_words) > 1: - new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist] - else: - new_hypothesis = [] + return self.state.inference.logits(tokens, encoder_feature) - logger.debug(f"new_hypothesis: {new_hypothesis}") - new_tokens = mx.array([new_hypothesis], dtype=mx.int32) - new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0) - self.state.tokens.append(new_tokens) + def _check_no_speech(self, logits): + if self.tokenizer.no_speech is not None: + probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1) + no_speech_probs = np.array( + probs_at_sot[:, self.tokenizer.no_speech], + ).tolist() + if no_speech_probs[0] > self.cfg.nonspeech_prob: + logger.info("no speech, stop") + return True + return False - logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") + def _suppress_blank_tokens(self, logits): + blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot] + logits = logits.at[:, blank_tokens].add(-float('inf')) + return logits - self._clean_cache() + def _apply_token_suppression(self, logits): + if self.state.suppress_tokens: + suppress_indices = mx.array( + list(self.state.suppress_tokens), dtype=mx.int32, + ) + logits = logits.at[:, suppress_indices].add(-float('inf')) + return logits - if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None: - self.state.first_timestamp = l_absolute_timestamps[0] - timestamped_words = [] - timestamp_idx = 0 - replacement_char = "\ufffd" - - for word, word_tokens in zip(split_words, split_tokens): - if replacement_char in word: - logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}") - timestamp_idx += len(word_tokens) - continue - - try: - current_timestamp = l_absolute_timestamps[timestamp_idx] - except IndexError: - pass - timestamp_idx += len(word_tokens) - - timestamp_entry = ASRToken( - start=round(current_timestamp, 2), - end=round(current_timestamp + 0.1, 2), - text=word, - speaker=self.state.speaker, - detected_language=self.state.detected_language - ).with_offset(self.state.global_time_offset) - timestamped_words.append(timestamp_entry) - self.state.pending_incomplete_tokens = [] - MAX_PENDING_TOKENS = 10 - if split_words and replacement_char in split_words[-1]: - if len(split_tokens[-1]) <= MAX_PENDING_TOKENS: - self.state.pending_incomplete_tokens = split_tokens[-1] - logger.debug(f"[UTF-8 Fix] Holding incomplete tokens") - else: - logger.warning(f"[UTF-8 Fix] Skipping too many tokens") - - return timestamped_words + def _update_tokens(self, current_tokens, logits, sum_logprobs): + return self.state.token_decoder.update(current_tokens, logits, sum_logprobs) def _process_cross_attention( - self, - cross_attns: List[List[mx.array]], - content_mel_len: int + self, cross_attns: List, content_mel_len: int, ) -> mx.array: - """ - Process cross-attention weights for alignment. - - Args: - cross_attns: List of cross-attention from each forward pass - Each element is a list of mx.arrays per layer - content_mel_len: Length of actual audio content - - Returns: - Processed attention tensor, shape (batch, seq_len, content_mel_len) - """ attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)] num_decoder_layers = self.num_decoder_layers @@ -713,14 +364,11 @@ class MLXAlignAtt: for idx, attn_mat in enumerate(flattened_attns): if attn_mat is None: continue - layer_rank = idx % num_decoder_layers align_heads_in_layer = self.state.align_source.get(layer_rank, []) - - if len(align_heads_in_layer) == 0: + if not align_heads_in_layer: continue attn_mat = mx.softmax(attn_mat, axis=-1) - for align_head_rank, head_id in align_heads_in_layer: if self.cfg.beam_size == 1: if attn_mat.ndim == 4: @@ -731,26 +379,43 @@ class MLXAlignAtt: else: a = attn_mat[:, head_id, :, :] attn_of_alignment_heads[align_head_rank].append(a) + tmp = [] for mat in attn_of_alignment_heads: if mat: - t = mx.concatenate(mat, axis=1) - tmp.append(t) - + tmp.append(mx.concatenate(mat, axis=1)) if not tmp: return mx.zeros((self.cfg.beam_size, 1, content_mel_len)) - attn_of_alignment_heads = mx.stack(tmp, axis=1) + attn_of_alignment_heads = mx.stack(tmp, axis=1) std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True) mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True) attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8) - attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7) - attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1) - attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len] - mx.eval(attn_of_alignment_heads) return attn_of_alignment_heads + def _get_attended_frames(self, attn): + most_attended_frames = mx.argmax(attn[:, -1, :], axis=-1) + frames_np = np.array(most_attended_frames) + return frames_np.tolist(), int(frames_np[0]) + + def _is_special_token(self, current_tokens): + return int(np.array(current_tokens[0, -2])) >= DEC_PAD + + def _rewind_tokens(self): + if len(self.state.tokens) > 0: + return mx.concatenate(self.state.tokens, axis=1) + return self.state.tokens[0] + + def _tokens_to_list(self, current_tokens, start_col): + return np.array(current_tokens[0, start_col:]).tolist() + + def _make_new_tokens_tensor(self, hypothesis): + new_tokens = mx.array([hypothesis], dtype=mx.int32) + return mx.repeat(new_tokens, self.cfg.beam_size, axis=0) + + def _evaluate(self, tensor): + mx.eval(tensor) diff --git a/whisperlivekit/simul_whisper/mlx_encoder.py b/whisperlivekit/simul_whisper/mlx_encoder.py index 7c64079..642ed59 100644 --- a/whisperlivekit/simul_whisper/mlx_encoder.py +++ b/whisperlivekit/simul_whisper/mlx_encoder.py @@ -7,21 +7,9 @@ from huggingface_hub import snapshot_download from mlx.utils import tree_unflatten from mlx_whisper import whisper -mlx_model_mapping = { - "tiny.en": "mlx-community/whisper-tiny.en-mlx", - "tiny": "mlx-community/whisper-tiny-mlx", - "base.en": "mlx-community/whisper-base.en-mlx", - "base": "mlx-community/whisper-base-mlx", - "small.en": "mlx-community/whisper-small.en-mlx", - "small": "mlx-community/whisper-small-mlx", - "medium.en": "mlx-community/whisper-medium.en-mlx", - "medium": "mlx-community/whisper-medium-mlx", - "large-v1": "mlx-community/whisper-large-v1-mlx", - "large-v2": "mlx-community/whisper-large-v2-mlx", - "large-v3": "mlx-community/whisper-large-v3-mlx", - "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", - "large": "mlx-community/whisper-large-mlx", -} +from whisperlivekit.model_mapping import MLX_MODEL_MAPPING + +mlx_model_mapping = MLX_MODEL_MAPPING def load_mlx_encoder( path_or_hf_repo: str, diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 0ea15a7..af2c768 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -1,7 +1,6 @@ import logging import os -from time import time -from typing import List, Optional, Tuple +from typing import List import numpy as np import torch @@ -9,8 +8,6 @@ import torch.nn.functional as F from whisperlivekit.backend_support import (faster_backend_available, mlx_backend_available) -from whisperlivekit.timed_objects import ASRToken -from whisperlivekit.whisper import DecodingOptions, tokenizer from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim) @@ -18,14 +15,13 @@ from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder, SuppressTokens) from whisperlivekit.whisper.timing import median_filter -from ..timed_objects import PUNCTUATION_MARKS +from .align_att_base import DEC_PAD, AlignAttBase from .beam import BeamPyTorchInference from .config import AlignAttConfig from .decoder_state import DecoderState from .eow_detection import fire_at_boundary, load_cif from .token_buffer import TokenBuffer -DEC_PAD = 50257 logger = logging.getLogger(__name__) if mlx_backend_available(): @@ -46,7 +42,10 @@ def load_coreml_encoder(): except ImportError: logger.warning("coremltools is not installed") return None - COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage") + COREML_ENCODER_PATH = os.environ.get( + "MLCORE_ENCODER_PATH", + "whisperlivekit/whisper/whisper_encoder.mlpackage", + ) _coreml_encoder = MLModel(COREML_ENCODER_PATH) spec = _coreml_encoder.get_spec() _coreml_input_name = spec.description.input[0].name if spec.description.input else "mel" @@ -54,92 +53,50 @@ def load_coreml_encoder(): return _coreml_encoder, _coreml_input_name, _coreml_output_name -class AlignAtt: +class AlignAtt(AlignAttBase): """ - Alignment-based Attention decoder for SimulStreaming. - - This class is now hookless - the model can be shared across multiple - sessions, with each session maintaining its own DecoderState. + PyTorch Alignment-based Attention decoder for SimulStreaming. + + Hookless — the model can be shared across multiple sessions, + with each session maintaining its own DecoderState. """ - - # Property accessors for backward compatibility - @property - def speaker(self): - return self.state.speaker - - @speaker.setter - def speaker(self, value): - self.state.speaker = value - - @property - def global_time_offset(self): - return self.state.global_time_offset - - @global_time_offset.setter - def global_time_offset(self, value): - self.state.global_time_offset = value - + def __init__( - self, - cfg: AlignAttConfig, - loaded_model=None, - mlx_encoder=None, - fw_encoder=None, - ) -> None: - # Shared model reference (can be shared across sessions) - self.model = loaded_model + self, + cfg: AlignAttConfig, + loaded_model=None, + mlx_encoder=None, + fw_encoder=None, + ) -> None: self.mlx_encoder = mlx_encoder - self.fw_encoder = fw_encoder + self.fw_encoder = fw_encoder if fw_encoder: - self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) + self.fw_feature_extractor = FeatureExtractor( + feature_size=loaded_model.dims.n_mels, + ) self.coreml_encoder_tuple = None if USE_MLCORE: self.coreml_encoder_tuple = load_coreml_encoder() self.use_mlcore = self.coreml_encoder_tuple is not None - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - + + # Common init (sets self.model, self.cfg, decode_options, etc.) + self._base_init(cfg, loaded_model) logger.info(f"Model dimensions: {self.model.dims}") - self.decode_options = DecodingOptions( - language=cfg.language, - without_timestamps=True, - task=cfg.task - ) - self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual - - self.max_text_len = self.model.dims.n_text_ctx - self.num_decoder_layers = len(self.model.decoder.blocks) - self.cfg = cfg - if self.cfg.max_context_tokens is None: - self.max_context_tokens = self.max_text_len - else: - self.max_context_tokens = self.cfg.max_context_tokens - - # Initialize per-session state + # Per-session state self.state = DecoderState() self._init_state(cfg) - + def _init_state(self, cfg: AlignAttConfig): - """Initialize the per-session decoder state.""" - # Create tokenizer - self.create_tokenizer(cfg.language if cfg.language != "auto" else None) - self.state.tokenizer = self.tokenizer - self.state.detected_language = cfg.language if cfg.language != "auto" else None - - # Timing state - self.state.global_time_offset = 0.0 - self.state.last_attend_frame = -cfg.rewind_threshold - self.state.speaker = -1 - + self._init_state_common(cfg) + # CIF helpers for end-of-word boundary detection self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif( - cfg, - n_audio_state=self.model.dims.n_audio_state, - device=self.model.device + cfg, n_audio_state=self.model.dims.n_audio_state, device=self.model.device, ) - # Build alignment source mapping from model's alignment_heads + # Build alignment source mapping self.state.align_source = {} self.state.num_align_heads = 0 for layer_rank, head_id in self.model.alignment_heads.indices().T: @@ -151,12 +108,9 @@ class AlignAtt: # Build suppress tokens function suppress_tokens = [ - self.tokenizer.transcribe, - self.tokenizer.translate, - self.tokenizer.sot, - self.tokenizer.sot_prev, - self.tokenizer.sot_lm, - self.tokenizer.no_timestamps, + self.tokenizer.transcribe, self.tokenizer.translate, + self.tokenizer.sot, self.tokenizer.sot_prev, + self.tokenizer.sot_lm, self.tokenizer.no_timestamps, ] + list(self.tokenizer.all_language_tokens) if self.tokenizer.no_speech is not None: suppress_tokens.append(self.tokenizer.no_speech) @@ -165,138 +119,80 @@ class AlignAtt: sup_tokens = SuppressTokens(suppress_tokens) self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None) - # Initialize tokens self.init_tokens() self.init_context() - # Set up decoder type + # Decoder type self.state.decoder_type = cfg.decoder_type if cfg.decoder_type == "greedy": logger.info("Using greedy decoder") self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot) elif cfg.decoder_type == "beam": logger.info("Using beam decoder") - self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length) + self.state.inference = BeamPyTorchInference( + self.model, self.state.initial_token_length, + ) self.state.inference.kv_cache = self.state.kv_cache self.state.token_decoder = BeamSearchDecoder( - inference=self.state.inference, - eot=self.tokenizer.eot, - beam_size=cfg.beam_size + inference=self.state.inference, + eot=self.tokenizer.eot, + beam_size=cfg.beam_size, ) - def warmup(self, audio): - try: - self.insert_audio(audio) - self.infer(is_last=True) - self.refresh_segment(complete=True) - logger.info("Model warmed up successfully") - except Exception as e: - logger.exception(f"Model warmup failed: {e}") + # === Abstract method implementations === - def create_tokenizer(self, language=None): - self.tokenizer = tokenizer.get_tokenizer( - multilingual=self.tokenizer_is_multilingual, - language=language, - num_languages=self.model.num_languages, - task=self.decode_options.task - ) - self.state.tokenizer = self.tokenizer + def init_tokens(self): + logger.debug(f"init tokens, {len(self.state.segments)}") + self.state.initial_tokens = torch.tensor( + self.tokenizer.sot_sequence_including_notimestamps, + dtype=torch.long, device=self.model.device, + ).unsqueeze(0) + self.state.initial_token_length = self.state.initial_tokens.shape[1] + self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) + logger.debug(f"init tokens after, {len(self.state.segments)}") + self.state.tokens = [self.state.initial_tokens] def init_context(self): - kw = {'tokenizer': self.tokenizer, - 'device': self.model.device, - 'prefix_token_ids': [self.tokenizer.sot_prev]} + kw = { + 'tokenizer': self.tokenizer, + 'device': self.model.device, + 'prefix_token_ids': [self.tokenizer.sot_prev], + } self.state.context = TokenBuffer.empty(**kw) if self.cfg.static_init_prompt is not None: self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw) if self.cfg.init_prompt is not None: self.state.context.text += self.cfg.init_prompt - def init_tokens(self): - logger.debug(f"init tokens, {len(self.state.segments)}") - # init tokens (mandatory prompt) - self.state.initial_tokens = torch.tensor( - self.tokenizer.sot_sequence_including_notimestamps, - dtype=torch.long, - device=self.model.device).unsqueeze(0) - self.state.initial_token_length = self.state.initial_tokens.shape[1] - self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) - logger.debug(f"init tokens after, {len(self.state.segments)}") - self.state.tokens = [self.state.initial_tokens] - - def trim_context(self): - logger.info("Trimming context") - c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids) - logger.info(f"Context text: {self.state.context.as_text()}") - l = sum(t.shape[1] for t in self.state.tokens) + c - if self.cfg.static_init_prompt is None: - after = 0 - else: - after = len(self.cfg.static_init_prompt) - while c > self.max_context_tokens or l > self.max_text_len - 20: - t = self.state.context.trim_words(after=after) - l -= t - c -= t - logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") - if t == 0: - break - logger.info(f"Context after trim: {self.state.context.text} (len: {l})") - - - def logits( - self, - tokens: torch.Tensor, - audio_features: torch.Tensor, - return_cross_attn: bool = False - ): - """Get logits from decoder, optionally returning cross-attention weights.""" - if self.state.decoder_type == "greedy": - return self.model.decoder( - tokens, audio_features, - kv_cache=self.state.kv_cache, - return_cross_attn=return_cross_attn + def insert_audio(self, segment=None): + if segment is not None: + self.state.segments.append(segment) + removed_len = 0 + segments_len = self.segments_len() + while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len: + removed_len = self.state.segments[0].shape[0] / 16000 + segments_len -= removed_len + self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len) + self.state.cumulative_time_offset += removed_len + self.state.segments = self.state.segments[1:] + logger.debug( + f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, " + f"cumulative offset: {self.state.cumulative_time_offset:.2f}s" ) - else: - logger.debug(f"Logits shape: {tokens.shape}") - return self.state.inference.logits( - tokens, audio_features, - return_cross_attn=return_cross_attn - ) - - - def refresh_segment(self, complete=False): - logger.debug("Refreshing segment:") - self.init_tokens() - self.state.last_attend_frame = -self.cfg.rewind_threshold - self.state.cumulative_time_offset = 0.0 - self.init_context() - logger.debug(f"Context: {self.state.context}") - if not complete and len(self.state.segments) > 2: - self.state.segments = self.state.segments[-2:] - else: - logger.debug("removing all segments.") - self.state.segments = [] - self.state.log_segments += 1 - self.state.pending_incomplete_tokens = [] - - def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): - if self.state.always_fire: - return True - if self.state.never_fire: - return False - return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear) + if len(self.state.tokens) > 1: + self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist()) + self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:] + return removed_len def _current_tokens(self): toks = self.state.tokens - # very first infer: duplicate start of seq to beam_size if toks[0].shape[0] == 1: toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0) - if not self.state.context.is_empty(): - context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device) + context_toks = self.state.context.as_tensor_beam( + self.cfg.beam_size, device=self.model.device, + ) toks = [context_toks] + toks - - # make it one tensor if len(toks) > 1: current_tokens = torch.cat(toks, dim=1) else: @@ -305,60 +201,19 @@ class AlignAtt: self.debug_print_tokens(current_tokens) return current_tokens - - def debug_print_tokens(self, tokens): - for i in range(self.cfg.beam_size): - logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist())) - - ### audio buffer - - def segments_len(self): - segments_len = sum(s.shape[0] for s in self.state.segments) / 16000 - return segments_len - - def _apply_minseglen(self): - segments_len = self.segments_len() - # wait for long enough audio to start - if segments_len < self.cfg.audio_min_len: - logger.debug("waiting for next segment") + def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): + if self.state.always_fire: + return True + if self.state.never_fire: return False - return True - - def insert_audio(self, segment=None): - if segment is not None: - self.state.segments.append(segment) - - removed_len = 0 - # len of audio is bigger than buffer_len. Going to remove the first segment - segments_len = self.segments_len() - while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len: - removed_len = self.state.segments[0].shape[0] / 16000 - segments_len -= removed_len - self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len) - self.state.cumulative_time_offset += removed_len # Track cumulative time removed - self.state.segments = self.state.segments[1:] - logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s") - if len(self.state.tokens) > 1: - self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist()) - self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:] - return removed_len - - def _clean_cache(self): - """Clean the kv_cache after each inference step.""" - self.state.clean_cache() + return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear) @torch.no_grad() def lang_id(self, encoder_features): - """Language detection from encoder features. - This code is trimmed and copy-pasted from whisper.decoding.detect_language. - """ - # forward pass using a single token, startoftranscript n_audio = encoder_features.shape[0] - x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1] - # Note: don't use kv_cache for language detection + x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) logits = self.model.logits(x, encoder_features)[:, 0] - # collect detected languages; suppress all non-language tokens mask = torch.ones(logits.shape[-1], dtype=torch.bool) mask[list(self.tokenizer.all_language_tokens)] = False logits[:, mask] = -np.inf @@ -367,45 +222,31 @@ class AlignAtt: language_probs = [ { c: language_token_probs[i, j].item() - for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes) + for j, c in zip( + self.tokenizer.all_language_tokens, + self.tokenizer.all_language_codes, + ) } for i in range(n_audio) ] - single = encoder_features.ndim == 2 if single: language_tokens = language_tokens[0] language_probs = language_probs[0] - self._clean_cache() return language_tokens, language_probs - ### transcription / translation - - @torch.no_grad() - def infer(self, is_last=False): - new_segment = True - if len(self.state.segments) == 0: - logger.debug("No segments, nothing to do") - return [] - if not self._apply_minseglen(): - logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") - return [] - - # input_segments is concatenation of audio, it's one array + def _concat_segments(self): if len(self.state.segments) > 1: - input_segments = torch.cat(self.state.segments, dim=0) - else: - input_segments = self.state.segments[0] + return torch.cat(self.state.segments, dim=0) + return self.state.segments[0] - beg_encode = time() + def _encode(self, input_segments): if self.use_mlcore: coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple mel_padded = log_mel_spectrogram( - input_segments, - n_mels=self.model.dims.n_mels, - padding=N_SAMPLES, - device="cpu", + input_segments, n_mels=self.model.dims.n_mels, + padding=N_SAMPLES, device="cpu", ).unsqueeze(0) mel = pad_or_trim(mel_padded, N_FRAMES) content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2) @@ -417,310 +258,151 @@ class AlignAtt: else: encoder_feature_np = next(iter(coreml_outputs.values())) encoder_feature = torch.as_tensor( - np.array(encoder_feature_np), - device=self.device, + np.array(encoder_feature_np), device=self.device, ) if self.mlx_encoder: - mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) + mlx_mel_padded = mlx_log_mel_spectrogram( + audio=input_segments.detach(), + n_mels=self.model.dims.n_mels, padding=N_SAMPLES, + ) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None]) encoder_feature = torch.as_tensor(mlx_encoder_feature) - content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) + content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2) elif self.fw_encoder: - audio_length_seconds = len(input_segments) / 16000 - content_mel_len = int(audio_length_seconds * 100)//2 - mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] + audio_length_seconds = len(input_segments) / 16000 + content_mel_len = int(audio_length_seconds * 100) // 2 + mel_padded_2 = self.fw_feature_extractor( + waveform=input_segments.numpy(), padding=N_SAMPLES, + )[None, :] mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1) encoder_feature_ctranslate = self.fw_encoder.encode(mel) - if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works + if self.device == 'cpu': encoder_feature_ctranslate = np.array(encoder_feature_ctranslate) try: - encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device) - except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case: - encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device) + encoder_feature = torch.as_tensor( + encoder_feature_ctranslate, device=self.device, + ) + except TypeError: + encoder_feature = torch.as_tensor( + np.array(encoder_feature_ctranslate), device=self.device, + ) else: - # mel + padding to 30s - mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, - device=self.device).unsqueeze(0) - # trim to 3000 + mel_padded = log_mel_spectrogram( + input_segments, n_mels=self.model.dims.n_mels, + padding=N_SAMPLES, device=self.device, + ).unsqueeze(0) mel = pad_or_trim(mel_padded, N_FRAMES) - # the len of actual audio - content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) + content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2) encoder_feature = self.model.encoder(mel) - end_encode = time() - # print('Encoder duration:', end_encode-beg_encode) - - if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp: - seconds_since_start = self.segments_len() - self.state.first_timestamp - if seconds_since_start >= 2.0: - language_tokens, language_probs = self.lang_id(encoder_feature) - top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) - print(f"Detected language: {top_lan} with p={p:.4f}") - self.create_tokenizer(top_lan) - self.state.last_attend_frame = -self.cfg.rewind_threshold - self.state.cumulative_time_offset = 0.0 - self.init_tokens() - self.init_context() - self.state.detected_language = top_lan - logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") + return encoder_feature, content_mel_len - self.trim_context() - current_tokens = self._current_tokens() - - fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) + def _init_sum_logprobs(self): + return torch.zeros(self.cfg.beam_size, device=self.device) - - sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device) - completed = False - # punctuation_stop = False - - attn_of_alignment_heads = None - most_attended_frame = None - - token_len_before_decoding = current_tokens.shape[1] - - l_absolute_timestamps = [] - - accumulated_cross_attns = [] - - audio_duration_s = self.segments_len() - # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50) - # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops. - max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5)) - tokens_produced_this_chunk = 0 - - while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens - tokens_produced_this_chunk += 1 - - if tokens_produced_this_chunk > max_tokens_per_chunk: - logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.") - current_tokens = current_tokens[:, :token_len_before_decoding] # Discard all new tokens - break - - if new_segment: - tokens_for_logits = current_tokens - else: - # only need to use the last token except in the first forward pass - tokens_for_logits = current_tokens[:, -1:] - - # Get logits and cross-attention weights from decoder - result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True) - logits, cross_attns = result - - # Accumulate cross-attention from this forward pass (rolling window to - # bound VRAM — only the last entry matters for alignment, and the - # median_filter kernel is 7, so 16 entries is more than enough). - accumulated_cross_attns.append(cross_attns) - if len(accumulated_cross_attns) > 16: - accumulated_cross_attns = accumulated_cross_attns[-16:] - - if new_segment and self.tokenizer.no_speech is not None: - probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1) - no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() - if no_speech_probs[0] > self.cfg.nonspeech_prob: - logger.info("no speech, stop") - break - - logits = logits[:, -1, :] # logits for the last token - - # suppress blank tokens only at the beginning of the segment - if new_segment: - logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf - new_segment = False - self.state.suppress_tokens_fn(logits) - current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs) - - logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ") - self.debug_print_tokens(current_tokens) - - # Process accumulated cross-attention weights for alignment - attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len) - - # for each beam, the most attended frame is: - most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1) - - # Calculate absolute timestamps accounting for cumulative offset - absolute_timestamps = [ - (frame * 0.02 + self.state.cumulative_time_offset) - for frame in most_attended_frames.tolist() - ] - - logger.debug(str(most_attended_frames.tolist()) + " most att frames") - logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)") - - most_attended_frame = most_attended_frames[0].item() - l_absolute_timestamps.append(absolute_timestamps[0]) - - logger.debug("current tokens" + str(current_tokens.shape)) - if completed: - # stripping the last token, the eot - current_tokens = current_tokens[:, :-1] - break - - # for some rare cases where the attention fails - if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold: - if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD: - logger.debug("omit rewinding from special tokens") - self.state.last_attend_frame = most_attended_frame - else: - logger.debug( - f"[rewind detected] current attention pos: {most_attended_frame}, " - f"last attention pos: {self.state.last_attend_frame}; omit this segment") - self.state.last_attend_frame = -self.cfg.rewind_threshold - current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0] - break - else: - self.state.last_attend_frame = most_attended_frame - - if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold): - logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}") - # stripping the last token, the one that is attended too close to the end - current_tokens = current_tokens[:, :-1] - break - - # debug print - for i in range(self.cfg.beam_size): - logger.debug("attn: {}, current pos: {}, current token: {}({})".format( - attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None, - most_attended_frames[i], - current_tokens[i, -1].item(), - self.tokenizer.decode([current_tokens[i, -1].item()]) - )) - - tokens_to_split = current_tokens[0, token_len_before_decoding:] - - # Prepend pending tokens from previous chunk if any - if self.state.pending_incomplete_tokens: - logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}") - pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device) - tokens_to_split = torch.cat([pending_tensor, tokens_to_split]) - - if fire_detected or is_last: - new_hypothesis = tokens_to_split.flatten().tolist() - split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) - else: - # going to truncate the tokens after the last space - split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist()) - if len(split_words) > 1: - new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist] - else: - new_hypothesis = [] - - logger.debug(f"new_hypothesis: {new_hypothesis}") - new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to( - device=self.device, - ) - self.state.tokens.append(new_tokens) - - logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") - - self._clean_cache() - - if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None: - self.state.first_timestamp = l_absolute_timestamps[0] - - timestamped_words = [] - timestamp_idx = 0 - replacement_char = "\ufffd" - for word, word_tokens in zip(split_words, split_tokens): - # Skip words containing incomplete UTF-8 from client output - if replacement_char in word: - logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}") - timestamp_idx += len(word_tokens) - continue - - try: - current_timestamp = l_absolute_timestamps[timestamp_idx] - except IndexError: - # Use last timestamp if index out of range - logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp") - current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0 - timestamp_idx += len(word_tokens) - - timestamp_entry = ASRToken( - start=round(current_timestamp, 2), - end=round(current_timestamp + 0.1, 2), - text=word, - speaker=self.state.speaker, - detected_language=self.state.detected_language - ).with_offset( - self.state.global_time_offset + def _get_logits_and_cross_attn(self, tokens, encoder_feature): + if self.state.decoder_type == "greedy": + return self.model.decoder( + tokens, encoder_feature, + kv_cache=self.state.kv_cache, + return_cross_attn=True, + ) + else: + logger.debug(f"Logits shape: {tokens.shape}") + return self.state.inference.logits( + tokens, encoder_feature, return_cross_attn=True, ) - timestamped_words.append(timestamp_entry) - # Hold incomplete tokens for next chunk (with limit to prevent hallucination accumulation) - self.state.pending_incomplete_tokens = [] - MAX_PENDING_TOKENS = 10 # Real incomplete UTF-8 chars are at most a few tokens - if split_words and replacement_char in split_words[-1]: - if len(split_tokens[-1]) <= MAX_PENDING_TOKENS: - self.state.pending_incomplete_tokens = split_tokens[-1] - logger.debug(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk") - else: - logger.warning(f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens (exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)") + def _check_no_speech(self, logits): + if self.tokenizer.no_speech is not None: + probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + if no_speech_probs[0] > self.cfg.nonspeech_prob: + logger.info("no speech, stop") + return True + return False - return timestamped_words + def _suppress_blank_tokens(self, logits): + logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf + return logits + + def _apply_token_suppression(self, logits): + self.state.suppress_tokens_fn(logits) + return logits + + def _update_tokens(self, current_tokens, logits, sum_logprobs): + return self.state.token_decoder.update(current_tokens, logits, sum_logprobs) def _process_cross_attention( - self, - cross_attns: List[torch.Tensor], - content_mel_len: int + self, cross_attns: List, content_mel_len: int, ) -> torch.Tensor: - """ - Process cross-attention weights from decoder layers for alignment. - - Args: - cross_attns: List of cross-attention tensors from each decoder layer. - Each tensor has shape (batch, n_head, seq_len, audio_len) - content_mel_len: Length of actual audio content in mel frames - - Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len) - """ attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)] num_decoder_layers = len(self.model.decoder.blocks) if cross_attns and isinstance(cross_attns[0], list): - flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list] + flattened_attns = [attn for layer_list in cross_attns for attn in layer_list] else: flattened_attns = cross_attns - + for idx, attn_mat in enumerate(flattened_attns): layer_rank = idx % num_decoder_layers - # attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1 align_heads_in_layer = self.state.align_source.get(layer_rank, []) - if len(align_heads_in_layer) == 0: + if not align_heads_in_layer: continue - attn_mat = F.softmax(attn_mat, dim=-1) - for align_head_rank, head_id in align_heads_in_layer: if self.cfg.beam_size == 1: - # (n_head, seq_len, audio_len) when squeezed if attn_mat.dim() == 4: - a = attn_mat[0, head_id, :, :] # (seq_len, audio_len) + a = attn_mat[0, head_id, :, :] else: a = attn_mat[head_id, :, :] - a = a.unsqueeze(0) # (1, seq_len, audio_len) + a = a.unsqueeze(0) else: - # attn_mat: (batch, n_head, seq_len, audio_len) - a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len) + a = attn_mat[:, head_id, :, :] attn_of_alignment_heads[align_head_rank].append(a) - + tmp = [] for mat in attn_of_alignment_heads: if mat: - t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len) - tmp.append(t) - + tmp.append(torch.cat(mat, dim=1)) if not tmp: return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device) - - # stck al heads: (batch, num_align_heads, seq_len, audio_len) + attn_of_alignment_heads = torch.stack(tmp, dim=1) - - std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False) + std, mean = torch.std_mean( + attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False, + ) attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8) - attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1) attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len] - return attn_of_alignment_heads \ No newline at end of file + return attn_of_alignment_heads + + def _get_attended_frames(self, attn): + most_attended_frames = torch.argmax(attn[:, -1, :], dim=-1) + return most_attended_frames.tolist(), most_attended_frames[0].item() + + def _is_special_token(self, current_tokens): + return current_tokens[0, -2].item() >= DEC_PAD + + def _rewind_tokens(self): + if len(self.state.tokens) > 0: + return torch.cat(self.state.tokens, dim=1) + return self.state.tokens[0] + + def _tokens_to_list(self, current_tokens, start_col): + return current_tokens[0, start_col:].flatten().tolist() + + def _make_new_tokens_tensor(self, hypothesis): + return ( + torch.tensor([hypothesis], dtype=torch.long) + .repeat_interleave(self.cfg.beam_size, dim=0) + .to(device=self.device) + ) + + def _evaluate(self, tensor): + pass # No-op for PyTorch + + @torch.no_grad() + def infer(self, is_last=False): + return super().infer(is_last) diff --git a/whisperlivekit/whisper/val.py b/whisperlivekit/whisper/val.py new file mode 100644 index 0000000..4223fc8 --- /dev/null +++ b/whisperlivekit/whisper/val.py @@ -0,0 +1,200 @@ +""" +The most atomic way to train and inference a GPT in pure, dependency-free Python. +This file is the complete algorithm. +Everything else is just efficiency. + +@karpathy +""" + +import os # os.path.exists +import math # math.log, math.exp +import random # random.seed, random.choices, random.gauss, random.shuffle +random.seed(42) # Let there be order among chaos + +# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names) +if not os.path.exists('input.txt'): + import urllib.request + names_url = 'https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt' + urllib.request.urlretrieve(names_url, 'input.txt') +docs = [l.strip() for l in open('input.txt').read().strip().split('\n') if l.strip()] # list[str] of documents +random.shuffle(docs) +print(f"num docs: {len(docs)}") + +# Let there be a Tokenizer to translate strings to discrete symbols and back +uchars = sorted(set(''.join(docs))) # unique characters in the dataset become token ids 0..n-1 +BOS = len(uchars) # token id for the special Beginning of Sequence (BOS) token +vocab_size = len(uchars) + 1 # total number of unique tokens, +1 is for BOS +print(f"vocab size: {vocab_size}") + +# Let there be Autograd, to recursively apply the chain rule through a computation graph +class Value: + __slots__ = ('data', 'grad', '_children', '_local_grads') # Python optimization for memory usage + + def __init__(self, data, children=(), local_grads=()): + self.data = data # scalar value of this node calculated during forward pass + self.grad = 0 # derivative of the loss w.r.t. this node, calculated in backward pass + self._children = children # children of this node in the computation graph + self._local_grads = local_grads # local derivative of this node w.r.t. its children + + def __add__(self, other): + other = other if isinstance(other, Value) else Value(other) + return Value(self.data + other.data, (self, other), (1, 1)) + + def __mul__(self, other): + other = other if isinstance(other, Value) else Value(other) + return Value(self.data * other.data, (self, other), (other.data, self.data)) + + def __pow__(self, other): return Value(self.data**other, (self,), (other * self.data**(other-1),)) + def log(self): return Value(math.log(self.data), (self,), (1/self.data,)) + def exp(self): return Value(math.exp(self.data), (self,), (math.exp(self.data),)) + def relu(self): return Value(max(0, self.data), (self,), (float(self.data > 0),)) + def __neg__(self): return self * -1 + def __radd__(self, other): return self + other + def __sub__(self, other): return self + (-other) + def __rsub__(self, other): return other + (-self) + def __rmul__(self, other): return self * other + def __truediv__(self, other): return self * other**-1 + def __rtruediv__(self, other): return other * self**-1 + + def backward(self): + topo = [] + visited = set() + def build_topo(v): + if v not in visited: + visited.add(v) + for child in v._children: + build_topo(child) + topo.append(v) + build_topo(self) + self.grad = 1 + for v in reversed(topo): + for child, local_grad in zip(v._children, v._local_grads): + child.grad += local_grad * v.grad + +# Initialize the parameters, to store the knowledge of the model. +n_embd = 16 # embedding dimension +n_head = 4 # number of attention heads +n_layer = 1 # number of layers +block_size = 16 # maximum sequence length +head_dim = n_embd // n_head # dimension of each head +matrix = lambda nout, nin, std=0.08: [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)] +state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)} +for i in range(n_layer): + state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd) + state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd) + state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd) + state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd) + state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd) + state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd) +params = [p for mat in state_dict.values() for row in mat for p in row] # flatten params into a single list[Value] +print(f"num params: {len(params)}") +# Define the model architecture: a stateless function mapping token sequence and parameters to logits over what comes next. +# Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU + +def linear(x, w): + return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w] + + +def softmax(logits): + max_val = max(val.data for val in logits) + exps = [(val - max_val).exp() for val in logits] + total = sum(exps) + return [e / total for e in exps] + +def rmsnorm(x): + ms = sum(xi * xi for xi in x) / len(x) + scale = (ms + 1e-5) ** -0.5 + return [xi * scale for xi in x] + +def gpt(token_id, pos_id, keys, values): + tok_emb = state_dict['wte'][token_id] # token embedding + pos_emb = state_dict['wpe'][pos_id] # position embedding + x = [t + p for t, p in zip(tok_emb, pos_emb)] # joint token and position embedding + x = rmsnorm(x) + + for li in range(n_layer): + # 1) Multi-head attention block + x_residual = x + x = rmsnorm(x) + q = linear(x, state_dict[f'layer{li}.attn_wq']) + k = linear(x, state_dict[f'layer{li}.attn_wk']) + v = linear(x, state_dict[f'layer{li}.attn_wv']) + keys[li].append(k) + values[li].append(v) + x_attn = [] + for h in range(n_head): + hs = h * head_dim + q_h = q[hs:hs+head_dim] + k_h = [ki[hs:hs+head_dim] for ki in keys[li]] + v_h = [vi[hs:hs+head_dim] for vi in values[li]] + attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))] + attn_weights = softmax(attn_logits) + head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)] + x_attn.extend(head_out) + x = linear(x_attn, state_dict[f'layer{li}.attn_wo']) + x = [a + b for a, b in zip(x, x_residual)] + # 2) MLP block + x_residual = x + x = rmsnorm(x) + x = linear(x, state_dict[f'layer{li}.mlp_fc1']) + x = [xi.relu() for xi in x] + x = linear(x, state_dict[f'layer{li}.mlp_fc2']) + x = [a + b for a, b in zip(x, x_residual)] + + logits = linear(x, state_dict['lm_head']) + return logits + +# Let there be Adam, the blessed optimizer and its buffers +learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8 +m = [0.0] * len(params) # first moment buffer +v = [0.0] * len(params) # second moment buffer +# Repeat in sequence +num_steps = 1000 # number of training steps +for step in range(num_steps): + + # Take single document, tokenize it, surround it with BOS special token on both sides + doc = docs[step % len(docs)] + tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS] + n = min(block_size, len(tokens) - 1) + + # Forward the token sequence through the model, building up the computation graph all the way to the loss. + keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)] + losses = [] + for pos_id in range(n): + token_id, target_id = tokens[pos_id], tokens[pos_id + 1] + logits = gpt(token_id, pos_id, keys, values) + probs = softmax(logits) + loss_t = -probs[target_id].log() + losses.append(loss_t) + loss = (1 / n) * sum(losses) # final average loss over the document sequence. May yours be low. + + # Backward the loss, calculating the gradients with respect to all model parameters. + loss.backward() + + # Adam optimizer update: update the model parameters based on the corresponding gradients. + lr_t = learning_rate * (1 - step / num_steps) # linear learning rate decay + for i, p in enumerate(params): + m[i] = beta1 * m[i] + (1 - beta1) * p.grad + v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2 + m_hat = m[i] / (1 - beta1 ** (step + 1)) + v_hat = v[i] / (1 - beta2 ** (step + 1)) + p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam) + p.grad = 0 + + print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}") + +# Inference: may the model babble back to us +temperature = 0.5 # in (0, 1], control the "creativity" of generated text, low to high +print("\n--- inference (new, hallucinated names) ---") +for sample_idx in range(20): + keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)] + token_id = BOS + sample = [] + for pos_id in range(block_size): + logits = gpt(token_id, pos_id, keys, values) + probs = softmax([l / temperature for l in logits]) + token_id = random.choices(range(vocab_size), weights=[p.data for p in probs])[0] + if token_id == BOS: + break + sample.append(uchars[token_id]) + print(f"sample {sample_idx+1:2d}: {''.join(sample)}") \ No newline at end of file