From 197293e25e91b2b9aea514d5f2dae7c99670e8ee Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 8 Aug 2025 18:07:51 +0200 Subject: [PATCH] refactor(simulstreaming): extract backend + online module into separate files from whisper streaming --- whisperlivekit/audio_processor.py | 3 +- whisperlivekit/core.py | 65 +++- whisperlivekit/simul_whisper/__init__.py | 6 + whisperlivekit/simul_whisper/backend.py | 331 ++++++++++++++++++ .../whisper_streaming_custom/backends.py | 201 +---------- .../whisper_streaming_custom/online_asr.py | 201 ----------- .../whisper_online.py | 73 +--- 7 files changed, 403 insertions(+), 477 deletions(-) create mode 100644 whisperlivekit/simul_whisper/backend.py diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index a5be846..beee001 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -6,8 +6,7 @@ import logging import traceback from datetime import timedelta from whisperlivekit.timed_objects import ASRToken -from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory -from whisperlivekit.core import TranscriptionEngine +from whisperlivekit.core import TranscriptionEngine, online_factory from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState # Set up logging once diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 7ec9887..fb1bbd9 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -1,9 +1,12 @@ try: from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr + from whisperlivekit.whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor except ImportError: from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr -from argparse import Namespace + from .whisper_streaming_custom.online_asr import VACOnlineASRProcessor, OnlineASRProcessor +from argparse import Namespace +import sys class TranscriptionEngine: _instance = None @@ -78,8 +81,32 @@ class TranscriptionEngine: self.diarization = None if self.args.transcription: - self.asr, self.tokenizer = backend_factory(self.args) - warmup_asr(self.asr, self.args.warmup_file) + if self.args.backend == "simulstreaming": + from simul_whisper import SimulStreamingASR + self.tokenizer = None + simulstreaming_kwargs = {} + for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len', + 'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt', + 'max_context_tokens', 'model_path']: + if hasattr(self.args, attr): + simulstreaming_kwargs[attr] = getattr(self.args, attr) + + # Add segment_length from min_chunk_size + simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5) + simulstreaming_kwargs['task'] = self.args.task + + size = self.args.model + self.asr = SimulStreamingASR( + modelsize=size, + lan=self.args.lan, + cache_dir=getattr(self.args, 'model_cache_dir', None), + model_dir=getattr(self.args, 'model_dir', None), + **simulstreaming_kwargs + ) + + else: + self.asr, self.tokenizer = backend_factory(self.args) + warmup_asr(self.asr, self.args.warmup_file) if self.args.diarization: from whisperlivekit.diarization.diarization_online import DiartDiarization @@ -90,3 +117,35 @@ class TranscriptionEngine: ) TranscriptionEngine._initialized = True + + + +def online_factory(args, asr, tokenizer, logfile=sys.stderr): + if args.backend == "simulstreaming": + from simul_whisper import SimulStreamingOnlineProcessor + online = SimulStreamingOnlineProcessor( + asr, + tokenizer, + logfile=logfile, + buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), + confidence_validation=args.confidence_validation + ) + elif args.vac: + online = VACOnlineASRProcessor( + args.min_chunk_size, + asr, + tokenizer, + logfile=logfile, + buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), + confidence_validation = args.confidence_validation + ) + else: + online = OnlineASRProcessor( + asr, + tokenizer, + logfile=logfile, + buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), + confidence_validation = args.confidence_validation + ) + return online + \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/__init__.py b/whisperlivekit/simul_whisper/__init__.py index e69de29..3dd7624 100644 --- a/whisperlivekit/simul_whisper/__init__.py +++ b/whisperlivekit/simul_whisper/__init__.py @@ -0,0 +1,6 @@ +from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor + +__all__ = [ + "SimulStreamingASR", + "SimulStreamingOnlineProcessor", +] diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py new file mode 100644 index 0000000..6ab2e23 --- /dev/null +++ b/whisperlivekit/simul_whisper/backend.py @@ -0,0 +1,331 @@ +import sys +import numpy as np +import logging +from typing import List, Tuple, Optional +import logging +from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript +from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE +logger = logging.getLogger(__name__) + +try: + import torch + from whisperlivekit.simul_whisper.config import AlignAttConfig + from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD + from whisperlivekit.simul_whisper.whisper import tokenizer + SIMULSTREAMING_AVAILABLE = True +except ImportError as e: + raise ImportError( + """SimulStreaming dependencies are not available. + Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""") + +class SimulStreamingOnlineProcessor: + SAMPLING_RATE = 16000 + + def __init__( + self, + asr, + tokenize_method: Optional[callable] = None, + buffer_trimming: Tuple[str, float] = ("segment", 15), + confidence_validation = False, + logfile=sys.stderr, + ): + if not SIMULSTREAMING_AVAILABLE: + raise ImportError("SimulStreaming dependencies are not available.") + + self.asr = asr + self.tokenize = tokenize_method + self.logfile = logfile + self.confidence_validation = confidence_validation + self.init() + + # buffer does not work yet + self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming + + def init(self, offset: Optional[float] = None): + """Initialize or reset the processing state.""" + self.audio_chunks = [] + self.offset = offset if offset is not None else 0.0 + self.is_last = False + self.beg = self.offset + self.end = self.offset + self.cumulative_audio_duration = 0.0 + self.last_audio_stream_end_time = self.offset + + self.committed: List[ASRToken] = [] + self.last_result_tokens: List[ASRToken] = [] + self.buffer_content = "" + self.processed_audio_duration = 0.0 + + def get_audio_buffer_end_time(self) -> float: + """Returns the absolute end time of the current audio buffer.""" + return self.end + + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None): + """Append an audio chunk to be processed by SimulStreaming.""" + if torch is None: + raise ImportError("PyTorch is required for SimulStreaming but not available") + + # Convert numpy array to torch tensor + audio_tensor = torch.from_numpy(audio).float() + self.audio_chunks.append(audio_tensor) + + # Update timing + chunk_duration = len(audio) / self.SAMPLING_RATE + self.cumulative_audio_duration += chunk_duration + + if audio_stream_end_time is not None: + self.last_audio_stream_end_time = audio_stream_end_time + self.end = audio_stream_end_time + else: + self.end = self.offset + self.cumulative_audio_duration + + def prompt(self) -> Tuple[str, str]: + """ + Returns a tuple: (prompt, context). + SimulStreaming handles prompting internally, so we return empty strings. + """ + return "", "" + + def get_buffer(self): + """ + Get the unvalidated buffer content. + """ + buffer_end = self.end if hasattr(self, 'end') else None + return Transcript( + start=None, + end=buffer_end, + text=self.buffer_content, + probability=None + ) + + def timestamped_text(self, tokens, generation): + # From the simulstreaming repo. self.model to self.asr.model + pr = generation["progress"] + if "result" not in generation: + split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens) + else: + split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"] + + frames = [p["most_attended_frames"][0] for p in pr] + tokens = tokens.copy() + ret = [] + for sw,st in zip(split_words,split_tokens): + b = None + for stt in st: + t,f = tokens.pop(0), frames.pop(0) + if t != stt: + raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.") + if b is None: + b = f + e = f + out = (b*0.02, e*0.02, sw) + ret.append(out) + logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}") + return ret + + def process_iter(self) -> Tuple[List[ASRToken], float]: + """ + Process accumulated audio chunks using SimulStreaming. + + Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). + """ + if not self.audio_chunks: + return [], self.end + + try: + # concatenate all audio chunks + if len(self.audio_chunks) == 1: + audio = self.audio_chunks[0] + else: + audio = torch.cat(self.audio_chunks, dim=0) + + audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0 + self.processed_audio_duration += audio_duration + + self.audio_chunks = [] + + logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s") + logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s") + + self.asr.model.insert_audio(audio) + tokens, generation_progress = self.asr.model.infer(is_last=self.is_last) + ts_words = self.timestamped_text(tokens, generation_progress) + text = self.asr.model.tokenizer.decode(tokens) + + new_tokens = [] + for ts_word in ts_words: + + start, end, word = ts_word + token = ASRToken( + start=start, + end=end, + text=word, + probability=0.95 # fake prob. Maybe we can extract it from the model? + ) + new_tokens.append(token) + self.committed.extend(new_tokens) + + return new_tokens, self.end + + + except Exception as e: + logger.exception(f"SimulStreaming processing error: {e}") + return [], self.end + + def finish(self) -> Tuple[List[ASRToken], float]: + logger.debug("SimulStreaming finish() called") + self.is_last = True + final_tokens, final_time = self.process_iter() + self.is_last = False + return final_tokens, final_time + + def concatenate_tokens( + self, + tokens: List[ASRToken], + sep: Optional[str] = None, + offset: float = 0 + ) -> Transcript: + """Concatenate tokens into a Transcript object.""" + sep = sep if sep is not None else self.asr.sep + text = sep.join(token.text for token in tokens) + probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None + if tokens: + start = offset + tokens[0].start + end = offset + tokens[-1].end + else: + start = None + end = None + return Transcript(start, end, text, probability=probability) + + def chunk_at(self, time: float): + """ + useless but kept for compatibility + """ + logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally") + pass + + def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]: + """ + Create simple sentences. + """ + if not tokens: + return [] + + full_text = " ".join(token.text for token in tokens) + sentence = Sentence( + start=tokens[0].start, + end=tokens[-1].end, + text=full_text + ) + return [sentence] + + + +class SimulStreamingASR(): + """SimulStreaming backend with AlignAtt policy.""" + sep = "" + + def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs): + logger.warning(SIMULSTREAMING_LICENSE) + self.logfile = logfile + self.transcribe_kargs = {} + self.original_language = None if lan == "auto" else lan + + self.model_path = kwargs.get('model_path', './large-v3.pt') + self.frame_threshold = kwargs.get('frame_threshold', 25) + self.audio_max_len = kwargs.get('audio_max_len', 30.0) + self.audio_min_len = kwargs.get('audio_min_len', 0.0) + self.segment_length = kwargs.get('segment_length', 0.5) + self.beams = kwargs.get('beams', 1) + self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam') + self.task = kwargs.get('task', 'transcribe') + self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None) + self.never_fire = kwargs.get('never_fire', False) + self.init_prompt = kwargs.get('init_prompt', None) + self.static_init_prompt = kwargs.get('static_init_prompt', None) + self.max_context_tokens = kwargs.get('max_context_tokens', None) + + if model_dir is not None: + self.model_path = model_dir + elif modelsize is not None: #For the moment the .en.pt models do not work! + model_mapping = { + 'tiny': './tiny.pt', + 'base': './base.pt', + 'small': './small.pt', + 'medium': './medium.pt', + 'medium.en': './medium.en.pt', + 'large-v1': './large-v1.pt', + 'base.en': './base.en.pt', + 'small.en': './small.en.pt', + 'tiny.en': './tiny.en.pt', + 'large-v2': './large-v2.pt', + 'large-v3': './large-v3.pt', + 'large': './large-v3.pt' + } + self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt') + + self.model = self.load_model(modelsize, cache_dir, model_dir) + + # Set up tokenizer for translation if needed + if self.task == "translate": + self.set_translate_task() + + def load_model(self, modelsize, cache_dir, model_dir): + try: + cfg = AlignAttConfig( + model_path=self.model_path, + segment_length=self.segment_length, + frame_threshold=self.frame_threshold, + language=self.original_language, + audio_max_len=self.audio_max_len, + audio_min_len=self.audio_min_len, + cif_ckpt_path=self.cif_ckpt_path, + decoder_type="beam", + beam_size=self.beams, + task=self.task, + never_fire=self.never_fire, + init_prompt=self.init_prompt, + max_context_tokens=self.max_context_tokens, + static_init_prompt=self.static_init_prompt, + ) + + logger.info(f"Loading SimulStreaming model with language: {self.original_language}") + model = PaddedAlignAttWhisper(cfg) + return model + + except Exception as e: + logger.error(f"Failed to load SimulStreaming model: {e}") + raise + + def segments_end_ts(self, result) -> List[float]: + """Get segment end timestamps.""" + if torch.is_tensor(result): + num_tokens = len(result) + return [num_tokens * 0.1] # rough estimate + return [1.0] + + def set_translate_task(self): + """Set up translation task.""" + try: + self.model.tokenizer = tokenizer.get_tokenizer( + multilingual=True, + language=self.model.cfg.language, + num_languages=self.model.model.num_languages, + task="translate" + ) + logger.info("SimulStreaming configured for translation task") + except Exception as e: + logger.error(f"Failed to configure SimulStreaming for translation: {e}") + raise + + def warmup(self, audio, init_prompt=""): + """Warmup the SimulStreaming model.""" + try: + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio).float() + self.model.insert_audio(audio) + self.model.infer(True) + self.model.refresh_segment(complete=True) + logger.info("SimulStreaming model warmed up successfully") + except Exception as e: + logger.exception(f"SimulStreaming warmup failed: {e}") diff --git a/whisperlivekit/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py index d6ad639..8f7d643 100644 --- a/whisperlivekit/whisper_streaming_custom/backends.py +++ b/whisperlivekit/whisper_streaming_custom/backends.py @@ -3,32 +3,10 @@ import logging import io import soundfile as sf import math -try: - import torch -except ImportError: - torch = None from typing import List import numpy as np from whisperlivekit.timed_objects import ASRToken -from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE logger = logging.getLogger(__name__) -SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError( -"""SimulStreaming dependencies are not available. -Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]" -""") - -try: - from whisperlivekit.simul_whisper.config import AlignAttConfig - from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD - from whisperlivekit.simul_whisper.whisper import tokenizer - SIMULSTREAMING_AVAILABLE = True -except ImportError: - SIMULSTREAMING_AVAILABLE = False - AlignAttConfig = None - PaddedAlignAttWhisper = None - DEC_PAD = None - tokenizer = None - class ASRBase: sep = " " # join transcribe words with this character (" " for whisper_timestamped, # "" for faster-whisper because it emits the spaces when needed) @@ -309,181 +287,4 @@ class OpenaiApiASR(ASRBase): self.use_vad_opt = True def set_translate_task(self): - self.task = "translate" - - -class SimulStreamingASR(ASRBase): - """SimulStreaming backend with AlignAtt policy.""" - sep = "" - - def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs): - if not SIMULSTREAMING_AVAILABLE: - raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS - logger.warning(SIMULSTREAMING_LICENSE) - self.logfile = logfile - self.transcribe_kargs = {} - self.original_language = None if lan == "auto" else lan - - self.model_path = kwargs.get('model_path', './large-v3.pt') - self.frame_threshold = kwargs.get('frame_threshold', 25) - self.audio_max_len = kwargs.get('audio_max_len', 30.0) - self.audio_min_len = kwargs.get('audio_min_len', 0.0) - self.segment_length = kwargs.get('segment_length', 0.5) - self.beams = kwargs.get('beams', 1) - self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam') - self.task = kwargs.get('task', 'transcribe') - self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None) - self.never_fire = kwargs.get('never_fire', False) - self.init_prompt = kwargs.get('init_prompt', None) - self.static_init_prompt = kwargs.get('static_init_prompt', None) - self.max_context_tokens = kwargs.get('max_context_tokens', None) - - if model_dir is not None: - self.model_path = model_dir - elif modelsize is not None: #For the moment the .en.pt models do not work! - model_mapping = { - 'tiny': './tiny.pt', - 'base': './base.pt', - 'small': './small.pt', - 'medium': './medium.pt', - 'medium.en': './medium.en.pt', - 'large-v1': './large-v1.pt', - 'base.en': './base.en.pt', - 'small.en': './small.en.pt', - 'tiny.en': './tiny.en.pt', - 'large-v2': './large-v2.pt', - 'large-v3': './large-v3.pt', - 'large': './large-v3.pt' - } - self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt') - - self.model = self.load_model(modelsize, cache_dir, model_dir) - - # Set up tokenizer for translation if needed - if self.task == "translate": - self.set_translate_task() - - def load_model(self, modelsize, cache_dir, model_dir): - try: - cfg = AlignAttConfig( - model_path=self.model_path, - segment_length=self.segment_length, - frame_threshold=self.frame_threshold, - language=self.original_language, - audio_max_len=self.audio_max_len, - audio_min_len=self.audio_min_len, - cif_ckpt_path=self.cif_ckpt_path, - decoder_type="beam", - beam_size=self.beams, - task=self.task, - never_fire=self.never_fire, - init_prompt=self.init_prompt, - max_context_tokens=self.max_context_tokens, - static_init_prompt=self.static_init_prompt, - ) - - logger.info(f"Loading SimulStreaming model with language: {self.original_language}") - model = PaddedAlignAttWhisper(cfg) - return model - - except Exception as e: - logger.error(f"Failed to load SimulStreaming model: {e}") - raise - - def transcribe(self, audio, init_prompt=""): - """Transcribe audio using SimulStreaming.""" - try: - if isinstance(audio, np.ndarray): - audio_tensor = torch.from_numpy(audio).float() - else: - audio_tensor = audio - - prompt = init_prompt if init_prompt else (self.init_prompt or "") - - result = self.model.infer(audio_tensor, init_prompt=prompt) - - if torch.is_tensor(result): - result = result[result < DEC_PAD] - - logger.debug(f"SimulStreaming transcription result: {result}") - return result - - except Exception as e: - logger.error(f"SimulStreaming transcription failed: {e}") - raise - - def ts_words(self, result) -> List[ASRToken]: - """Convert SimulStreaming result to ASRToken list.""" - tokens = [] - - try: - if torch.is_tensor(result): - text = self.model.tokenizer.decode(result.cpu().numpy()) - else: - text = str(result) - - if not text or len(text.strip()) == 0: - return tokens - - # We dont have word-level timestamps here. 1rst approach, should be improved later. - words = text.strip().split() - if not words: - return tokens - - duration_per_word = 0.1 # this will be modified based on actual audio duration - #with the SimulStreamingOnlineProcessor - - for i, word in enumerate(words): - start_time = i * duration_per_word - end_time = (i + 1) * duration_per_word - - token = ASRToken( - start=start_time, - end=end_time, - text=word, - probability=1.0 - ) - tokens.append(token) - - except Exception as e: - logger.error(f"Error converting SimulStreaming result to tokens: {e}") - - return tokens - - def segments_end_ts(self, result) -> List[float]: - """Get segment end timestamps.""" - if torch.is_tensor(result): - num_tokens = len(result) - return [num_tokens * 0.1] # rough estimate - return [1.0] - - def use_vad(self): - """Enable VAD - SimulStreaming has different VAD handling.""" - logger.info("VAD requested for SimulStreaming - handled internally by the model") - pass - - def set_translate_task(self): - """Set up translation task.""" - try: - self.model.tokenizer = tokenizer.get_tokenizer( - multilingual=True, - language=self.model.cfg.language, - num_languages=self.model.model.num_languages, - task="translate" - ) - logger.info("SimulStreaming configured for translation task") - except Exception as e: - logger.error(f"Failed to configure SimulStreaming for translation: {e}") - raise - - def warmup(self, audio, init_prompt=""): - """Warmup the SimulStreaming model.""" - try: - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio).float() - self.model.insert_audio(audio) - self.model.infer(True) - self.model.refresh_segment(complete=True) - logger.info("SimulStreaming model warmed up successfully") - except Exception as e: - logger.exception(f"SimulStreaming warmup failed: {e}") + self.task = "translate" \ No newline at end of file diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index 7f2c65c..04a57a6 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -528,204 +528,3 @@ class VACOnlineASRProcessor: """ return self.online.concatenate_tokens(self.online.transcript_buffer.buffer) - -class SimulStreamingOnlineProcessor: - SAMPLING_RATE = 16000 - - def __init__( - self, - asr, - tokenize_method: Optional[callable] = None, - buffer_trimming: Tuple[str, float] = ("segment", 15), - confidence_validation = False, - logfile=sys.stderr, - ): - if not SIMULSTREAMING_AVAILABLE: - raise ImportError("SimulStreaming dependencies are not available.") - - self.asr = asr - self.tokenize = tokenize_method - self.logfile = logfile - self.confidence_validation = confidence_validation - self.init() - - # buffer does not work yet - self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming - - def init(self, offset: Optional[float] = None): - """Initialize or reset the processing state.""" - self.audio_chunks = [] - self.offset = offset if offset is not None else 0.0 - self.is_last = False - self.beg = self.offset - self.end = self.offset - self.cumulative_audio_duration = 0.0 - self.last_audio_stream_end_time = self.offset - - self.committed: List[ASRToken] = [] - self.last_result_tokens: List[ASRToken] = [] - self.buffer_content = "" - self.processed_audio_duration = 0.0 - - def get_audio_buffer_end_time(self) -> float: - """Returns the absolute end time of the current audio buffer.""" - return self.end - - def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None): - """Append an audio chunk to be processed by SimulStreaming.""" - if torch is None: - raise ImportError("PyTorch is required for SimulStreaming but not available") - - # Convert numpy array to torch tensor - audio_tensor = torch.from_numpy(audio).float() - self.audio_chunks.append(audio_tensor) - - # Update timing - chunk_duration = len(audio) / self.SAMPLING_RATE - self.cumulative_audio_duration += chunk_duration - - if audio_stream_end_time is not None: - self.last_audio_stream_end_time = audio_stream_end_time - self.end = audio_stream_end_time - else: - self.end = self.offset + self.cumulative_audio_duration - - def prompt(self) -> Tuple[str, str]: - """ - Returns a tuple: (prompt, context). - SimulStreaming handles prompting internally, so we return empty strings. - """ - return "", "" - - def get_buffer(self): - """ - Get the unvalidated buffer content. - """ - buffer_end = self.end if hasattr(self, 'end') else None - return Transcript( - start=None, - end=buffer_end, - text=self.buffer_content, - probability=None - ) - - def timestamped_text(self, tokens, generation): - # From the simulstreaming repo. self.model to self.asr.model - pr = generation["progress"] - if "result" not in generation: - split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens) - else: - split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"] - - frames = [p["most_attended_frames"][0] for p in pr] - tokens = tokens.copy() - ret = [] - for sw,st in zip(split_words,split_tokens): - b = None - for stt in st: - t,f = tokens.pop(0), frames.pop(0) - if t != stt: - raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.") - if b is None: - b = f - e = f - out = (b*0.02, e*0.02, sw) - ret.append(out) - logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}") - return ret - - def process_iter(self) -> Tuple[List[ASRToken], float]: - """ - Process accumulated audio chunks using SimulStreaming. - - Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). - """ - if not self.audio_chunks: - return [], self.end - - try: - # concatenate all audio chunks - if len(self.audio_chunks) == 1: - audio = self.audio_chunks[0] - else: - audio = torch.cat(self.audio_chunks, dim=0) - - audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0 - self.processed_audio_duration += audio_duration - - self.audio_chunks = [] - - logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s") - logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s") - - self.asr.model.insert_audio(audio) - tokens, generation_progress = self.asr.model.infer(is_last=self.is_last) - ts_words = self.timestamped_text(tokens, generation_progress) - text = self.asr.model.tokenizer.decode(tokens) - - new_tokens = [] - for ts_word in ts_words: - - start, end, word = ts_word - token = ASRToken( - start=start, - end=end, - text=word, - probability=0.95 # fake prob. Maybe we can extract it from the model? - ) - new_tokens.append(token) - self.committed.extend(new_tokens) - - return new_tokens, self.end - - - except Exception as e: - logger.exception(f"SimulStreaming processing error: {e}") - return [], self.end - - def finish(self) -> Tuple[List[ASRToken], float]: - logger.debug("SimulStreaming finish() called") - self.is_last = True - final_tokens, final_time = self.process_iter() - self.is_last = False - return final_tokens, final_time - - def concatenate_tokens( - self, - tokens: List[ASRToken], - sep: Optional[str] = None, - offset: float = 0 - ) -> Transcript: - """Concatenate tokens into a Transcript object.""" - sep = sep if sep is not None else self.asr.sep - text = sep.join(token.text for token in tokens) - probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None - if tokens: - start = offset + tokens[0].start - end = offset + tokens[-1].end - else: - start = None - end = None - return Transcript(start, end, text, probability=probability) - - def chunk_at(self, time: float): - """ - useless but kept for compatibility - """ - logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally") - pass - - def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]: - """ - Create simple sentences. - """ - if not tokens: - return [] - - full_text = " ".join(token.text for token in tokens) - sentence = Sentence( - start=tokens[0].start, - end=tokens[-1].end, - text=full_text - ) - return [sentence] diff --git a/whisperlivekit/whisper_streaming_custom/whisper_online.py b/whisperlivekit/whisper_streaming_custom/whisper_online.py index 352ae1d..25349ac 100644 --- a/whisperlivekit/whisper_streaming_custom/whisper_online.py +++ b/whisperlivekit/whisper_streaming_custom/whisper_online.py @@ -5,8 +5,7 @@ import librosa from functools import lru_cache import time import logging -from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS -from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE +from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR logger = logging.getLogger(__name__) @@ -68,35 +67,7 @@ def backend_factory(args): backend = args.backend if backend == "openai-api": logger.debug("Using OpenAI API.") - asr = OpenaiApiASR(lan=args.lan) - elif backend == "simulstreaming": - logger.debug("Using SimulStreaming backend.") - if not SIMULSTREAMING_AVAILABLE: - raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS - - simulstreaming_kwargs = {} - for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len', - 'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt', - 'max_context_tokens', 'model_path']: - if hasattr(args, attr): - simulstreaming_kwargs[attr] = getattr(args, attr) - - # Add segment_length from min_chunk_size - simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5) - simulstreaming_kwargs['task'] = args.task - - size = args.model - t = time.time() - logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...") - asr = SimulStreamingASR( - modelsize=size, - lan=args.lan, - cache_dir=getattr(args, 'model_cache_dir', None), - model_dir=getattr(args, 'model_dir', None), - **simulstreaming_kwargs - ) - e = time.time() - logger.info(f"done. It took {round(e-t,2)} seconds.") + asr = OpenaiApiASR(lan=args.lan) else: if backend == "faster-whisper": asr_cls = FasterWhisperASR @@ -138,46 +109,6 @@ def backend_factory(args): tokenizer = None return asr, tokenizer -def online_factory(args, asr, tokenizer, logfile=sys.stderr): - if args.backend == "simulstreaming": - if not SIMULSTREAMING_ONLINE_AVAILABLE: - raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS - - logger.debug("Creating SimulStreaming online processor") - online = SimulStreamingOnlineProcessor( - asr, - tokenizer, - logfile=logfile, - buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), - confidence_validation=args.confidence_validation - ) - elif args.vac: - online = VACOnlineASRProcessor( - args.min_chunk_size, - asr, - tokenizer, - logfile=logfile, - buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), - confidence_validation = args.confidence_validation - ) - else: - online = OnlineASRProcessor( - asr, - tokenizer, - logfile=logfile, - buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), - confidence_validation = args.confidence_validation - ) - return online - -def asr_factory(args, logfile=sys.stderr): - """ - Creates and configures an ASR and ASR Online instance based on the specified backend and arguments. - """ - asr, tokenizer = backend_factory(args) - online = online_factory(args, asr, tokenizer, logfile=logfile) - return asr, online - def warmup_asr(asr, warmup_file=None, timeout=5): """ Warmup the ASR model by transcribing a short audio file.