latest version of simulstreaming

This commit is contained in:
Quentin Fuxa 2025-07-31 16:44:23 +02:00
parent 4b738d6f63
commit 00424d7ca3
2 changed files with 125 additions and 48 deletions

View file

@ -25,6 +25,9 @@ class BeamTokens(Tokens):
def __repr__(self):
return self.__str__()
def as_text(self, tokenizer):
return tokenizer.decode(self.tokens)
class Logits(Tokens):
def __init__(self, logits):
super().__init__(logits)

View file

@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
from token_buffer import TokenBuffer
import numpy as np
from .generation_progress import *
@ -24,6 +24,7 @@ DEC_PAD = 50257
logger = logging.getLogger(__name__)
import sys
import wave
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
@ -33,28 +34,26 @@ import sys
# - context
class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
self.model = load_model(name=model_name, download_root=model_path)
logger.info(f"Model dimensions: {self.model.dims}")
decode_options = DecodingOptions(
self.decode_options = DecodingOptions(
language = cfg.language,
without_timestamps = True,
task=cfg.task
)
self.tokenizer = tokenizer.get_tokenizer(
multilingual=not model_name.endswith(".en"),
language=cfg.language,
num_languages=self.model.num_languages,
task=decode_options.task
)
self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.detected_language = cfg.language if cfg.language != "auto" else None
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg
# model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
n_audio_state=self.model.dims.n_audio_state,
@ -95,14 +94,6 @@ class PaddedAlignAttWhisper:
self.num_align_heads += 1
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [
self.tokenizer.transcribe,
@ -121,6 +112,17 @@ class PaddedAlignAttWhisper:
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
# decoder type: greedy or beam
if cfg.decoder_type == "greedy":
@ -135,16 +137,13 @@ class PaddedAlignAttWhisper:
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# init state
self.segments = []
self.tokens = [self.initial_tokens]
self.last_attend_frame = -self.cfg.rewind_threshold
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
def init_context(self):
kw = {'tokenizer': self.tokenizer,
@ -156,6 +155,19 @@ class PaddedAlignAttWhisper:
if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt
def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}")
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# self.segments = []
logger.debug(f"init tokens after, {len(self.segments)}")
self.tokens = [self.initial_tokens]
def trim_context(self):
logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
@ -191,15 +203,19 @@ class PaddedAlignAttWhisper:
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment")
self.tokens = [self.initial_tokens]
logger.debug("Refreshing segment:")
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:]
else:
logger.debug("removing all segments.")
self.segments = []
self.log_segments += 1
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
@ -208,8 +224,6 @@ class PaddedAlignAttWhisper:
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
def _current_tokens(self):
toks = self.tokens
@ -256,16 +270,59 @@ class PaddedAlignAttWhisper:
removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len()
while segments_len > self.cfg.audio_max_len:
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.segments = self.segments[1:]
if len(self.tokens) > 1: # When warming up, we can have a too long segments_len while not having any tokens yet
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len
def _clean_cache(self):
'''clean the cache that stores the attention matrices and kv_cache.
It must be called every time after generation with the model.'''
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
@torch.no_grad()
def lang_id(self, encoder_features):
"""Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
"""
# forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(self.tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
single = encoder_features.ndim == 2
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
self._clean_cache()
return language_tokens, language_probs
### transcription / translation
@ -273,9 +330,12 @@ class PaddedAlignAttWhisper:
def infer(self, is_last=False):
new_segment = True
if len(self.segments) == 0:
return []
logger.debug("No segments, nothing to do")
return [], {}
if not self._apply_minseglen():
return []
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return [], {}
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
@ -283,8 +343,7 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
self.trim_context()
current_tokens = self._current_tokens()
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
@ -295,18 +354,38 @@ class PaddedAlignAttWhisper:
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
# encode
encoder_feature = self.model.encoder(mel)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context()
current_tokens = self._current_tokens()
#
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
attn_of_alignment_heads = None
miost_attended_frame = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
@ -515,11 +594,6 @@ class PaddedAlignAttWhisper:
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
self._clean_cache()
return new_hypothesis, generation
return new_hypothesis, generation