From 6e85c16614ca18e3d5d2babf08b12ffb522ee552 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 18 Jan 2026 15:27:00 +0100 Subject: [PATCH] Refactor TranscriptionEngine singleton --- whisperlivekit/core.py | 57 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index c306f52..30a8da7 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -1,5 +1,4 @@ import logging -import sys import threading from argparse import Namespace from dataclasses import asdict @@ -15,7 +14,7 @@ class TranscriptionEngine: _instance = None _initialized = False _lock = threading.Lock() # Thread-safe singleton lock - + def __new__(cls, *args, **kwargs): # Double-checked locking pattern for thread-safe singleton if cls._instance is None: @@ -24,7 +23,18 @@ class TranscriptionEngine: if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance - + + @classmethod + def reset(cls): + """Reset the singleton so a new instance can be created. + + For testing only — allows switching backends between test runs. + In production, the singleton should never be reset. + """ + with cls._lock: + cls._instance = None + cls._initialized = False + def __init__(self, config=None, **kwargs): # Thread-safe initialization check with TranscriptionEngine._lock: @@ -102,6 +112,17 @@ class TranscriptionEngine: self.tokenizer = None self.asr = VoxtralHFStreamingASR(**transcription_common_params) logger.info("Using Voxtral HF Transformers streaming backend") + elif config.backend == "qwen3": + from whisperlivekit.qwen3_asr import Qwen3ASR + self.asr = Qwen3ASR(**transcription_common_params) + self.asr.confidence_validation = config.confidence_validation + self.asr.tokenizer = None + self.asr.buffer_trimming = config.buffer_trimming + self.asr.buffer_trimming_sec = config.buffer_trimming_sec + self.asr.backend_choice = "qwen3" + from whisperlivekit.warmup import warmup_asr + warmup_asr(self.asr, config.warmup_file) + logger.info("Using Qwen3-ASR backend with LocalAgreement policy") elif config.backend_policy == "simulstreaming": simulstreaming_params = { "disable_fast_encoder": config.disable_fast_encoder, @@ -173,26 +194,42 @@ class TranscriptionEngine: ) -def online_factory(args, asr): - if getattr(args, 'backend', None) == "voxtral-mlx": +def online_factory(args, asr, language=None): + """Create an online ASR processor for a session. + + Args: + args: Configuration namespace. + asr: Shared ASR backend instance. + language: Optional per-session language override (e.g. "en", "fr", "auto"). + If provided and the backend supports it, transcription will use + this language instead of the server-wide default. + """ + # Wrap the shared ASR with a per-session language if requested + if language is not None: + from whisperlivekit.session_asr_proxy import SessionASRProxy + asr = SessionASRProxy(asr, language) + + backend = getattr(args, 'backend', None) + if backend == "voxtral-mlx": from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor return VoxtralMLXOnlineProcessor(asr) - if getattr(args, 'backend', None) == "voxtral": + if backend == "voxtral": from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor return VoxtralHFStreamingOnlineProcessor(asr) + if backend == "qwen3": + return OnlineASRProcessor(asr) if args.backend_policy == "simulstreaming": from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor return SimulStreamingOnlineProcessor(asr) return OnlineASRProcessor(asr) - - + + def online_diarization_factory(args, diarization_backend): if args.diarization_backend == "diart": online = diarization_backend # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended elif args.diarization_backend == "sortformer": - from whisperlivekit.diarization.sortformer_backend import \ - SortformerDiarizationOnline + from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline online = SortformerDiarizationOnline(shared_model=diarization_backend) else: raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")