diff --git a/whisperlivekit/simul_whisper/generation_progress.py b/whisperlivekit/simul_whisper/generation_progress.py index e17e3ea..86c6263 100644 --- a/whisperlivekit/simul_whisper/generation_progress.py +++ b/whisperlivekit/simul_whisper/generation_progress.py @@ -25,6 +25,9 @@ class BeamTokens(Tokens): def __repr__(self): return self.__str__() + def as_text(self, tokenizer): + return tokenizer.decode(self.tokens) + class Logits(Tokens): def __init__(self, logits): super().__init__(logits) diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index ce8123a..59c4089 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer from .config import AlignAttConfig from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES from .whisper.timing import median_filter -from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens +from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language from .beam import BeamPyTorchInference from .eow_detection import fire_at_boundary, load_cif import os -from whisperlivekit.simul_whisper.token_buffer import TokenBuffer +from token_buffer import TokenBuffer import numpy as np from .generation_progress import * @@ -24,6 +24,7 @@ DEC_PAD = 50257 logger = logging.getLogger(__name__) import sys +import wave # New features added to the original version of Simul-Whisper: # - large-v3 model support @@ -33,28 +34,26 @@ import sys # - context class PaddedAlignAttWhisper: def __init__(self, cfg: AlignAttConfig) -> None: + self.log_segments = 0 model_name = os.path.basename(cfg.model_path).replace(".pt", "") model_path = os.path.dirname(os.path.abspath(cfg.model_path)) self.model = load_model(name=model_name, download_root=model_path) logger.info(f"Model dimensions: {self.model.dims}") - decode_options = DecodingOptions( + self.decode_options = DecodingOptions( language = cfg.language, without_timestamps = True, task=cfg.task ) - self.tokenizer = tokenizer.get_tokenizer( - multilingual=not model_name.endswith(".en"), - language=cfg.language, - num_languages=self.model.num_languages, - task=decode_options.task - ) + self.tokenizer_is_multilingual = not model_name.endswith(".en") + self.create_tokenizer(cfg.language if cfg.language != "auto" else None) + self.detected_language = cfg.language if cfg.language != "auto" else None + self.max_text_len = self.model.dims.n_text_ctx self.num_decoder_layers = len(self.model.decoder.blocks) self.cfg = cfg - # model to detect end-of-word boundary at the end of the segment self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg, n_audio_state=self.model.dims.n_audio_state, @@ -95,14 +94,6 @@ class PaddedAlignAttWhisper: self.num_align_heads += 1 - # init tokens (mandatory prompt) - self.initial_tokens = torch.tensor( - self.tokenizer.sot_sequence_including_notimestamps, - dtype=torch.long, - device=self.model.device).unsqueeze(0) - self.initial_token_length = self.initial_tokens.shape[1] - self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) - # tokens to be suppressed from decoding, to prevent hallucinations suppress_tokens = [ self.tokenizer.transcribe, @@ -121,6 +112,17 @@ class PaddedAlignAttWhisper: self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None) # blank tokens are suppresed for new segments near the line 334 + # it's going to be regenerated after lang id + self.segments = [] + self.init_tokens() + + self.last_attend_frame = -self.cfg.rewind_threshold + + if self.cfg.max_context_tokens is None: + self.max_context_tokens = self.max_text_len + else: + self.max_context_tokens = self.cfg.max_context_tokens + self.init_context() # decoder type: greedy or beam if cfg.decoder_type == "greedy": @@ -135,16 +137,13 @@ class PaddedAlignAttWhisper: self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) - # init state - self.segments = [] - self.tokens = [self.initial_tokens] - self.last_attend_frame = -self.cfg.rewind_threshold - - if self.cfg.max_context_tokens is None: - self.max_context_tokens = self.max_text_len - else: - self.max_context_tokens = self.cfg.max_context_tokens - self.init_context() + def create_tokenizer(self, language=None): + self.tokenizer = tokenizer.get_tokenizer( + multilingual=self.tokenizer_is_multilingual, + language=language, + num_languages=self.model.num_languages, + task=self.decode_options.task + ) def init_context(self): kw = {'tokenizer': self.tokenizer, @@ -156,6 +155,19 @@ class PaddedAlignAttWhisper: if self.cfg.init_prompt is not None: self.context.text += self.cfg.init_prompt + def init_tokens(self): + logger.debug(f"init tokens, {len(self.segments)}") + # init tokens (mandatory prompt) + self.initial_tokens = torch.tensor( + self.tokenizer.sot_sequence_including_notimestamps, + dtype=torch.long, + device=self.model.device).unsqueeze(0) + self.initial_token_length = self.initial_tokens.shape[1] + self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) +# self.segments = [] + logger.debug(f"init tokens after, {len(self.segments)}") + self.tokens = [self.initial_tokens] + def trim_context(self): logger.info("Trimming context") c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids) @@ -191,15 +203,19 @@ class PaddedAlignAttWhisper: def refresh_segment(self, complete=False): - logger.debug("Refreshing segment") - self.tokens = [self.initial_tokens] + logger.debug("Refreshing segment:") + self.init_tokens() self.last_attend_frame = -self.cfg.rewind_threshold + self.detected_language = None self.init_context() logger.debug(f"Context: {self.context}") if not complete and len(self.segments) > 2: + logger.debug("keeping last two segments because they are and it is not complete.") self.segments = self.segments[-2:] else: + logger.debug("removing all segments.") self.segments = [] + self.log_segments += 1 def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): @@ -208,8 +224,6 @@ class PaddedAlignAttWhisper: return fire_at_boundary(chunked_encoder_feature, self.CIFLinear) - - def _current_tokens(self): toks = self.tokens @@ -256,16 +270,59 @@ class PaddedAlignAttWhisper: removed_len = 0 # len of audio is bigger than buffer_len. Going to remove the first segment segments_len = self.segments_len() - while segments_len > self.cfg.audio_max_len: + while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len: removed_len = self.segments[0].shape[0] / 16000 segments_len -= removed_len self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len) self.segments = self.segments[1:] - if len(self.tokens) > 1: # When warming up, we can have a too long segments_len while not having any tokens yet + logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}") + if len(self.tokens) > 1: self.context.append_token_ids(self.tokens[1][0,:]) self.tokens = [self.initial_tokens] + self.tokens[2:] return removed_len + def _clean_cache(self): + '''clean the cache that stores the attention matrices and kv_cache. + It must be called every time after generation with the model.''' + # cleaning cache + self.dec_attns = [] + self.kv_cache = {} + if self.decoder_type == "beam": + self.inference.kv_cache = self.kv_cache + self.token_decoder.reset() + + @torch.no_grad() + def lang_id(self, encoder_features): + """Language detection from encoder features. + This code is trimmed and copy-pasted from whisper.decoding.detect_language . + """ + + # forward pass using a single token, startoftranscript + n_audio = encoder_features.shape[0] + x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1] + logits = self.model.logits(x, encoder_features)[:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) + mask[list(self.tokenizer.all_language_tokens)] = False + logits[:, mask] = -np.inf + language_tokens = logits.argmax(dim=-1) + language_token_probs = logits.softmax(dim=-1).cpu() + language_probs = [ + { + c: language_token_probs[i, j].item() + for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes) + } + for i in range(n_audio) + ] + + single = encoder_features.ndim == 2 + if single: + language_tokens = language_tokens[0] + language_probs = language_probs[0] + + self._clean_cache() + return language_tokens, language_probs ### transcription / translation @@ -273,9 +330,12 @@ class PaddedAlignAttWhisper: def infer(self, is_last=False): new_segment = True if len(self.segments) == 0: - return [] + logger.debug("No segments, nothing to do") + return [], {} if not self._apply_minseglen(): - return [] + logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") + input_segments = torch.cat(self.segments, dim=0) + return [], {} # input_segments is concatenation of audio, it's one array if len(self.segments) > 1: @@ -283,8 +343,7 @@ class PaddedAlignAttWhisper: else: input_segments = self.segments[0] - self.trim_context() - current_tokens = self._current_tokens() + # mel + padding to 30s mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, @@ -295,18 +354,38 @@ class PaddedAlignAttWhisper: # the len of actual audio content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) + # encode encoder_feature = self.model.encoder(mel) - sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device) - completed = False +# logger.debug(f"Encoder feature shape: {encoder_feature.shape}") +# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): +# logger.debug("mel ") + if self.cfg.language == "auto" and self.detected_language is None: + language_tokens, language_probs = self.lang_id(encoder_feature) + logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}") + top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) + logger.info(f"Detected language: {top_lan} with p={p:.4f}") + #self.tokenizer.language = top_lan + #self.tokenizer.__post_init__() + self.create_tokenizer(top_lan) + self.detected_language = top_lan + self.init_tokens() + logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") + + self.trim_context() + current_tokens = self._current_tokens() +# fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) ####################### Decoding loop logger.info("Decoding loop starts\n") + sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device) + completed = False + attn_of_alignment_heads = None - miost_attended_frame = None + most_attended_frame = None token_len_before_decoding = current_tokens.shape[1] @@ -515,11 +594,6 @@ class PaddedAlignAttWhisper: logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") - # cleaning cache - self.dec_attns = [] - self.kv_cache = {} - if self.decoder_type == "beam": - self.inference.kv_cache = self.kv_cache - self.token_decoder.reset() + self._clean_cache() - return new_hypothesis, generation + return new_hypothesis, generation \ No newline at end of file