Refactor TranscriptionEngine singleton

This commit is contained in:
Quentin Fuxa 2026-01-18 15:27:00 +01:00
parent e1823dd99c
commit 6e85c16614

View file

@ -1,5 +1,4 @@
import logging import logging
import sys
import threading import threading
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
@ -15,7 +14,7 @@ class TranscriptionEngine:
_instance = None _instance = None
_initialized = False _initialized = False
_lock = threading.Lock() # Thread-safe singleton lock _lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton # Double-checked locking pattern for thread-safe singleton
if cls._instance is None: if cls._instance is None:
@ -24,7 +23,18 @@ class TranscriptionEngine:
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
@classmethod
def reset(cls):
    """Discard the cached singleton so that a new instance can be created.

    Intended for tests only: it allows switching backends between test
    runs. In production, the singleton should never be reset.
    """
    # Clear both class-level flags while holding the class lock.
    with cls._lock:
        cls._instance, cls._initialized = None, False
def __init__(self, config=None, **kwargs): def __init__(self, config=None, **kwargs):
# Thread-safe initialization check # Thread-safe initialization check
with TranscriptionEngine._lock: with TranscriptionEngine._lock:
@ -102,6 +112,17 @@ class TranscriptionEngine:
self.tokenizer = None self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params) self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend") logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend == "qwen3":
from whisperlivekit.qwen3_asr import Qwen3ASR
self.asr = Qwen3ASR(**transcription_common_params)
self.asr.confidence_validation = config.confidence_validation
self.asr.tokenizer = None
self.asr.buffer_trimming = config.buffer_trimming
self.asr.buffer_trimming_sec = config.buffer_trimming_sec
self.asr.backend_choice = "qwen3"
from whisperlivekit.warmup import warmup_asr
warmup_asr(self.asr, config.warmup_file)
logger.info("Using Qwen3-ASR backend with LocalAgreement policy")
elif config.backend_policy == "simulstreaming": elif config.backend_policy == "simulstreaming":
simulstreaming_params = { simulstreaming_params = {
"disable_fast_encoder": config.disable_fast_encoder, "disable_fast_encoder": config.disable_fast_encoder,
@ -173,26 +194,42 @@ class TranscriptionEngine:
) )
def online_factory(args, asr, language=None):
    """Create an online ASR processor for a session.

    Args:
        args: Configuration namespace.
        asr: Shared ASR backend instance.
        language: Optional per-session language override (e.g. "en", "fr", "auto").
            If provided and the backend supports it, transcription will use
            this language instead of the server-wide default.

    Returns:
        The online processor selected by ``args.backend`` and
        ``args.backend_policy``, constructed around ``asr`` (or around its
        per-session proxy when ``language`` is given).
    """
    # Wrap the shared ASR with a per-session language if requested
    if language is not None:
        from whisperlivekit.session_asr_proxy import SessionASRProxy
        asr = SessionASRProxy(asr, language)

    # ``backend`` is optional on the namespace; a missing attribute is
    # treated the same as "no dedicated backend".
    backend = getattr(args, 'backend', None)
    if backend == "voxtral-mlx":
        from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
        return VoxtralMLXOnlineProcessor(asr)
    if backend == "voxtral":
        from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
        return VoxtralHFStreamingOnlineProcessor(asr)
    # Checked before ``backend_policy`` so that qwen3 always gets the
    # OnlineASRProcessor, even when the policy is "simulstreaming".
    if backend == "qwen3":
        return OnlineASRProcessor(asr)
    if args.backend_policy == "simulstreaming":
        from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
        return SimulStreamingOnlineProcessor(asr)
    return OnlineASRProcessor(asr)
def online_diarization_factory(args, diarization_backend): def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart": if args.diarization_backend == "diart":
online = diarization_backend online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
elif args.diarization_backend == "sortformer": elif args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import \ from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend) online = SortformerDiarizationOnline(shared_model=diarization_backend)
else: else:
raise ValueError(f"Unknown diarization backend: {args.diarization_backend}") raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")