Ruff lint cleanup

This commit is contained in:
Quentin Fuxa 2026-01-03 10:23:00 +01:00
parent 451535d48f
commit cf6c49f502
18 changed files with 150 additions and 167 deletions

View file

@ -1,7 +1,6 @@
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX).""" """Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer from whisperlivekit.whisper import DecodingOptions, tokenizer
@ -151,7 +150,7 @@ class AlignAttBase(ABC):
if seconds_since_start >= 2.0: if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature) language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}") logger.info(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan) self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0 self.state.cumulative_time_offset = 0.0

View file

@ -1,31 +1,27 @@
import gc import gc
import logging import logging
import os
import platform import platform
import sys import sys
from pathlib import Path from typing import List, Tuple
from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from whisperlivekit.backend_support import (faster_backend_available, from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
from whisperlivekit.warmup import load_file from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
from .mlx import MLXAlignAtt from .mlx import MLXAlignAtt
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
else: else:
mlx_model_mapping = {} mlx_model_mapping = {}
MLXAlignAtt = None MLXAlignAtt = None
@ -259,7 +255,7 @@ class SimulStreamingASR:
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path) self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
self.shared_model = self.load_model() self.shared_model = self.load_model()
elif self.encoder_backend == "faster-whisper": elif self.encoder_backend == "faster-whisper":
print('SimulStreaming will use Faster Whisper for the encoder.') logger.info('SimulStreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None: if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path) fw_model = str(self._resolved_model_path)
else: else:

View file

@ -21,4 +21,3 @@ class AlignAttConfig():
init_prompt: str = field(default=None) init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None) static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None) max_context_tokens: int = field(default=None)

View file

@ -1,5 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch

View file

@ -4,7 +4,6 @@ from typing import Any, List, Tuple
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
@ -15,7 +14,6 @@ from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -6,13 +6,9 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from whisperlivekit.backend_support import (faster_backend_available, from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
mlx_backend_available) from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES, from whisperlivekit.whisper.decoding import BeamSearchDecoder, GreedyDecoder, SuppressTokens
TOKENS_PER_SECOND,
log_mel_spectrogram, pad_or_trim)
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
SuppressTokens)
from whisperlivekit.whisper.timing import median_filter from whisperlivekit.whisper.timing import median_filter
from .align_att_base import DEC_PAD, AlignAttBase from .align_att_base import DEC_PAD, AlignAttBase
@ -25,8 +21,7 @@ from .token_buffer import TokenBuffer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if mlx_backend_available(): if mlx_backend_available():
from mlx_whisper.audio import \ from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
if faster_backend_available(): if faster_backend_available():

View file

@ -1,4 +1,3 @@
import sys
import torch import torch

View file

@ -11,10 +11,8 @@ import torch
from torch import Tensor from torch import Tensor
from tqdm import tqdm from tqdm import tqdm
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram, from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
pad_or_trim) from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
decode, detect_language)
from whisperlivekit.whisper.model import ModelDimensions, Whisper from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__ from whisperlivekit.whisper.version import __version__

View file

@ -1,6 +1,5 @@
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
Tuple, Union)
import numpy as np import numpy as np
import torch import torch

View file

@ -8,13 +8,11 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, from .audio import FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (exact_div, format_timestamp, get_end, get_writer, from .utils import exact_div, format_timestamp, get_end, get_writer, make_safe, optional_float, optional_int, str2bool
make_safe, optional_float, optional_int, str2bool)
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper

View file

@ -6,9 +6,10 @@ Everything else is just efficiency.
@karpathy @karpathy
""" """
import os # os.path.exists import math # math.log, math.exp
import math # math.log, math.exp import os # os.path.exists
import random # random.seed, random.choices, random.gauss, random.shuffle import random # random.seed, random.choices, random.gauss, random.shuffle
random.seed(42) # Let there be order among chaos random.seed(42) # Let there be order among chaos
# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names) # Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)