Ruff lint cleanup

This commit is contained in:
Quentin Fuxa 2026-01-03 10:23:00 +01:00
parent 451535d48f
commit cf6c49f502
18 changed files with 150 additions and 167 deletions

View file

@ -130,18 +130,18 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat
available_ops = [15, 16] available_ops = [15, 16]
if opset_version not in available_ops: if opset_version not in available_ops:
raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}') raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}')
if model_path is None: if model_path is None:
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models' data_dir = current_dir / 'silero_vad_models'
if opset_version == 16: if opset_version == 16:
model_name = 'silero_vad.onnx' model_name = 'silero_vad.onnx'
else: else:
model_name = f'silero_vad_16k_op{opset_version}.onnx' model_name = f'silero_vad_16k_op{opset_version}.onnx'
model_path = data_dir / model_name model_path = data_dir / model_name
if not model_path.exists(): if not model_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
f"Model file not found: {model_path}\n" f"Model file not found: {model_path}\n"
@ -149,7 +149,7 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Pat
) )
else: else:
model_path = Path(model_path) model_path = Path(model_path)
return model_path return model_path
@ -169,9 +169,9 @@ def load_jit_vad(model_path: str = None):
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models' data_dir = current_dir / 'silero_vad_models'
model_name = 'silero_vad.jit' model_name = 'silero_vad.jit'
model_path = data_dir / model_name model_path = data_dir / model_name
if not model_path.exists(): if not model_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
f"Model file not found: {model_path}\n" f"Model file not found: {model_path}\n"
@ -181,17 +181,17 @@ def load_jit_vad(model_path: str = None):
model_path = Path(model_path) model_path = Path(model_path)
model = init_jit_model(str(model_path)) model = init_jit_model(str(model_path))
return model return model
class VADIterator: class VADIterator:
""" """
Voice Activity Detection iterator for streaming audio. Voice Activity Detection iterator for streaming audio.
This is the Silero VAD v6 implementation. This is the Silero VAD v6 implementation.
""" """
def __init__(self, def __init__(self,
model, model,
threshold: float = 0.5, threshold: float = 0.5,
@ -319,8 +319,8 @@ if __name__ == "__main__":
audio_buffer = np.array([0] * 512, dtype=np.float32) audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer) result = vad(audio_buffer)
print(f" 512 samples: {result}") print(f" 512 samples: {result}")
# test with 511 samples # test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32) audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer) result = vad(audio_buffer)
print(f" 511 samples: {result}") print(f" 511 samples: {result}")

View file

@ -1,7 +1,6 @@
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX).""" """Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer from whisperlivekit.whisper import DecodingOptions, tokenizer
@ -151,7 +150,7 @@ class AlignAttBase(ABC):
if seconds_since_start >= 2.0: if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature) language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}") logger.info(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan) self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0 self.state.cumulative_time_offset = 0.0

View file

@ -1,31 +1,27 @@
import gc import gc
import logging import logging
import os
import platform import platform
import sys import sys
from pathlib import Path from typing import List, Tuple
from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from whisperlivekit.backend_support import (faster_backend_available, from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
from whisperlivekit.warmup import load_file from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
from .mlx import MLXAlignAtt from .mlx import MLXAlignAtt
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
else: else:
mlx_model_mapping = {} mlx_model_mapping = {}
MLXAlignAtt = None MLXAlignAtt = None
@ -47,7 +43,7 @@ class SimulStreamingOnlineProcessor:
self.end = 0.0 self.end = 0.0
self.buffer = [] self.buffer = []
self.model = self._create_alignatt() self.model = self._create_alignatt()
if asr.tokenizer: if asr.tokenizer:
self.model.tokenizer = asr.tokenizer self.model.tokenizer = asr.tokenizer
self.model.state.tokenizer = asr.tokenizer self.model.state.tokenizer = asr.tokenizer
@ -99,7 +95,7 @@ class SimulStreamingOnlineProcessor:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker self.model.speaker = change_speaker.speaker
self.model.global_time_offset = change_speaker.start self.model.global_time_offset = change_speaker.start
def get_buffer(self): def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
return concat_buffer return concat_buffer
@ -107,19 +103,19 @@ class SimulStreamingOnlineProcessor:
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
""" """
Process accumulated audio chunks using SimulStreaming. Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
""" """
try: try:
timestamped_words = self.model.infer(is_last=is_last) timestamped_words = self.model.infer(is_last=is_last)
if not timestamped_words: if not timestamped_words:
return [], self.end return [], self.end
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None: if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
self.buffer.extend(timestamped_words) self.buffer.extend(timestamped_words)
return [], self.end return [], self.end
self.buffer = [] self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end
except Exception as e: except Exception as e:
@ -156,7 +152,7 @@ class SimulStreamingASR:
def __init__(self, logfile=sys.stderr, **kwargs): def __init__(self, logfile=sys.stderr, **kwargs):
self.logfile = logfile self.logfile = logfile
self.transcribe_kargs = {} self.transcribe_kargs = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(self, key, value) setattr(self, key, value)
@ -169,20 +165,20 @@ class SimulStreamingASR:
self.use_full_mlx = getattr(self, "use_full_mlx", False) self.use_full_mlx = getattr(self, "use_full_mlx", False)
preferred_backend = getattr(self, "backend", "auto") preferred_backend = getattr(self, "backend", "auto")
compatible_whisper_mlx, compatible_faster_whisper = True, True compatible_whisper_mlx, compatible_faster_whisper = True, True
if self.model_path: if self.model_path:
resolved_model_path = resolve_model_path(self.model_path) resolved_model_path = resolve_model_path(self.model_path)
self._resolved_model_path = resolved_model_path self._resolved_model_path = resolved_model_path
self.model_path = str(resolved_model_path) self.model_path = str(resolved_model_path)
model_info = detect_model_format(resolved_model_path) model_info = detect_model_format(resolved_model_path)
compatible_whisper_mlx = model_info.compatible_whisper_mlx compatible_whisper_mlx = model_info.compatible_whisper_mlx
compatible_faster_whisper = model_info.compatible_faster_whisper compatible_faster_whisper = model_info.compatible_faster_whisper
if not self.use_full_mlx and not model_info.has_pytorch: if not self.use_full_mlx and not model_info.has_pytorch:
raise FileNotFoundError( raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}" f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
) )
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
elif self.model_size is not None: elif self.model_size is not None:
self.model_name = self.model_size self.model_name = self.model_size
@ -199,14 +195,14 @@ class SimulStreamingASR:
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper") self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
if self.encoder_backend == "whisper": if self.encoder_backend == "whisper":
self.disable_fast_encoder = True self.disable_fast_encoder = True
# MLX full decoder disabled by default — MLXAlignAtt has known issues # MLX full decoder disabled by default — MLXAlignAtt has known issues
# with token generation after punctuation. Users can opt-in with # with token generation after punctuation. Users can opt-in with
# --use-full-mlx if they want to test it. # --use-full-mlx if they want to test it.
# if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin": # if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
# if not hasattr(self, '_full_mlx_disabled'): # if not hasattr(self, '_full_mlx_disabled'):
# self.use_full_mlx = True # self.use_full_mlx = True
self.cfg = AlignAttConfig( self.cfg = AlignAttConfig(
tokenizer_is_multilingual= is_multilingual, tokenizer_is_multilingual= is_multilingual,
segment_length=self.min_chunk_size, segment_length=self.min_chunk_size,
@ -222,8 +218,8 @@ class SimulStreamingASR:
init_prompt=self.init_prompt, init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens, max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt, static_init_prompt=self.static_init_prompt,
) )
# Set up tokenizer for translation if needed # Set up tokenizer for translation if needed
if self.direct_english_translation: if self.direct_english_translation:
self.tokenizer = self.set_translate_task() self.tokenizer = self.set_translate_task()
@ -232,7 +228,7 @@ class SimulStreamingASR:
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
self.shared_model = None self.shared_model = None
if self.use_full_mlx and HAS_MLX_WHISPER: if self.use_full_mlx and HAS_MLX_WHISPER:
logger.info('MLX Whisper backend used.') logger.info('MLX Whisper backend used.')
if self._resolved_model_path is not None: if self._resolved_model_path is not None:
@ -259,7 +255,7 @@ class SimulStreamingASR:
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path) self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
self.shared_model = self.load_model() self.shared_model = self.load_model()
elif self.encoder_backend == "faster-whisper": elif self.encoder_backend == "faster-whisper":
print('SimulStreaming will use Faster Whisper for the encoder.') logger.info('SimulStreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None: if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path) fw_model = str(self._resolved_model_path)
else: else:
@ -272,7 +268,7 @@ class SimulStreamingASR:
self.shared_model = self.load_model() self.shared_model = self.load_model()
else: else:
self.shared_model = self.load_model() self.shared_model = self.load_model()
def _warmup_mlx_model(self): def _warmup_mlx_model(self):
"""Warmup the full MLX model.""" """Warmup the full MLX model."""
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)

View file

@ -19,14 +19,14 @@ class BeamPyTorchInference(PyTorchInference):
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach() self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
def logits( def logits(
self, self,
tokens: Tensor, tokens: Tensor,
audio_features: Tensor, audio_features: Tensor,
return_cross_attn: bool = False, return_cross_attn: bool = False,
): ):
"""Get logits, optionally returning cross-attention weights.""" """Get logits, optionally returning cross-attention weights."""
return self.model.decoder( return self.model.decoder(
tokens, audio_features, tokens, audio_features,
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
return_cross_attn=return_cross_attn, return_cross_attn=return_cross_attn,
) )

View file

@ -21,4 +21,3 @@ class AlignAttConfig():
init_prompt: str = field(default=None) init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None) static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None) max_context_tokens: int = field(default=None)

View file

@ -1,5 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
@ -7,23 +8,23 @@ import torch
class DecoderState: class DecoderState:
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict) kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
tokenizer: Any = None tokenizer: Any = None
detected_language: Optional[str] = None detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False reset_tokenizer_to_auto_next_call: bool = False
tokens: List[torch.Tensor] = field(default_factory=list) tokens: List[torch.Tensor] = field(default_factory=list)
initial_tokens: Optional[torch.Tensor] = None initial_tokens: Optional[torch.Tensor] = None
initial_token_length: int = 0 initial_token_length: int = 0
sot_index: int = 0 sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict) align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0 num_align_heads: int = 0
segments: List[torch.Tensor] = field(default_factory=list) segments: List[torch.Tensor] = field(default_factory=list)
context: Any = None context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list) pending_incomplete_tokens: List[int] = field(default_factory=list)
pending_retries: int = 0 pending_retries: int = 0
@ -31,21 +32,21 @@ class DecoderState:
cumulative_time_offset: float = 0.0 cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None first_timestamp: Optional[float] = None
last_attend_frame: int = 0 last_attend_frame: int = 0
speaker: int = -1 speaker: int = -1
log_segments: int = 0 log_segments: int = 0
CIFLinear: Optional[torch.nn.Module] = None CIFLinear: Optional[torch.nn.Module] = None
always_fire: bool = False always_fire: bool = False
never_fire: bool = False never_fire: bool = False
suppress_tokens_fn: Any = None suppress_tokens_fn: Any = None
token_decoder: Any = None token_decoder: Any = None
decoder_type: str = "greedy" decoder_type: str = "greedy"
inference: Any = None inference: Any = None
def clean_cache(self): def clean_cache(self):
"""Clean the kv_cache after each inference step.""" """Clean the kv_cache after each inference step."""
# Explicitly delete tensor references to free GPU memory # Explicitly delete tensor references to free GPU memory
@ -68,11 +69,11 @@ class DecoderState:
self.inference.kv_cache = {} self.inference.kv_cache = {}
if self.token_decoder is not None: if self.token_decoder is not None:
self.token_decoder.reset() self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200): def reset(self, rewind_threshold: int = 200):
""" """
Reset transient state for a new segment. Reset transient state for a new segment.
Args: Args:
rewind_threshold: Value for resetting last_attend_frame rewind_threshold: Value for resetting last_attend_frame
""" """
@ -85,7 +86,7 @@ class DecoderState:
def full_reset(self, rewind_threshold: int = 200): def full_reset(self, rewind_threshold: int = 200):
""" """
Full reset including audio segments and tokens. Full reset including audio segments and tokens.
Args: Args:
rewind_threshold: Value for resetting last_attend_frame rewind_threshold: Value for resetting last_attend_frame
""" """

View file

@ -46,7 +46,7 @@ def resize(alphas, target_lengths, threshold=0.999):
_alphas[x] = _alphas[x] * 0.5 + mean * mask _alphas[x] = _alphas[x] * 0.5 + mean * mask
return _alphas, _num return _alphas, _num
def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear): def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
content_mel_len = chunked_encoder_feature.shape[1] # B, T, D content_mel_len = chunked_encoder_feature.shape[1] # B, T, D
alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T
@ -62,4 +62,4 @@ def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
if important_positions.numel() == 0: if important_positions.numel() == 0:
return False return False
else: else:
return important_positions[0] >= content_mel_len-2 return important_positions[0] >= content_mel_len-2

View file

@ -13,21 +13,21 @@ class MLXDecoderState:
""" """
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
tokenizer: Any = None tokenizer: Any = None
detected_language: Optional[str] = None detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False reset_tokenizer_to_auto_next_call: bool = False
tokens: List[mx.array] = field(default_factory=list) tokens: List[mx.array] = field(default_factory=list)
initial_tokens: Optional[mx.array] = None initial_tokens: Optional[mx.array] = None
initial_token_length: int = 0 initial_token_length: int = 0
sot_index: int = 0 sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict) align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0 num_align_heads: int = 0
segments: List[np.ndarray] = field(default_factory=list) segments: List[np.ndarray] = field(default_factory=list)
context: Any = None context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list) pending_incomplete_tokens: List[int] = field(default_factory=list)
pending_retries: int = 0 pending_retries: int = 0
@ -35,27 +35,27 @@ class MLXDecoderState:
cumulative_time_offset: float = 0.0 cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None first_timestamp: Optional[float] = None
last_attend_frame: int = 0 last_attend_frame: int = 0
speaker: int = -1 speaker: int = -1
log_segments: int = 0 log_segments: int = 0
cif_weights: Optional[mx.array] = None cif_weights: Optional[mx.array] = None
always_fire: bool = False always_fire: bool = False
never_fire: bool = False never_fire: bool = False
suppress_tokens: Optional[Tuple[int, ...]] = None suppress_tokens: Optional[Tuple[int, ...]] = None
token_decoder: Any = None token_decoder: Any = None
decoder_type: str = "greedy" decoder_type: str = "greedy"
inference: Any = None inference: Any = None
def clean_cache(self): def clean_cache(self):
self.kv_cache = None self.kv_cache = None
if self.decoder_type == "beam" and self.inference is not None: if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = None self.inference.kv_cache = None
if self.token_decoder is not None: if self.token_decoder is not None:
self.token_decoder.reset() self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200): def reset(self, rewind_threshold: int = 200):
self.last_attend_frame = -rewind_threshold self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0

View file

@ -9,7 +9,7 @@ import numpy as np
class MLXGreedyDecoder: class MLXGreedyDecoder:
"""Greedy decoder using MLX operations.""" """Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int): def __init__(self, temperature: float, eot: int):
self.temperature = temperature self.temperature = temperature
self.eot = eot self.eot = eot
@ -33,18 +33,18 @@ class MLXGreedyDecoder:
else: else:
probs = mx.softmax(logits / self.temperature, axis=-1) probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10)) next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1) logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10) logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0] batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens] current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32) mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot) eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens) next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1) tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot)) completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array): def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
@ -56,7 +56,7 @@ class MLXGreedyDecoder:
class MLXBeamSearchDecoder: class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations.""" """Beam search decoder using MLX operations."""
def __init__( def __init__(
self, self,
beam_size: int, beam_size: int,
@ -100,21 +100,21 @@ class MLXBeamSearchDecoder:
if self.finished_sequences is None: if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)] self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1) logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10) logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs) logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens) tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs) sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], [] next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = [] new_sum_logprobs = []
for i in range(n_audio): for i in range(n_audio):
scores, sources, finished = {}, {}, {} scores, sources, finished = {}, {}, {}
for j in range(self.beam_size): for j in range(self.beam_size):
idx = i * self.beam_size + j idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist() prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1] top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices: for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx] logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob new_logprob = sum_logprobs_np[idx] + logprob
@ -136,7 +136,7 @@ class MLXBeamSearchDecoder:
finished_sequences.append(finished) finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32)) tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32)) sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices) self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences) assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip( for previously_finished, newly_finished in zip(
@ -150,14 +150,14 @@ class MLXBeamSearchDecoder:
len(sequences) >= self.max_candidates len(sequences) >= self.max_candidates
for sequences in self.finished_sequences for sequences in self.finished_sequences
) )
return tokens, completed return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array): def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences.""" """Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens) preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs) sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)] tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio sum_logprobs_list: List[float] = [0.0] * n_audio
@ -181,34 +181,34 @@ class MLXBeamSearchDecoder:
class MLXInference: class MLXInference:
"""MLX inference wrapper for beam search KV cache management.""" """MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int): def __init__(self, model, initial_token_length: int):
self.model = model self.model = model
self.initial_token_length = initial_token_length self.initial_token_length = initial_token_length
self.kv_cache = None self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]): def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices.""" """Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None: if self.kv_cache is None:
return return
if source_indices == list(range(len(source_indices))): if source_indices == list(range(len(source_indices))):
return return
source_indices_mx = mx.array(source_indices, dtype=mx.int32) source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = [] new_cache = []
for layer_cache in self.kv_cache: for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache (k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx] new_k = k[source_indices_mx]
new_v = v[source_indices_mx] new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v))) new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache self.kv_cache = new_cache
def logits( def logits(
self, self,
tokens: mx.array, tokens: mx.array,
audio_features: mx.array, audio_features: mx.array,
) -> Tuple[mx.array, List]: ) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache.""" """Get logits from decoder with KV cache."""

View file

@ -4,7 +4,6 @@ from typing import Any, List, Tuple
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
@ -15,7 +14,6 @@ from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -41,17 +41,17 @@ def load_mlx_encoder(
nn.quantize(model, **quantization, class_predicate=class_predicate) nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items())) weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here. # we only want to load the encoder weights here.
# Size examples: for tiny.en, # Size examples: for tiny.en,
# Decoder weights: 59110771 bytes # Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes # Encoder weights: 15268874 bytes
encoder_weights = {} encoder_weights = {}
encoder_weights['encoder'] = weights['encoder'] encoder_weights['encoder'] = weights['encoder']
del(weights) del(weights)
model.update(encoder_weights) model.update(encoder_weights)
@ -89,7 +89,7 @@ def load_mlx_model(
nn.quantize(model, **quantization, class_predicate=class_predicate) nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items())) weights = tree_unflatten(list(weights.items()))
model.update(weights) model.update(weights)
mx.eval(model.parameters()) mx.eval(model.parameters())
return model return model

View file

@ -6,13 +6,9 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from whisperlivekit.backend_support import (faster_backend_available, from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
mlx_backend_available) from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES, from whisperlivekit.whisper.decoding import BeamSearchDecoder, GreedyDecoder, SuppressTokens
TOKENS_PER_SECOND,
log_mel_spectrogram, pad_or_trim)
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
SuppressTokens)
from whisperlivekit.whisper.timing import median_filter from whisperlivekit.whisper.timing import median_filter
from .align_att_base import DEC_PAD, AlignAttBase from .align_att_base import DEC_PAD, AlignAttBase
@ -25,8 +21,7 @@ from .token_buffer import TokenBuffer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if mlx_backend_available(): if mlx_backend_available():
from mlx_whisper.audio import \ from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
if faster_backend_available(): if faster_backend_available():

View file

@ -1,4 +1,3 @@
import sys
import torch import torch
@ -17,7 +16,7 @@ class TokenBuffer:
if tokenizer is None: if tokenizer is None:
tokenizer = self.tokenizer tokenizer = self.tokenizer
if tokenizer is None: if tokenizer is None:
raise ValueError("Tokenizer is not set.") raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text) return self.prefix_token_ids + tokenizer.encode(self.text)
def as_tensor(self, device=None): def as_tensor(self, device=None):
@ -26,7 +25,7 @@ class TokenBuffer:
if device is None: if device is None:
raise ValueError("Device is not set.") raise ValueError("Device is not set.")
tok_ids = self.as_token_ids() tok_ids = self.as_token_ids()
return torch.tensor(tok_ids, return torch.tensor(tok_ids,
dtype=torch.long, device=device).unsqueeze(0) dtype=torch.long, device=device).unsqueeze(0)
def as_tensor_beam(self, beam, device=None): def as_tensor_beam(self, beam, device=None):
@ -44,7 +43,7 @@ class TokenBuffer:
@staticmethod @staticmethod
def from_text(text, *a, **kw): def from_text(text, *a, **kw):
return TokenBuffer(*a, text=text, **kw) return TokenBuffer(*a, text=text, **kw)
def is_empty(self): def is_empty(self):
return self.text is None or self.text == "" return self.text is None or self.text == ""

View file

@ -11,10 +11,8 @@ import torch
from torch import Tensor from torch import Tensor
from tqdm import tqdm from tqdm import tqdm
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram, from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
pad_or_trim) from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
decode, detect_language)
from whisperlivekit.whisper.model import ModelDimensions, Whisper from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__ from whisperlivekit.whisper.version import __version__
@ -266,7 +264,7 @@ def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, to
for key, value in state_dict.items(): for key, value in state_dict.items():
if key == "alignment_heads": if key == "alignment_heads":
continue continue
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.") new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
converted[new_key] = value converted[new_key] = value
@ -310,13 +308,13 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
""" """
if not lora_path: if not lora_path:
return None return None
# Check if it's already a valid local path # Check if it's already a valid local path
if os.path.isdir(lora_path): if os.path.isdir(lora_path):
config_path = os.path.join(lora_path, "adapter_config.json") config_path = os.path.join(lora_path, "adapter_config.json")
if os.path.isfile(config_path): if os.path.isfile(config_path):
return lora_path return lora_path
# Try to download from HuggingFace Hub # Try to download from HuggingFace Hub
if "/" in lora_path: if "/" in lora_path:
try: try:
@ -330,7 +328,7 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
raise FileNotFoundError( raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}" f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
) )
raise FileNotFoundError( raise FileNotFoundError(
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID." f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
) )
@ -339,7 +337,7 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]): def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
if not lora_path: if not lora_path:
return return
# Resolve path (handles HuggingFace Hub download) # Resolve path (handles HuggingFace Hub download)
lora_path = _resolve_lora_path(lora_path) lora_path = _resolve_lora_path(lora_path)
if not lora_path: if not lora_path:
@ -410,10 +408,10 @@ def _load_checkpoint(
if checkpoint_bytes is not None: if checkpoint_bytes is not None:
with io.BytesIO(checkpoint_bytes) as fp: with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device) return torch.load(fp, map_location=device)
file_path = Path(file_path) file_path = Path(file_path)
suffix = file_path.suffix.lower() suffix = file_path.suffix.lower()
if suffix == '.safetensors': if suffix == '.safetensors':
try: try:
from safetensors.torch import load_file from safetensors.torch import load_file
@ -444,7 +442,7 @@ def _load_sharded_checkpoint(
""" """
merged_state_dict = {} merged_state_dict = {}
first_suffix = shard_files[0].suffix.lower() first_suffix = shard_files[0].suffix.lower()
if first_suffix == '.safetensors': if first_suffix == '.safetensors':
try: try:
from safetensors.torch import load_file from safetensors.torch import load_file
@ -461,7 +459,7 @@ def _load_sharded_checkpoint(
shard_dict = torch.load(fp, map_location=device) shard_dict = torch.load(fp, map_location=device)
if isinstance(shard_dict, dict): if isinstance(shard_dict, dict):
merged_state_dict.update(shard_dict) merged_state_dict.update(shard_dict)
return merged_state_dict return merged_state_dict
@ -505,10 +503,10 @@ def load_model(
if download_root is None: if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache") default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
checkpoint = None checkpoint = None
model_path_for_config = name # Used to find config.json for dims inference model_path_for_config = name # Used to find config.json for dims inference
if name in _MODELS: if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory) checkpoint_file = _download(_MODELS[name], download_root, in_memory)
if in_memory: if in_memory:
@ -525,13 +523,13 @@ def load_model(
model_path_for_config = name model_path_for_config = name
elif os.path.isdir(name): elif os.path.isdir(name):
model_info = detect_model_format(name) model_info = detect_model_format(name)
if not model_info.has_pytorch: if not model_info.has_pytorch:
raise RuntimeError( raise RuntimeError(
f"No PyTorch checkpoint found in directory {name}. " f"No PyTorch checkpoint found in directory {name}. "
f"Expected .pt, .bin, or .safetensors file(s)." f"Expected .pt, .bin, or .safetensors file(s)."
) )
if model_info.is_sharded: if model_info.is_sharded:
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device) checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
else: else:
@ -547,7 +545,7 @@ def load_model(
raise RuntimeError( raise RuntimeError(
f"Model {name} not found; available models = {available_models()}" f"Model {name} not found; available models = {available_models()}"
) )
alignment_heads = _ALIGNMENT_HEADS.get(name, None) alignment_heads = _ALIGNMENT_HEADS.get(name, None)
if custom_alignment_heads: if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode() alignment_heads = custom_alignment_heads.encode()
@ -557,10 +555,10 @@ def load_model(
state_dict = checkpoint["model_state_dict"] state_dict = checkpoint["model_state_dict"]
else: else:
state_dict = checkpoint state_dict = checkpoint
if alignment_heads is None and "alignment_heads" in state_dict: if alignment_heads is None and "alignment_heads" in state_dict:
alignment_heads = state_dict["alignment_heads"] alignment_heads = state_dict["alignment_heads"]
state_dict = _convert_hf_state_dict(state_dict) state_dict = _convert_hf_state_dict(state_dict)
state_dict = _convert_mlx_state_dict(state_dict) state_dict = _convert_mlx_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path) _apply_lora_adapter(state_dict, lora_path)
@ -578,10 +576,10 @@ def load_model(
state_dict = checkpoint state_dict = checkpoint
model = Whisper(dims, decoder_only=decoder_only) model = Whisper(dims, decoder_only=decoder_only)
if decoder_only: if decoder_only:
state_dict = { state_dict = {
k: v for k, v in state_dict.items() k: v for k, v in state_dict.items()
if 'encoder' not in k if 'encoder' not in k
} }
@ -604,7 +602,7 @@ def convert_encoder_to_coreml(
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
precision = "float16", precision = "float16",
): ):
import coremltools as ct import coremltools as ct
model = load_model(model_name, device="cpu", decoder_only=False) model = load_model(model_name, device="cpu", decoder_only=False)
encoder = model.encoder.eval().cpu() encoder = model.encoder.eval().cpu()
@ -639,4 +637,4 @@ def convert_encoder_to_coreml(
return output_path return output_path
# if __name__ == "__main__": # if __name__ == "__main__":
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram") # convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")

View file

@ -1,6 +1,5 @@
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
Tuple, Union)
import numpy as np import numpy as np
import torch import torch

View file

@ -175,7 +175,7 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
def __init__( def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False, self, n_state: int, n_head: int, cross_attention: bool = False,
cache_id: str = "", n_text_ctx: int = 448 cache_id: str = "", n_text_ctx: int = 448
): ):
super().__init__() super().__init__()
@ -267,7 +267,7 @@ class TextDecoder(nn.Module):
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ [
ResidualAttentionBlock( ResidualAttentionBlock(
n_state, n_head, cross_attention=True, n_state, n_head, cross_attention=True,
cache_id=f"dec_layer{i}", n_text_ctx=n_ctx cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
) )
for i in range(n_layer) for i in range(n_layer)
@ -279,9 +279,9 @@ class TextDecoder(nn.Module):
self.register_buffer("mask", mask, persistent=False) self.register_buffer("mask", mask, persistent=False)
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
xa: Tensor, xa: Tensor,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
return_cross_attn: bool = False, return_cross_attn: bool = False,
): ):
@ -309,7 +309,7 @@ class TextDecoder(nn.Module):
first_self_attn_key = self.blocks[0].attn.key_cache_id first_self_attn_key = self.blocks[0].attn.key_cache_id
if first_self_attn_key in kv_cache: if first_self_attn_key in kv_cache:
offset = kv_cache[first_self_attn_key].shape[1] offset = kv_cache[first_self_attn_key].shape[1]
x = ( x = (
self.token_embedding(x) self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]] + self.positional_embedding[offset : offset + x.shape[-1]]
@ -336,7 +336,7 @@ class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions, decoder_only: bool = False): def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__() super().__init__()
self.dims = dims self.dims = dims
if not decoder_only: if not decoder_only:
self.encoder = AudioEncoder( self.encoder = AudioEncoder(
self.dims.n_mels, self.dims.n_mels,
@ -373,15 +373,15 @@ class Whisper(nn.Module):
return self.encoder(mel) return self.encoder(mel)
def logits( def logits(
self, self,
tokens: torch.Tensor, tokens: torch.Tensor,
audio_features: torch.Tensor, audio_features: torch.Tensor,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
return_cross_attn: bool = False, return_cross_attn: bool = False,
): ):
return self.decoder( return self.decoder(
tokens, audio_features, tokens, audio_features,
kv_cache=kv_cache, kv_cache=kv_cache,
return_cross_attn=return_cross_attn return_cross_attn=return_cross_attn
) )

View file

@ -8,13 +8,11 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, from .audio import FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (exact_div, format_timestamp, get_end, get_writer, from .utils import exact_div, format_timestamp, get_end, get_writer, make_safe, optional_float, optional_int, str2bool
make_safe, optional_float, optional_int, str2bool)
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper

View file

@ -6,9 +6,10 @@ Everything else is just efficiency.
@karpathy @karpathy
""" """
import os # os.path.exists import math # math.log, math.exp
import math # math.log, math.exp import os # os.path.exists
import random # random.seed, random.choices, random.gauss, random.shuffle import random # random.seed, random.choices, random.gauss, random.shuffle
random.seed(42) # Let there be order among chaos random.seed(42) # Let there be order among chaos
# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names) # Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
@ -197,4 +198,4 @@ for sample_idx in range(20):
if token_id == BOS: if token_id == BOS:
break break
sample.append(uchars[token_id]) sample.append(uchars[token_id])
print(f"sample {sample_idx+1:2d}: {''.join(sample)}") print(f"sample {sample_idx+1:2d}: {''.join(sample)}")