From cf6c49f502b860fffb03ffd18ffbef3c9bdabd7e Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 3 Jan 2026 10:23:00 +0100 Subject: [PATCH] Ruff lint cleanup --- whisperlivekit/silero_vad_iterator.py | 24 ++++----- .../simul_whisper/align_att_base.py | 3 +- whisperlivekit/simul_whisper/backend.py | 46 ++++++++-------- whisperlivekit/simul_whisper/beam.py | 8 +-- whisperlivekit/simul_whisper/config.py | 1 - whisperlivekit/simul_whisper/decoder_state.py | 31 +++++------ whisperlivekit/simul_whisper/eow_detection.py | 4 +- .../simul_whisper/mlx/decoder_state.py | 26 ++++----- whisperlivekit/simul_whisper/mlx/decoders.py | 54 +++++++++---------- .../simul_whisper/mlx/simul_whisper.py | 2 - whisperlivekit/simul_whisper/mlx_encoder.py | 12 ++--- whisperlivekit/simul_whisper/simul_whisper.py | 13 ++--- whisperlivekit/simul_whisper/token_buffer.py | 7 ++- whisperlivekit/whisper/__init__.py | 46 ++++++++-------- whisperlivekit/whisper/decoding.py | 3 +- whisperlivekit/whisper/model.py | 22 ++++---- whisperlivekit/whisper/transcribe.py | 6 +-- whisperlivekit/whisper/val.py | 9 ++-- 18 files changed, 150 insertions(+), 167 deletions(-) diff --git a/whisperlivekit/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py index 05d9acd..7d3d950 100644 --- a/whisperlivekit/silero_vad_iterator.py +++ b/whisperlivekit/silero_vad_iterator.py @@ -130,18 +130,18 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat available_ops = [15, 16] if opset_version not in available_ops: raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}') - + if model_path is None: current_dir = Path(__file__).parent data_dir = current_dir / 'silero_vad_models' - + if opset_version == 16: model_name = 'silero_vad.onnx' else: model_name = f'silero_vad_16k_op{opset_version}.onnx' - + model_path = data_dir / model_name - + if not model_path.exists(): raise FileNotFoundError( f"Model file not found: {model_path}\n" @@ -149,7 +149,7 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat ) else: model_path = Path(model_path) - + return model_path @@ -169,9 +169,9 @@ def load_jit_vad(model_path: str = None): current_dir = Path(__file__).parent data_dir = current_dir / 'silero_vad_models' model_name = 'silero_vad.jit' - + model_path = data_dir / model_name - + if not model_path.exists(): raise FileNotFoundError( f"Model file not found: {model_path}\n" @@ -181,17 +181,17 @@ def load_jit_vad(model_path: str = None): model_path = Path(model_path) model = init_jit_model(str(model_path)) - + return model class VADIterator: """ Voice Activity Detection iterator for streaming audio. - + This is the Silero VAD v6 implementation. """ - + def __init__(self, model, threshold: float = 0.5, @@ -319,8 +319,8 @@ if __name__ == "__main__": audio_buffer = np.array([0] * 512, dtype=np.float32) result = vad(audio_buffer) print(f" 512 samples: {result}") - + # test with 511 samples audio_buffer = np.array([0] * 511, dtype=np.float32) result = vad(audio_buffer) - print(f" 511 samples: {result}") \ No newline at end of file + print(f" 511 samples: {result}") diff --git a/whisperlivekit/simul_whisper/align_att_base.py b/whisperlivekit/simul_whisper/align_att_base.py index 910be1d..9bbb2cd 100644 --- a/whisperlivekit/simul_whisper/align_att_base.py +++ b/whisperlivekit/simul_whisper/align_att_base.py @@ -1,7 +1,6 @@ """Abstract base class for AlignAtt streaming decoders (PyTorch & MLX).""" import logging from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple from whisperlivekit.timed_objects import ASRToken from whisperlivekit.whisper import DecodingOptions, tokenizer @@ -151,7 +150,7 @@ class AlignAttBase(ABC): if seconds_since_start >= 2.0: language_tokens, language_probs = self.lang_id(encoder_feature) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) - print(f"Detected language: {top_lan} with p={p:.4f}") + logger.info(f"Detected language: {top_lan} with p={p:.4f}") self.create_tokenizer(top_lan) self.state.last_attend_frame = -self.cfg.rewind_threshold self.state.cumulative_time_offset = 0.0 diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 4d010db..4a3c702 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -1,31 +1,27 @@ import gc import logging -import os import platform import sys -from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import torch -from whisperlivekit.backend_support import (faster_backend_available, - mlx_backend_available) +from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.simul_whisper import AlignAtt from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript from whisperlivekit.warmup import load_file from whisperlivekit.whisper import load_model, tokenizer -from whisperlivekit.whisper.audio import TOKENS_PER_SECOND logger = logging.getLogger(__name__) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) if HAS_MLX_WHISPER: - from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping from .mlx import MLXAlignAtt + from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping else: mlx_model_mapping = {} MLXAlignAtt = None @@ -47,7 +43,7 @@ class SimulStreamingOnlineProcessor: self.end = 0.0 self.buffer = [] self.model = self._create_alignatt() - + if asr.tokenizer: self.model.tokenizer = asr.tokenizer self.model.state.tokenizer = asr.tokenizer @@ -99,7 +95,7 @@ class SimulStreamingOnlineProcessor: self.model.refresh_segment(complete=True) self.model.speaker = change_speaker.speaker self.model.global_time_offset = change_speaker.start - + def get_buffer(self): concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') return concat_buffer @@ -107,19 +103,19 @@ class SimulStreamingOnlineProcessor: def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: """ Process accumulated audio chunks using SimulStreaming. - + Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). """ try: timestamped_words = self.model.infer(is_last=is_last) - + if not timestamped_words: return [], self.end - + if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None: self.buffer.extend(timestamped_words) return [], self.end - + self.buffer = [] return timestamped_words, self.end except Exception as e: @@ -156,7 +152,7 @@ class SimulStreamingASR: def __init__(self, logfile=sys.stderr, **kwargs): self.logfile = logfile self.transcribe_kargs = {} - + for key, value in kwargs.items(): setattr(self, key, value) @@ -169,20 +165,20 @@ class SimulStreamingASR: self.use_full_mlx = getattr(self, "use_full_mlx", False) preferred_backend = getattr(self, "backend", "auto") compatible_whisper_mlx, compatible_faster_whisper = True, True - + if self.model_path: resolved_model_path = resolve_model_path(self.model_path) self._resolved_model_path = resolved_model_path self.model_path = str(resolved_model_path) - + model_info = detect_model_format(resolved_model_path) compatible_whisper_mlx = model_info.compatible_whisper_mlx compatible_faster_whisper = model_info.compatible_faster_whisper - + if not self.use_full_mlx and not model_info.has_pytorch: raise FileNotFoundError( f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}" - ) + ) self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem elif self.model_size is not None: self.model_name = self.model_size @@ -199,14 +195,14 @@ class SimulStreamingASR: self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper") if self.encoder_backend == "whisper": self.disable_fast_encoder = True - + # MLX full decoder disabled by default — MLXAlignAtt has known issues # with token generation after punctuation. Users can opt-in with # --use-full-mlx if they want to test it. # if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin": # if not hasattr(self, '_full_mlx_disabled'): # self.use_full_mlx = True - + self.cfg = AlignAttConfig( tokenizer_is_multilingual= is_multilingual, segment_length=self.min_chunk_size, @@ -222,8 +218,8 @@ class SimulStreamingASR: init_prompt=self.init_prompt, max_context_tokens=self.max_context_tokens, static_init_prompt=self.static_init_prompt, - ) - + ) + # Set up tokenizer for translation if needed if self.direct_english_translation: self.tokenizer = self.set_translate_task() @@ -232,7 +228,7 @@ class SimulStreamingASR: self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None self.shared_model = None - + if self.use_full_mlx and HAS_MLX_WHISPER: logger.info('MLX Whisper backend used.') if self._resolved_model_path is not None: @@ -259,7 +255,7 @@ class SimulStreamingASR: self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path) self.shared_model = self.load_model() elif self.encoder_backend == "faster-whisper": - print('SimulStreaming will use Faster Whisper for the encoder.') + logger.info('SimulStreaming will use Faster Whisper for the encoder.') if self._resolved_model_path is not None: fw_model = str(self._resolved_model_path) else: @@ -272,7 +268,7 @@ class SimulStreamingASR: self.shared_model = self.load_model() else: self.shared_model = self.load_model() - + def _warmup_mlx_model(self): """Warmup the full MLX model.""" warmup_audio = load_file(self.warmup_file) diff --git a/whisperlivekit/simul_whisper/beam.py b/whisperlivekit/simul_whisper/beam.py index 27cec0b..06e9845 100644 --- a/whisperlivekit/simul_whisper/beam.py +++ b/whisperlivekit/simul_whisper/beam.py @@ -19,14 +19,14 @@ class BeamPyTorchInference(PyTorchInference): self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach() def logits( - self, - tokens: Tensor, + self, + tokens: Tensor, audio_features: Tensor, return_cross_attn: bool = False, ): """Get logits, optionally returning cross-attention weights.""" return self.model.decoder( - tokens, audio_features, + tokens, audio_features, kv_cache=self.kv_cache, return_cross_attn=return_cross_attn, - ) \ No newline at end of file + ) diff --git a/whisperlivekit/simul_whisper/config.py b/whisperlivekit/simul_whisper/config.py index 1897aac..9a28255 100644 --- a/whisperlivekit/simul_whisper/config.py +++ b/whisperlivekit/simul_whisper/config.py @@ -21,4 +21,3 @@ class AlignAttConfig(): init_prompt: str = field(default=None) static_init_prompt: str = field(default=None) max_context_tokens: int = field(default=None) - \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/decoder_state.py b/whisperlivekit/simul_whisper/decoder_state.py index 44076a3..8722851 100644 --- a/whisperlivekit/simul_whisper/decoder_state.py +++ b/whisperlivekit/simul_whisper/decoder_state.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple + import torch @@ -7,23 +8,23 @@ import torch class DecoderState: kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict) - + tokenizer: Any = None detected_language: Optional[str] = None reset_tokenizer_to_auto_next_call: bool = False - + tokens: List[torch.Tensor] = field(default_factory=list) initial_tokens: Optional[torch.Tensor] = None initial_token_length: int = 0 sot_index: int = 0 - + align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict) num_align_heads: int = 0 - + segments: List[torch.Tensor] = field(default_factory=list) - + context: Any = None - + pending_incomplete_tokens: List[int] = field(default_factory=list) pending_retries: int = 0 @@ -31,21 +32,21 @@ class DecoderState: cumulative_time_offset: float = 0.0 first_timestamp: Optional[float] = None last_attend_frame: int = 0 - + speaker: int = -1 log_segments: int = 0 - + CIFLinear: Optional[torch.nn.Module] = None always_fire: bool = False never_fire: bool = False - + suppress_tokens_fn: Any = None - + token_decoder: Any = None decoder_type: str = "greedy" - + inference: Any = None - + def clean_cache(self): """Clean the kv_cache after each inference step.""" # Explicitly delete tensor references to free GPU memory @@ -68,11 +69,11 @@ class DecoderState: self.inference.kv_cache = {} if self.token_decoder is not None: self.token_decoder.reset() - + def reset(self, rewind_threshold: int = 200): """ Reset transient state for a new segment. - + Args: rewind_threshold: Value for resetting last_attend_frame """ @@ -85,7 +86,7 @@ class DecoderState: def full_reset(self, rewind_threshold: int = 200): """ Full reset including audio segments and tokens. - + Args: rewind_threshold: Value for resetting last_attend_frame """ diff --git a/whisperlivekit/simul_whisper/eow_detection.py b/whisperlivekit/simul_whisper/eow_detection.py index 252a856..f99e543 100644 --- a/whisperlivekit/simul_whisper/eow_detection.py +++ b/whisperlivekit/simul_whisper/eow_detection.py @@ -46,7 +46,7 @@ def resize(alphas, target_lengths, threshold=0.999): _alphas[x] = _alphas[x] * 0.5 + mean * mask return _alphas, _num - + def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear): content_mel_len = chunked_encoder_feature.shape[1] # B, T, D alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T @@ -62,4 +62,4 @@ def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear): if important_positions.numel() == 0: return False else: - return important_positions[0] >= content_mel_len-2 \ No newline at end of file + return important_positions[0] >= content_mel_len-2 diff --git a/whisperlivekit/simul_whisper/mlx/decoder_state.py b/whisperlivekit/simul_whisper/mlx/decoder_state.py index 5065b4c..79fa714 100644 --- a/whisperlivekit/simul_whisper/mlx/decoder_state.py +++ b/whisperlivekit/simul_whisper/mlx/decoder_state.py @@ -13,21 +13,21 @@ class MLXDecoderState: """ kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None - + tokenizer: Any = None detected_language: Optional[str] = None reset_tokenizer_to_auto_next_call: bool = False - + tokens: List[mx.array] = field(default_factory=list) initial_tokens: Optional[mx.array] = None initial_token_length: int = 0 - sot_index: int = 0 + sot_index: int = 0 align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict) - num_align_heads: int = 0 + num_align_heads: int = 0 segments: List[np.ndarray] = field(default_factory=list) - + context: Any = None - + pending_incomplete_tokens: List[int] = field(default_factory=list) pending_retries: int = 0 @@ -35,27 +35,27 @@ class MLXDecoderState: cumulative_time_offset: float = 0.0 first_timestamp: Optional[float] = None last_attend_frame: int = 0 - + speaker: int = -1 - log_segments: int = 0 + log_segments: int = 0 cif_weights: Optional[mx.array] = None always_fire: bool = False never_fire: bool = False - + suppress_tokens: Optional[Tuple[int, ...]] = None - + token_decoder: Any = None decoder_type: str = "greedy" - + inference: Any = None - + def clean_cache(self): self.kv_cache = None if self.decoder_type == "beam" and self.inference is not None: self.inference.kv_cache = None if self.token_decoder is not None: self.token_decoder.reset() - + def reset(self, rewind_threshold: int = 200): self.last_attend_frame = -rewind_threshold self.cumulative_time_offset = 0.0 diff --git a/whisperlivekit/simul_whisper/mlx/decoders.py b/whisperlivekit/simul_whisper/mlx/decoders.py index 58c3b8c..d0d245f 100644 --- a/whisperlivekit/simul_whisper/mlx/decoders.py +++ b/whisperlivekit/simul_whisper/mlx/decoders.py @@ -9,7 +9,7 @@ import numpy as np class MLXGreedyDecoder: """Greedy decoder using MLX operations.""" - + def __init__(self, temperature: float, eot: int): self.temperature = temperature self.eot = eot @@ -33,18 +33,18 @@ class MLXGreedyDecoder: else: probs = mx.softmax(logits / self.temperature, axis=-1) next_tokens = mx.random.categorical(mx.log(probs + 1e-10)) - + logprobs = mx.softmax(logits, axis=-1) - logprobs = mx.log(logprobs + 1e-10) + logprobs = mx.log(logprobs + 1e-10) batch_size = logprobs.shape[0] - current_logprobs = logprobs[mx.arange(batch_size), next_tokens] + current_logprobs = logprobs[mx.arange(batch_size), next_tokens] mask = (tokens[:, -1] != self.eot).astype(mx.float32) - sum_logprobs = sum_logprobs + current_logprobs * mask + sum_logprobs = sum_logprobs + current_logprobs * mask eot_mask = (tokens[:, -1] == self.eot) - next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens) - tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1) + next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens) + tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1) completed = bool(mx.all(tokens[:, -1] == self.eot)) - + return tokens, completed def finalize(self, tokens: mx.array, sum_logprobs: mx.array): @@ -56,7 +56,7 @@ class MLXGreedyDecoder: class MLXBeamSearchDecoder: """Beam search decoder using MLX operations.""" - + def __init__( self, beam_size: int, @@ -100,21 +100,21 @@ class MLXBeamSearchDecoder: if self.finished_sequences is None: self.finished_sequences = [{} for _ in range(n_audio)] logprobs = mx.softmax(logits, axis=-1) - logprobs = mx.log(logprobs + 1e-10) + logprobs = mx.log(logprobs + 1e-10) logprobs_np = np.array(logprobs) tokens_np = np.array(tokens) sum_logprobs_np = np.array(sum_logprobs) - + next_tokens, source_indices, finished_sequences = [], [], [] new_sum_logprobs = [] - + for i in range(n_audio): scores, sources, finished = {}, {}, {} for j in range(self.beam_size): idx = i * self.beam_size + j - prefix = tokens_np[idx].tolist() + prefix = tokens_np[idx].tolist() top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1] - + for token_idx in top_k_indices: logprob = logprobs_np[idx, token_idx] new_logprob = sum_logprobs_np[idx] + logprob @@ -136,7 +136,7 @@ class MLXBeamSearchDecoder: finished_sequences.append(finished) tokens = mx.array(np.array(next_tokens, dtype=np.int32)) - sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32)) + sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32)) self.inference.rearrange_kv_cache(source_indices) assert len(self.finished_sequences) == len(finished_sequences) for previously_finished, newly_finished in zip( @@ -150,14 +150,14 @@ class MLXBeamSearchDecoder: len(sequences) >= self.max_candidates for sequences in self.finished_sequences ) - + return tokens, completed def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array): """Finalize beam search by selecting best sequences.""" preceding_tokens_np = np.array(preceding_tokens) sum_logprobs_np = np.array(sum_logprobs) - + n_audio = preceding_tokens_np.shape[0] // self.beam_size tokens_list: List[List[int]] = [[] for _ in range(n_audio)] sum_logprobs_list: List[float] = [0.0] * n_audio @@ -181,34 +181,34 @@ class MLXBeamSearchDecoder: class MLXInference: """MLX inference wrapper for beam search KV cache management.""" - + def __init__(self, model, initial_token_length: int): self.model = model self.initial_token_length = initial_token_length self.kv_cache = None - + def rearrange_kv_cache(self, source_indices: List[int]): """Rearrange KV cache based on beam search source indices.""" if self.kv_cache is None: return - + if source_indices == list(range(len(source_indices))): return - + source_indices_mx = mx.array(source_indices, dtype=mx.int32) - + new_cache = [] for layer_cache in self.kv_cache: - (k, v), (cross_k, cross_v) = layer_cache + (k, v), (cross_k, cross_v) = layer_cache new_k = k[source_indices_mx] new_v = v[source_indices_mx] new_cache.append(((new_k, new_v), (cross_k, cross_v))) - + self.kv_cache = new_cache - + def logits( - self, - tokens: mx.array, + self, + tokens: mx.array, audio_features: mx.array, ) -> Tuple[mx.array, List]: """Get logits from decoder with KV cache.""" diff --git a/whisperlivekit/simul_whisper/mlx/simul_whisper.py b/whisperlivekit/simul_whisper/mlx/simul_whisper.py index 3211320..dbe4ed3 100644 --- a/whisperlivekit/simul_whisper/mlx/simul_whisper.py +++ b/whisperlivekit/simul_whisper/mlx/simul_whisper.py @@ -4,7 +4,6 @@ from typing import Any, List, Tuple import mlx.core as mx import numpy as np - from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim @@ -15,7 +14,6 @@ from ..config import AlignAttConfig from .decoder_state import MLXDecoderState from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference - logger = logging.getLogger(__name__) diff --git a/whisperlivekit/simul_whisper/mlx_encoder.py b/whisperlivekit/simul_whisper/mlx_encoder.py index 642ed59..9190ed7 100644 --- a/whisperlivekit/simul_whisper/mlx_encoder.py +++ b/whisperlivekit/simul_whisper/mlx_encoder.py @@ -41,17 +41,17 @@ def load_mlx_encoder( nn.quantize(model, **quantization, class_predicate=class_predicate) weights = tree_unflatten(list(weights.items())) - + # we only want to load the encoder weights here. - # Size examples: for tiny.en, + # Size examples: for tiny.en, # Decoder weights: 59110771 bytes # Encoder weights: 15268874 bytes - + encoder_weights = {} encoder_weights['encoder'] = weights['encoder'] del(weights) - + model.update(encoder_weights) @@ -89,7 +89,7 @@ def load_mlx_model( nn.quantize(model, **quantization, class_predicate=class_predicate) weights = tree_unflatten(list(weights.items())) - + model.update(weights) mx.eval(model.parameters()) - return model \ No newline at end of file + return model diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index c5d7923..f6be082 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -6,13 +6,9 @@ import numpy as np import torch import torch.nn.functional as F -from whisperlivekit.backend_support import (faster_backend_available, - mlx_backend_available) -from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES, - TOKENS_PER_SECOND, - log_mel_spectrogram, pad_or_trim) -from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder, - SuppressTokens) +from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available +from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim +from whisperlivekit.whisper.decoding import BeamSearchDecoder, GreedyDecoder, SuppressTokens from whisperlivekit.whisper.timing import median_filter from .align_att_base import DEC_PAD, AlignAttBase @@ -25,8 +21,7 @@ from .token_buffer import TokenBuffer logger = logging.getLogger(__name__) if mlx_backend_available(): - from mlx_whisper.audio import \ - log_mel_spectrogram as mlx_log_mel_spectrogram + from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim if faster_backend_available(): diff --git a/whisperlivekit/simul_whisper/token_buffer.py b/whisperlivekit/simul_whisper/token_buffer.py index c7ace39..5938b27 100644 --- a/whisperlivekit/simul_whisper/token_buffer.py +++ b/whisperlivekit/simul_whisper/token_buffer.py @@ -1,4 +1,3 @@ -import sys import torch @@ -17,7 +16,7 @@ class TokenBuffer: if tokenizer is None: tokenizer = self.tokenizer if tokenizer is None: - raise ValueError("Tokenizer is not set.") + raise ValueError("Tokenizer is not set.") return self.prefix_token_ids + tokenizer.encode(self.text) def as_tensor(self, device=None): @@ -26,7 +25,7 @@ class TokenBuffer: if device is None: raise ValueError("Device is not set.") tok_ids = self.as_token_ids() - return torch.tensor(tok_ids, + return torch.tensor(tok_ids, dtype=torch.long, device=device).unsqueeze(0) def as_tensor_beam(self, beam, device=None): @@ -44,7 +43,7 @@ class TokenBuffer: @staticmethod def from_text(text, *a, **kw): return TokenBuffer(*a, text=text, **kw) - + def is_empty(self): return self.text is None or self.text == "" diff --git a/whisperlivekit/whisper/__init__.py b/whisperlivekit/whisper/__init__.py index 00cf761..c56556d 100644 --- a/whisperlivekit/whisper/__init__.py +++ b/whisperlivekit/whisper/__init__.py @@ -11,10 +11,8 @@ import torch from torch import Tensor from tqdm import tqdm -from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram, - pad_or_trim) -from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult, - decode, detect_language) +from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim +from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language from whisperlivekit.whisper.model import ModelDimensions, Whisper from whisperlivekit.whisper.transcribe import transcribe from whisperlivekit.whisper.version import __version__ @@ -266,7 +264,7 @@ def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, to for key, value in state_dict.items(): if key == "alignment_heads": continue - + new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.") converted[new_key] = value @@ -310,13 +308,13 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]: """ if not lora_path: return None - + # Check if it's already a valid local path if os.path.isdir(lora_path): config_path = os.path.join(lora_path, "adapter_config.json") if os.path.isfile(config_path): return lora_path - + # Try to download from HuggingFace Hub if "/" in lora_path: try: @@ -330,7 +328,7 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]: raise FileNotFoundError( f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}" ) - + raise FileNotFoundError( f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID." ) @@ -339,7 +337,7 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]: def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]): if not lora_path: return - + # Resolve path (handles HuggingFace Hub download) lora_path = _resolve_lora_path(lora_path) if not lora_path: @@ -410,10 +408,10 @@ def _load_checkpoint( if checkpoint_bytes is not None: with io.BytesIO(checkpoint_bytes) as fp: return torch.load(fp, map_location=device) - + file_path = Path(file_path) suffix = file_path.suffix.lower() - + if suffix == '.safetensors': try: from safetensors.torch import load_file @@ -444,7 +442,7 @@ def _load_sharded_checkpoint( """ merged_state_dict = {} first_suffix = shard_files[0].suffix.lower() - + if first_suffix == '.safetensors': try: from safetensors.torch import load_file @@ -461,7 +459,7 @@ def _load_sharded_checkpoint( shard_dict = torch.load(fp, map_location=device) if isinstance(shard_dict, dict): merged_state_dict.update(shard_dict) - + return merged_state_dict @@ -505,10 +503,10 @@ def load_model( if download_root is None: default = os.path.join(os.path.expanduser("~"), ".cache") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") - + checkpoint = None model_path_for_config = name # Used to find config.json for dims inference - + if name in _MODELS: checkpoint_file = _download(_MODELS[name], download_root, in_memory) if in_memory: @@ -525,13 +523,13 @@ def load_model( model_path_for_config = name elif os.path.isdir(name): model_info = detect_model_format(name) - + if not model_info.has_pytorch: raise RuntimeError( f"No PyTorch checkpoint found in directory {name}. " f"Expected .pt, .bin, or .safetensors file(s)." ) - + if model_info.is_sharded: checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device) else: @@ -547,7 +545,7 @@ def load_model( raise RuntimeError( f"Model {name} not found; available models = {available_models()}" ) - + alignment_heads = _ALIGNMENT_HEADS.get(name, None) if custom_alignment_heads: alignment_heads = custom_alignment_heads.encode() @@ -557,10 +555,10 @@ def load_model( state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint - + if alignment_heads is None and "alignment_heads" in state_dict: alignment_heads = state_dict["alignment_heads"] - + state_dict = _convert_hf_state_dict(state_dict) state_dict = _convert_mlx_state_dict(state_dict) _apply_lora_adapter(state_dict, lora_path) @@ -578,10 +576,10 @@ def load_model( state_dict = checkpoint model = Whisper(dims, decoder_only=decoder_only) - + if decoder_only: state_dict = { - k: v for k, v in state_dict.items() + k: v for k, v in state_dict.items() if 'encoder' not in k } @@ -604,7 +602,7 @@ def convert_encoder_to_coreml( dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing precision = "float16", ): - + import coremltools as ct model = load_model(model_name, device="cpu", decoder_only=False) encoder = model.encoder.eval().cpu() @@ -639,4 +637,4 @@ def convert_encoder_to_coreml( return output_path # if __name__ == "__main__": -# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram") \ No newline at end of file +# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram") diff --git a/whisperlivekit/whisper/decoding.py b/whisperlivekit/whisper/decoding.py index 1ef7bf7..83f967e 100644 --- a/whisperlivekit/whisper/decoding.py +++ b/whisperlivekit/whisper/decoding.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field, replace -from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, - Tuple, Union) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch diff --git a/whisperlivekit/whisper/model.py b/whisperlivekit/whisper/model.py index 2d0a298..7d60234 100644 --- a/whisperlivekit/whisper/model.py +++ b/whisperlivekit/whisper/model.py @@ -175,7 +175,7 @@ class MultiHeadAttention(nn.Module): class ResidualAttentionBlock(nn.Module): def __init__( - self, n_state: int, n_head: int, cross_attention: bool = False, + self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = "", n_text_ctx: int = 448 ): super().__init__() @@ -267,7 +267,7 @@ class TextDecoder(nn.Module): self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( [ ResidualAttentionBlock( - n_state, n_head, cross_attention=True, + n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}", n_text_ctx=n_ctx ) for i in range(n_layer) @@ -279,9 +279,9 @@ class TextDecoder(nn.Module): self.register_buffer("mask", mask, persistent=False) def forward( - self, - x: Tensor, - xa: Tensor, + self, + x: Tensor, + xa: Tensor, kv_cache: Optional[dict] = None, return_cross_attn: bool = False, ): @@ -309,7 +309,7 @@ class TextDecoder(nn.Module): first_self_attn_key = self.blocks[0].attn.key_cache_id if first_self_attn_key in kv_cache: offset = kv_cache[first_self_attn_key].shape[1] - + x = ( self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] @@ -336,7 +336,7 @@ class Whisper(nn.Module): def __init__(self, dims: ModelDimensions, decoder_only: bool = False): super().__init__() self.dims = dims - + if not decoder_only: self.encoder = AudioEncoder( self.dims.n_mels, @@ -373,15 +373,15 @@ class Whisper(nn.Module): return self.encoder(mel) def logits( - self, - tokens: torch.Tensor, + self, + tokens: torch.Tensor, audio_features: torch.Tensor, kv_cache: Optional[dict] = None, return_cross_attn: bool = False, ): return self.decoder( - tokens, audio_features, - kv_cache=kv_cache, + tokens, audio_features, + kv_cache=kv_cache, return_cross_attn=return_cross_attn ) diff --git a/whisperlivekit/whisper/transcribe.py b/whisperlivekit/whisper/transcribe.py index 7c192b6..96f69a5 100644 --- a/whisperlivekit/whisper/transcribe.py +++ b/whisperlivekit/whisper/transcribe.py @@ -8,13 +8,11 @@ import numpy as np import torch import tqdm -from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, - SAMPLE_RATE, log_mel_spectrogram, pad_or_trim) +from .audio import FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult from .timing import add_word_timestamps from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from .utils import (exact_div, format_timestamp, get_end, get_writer, - make_safe, optional_float, optional_int, str2bool) +from .utils import exact_div, format_timestamp, get_end, get_writer, make_safe, optional_float, optional_int, str2bool if TYPE_CHECKING: from .model import Whisper diff --git a/whisperlivekit/whisper/val.py b/whisperlivekit/whisper/val.py index 4223fc8..aedaa5e 100644 --- a/whisperlivekit/whisper/val.py +++ b/whisperlivekit/whisper/val.py @@ -6,9 +6,10 @@ Everything else is just efficiency. @karpathy """ -import os # os.path.exists -import math # math.log, math.exp -import random # random.seed, random.choices, random.gauss, random.shuffle +import math # math.log, math.exp +import os # os.path.exists +import random # random.seed, random.choices, random.gauss, random.shuffle + random.seed(42) # Let there be order among chaos # Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names) @@ -197,4 +198,4 @@ for sample_idx in range(20): if token_id == BOS: break sample.append(uchars[token_id]) - print(f"sample {sample_idx+1:2d}: {''.join(sample)}") \ No newline at end of file + print(f"sample {sample_idx+1:2d}: {''.join(sample)}")