clean simulwhisper backend and online

This commit is contained in:
Quentin Fuxa 2025-08-09 18:02:15 +02:00
parent 197293e25e
commit b05297a96d
2 changed files with 5 additions and 97 deletions

View file

@ -3,16 +3,15 @@ import numpy as np
import logging import logging
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import logging import logging
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
import torch import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer from whisperlivekit.simul_whisper.whisper import tokenizer
SIMULSTREAMING_AVAILABLE = True
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"""SimulStreaming dependencies are not available. """SimulStreaming dependencies are not available.
@ -28,23 +27,14 @@ class SimulStreamingOnlineProcessor:
buffer_trimming: Tuple[str, float] = ("segment", 15), buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False, confidence_validation = False,
logfile=sys.stderr, logfile=sys.stderr,
): ):
if not SIMULSTREAMING_AVAILABLE:
raise ImportError("SimulStreaming dependencies are not available.")
self.asr = asr self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile self.logfile = logfile
self.confidence_validation = confidence_validation self.confidence_validation = confidence_validation
self.init()
# buffer does not work yet # buffer does not work yet
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing state."""
self.audio_chunks = [] self.audio_chunks = []
self.offset = offset if offset is not None else 0.0 self.offset = 0.0
self.is_last = False self.is_last = False
self.beg = self.offset self.beg = self.offset
self.end = self.offset self.end = self.offset
@ -56,14 +46,8 @@ class SimulStreamingOnlineProcessor:
self.buffer_content = "" self.buffer_content = ""
self.processed_audio_duration = 0.0 self.processed_audio_duration = 0.0
def get_audio_buffer_end_time(self) -> float:
"""Returns the absolute end time of the current audio buffer."""
return self.end
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming.""" """Append an audio chunk to be processed by SimulStreaming."""
if torch is None:
raise ImportError("PyTorch is required for SimulStreaming but not available")
# Convert numpy array to torch tensor # Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float() audio_tensor = torch.from_numpy(audio).float()
@ -79,13 +63,6 @@ class SimulStreamingOnlineProcessor:
else: else:
self.end = self.offset + self.cumulative_audio_duration self.end = self.offset + self.cumulative_audio_duration
def prompt(self) -> Tuple[str, str]:
"""
Returns a tuple: (prompt, context).
SimulStreaming handles prompting internally, so we return empty strings.
"""
return "", ""
def get_buffer(self): def get_buffer(self):
""" """
Get the unvalidated buffer content. Get the unvalidated buffer content.
@ -150,7 +127,6 @@ class SimulStreamingOnlineProcessor:
self.asr.model.insert_audio(audio) self.asr.model.insert_audio(audio)
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last) tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress) ts_words = self.timestamped_text(tokens, generation_progress)
text = self.asr.model.tokenizer.decode(tokens)
new_tokens = [] new_tokens = []
for ts_word in ts_words: for ts_word in ts_words:
@ -172,55 +148,6 @@ class SimulStreamingOnlineProcessor:
logger.exception(f"SimulStreaming processing error: {e}") logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end return [], self.end
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug("SimulStreaming finish() called")
self.is_last = True
final_tokens, final_time = self.process_iter()
self.is_last = False
return final_tokens, final_time
def concatenate_tokens(
self,
tokens: List[ASRToken],
sep: Optional[str] = None,
offset: float = 0
) -> Transcript:
"""Concatenate tokens into a Transcript object."""
sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens:
start = offset + tokens[0].start
end = offset + tokens[-1].end
else:
start = None
end = None
return Transcript(start, end, text, probability=probability)
def chunk_at(self, time: float):
"""
useless but kept for compatibility
"""
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
pass
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
"""
Create simple sentences.
"""
if not tokens:
return []
full_text = " ".join(token.text for token in tokens)
sentence = Sentence(
start=tokens[0].start,
end=tokens[-1].end,
text=full_text
)
return [sentence]
class SimulStreamingASR(): class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy.""" """SimulStreaming backend with AlignAtt policy."""
sep = "" sep = ""
@ -247,7 +174,7 @@ class SimulStreamingASR():
if model_dir is not None: if model_dir is not None:
self.model_path = model_dir self.model_path = model_dir
elif modelsize is not None: #For the moment the .en.pt models do not work! elif modelsize is not None:
model_mapping = { model_mapping = {
'tiny': './tiny.pt', 'tiny': './tiny.pt',
'base': './base.pt', 'base': './base.pt',
@ -297,13 +224,6 @@ class SimulStreamingASR():
logger.error(f"Failed to load SimulStreaming model: {e}") logger.error(f"Failed to load SimulStreaming model: {e}")
raise raise
def segments_end_ts(self, result) -> List[float]:
"""Get segment end timestamps."""
if torch.is_tensor(result):
num_tokens = len(result)
return [num_tokens * 0.1] # rough estimate
return [1.0]
def set_translate_task(self): def set_translate_task(self):
"""Set up translation task.""" """Set up translation task."""
try: try:

View file

@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available for online processor.")
SIMULSTREAMING_AVAILABLE = False
OnlineProcessorInterface = None
torch = None
class HypothesisBuffer: class HypothesisBuffer:
""" """
Buffer to store and process ASR hypothesis tokens. Buffer to store and process ASR hypothesis tokens.