Refactor TranscriptionEngine singleton
This commit is contained in:
parent
e1823dd99c
commit
6e85c16614
1 changed file with 47 additions and 10 deletions
|
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
@ -15,7 +14,7 @@ class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
_lock = threading.Lock() # Thread-safe singleton lock
|
_lock = threading.Lock() # Thread-safe singleton lock
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
# Double-checked locking pattern for thread-safe singleton
|
# Double-checked locking pattern for thread-safe singleton
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
|
|
@ -24,7 +23,18 @@ class TranscriptionEngine:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset(cls):
|
||||||
|
"""Reset the singleton so a new instance can be created.
|
||||||
|
|
||||||
|
For testing only — allows switching backends between test runs.
|
||||||
|
In production, the singleton should never be reset.
|
||||||
|
"""
|
||||||
|
with cls._lock:
|
||||||
|
cls._instance = None
|
||||||
|
cls._initialized = False
|
||||||
|
|
||||||
def __init__(self, config=None, **kwargs):
|
def __init__(self, config=None, **kwargs):
|
||||||
# Thread-safe initialization check
|
# Thread-safe initialization check
|
||||||
with TranscriptionEngine._lock:
|
with TranscriptionEngine._lock:
|
||||||
|
|
@ -102,6 +112,17 @@ class TranscriptionEngine:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
||||||
logger.info("Using Voxtral HF Transformers streaming backend")
|
logger.info("Using Voxtral HF Transformers streaming backend")
|
||||||
|
elif config.backend == "qwen3":
|
||||||
|
from whisperlivekit.qwen3_asr import Qwen3ASR
|
||||||
|
self.asr = Qwen3ASR(**transcription_common_params)
|
||||||
|
self.asr.confidence_validation = config.confidence_validation
|
||||||
|
self.asr.tokenizer = None
|
||||||
|
self.asr.buffer_trimming = config.buffer_trimming
|
||||||
|
self.asr.buffer_trimming_sec = config.buffer_trimming_sec
|
||||||
|
self.asr.backend_choice = "qwen3"
|
||||||
|
from whisperlivekit.warmup import warmup_asr
|
||||||
|
warmup_asr(self.asr, config.warmup_file)
|
||||||
|
logger.info("Using Qwen3-ASR backend with LocalAgreement policy")
|
||||||
elif config.backend_policy == "simulstreaming":
|
elif config.backend_policy == "simulstreaming":
|
||||||
simulstreaming_params = {
|
simulstreaming_params = {
|
||||||
"disable_fast_encoder": config.disable_fast_encoder,
|
"disable_fast_encoder": config.disable_fast_encoder,
|
||||||
|
|
@ -173,26 +194,42 @@ class TranscriptionEngine:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def online_factory(args, asr):
|
def online_factory(args, asr, language=None):
|
||||||
if getattr(args, 'backend', None) == "voxtral-mlx":
|
"""Create an online ASR processor for a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Configuration namespace.
|
||||||
|
asr: Shared ASR backend instance.
|
||||||
|
language: Optional per-session language override (e.g. "en", "fr", "auto").
|
||||||
|
If provided and the backend supports it, transcription will use
|
||||||
|
this language instead of the server-wide default.
|
||||||
|
"""
|
||||||
|
# Wrap the shared ASR with a per-session language if requested
|
||||||
|
if language is not None:
|
||||||
|
from whisperlivekit.session_asr_proxy import SessionASRProxy
|
||||||
|
asr = SessionASRProxy(asr, language)
|
||||||
|
|
||||||
|
backend = getattr(args, 'backend', None)
|
||||||
|
if backend == "voxtral-mlx":
|
||||||
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
|
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
|
||||||
return VoxtralMLXOnlineProcessor(asr)
|
return VoxtralMLXOnlineProcessor(asr)
|
||||||
if getattr(args, 'backend', None) == "voxtral":
|
if backend == "voxtral":
|
||||||
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
|
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
|
||||||
return VoxtralHFStreamingOnlineProcessor(asr)
|
return VoxtralHFStreamingOnlineProcessor(asr)
|
||||||
|
if backend == "qwen3":
|
||||||
|
return OnlineASRProcessor(asr)
|
||||||
if args.backend_policy == "simulstreaming":
|
if args.backend_policy == "simulstreaming":
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
return SimulStreamingOnlineProcessor(asr)
|
return SimulStreamingOnlineProcessor(asr)
|
||||||
return OnlineASRProcessor(asr)
|
return OnlineASRProcessor(asr)
|
||||||
|
|
||||||
|
|
||||||
def online_diarization_factory(args, diarization_backend):
|
def online_diarization_factory(args, diarization_backend):
|
||||||
if args.diarization_backend == "diart":
|
if args.diarization_backend == "diart":
|
||||||
online = diarization_backend
|
online = diarization_backend
|
||||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||||
elif args.diarization_backend == "sortformer":
|
elif args.diarization_backend == "sortformer":
|
||||||
from whisperlivekit.diarization.sortformer_backend import \
|
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||||
SortformerDiarizationOnline
|
|
||||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")
|
raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue