Ruff lint cleanup
This commit is contained in:
parent
451535d48f
commit
cf6c49f502
18 changed files with 150 additions and 167 deletions
|
|
@ -1,7 +1,6 @@
|
|||
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||
|
|
@ -151,7 +150,7 @@ class AlignAttBase(ABC):
|
|||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.state.cumulative_time_offset = 0.0
|
||||
|
|
|
|||
|
|
@ -1,31 +1,27 @@
|
|||
import gc
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.whisper import load_model, tokenizer
|
||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||
from .mlx import MLXAlignAtt
|
||||
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||
else:
|
||||
mlx_model_mapping = {}
|
||||
MLXAlignAtt = None
|
||||
|
|
@ -259,7 +255,7 @@ class SimulStreamingASR:
|
|||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
||||
self.shared_model = self.load_model()
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('SimulStreaming will use Faster Whisper for the encoder.')
|
||||
logger.info('SimulStreaming will use Faster Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -21,4 +21,3 @@ class AlignAttConfig():
|
|||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from typing import Any, List, Tuple
|
|||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
|
||||
|
|
@ -15,7 +14,6 @@ from ..config import AlignAttConfig
|
|||
from .decoder_state import MLXDecoderState
|
||||
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,13 +6,9 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
||||
TOKENS_PER_SECOND,
|
||||
log_mel_spectrogram, pad_or_trim)
|
||||
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
|
||||
SuppressTokens)
|
||||
from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
|
||||
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
|
||||
from whisperlivekit.whisper.decoding import BeamSearchDecoder, GreedyDecoder, SuppressTokens
|
||||
from whisperlivekit.whisper.timing import median_filter
|
||||
|
||||
from .align_att_base import DEC_PAD, AlignAttBase
|
||||
|
|
@ -25,8 +21,7 @@ from .token_buffer import TokenBuffer
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import \
|
||||
log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
|
||||
if faster_backend_available():
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,8 @@ import torch
|
|||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
|
||||
pad_or_trim)
|
||||
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
|
||||
decode, detect_language)
|
||||
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||
from whisperlivekit.whisper.transcribe import transcribe
|
||||
from whisperlivekit.whisper.version import __version__
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
|
||||
Tuple, Union)
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -8,13 +8,11 @@ import numpy as np
|
|||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
|
||||
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
|
||||
from .audio import FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import (exact_div, format_timestamp, get_end, get_writer,
|
||||
make_safe, optional_float, optional_int, str2bool)
|
||||
from .utils import exact_div, format_timestamp, get_end, get_writer, make_safe, optional_float, optional_int, str2bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ Everything else is just efficiency.
|
|||
@karpathy
|
||||
"""
|
||||
|
||||
import os # os.path.exists
|
||||
import math # math.log, math.exp
|
||||
import random # random.seed, random.choices, random.gauss, random.shuffle
|
||||
import math # math.log, math.exp
|
||||
import os # os.path.exists
|
||||
import random # random.seed, random.choices, random.gauss, random.shuffle
|
||||
|
||||
random.seed(42) # Let there be order among chaos
|
||||
|
||||
# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
|
||||
|
|
|
|||
Loading…
Reference in a new issue