use platform to determine system and recommend mlx whisper
This commit is contained in:
parent
72f33be6f2
commit
334b338ab0
3 changed files with 11 additions and 10 deletions
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "whisperlivekit"
|
name = "whisperlivekit"
|
||||||
version = "0.2.8"
|
version = "0.2.8.post1"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@ import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
import logging
|
import logging
|
||||||
|
import platform
|
||||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
from whisperlivekit.warmup import load_file
|
from whisperlivekit.warmup import load_file
|
||||||
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||||
from .whisper import load_model, tokenizer
|
from .whisper import load_model, tokenizer
|
||||||
from .whisper.audio import TOKENS_PER_SECOND
|
from .whisper.audio import TOKENS_PER_SECOND
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import gc
|
import gc
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -22,6 +22,8 @@ try:
|
||||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||||
HAS_MLX_WHISPER = True
|
HAS_MLX_WHISPER = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||||
|
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
|
||||||
HAS_MLX_WHISPER = False
|
HAS_MLX_WHISPER = False
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
HAS_FASTER_WHISPER = False
|
HAS_FASTER_WHISPER = False
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,8 @@ class PaddedAlignAttWhisper:
|
||||||
self.model = loaded_model
|
self.model = loaded_model
|
||||||
else:
|
else:
|
||||||
self.model = load_model(name=model_name, download_root=model_path)
|
self.model = load_model(name=model_name, download_root=model_path)
|
||||||
|
|
||||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
self.mlx_encoder = mlx_encoder
|
self.mlx_encoder = mlx_encoder
|
||||||
self.fw_encoder = fw_encoder
|
self.fw_encoder = fw_encoder
|
||||||
|
|
@ -401,25 +403,22 @@ class PaddedAlignAttWhisper:
|
||||||
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
||||||
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
||||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
|
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
|
||||||
device = encoder_feature.device #'cpu' is apple silicon
|
|
||||||
elif self.fw_encoder:
|
elif self.fw_encoder:
|
||||||
audio_length_seconds = len(input_segments) / 16000
|
audio_length_seconds = len(input_segments) / 16000
|
||||||
content_mel_len = int(audio_length_seconds * 100)//2
|
content_mel_len = int(audio_length_seconds * 100)//2
|
||||||
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
||||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
encoder_feature_ctranslate = np.array(self.fw_encoder.encode(mel))
|
||||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate)
|
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||||
device = encoder_feature.device
|
|
||||||
else:
|
else:
|
||||||
# mel + padding to 30s
|
# mel + padding to 30s
|
||||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||||
device=self.model.device).unsqueeze(0)
|
device=self.device).unsqueeze(0)
|
||||||
# trim to 3000
|
# trim to 3000
|
||||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||||
# the len of actual audio
|
# the len of actual audio
|
||||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||||
encoder_feature = self.model.encoder(mel)
|
encoder_feature = self.model.encoder(mel)
|
||||||
device = mel.device
|
|
||||||
end_encode = time()
|
end_encode = time()
|
||||||
# print('Encoder duration:', end_encode-beg_encode)
|
# print('Encoder duration:', end_encode-beg_encode)
|
||||||
|
|
||||||
|
|
@ -447,7 +446,7 @@ class PaddedAlignAttWhisper:
|
||||||
####################### Decoding loop
|
####################### Decoding loop
|
||||||
logger.info("Decoding loop starts\n")
|
logger.info("Decoding loop starts\n")
|
||||||
|
|
||||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
|
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
||||||
completed = False
|
completed = False
|
||||||
|
|
||||||
attn_of_alignment_heads = None
|
attn_of_alignment_heads = None
|
||||||
|
|
@ -658,7 +657,7 @@ class PaddedAlignAttWhisper:
|
||||||
### new hypothesis
|
### new hypothesis
|
||||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||||
device=self.model.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.tokens.append(new_tokens)
|
self.tokens.append(new_tokens)
|
||||||
# TODO: test if this is redundant or not
|
# TODO: test if this is redundant or not
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue