O(n) to O(1) for simulstreaming timestamp determination

This commit is contained in:
Quentin Fuxa 2025-09-21 11:04:00 +02:00
parent e61afdefa3
commit a5503308c5
5 changed files with 98 additions and 71 deletions

View file

@ -429,7 +429,7 @@ class AudioProcessor:
state = await self.get_current_state() state = await self.get_current_state()
# Format output # Format output
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output( lines, undiarized_text, end_w_silence = format_output(
state, state,
self.silence, self.silence,
current_time = time() - self.beg_loop if self.beg_loop else None, current_time = time() - self.beg_loop if self.beg_loop else None,
@ -437,6 +437,13 @@ class AudioProcessor:
debug = self.debug, debug = self.debug,
sep=self.sep sep=self.sep
) )
if end_w_silence:
buffer_transcription = ''
buffer_diarization = ''
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
# Handle undiarized text # Handle undiarized text
if undiarized_text: if undiarized_text:
combined = self.sep.join(undiarized_text) combined = self.sep.join(undiarized_text)

View file

@ -77,15 +77,17 @@ def no_token_to_silence(tokens):
new_tokens.append(token) new_tokens.append(token)
return new_tokens return new_tokens
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence): def ends_with_silence(tokens, current_time, vac_detected_silence):
end_w_silence = False
if not tokens: if not tokens:
return [], buffer_transcription, buffer_diarization return [], end_w_silence
last_token = tokens[-1] last_token = tokens[-1]
if tokens and current_time and ( if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION current_time - last_token.end >= END_SILENCE_DURATION
or or
(current_time - last_token.end >= 3 and vac_detected_silence) (current_time - last_token.end >= 3 and vac_detected_silence)
): ):
end_w_silence = True
if last_token.speaker == -2: if last_token.speaker == -2:
last_token.end = current_time last_token.end = current_time
else: else:
@ -97,14 +99,12 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
probability=0.95 probability=0.95
) )
) )
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence return tokens, end_w_silence
buffer_diarization = ""
return tokens, buffer_transcription, buffer_diarization
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence): def handle_silences(tokens, current_time, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens) tokens = no_token_to_silence(tokens)
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence) tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
return tokens, buffer_transcription, buffer_diarization return tokens, end_w_silence

View file

@ -50,14 +50,12 @@ def format_output(state, silence, current_time, args, debug, sep):
disable_punctuation_split = args.disable_punctuation_split disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens tokens = state.tokens
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
end_attributed_speaker = state.end_attributed_speaker end_attributed_speaker = state.end_attributed_speaker
previous_speaker = -1 previous_speaker = -1
lines = [] lines = []
undiarized_text = [] undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence) tokens, end_w_silence = handle_silences(tokens, current_time, silence)
last_punctuation = None last_punctuation = None
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
speaker = token.speaker speaker = token.speaker
@ -121,6 +119,7 @@ def format_output(state, silence, current_time, args, debug, sep):
pass pass
append_token_to_last_line(lines, sep, token, debug_info) append_token_to_last_line(lines, sep, token, debug_info)
if lines and translated_segments: if lines and translated_segments:
unassigned_translated_segments = [] unassigned_translated_segments = []
for ts in translated_segments: for ts in translated_segments:
@ -151,4 +150,4 @@ def format_output(state, silence, current_time, args, debug, sep):
else: else:
remaining_segments.append(ts) remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
return lines, undiarized_text, buffer_transcription, '' return lines, undiarized_text, end_w_silence

View file

@ -6,7 +6,6 @@ import logging
import platform import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment
from whisperlivekit.warmup import load_file from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND from .whisper.audio import TOKENS_PER_SECOND
import os import os
@ -23,7 +22,11 @@ try:
HAS_MLX_WHISPER = True HAS_MLX_WHISPER = True
except ImportError: except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64": if platform.system() == "Darwin" and platform.machine() == "arm64":
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper') print(f"""
{"="*50}
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
{"="*50}
""")
HAS_MLX_WHISPER = False HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False HAS_FASTER_WHISPER = False
@ -49,8 +52,12 @@ class SimulStreamingOnlineProcessor:
self.asr = asr self.asr = asr
self.logfile = logfile self.logfile = logfile
self.end = 0.0 self.end = 0.0
self.global_time_offset = 0.0 self.buffer = Transcript(
start=None,
end=None,
text='',
probability=None
)
self.committed: List[ASRToken] = [] self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = []
self.load_new_backend() self.load_new_backend()
@ -79,7 +86,7 @@ class SimulStreamingOnlineProcessor:
else: else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer. self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.global_time_offset = silence_duration + offset self.model.global_time_offset = silence_duration + offset
@ -96,31 +103,7 @@ class SimulStreamingOnlineProcessor:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
def get_buffer(self): def get_buffer(self):
return Transcript( return self.buffer
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, split_words, split_tokens, l_absolute_timestamps):
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
for i in word_tokens:
current_timestamp = l_absolute_timestamps.pop(0)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text=word,
probability=0.95
).with_offset(
self.global_time_offset
)
timestamped_words.append(timestamp_entry)
return timestamped_words
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
""" """
@ -129,9 +112,7 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
""" """
try: try:
split_words, split_tokens, l_absolute_timestamps = self.model.infer(is_last=is_last) new_tokens = self.model.infer(is_last=is_last)
new_tokens = self.timestamped_text(split_words, split_tokens, l_absolute_timestamps)
self.committed.extend(new_tokens) self.committed.extend(new_tokens)
return new_tokens, self.end return new_tokens, self.end
@ -163,7 +144,6 @@ class SimulStreamingASR():
sep = "" sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs): def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile self.logfile = logfile
self.transcribe_kargs = {} self.transcribe_kargs = {}
self.original_language = lan self.original_language = lan

View file

@ -8,6 +8,7 @@ import torch.nn.functional as F
from .whisper import load_model, DecodingOptions, tokenizer from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter from .whisper.timing import median_filter
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
@ -18,6 +19,7 @@ from time import time
from .token_buffer import TokenBuffer from .token_buffer import TokenBuffer
import numpy as np import numpy as np
from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import * from .generation_progress import *
DEC_PAD = 50257 DEC_PAD = 50257
@ -40,12 +42,6 @@ else:
except ImportError: except ImportError:
HAS_FASTER_WHISPER = False HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
# - translation support
# - beam search
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper: class PaddedAlignAttWhisper:
def __init__( def __init__(
self, self,
@ -79,6 +75,9 @@ class PaddedAlignAttWhisper:
self.tokenizer_is_multilingual = not model_name.endswith(".en") self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.create_tokenizer(cfg.language if cfg.language != "auto" else None) self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.detected_language = cfg.language if cfg.language != "auto" else None self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.sentence_start_time = 0.0
self.max_text_len = self.model.dims.n_text_ctx self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks) self.num_decoder_layers = len(self.model.decoder.blocks)
@ -153,6 +152,7 @@ class PaddedAlignAttWhisper:
self.last_attend_frame = -self.cfg.rewind_threshold self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0
self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
if self.cfg.max_context_tokens is None: if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len self.max_context_tokens = self.max_text_len
@ -382,11 +382,11 @@ class PaddedAlignAttWhisper:
new_segment = True new_segment = True
if len(self.segments) == 0: if len(self.segments) == 0:
logger.debug("No segments, nothing to do") logger.debug("No segments, nothing to do")
return [], [], [] return []
if not self._apply_minseglen(): if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0) input_segments = torch.cat(self.segments, dim=0)
return [], [], [] return []
# input_segments is concatenation of audio, it's one array # input_segments is concatenation of audio, it's one array
if len(self.segments) > 1: if len(self.segments) > 1:
@ -394,6 +394,13 @@ class PaddedAlignAttWhisper:
else: else:
input_segments = self.segments[0] input_segments = self.segments[0]
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
# logger.debug("Resetting tokenizer to auto for new sentence.")
# self.create_tokenizer(None)
# self.detected_language = None
# self.init_tokens()
# self.reset_tokenizer_to_auto_next_call = False
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time() beg_encode = time()
if self.mlx_encoder: if self.mlx_encoder:
@ -426,17 +433,21 @@ class PaddedAlignAttWhisper:
end_encode = time() end_encode = time()
# print('Encoder duration:', end_encode-beg_encode) # print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None: # if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature) # seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}") # if seconds_since_start >= 3.0:
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) # language_tokens, language_probs = self.lang_id(encoder_feature)
logger.info(f"Detected language: {top_lan} with p={p:.4f}") # logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
#self.tokenizer.language = top_lan # top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
#self.tokenizer.__post_init__() # logger.info(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan) # #self.tokenizer.language = top_lan
self.detected_language = top_lan # #self.tokenizer.__post_init__()
self.init_tokens() # self.create_tokenizer(top_lan)
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") # self.detected_language = top_lan
# self.init_tokens()
# logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
# else:
# logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
self.trim_context() self.trim_context()
current_tokens = self._current_tokens() current_tokens = self._current_tokens()
@ -446,6 +457,7 @@ class PaddedAlignAttWhisper:
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device) sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
completed = False completed = False
# punctuation_stop = False
attn_of_alignment_heads = None attn_of_alignment_heads = None
most_attended_frame = None most_attended_frame = None
@ -467,9 +479,7 @@ class PaddedAlignAttWhisper:
if new_segment and self.tokenizer.no_speech is not None: if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1) probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# generation["no_speech_prob"] = no_speech_probs[0]
if no_speech_probs[0] > self.cfg.nonspeech_prob: if no_speech_probs[0] > self.cfg.nonspeech_prob:
# generation["no_speech"] = True
logger.info("no speech, stop") logger.info("no speech, stop")
break break
@ -485,6 +495,19 @@ class PaddedAlignAttWhisper:
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ") logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens) self.debug_print_tokens(current_tokens)
# # Early stop on sentence-ending punctuation when language is auto
# if not completed and self.cfg.language == "auto":
# last_token_id = current_tokens[0, -1].item()
# last_token_text = self.tokenizer.decode([last_token_id]).strip()
# if last_token_text in PUNCTUATION_MARKS:
# logger.debug(f"Punctuation boundary '{last_token_text}' hit; stopping early to allow language re-check.")
# punctuation_stop = True
# # Ensure next call starts with auto language (re-detect for new sentence)
# self.reset_tokenizer_to_auto_next_call = True
# self.detected_language = None
# self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
# break
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)] attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns): for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks)) layer_rank = int(i % len(self.model.decoder.blocks))
@ -560,7 +583,7 @@ class PaddedAlignAttWhisper:
tokens_to_split = current_tokens[0, token_len_before_decoding:] tokens_to_split = current_tokens[0, token_len_before_decoding:]
if fire_detected or is_last: if fire_detected or is_last: #or punctuation_stop:
new_hypothesis = tokens_to_split.flatten().tolist() new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else: else:
@ -582,4 +605,22 @@ class PaddedAlignAttWhisper:
self._clean_cache() self._clean_cache()
return split_words, split_tokens, l_absolute_timestamps timestamped_words = []
timestamp_idx = 0
for word, word_tokens in zip(split_words, split_tokens):
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text=word,
probability=0.95
).with_offset(
self.global_time_offset
)
timestamped_words.append(timestamp_entry)
return timestamped_words