commit
d27b5eb23e
3 changed files with 29 additions and 41 deletions
|
|
@ -4,7 +4,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from .whisper_streaming_custom.whisper_online import backend_factory
|
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||||
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||||
from whisperlivekit.warmup import warmup_asr, warmup_online
|
from whisperlivekit.warmup import warmup_asr
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
@ -120,7 +120,7 @@ class TranscriptionEngine:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.asr, self.tokenizer = backend_factory(self.args)
|
self.asr, self.tokenizer = backend_factory(self.args)
|
||||||
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
|
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
|
||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
if self.args.diarization_backend == "diart":
|
if self.args.diarization_backend == "diart":
|
||||||
|
|
@ -155,7 +155,6 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||||
asr,
|
asr,
|
||||||
logfile=logfile,
|
logfile=logfile,
|
||||||
)
|
)
|
||||||
# warmup_online(online, args.warmup_file)
|
|
||||||
else:
|
else:
|
||||||
online = OnlineASRProcessor(
|
online = OnlineASRProcessor(
|
||||||
asr,
|
asr,
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def parse_args():
|
||||||
help="""
|
help="""
|
||||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||||
If False, no warmup is performed.
|
If empty, no warmup is performed.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,57 +6,46 @@ logger = logging.getLogger(__name__)
|
||||||
def load_file(warmup_file=None, timeout=5):
|
def load_file(warmup_file=None, timeout=5):
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import urllib.request
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
|
if warmup_file == "":
|
||||||
|
logger.info(f"Skipping warmup.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Download JFK sample if not already present
|
||||||
if warmup_file is None:
|
if warmup_file is None:
|
||||||
# Download JFK sample if not already present
|
|
||||||
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||||
temp_dir = tempfile.gettempdir()
|
temp_dir = tempfile.gettempdir()
|
||||||
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||||
|
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||||
if not os.path.exists(warmup_file):
|
|
||||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
|
||||||
print(f"Downloading warmup file from {jfk_url}")
|
|
||||||
import time
|
|
||||||
import urllib.request
|
|
||||||
import urllib.error
|
|
||||||
import socket
|
|
||||||
|
|
||||||
original_timeout = socket.getdefaulttimeout()
|
|
||||||
socket.setdefaulttimeout(timeout)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
try:
|
try:
|
||||||
urllib.request.urlretrieve(jfk_url, warmup_file)
|
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||||
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
with urllib.request.urlopen(jfk_url, timeout=timeout) as r, open(warmup_file, "wb") as f:
|
||||||
except (urllib.error.URLError, socket.timeout) as e:
|
f.write(r.read())
|
||||||
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
except Exception as e:
|
||||||
|
logger.warning(f"Warmup file download failed: {e}.")
|
||||||
return None
|
return None
|
||||||
finally:
|
|
||||||
socket.setdefaulttimeout(original_timeout)
|
# Validate file and load
|
||||||
elif not warmup_file:
|
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||||
return None
|
logger.warning(f"Warmup file {warmup_file} is invalid or missing.")
|
||||||
|
|
||||||
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
|
||||||
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
audio, sr = librosa.load(warmup_file, sr=16000)
|
audio, _ = librosa.load(warmup_file, sr=16000)
|
||||||
|
return audio
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load audio file: {e}")
|
logger.warning(f"Failed to load warmup file: {e}")
|
||||||
return None
|
return None
|
||||||
return audio
|
|
||||||
|
|
||||||
def warmup_asr(asr, warmup_file=None, timeout=5):
|
def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||||
"""
|
"""
|
||||||
Warmup the ASR model by transcribing a short audio file.
|
Warmup the ASR model by transcribing a short audio file.
|
||||||
"""
|
"""
|
||||||
audio = load_file(warmup_file=None, timeout=5)
|
audio = load_file(warmup_file=warmup_file, timeout=timeout)
|
||||||
|
if audio is None:
|
||||||
|
logger.warning("Warmup file unavailable. Skipping ASR warmup.")
|
||||||
|
return
|
||||||
asr.transcribe(audio)
|
asr.transcribe(audio)
|
||||||
logger.info("ASR model is warmed up")
|
logger.info("ASR model is warmed up.")
|
||||||
|
|
||||||
def warmup_online(online, warmup_file=None, timeout=5):
|
|
||||||
audio = load_file(warmup_file=None, timeout=5)
|
|
||||||
online.warmup(audio)
|
|
||||||
logger.warning("ASR is warmed up")
|
|
||||||
Loading…
Reference in a new issue