Ruff lint cleanup
This commit is contained in:
parent
451535d48f
commit
cf6c49f502
18 changed files with 150 additions and 167 deletions
|
|
@ -1,7 +1,6 @@
|
||||||
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
|
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional, Tuple
|
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||||
|
|
@ -151,7 +150,7 @@ class AlignAttBase(ABC):
|
||||||
if seconds_since_start >= 2.0:
|
if seconds_since_start >= 2.0:
|
||||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
self.create_tokenizer(top_lan)
|
self.create_tokenizer(top_lan)
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
self.state.cumulative_time_offset = 0.0
|
self.state.cumulative_time_offset = 0.0
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,27 @@
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from typing import List, Tuple
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from whisperlivekit.backend_support import (faster_backend_available,
|
from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
|
||||||
mlx_backend_available)
|
|
||||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||||
from whisperlivekit.warmup import load_file
|
from whisperlivekit.warmup import load_file
|
||||||
from whisperlivekit.whisper import load_model, tokenizer
|
from whisperlivekit.whisper import load_model, tokenizer
|
||||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
|
||||||
from .mlx import MLXAlignAtt
|
from .mlx import MLXAlignAtt
|
||||||
|
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||||
else:
|
else:
|
||||||
mlx_model_mapping = {}
|
mlx_model_mapping = {}
|
||||||
MLXAlignAtt = None
|
MLXAlignAtt = None
|
||||||
|
|
@ -259,7 +255,7 @@ class SimulStreamingASR:
|
||||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
||||||
self.shared_model = self.load_model()
|
self.shared_model = self.load_model()
|
||||||
elif self.encoder_backend == "faster-whisper":
|
elif self.encoder_backend == "faster-whisper":
|
||||||
print('SimulStreaming will use Faster Whisper for the encoder.')
|
logger.info('SimulStreaming will use Faster Whisper for the encoder.')
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
fw_model = str(self._resolved_model_path)
|
fw_model = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -21,4 +21,3 @@ class AlignAttConfig():
|
||||||
init_prompt: str = field(default=None)
|
init_prompt: str = field(default=None)
|
||||||
static_init_prompt: str = field(default=None)
|
static_init_prompt: str = field(default=None)
|
||||||
max_context_tokens: int = field(default=None)
|
max_context_tokens: int = field(default=None)
|
||||||
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from typing import Any, List, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||||
|
|
||||||
|
|
@ -15,7 +14,6 @@ from ..config import AlignAttConfig
|
||||||
from .decoder_state import MLXDecoderState
|
from .decoder_state import MLXDecoderState
|
||||||
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,9 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from whisperlivekit.backend_support import (faster_backend_available,
|
from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
|
||||||
mlx_backend_available)
|
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
|
||||||
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
from whisperlivekit.whisper.decoding import BeamSearchDecoder, GreedyDecoder, SuppressTokens
|
||||||
TOKENS_PER_SECOND,
|
|
||||||
log_mel_spectrogram, pad_or_trim)
|
|
||||||
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
|
|
||||||
SuppressTokens)
|
|
||||||
from whisperlivekit.whisper.timing import median_filter
|
from whisperlivekit.whisper.timing import median_filter
|
||||||
|
|
||||||
from .align_att_base import DEC_PAD, AlignAttBase
|
from .align_att_base import DEC_PAD, AlignAttBase
|
||||||
|
|
@ -25,8 +21,7 @@ from .token_buffer import TokenBuffer
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if mlx_backend_available():
|
if mlx_backend_available():
|
||||||
from mlx_whisper.audio import \
|
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||||
log_mel_spectrogram as mlx_log_mel_spectrogram
|
|
||||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||||
|
|
||||||
if faster_backend_available():
|
if faster_backend_available():
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,8 @@ import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
|
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
pad_or_trim)
|
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||||
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
|
|
||||||
decode, detect_language)
|
|
||||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||||
from whisperlivekit.whisper.transcribe import transcribe
|
from whisperlivekit.whisper.transcribe import transcribe
|
||||||
from whisperlivekit.whisper.version import __version__
|
from whisperlivekit.whisper.version import __version__
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
Tuple, Union)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,11 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
|
from .audio import FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim
|
||||||
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
|
|
||||||
from .decoding import DecodingOptions, DecodingResult
|
from .decoding import DecodingOptions, DecodingResult
|
||||||
from .timing import add_word_timestamps
|
from .timing import add_word_timestamps
|
||||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||||
from .utils import (exact_div, format_timestamp, get_end, get_writer,
|
from .utils import exact_div, format_timestamp, get_end, get_writer, make_safe, optional_float, optional_int, str2bool
|
||||||
make_safe, optional_float, optional_int, str2bool)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .model import Whisper
|
from .model import Whisper
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,10 @@ Everything else is just efficiency.
|
||||||
@karpathy
|
@karpathy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os # os.path.exists
|
import math # math.log, math.exp
|
||||||
import math # math.log, math.exp
|
import os # os.path.exists
|
||||||
import random # random.seed, random.choices, random.gauss, random.shuffle
|
import random # random.seed, random.choices, random.gauss, random.shuffle
|
||||||
|
|
||||||
random.seed(42) # Let there be order among chaos
|
random.seed(42) # Let there be order among chaos
|
||||||
|
|
||||||
# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
|
# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue