Merge pull request #85 from SilasK/warm-up

Add ASR warmup, using https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav as the default warmup file.
This commit is contained in:
Quentin Fuxa 2025-03-14 11:43:24 +01:00 committed by GitHub
commit f4a57cd810
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 60 additions and 8 deletions

View file

@ -10,7 +10,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
from timed_objects import ASRToken
import math
@ -42,8 +42,13 @@ parser.add_argument(
parser.add_argument(
"--warmup-file",
type=str,
default=None,
dest="warmup_file",
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
help="""
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
If False, no warmup is performed.
""",
)
parser.add_argument(
@ -160,6 +165,7 @@ async def lifespan(app: FastAPI):
global asr, tokenizer, diarization
if args.transcription:
asr, tokenizer = backend_factory(args)
warmup_asr(asr, args.warmup_file)
else:
asr, tokenizer = None, None

View file

@ -227,11 +227,57 @@ def asr_factory(args, logfile=sys.stderr):
online = online_factory(args, asr, tokenizer, logfile=logfile)
return asr, online
def set_logging(args, logger, others=()):
    """
    Configure log formatting and apply the requested level to the given loggers.

    Args:
        args: Object exposing a ``log_level`` attribute (a value accepted by
            ``Logger.setLevel``).
        logger: The logger whose level is set to ``args.log_level``.
        others: Iterable of logger names that receive the same level.
            Defaults to an empty tuple so no mutable object is shared
            between calls.
    """
    # A "%(name)s" prefix could be added to the format if logger names are wanted.
    logging.basicConfig(format="%(levelname)s\t%(message)s")
    logger.setLevel(args.log_level)
    for other in others:
        logging.getLogger(other).setLevel(args.log_level)
def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warm up the ASR model by transcribing a short audio file.

    Args:
        asr: ASR backend; only its ``transcribe(audio)`` method is called.
        warmup_file: Path to a speech audio file. ``None`` selects the JFK
            sample, which is downloaded once into the system temp directory
            and reused afterwards. Any other falsy value (``False``, ``""``)
            disables warmup.
        timeout: Socket timeout in seconds applied while downloading the
            default sample.

    Returns:
        True if the warmup transcription ran; False if warmup was skipped
        (disabled, download failed, file missing or empty, audio unreadable).
    """
    import os
    import tempfile

    if warmup_file is None:
        # Download JFK sample if not already present
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        temp_dir = tempfile.gettempdir()
        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")

        if not os.path.exists(warmup_file):
            logger.debug(f"Downloading warmup file from {jfk_url}")
            print(f"Downloading warmup file from {jfk_url}")
            import contextlib
            import http.client
            import socket
            import time
            import urllib.request

            # urlretrieve() takes no timeout argument, so the process-wide
            # default is changed for the duration of the download and
            # restored in the finally clause.
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(timeout)

            start_time = time.time()
            partial_path = None
            try:
                # Download next to the final location and move into place only
                # once complete; an interrupted download must not leave a
                # truncated file that later runs would mistake for a valid cache.
                fd, partial_path = tempfile.mkstemp(suffix=".part", dir=temp_dir)
                os.close(fd)
                urllib.request.urlretrieve(jfk_url, partial_path)
                os.replace(partial_path, warmup_file)
                partial_path = None
                logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
            except (OSError, http.client.HTTPException) as e:
                # OSError covers URLError, ContentTooShortError, socket
                # timeouts, connection resets and temp-dir write failures.
                logger.warning(f"Download failed: {e}. Proceeding without warmup.")
                return False
            finally:
                socket.setdefaulttimeout(original_timeout)
                if partial_path is not None:
                    with contextlib.suppress(OSError):
                        os.remove(partial_path)
    elif not warmup_file:
        # Explicitly disabled (False, empty string, ...).
        return False

    if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False

    print(f"Warmping up Whisper with {warmup_file}")
    try:
        # Imported lazily; a missing librosa is treated like an unreadable file.
        import librosa
        audio, _ = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False

    # Process the audio
    asr.transcribe(audio)

    logger.info("Whisper is warmed up")
    return True