diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index bcdef9b..e1e12e2 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -4,7 +4,7 @@ import logging from typing import List, Tuple, Optional import logging import platform -from whisperlivekit.timed_objects import ASRToken, Transcript +from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment from whisperlivekit.warmup import load_file from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE from .whisper import load_model, tokenizer @@ -91,6 +91,10 @@ class SimulStreamingOnlineProcessor: self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. self.model.insert_audio(audio_tensor) + def on_new_speaker(self, last_segment: SpeakerSegment): + self.model.on_new_speaker(last_segment) + self.model.refresh_segment(complete=True) + def get_buffer(self): return Transcript( start=None, @@ -99,54 +103,23 @@ class SimulStreamingOnlineProcessor: probability=None ) - def timestamped_text(self, tokens, generation): - """ - generate timestamped text from tokens and generation data. - - args: - tokens: List of tokens to process - generation: Dictionary containing generation progress and optionally results - - returns: - List of tuples containing (start_time, end_time, word) for each word - """ - FRAME_DURATION = 0.02 - if "result" in generation: - split_words = generation["result"]["split_words"] - split_tokens = generation["result"]["split_tokens"] - else: - split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens) - progress = generation["progress"] - frames = [p["most_attended_frames"][0] for p in progress] - absolute_timestamps = [p["absolute_timestamps"][0] for p in progress] - tokens_queue = tokens.copy() + def timestamped_text(self, split_words, split_tokens, l_absolute_timestamps): timestamped_words = [] - + for word, word_tokens in zip(split_words, split_tokens): - # start_frame = None - # end_frame = None - for expected_token in word_tokens: - if not tokens_queue or not frames: - raise ValueError(f"Insufficient tokens or frames for word '{word}'") - - actual_token = tokens_queue.pop(0) - current_frame = frames.pop(0) - current_timestamp = absolute_timestamps.pop(0) - if actual_token != expected_token: - raise ValueError( - f"Token mismatch: expected '{expected_token}', " - f"got '{actual_token}' at frame {current_frame}" - ) - # if start_frame is None: - # start_frame = current_frame - # end_frame = current_frame - # start_time = start_frame * FRAME_DURATION - # end_time = end_frame * FRAME_DURATION - start_time = current_timestamp - end_time = current_timestamp + 0.1 - timestamp_entry = (start_time, end_time, word) + + for i in word_tokens: + current_timestamp = l_absolute_timestamps.pop(0) + + timestamp_entry = ASRToken( + start=current_timestamp, + end=current_timestamp + 0.1, + text=word, + probability=0.95 + ).with_offset( + self.global_time_offset + ) timestamped_words.append(timestamp_entry) - logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}") return timestamped_words def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: @@ -156,46 +129,10 @@ class SimulStreamingOnlineProcessor: Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). """ try: - tokens, generation_progress = self.model.infer(is_last=is_last) - ts_words = self.timestamped_text(tokens, generation_progress) - - new_tokens = [] - for ts_word in ts_words: - - start, end, word = ts_word - token = ASRToken( - start=start, - end=end, - text=word, - probability=0.95 # fake prob. Maybe we can extract it from the model? - ).with_offset( - self.global_time_offset - ) - new_tokens.append(token) - - # identical_tokens = 0 - # n_new_tokens = len(new_tokens) - # if n_new_tokens: + split_words, split_tokens, l_absolute_timestamps = self.model.infer(is_last=is_last) + new_tokens = self.timestamped_text(split_words, split_tokens, l_absolute_timestamps) self.committed.extend(new_tokens) - - # if token in self.committed: - # pos = len(self.committed) - 1 - self.committed[::-1].index(token) - # if pos: - # for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens): - # commited_segment = self.committed[i:i+n_new_tokens] - # if commited_segment == new_tokens: - # identical_segments +=1 - # if identical_tokens >= TOO_MANY_REPETITIONS: - # logger.warning('Too many repetition, model is stuck. Load a new one') - # self.committed = self.committed[:i] - # self.load_new_backend() - # return [], self.end - - # pos = self.committed.rindex(token) - - - return new_tokens, self.end @@ -362,4 +299,4 @@ class SimulStreamingASR(): """ Warmup is done directly in load_model """ - pass \ No newline at end of file + pass diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index cc183c4..49468e6 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -382,11 +382,11 @@ class PaddedAlignAttWhisper: new_segment = True if len(self.segments) == 0: logger.debug("No segments, nothing to do") - return [], {} + return [], [], [] if not self._apply_minseglen(): logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") input_segments = torch.cat(self.segments, dim=0) - return [], {} + return [], [], [] # input_segments is concatenation of audio, it's one array if len(self.segments) > 1: @@ -426,9 +426,6 @@ class PaddedAlignAttWhisper: end_encode = time() # print('Encoder duration:', end_encode-beg_encode) -# logger.debug(f"Encoder feature shape: {encoder_feature.shape}") -# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): -# logger.debug("mel ") if self.cfg.language == "auto" and self.detected_language is None: language_tokens, language_probs = self.lang_id(encoder_feature) logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}") @@ -443,13 +440,10 @@ class PaddedAlignAttWhisper: self.trim_context() current_tokens = self._current_tokens() -# + fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :]) - ####################### Decoding loop - logger.info("Decoding loop starts\n") - sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device) completed = False @@ -458,26 +452,9 @@ class PaddedAlignAttWhisper: token_len_before_decoding = current_tokens.shape[1] - generation_progress = [] - generation = { - "starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size), - "token_len_before_decoding": token_len_before_decoding, - #"fire_detected": fire_detected, - "frames_len": content_mel_len, - "frames_threshold": 4 if is_last else self.cfg.frame_threshold, - - # to be filled later - "logits_starting": None, - - # to be filled later - "no_speech_prob": None, - "no_speech": False, - - # to be filled in the loop - "progress": generation_progress, - } + l_absolute_timestamps = [] + while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens - generation_progress_loop = [] if new_segment: tokens_for_logits = current_tokens @@ -486,50 +463,28 @@ class PaddedAlignAttWhisper: tokens_for_logits = current_tokens[:,-1:] logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size - if new_segment: - generation["logits_starting"] = Logits(logits[:,:,:]) if new_segment and self.tokenizer.no_speech is not None: probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1) no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() - generation["no_speech_prob"] = no_speech_probs[0] + # generation["no_speech_prob"] = no_speech_probs[0] if no_speech_probs[0] > self.cfg.nonspeech_prob: - generation["no_speech"] = True + # generation["no_speech"] = True logger.info("no speech, stop") break logits = logits[:, -1, :] # logits for the last token - generation_progress_loop.append(("logits_before_suppress",Logits(logits))) # supress blank tokens only at the beginning of the segment if new_segment: logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf new_segment = False self.suppress_tokens(logits) - #generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size))) - generation_progress_loop.append(("logits_after_suppress",Logits(logits))) - current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs) - generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone()))) - generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist())) - generation_progress_loop.append(("completed",completed)) logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ") self.debug_print_tokens(current_tokens) - - # if self.decoder_type == "beam": - # logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}") - - # logprobs = F.log_softmax(logits.float(), dim=-1) - # idx = 0 - # logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}") - # logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}") - # if completed: - # self.debug_print_tokens(current_tokens) - - # logger.debug("decode stopped because decoder completed") - attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)] for i, attn_mat in enumerate(self.dec_attns): layer_rank = int(i % len(self.model.decoder.blocks)) @@ -548,30 +503,24 @@ class PaddedAlignAttWhisper: t = torch.cat(mat, dim=1) tmp.append(t) attn_of_alignment_heads = torch.stack(tmp, dim=1) -# logger.debug(str(attn_of_alignment_heads.shape) + " tttady") std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False) attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1) -# logger.debug(str(attn_of_alignment_heads.shape) + " po mean") attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len] -# logger.debug(str(attn_of_alignment_heads.shape) + " pak ") # for each beam, the most attended frame is: most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1) - generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist())) # Calculate absolute timestamps accounting for cumulative offset absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()] - generation_progress_loop.append(("absolute_timestamps", absolute_timestamps)) logger.debug(str(most_attended_frames.tolist()) + " most att frames") logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)") most_attended_frame = most_attended_frames[0].item() + l_absolute_timestamps.append(absolute_timestamps[0]) - - generation_progress.append(dict(generation_progress_loop)) logger.debug("current tokens" + str(current_tokens.shape)) if completed: # # stripping the last token, the eot @@ -609,66 +558,28 @@ class PaddedAlignAttWhisper: self.tokenizer.decode([current_tokens[i, -1].item()]) )) -# for k,v in generation.items(): -# print(k,v,file=sys.stderr) -# for x in generation_progress: -# for y in x.items(): -# print("\t\t",*y,file=sys.stderr) -# print("\t","----", file=sys.stderr) -# print("\t", "end of generation_progress_loop", file=sys.stderr) - # sys.exit(1) - ####################### End of decoding loop - - logger.info("End of decoding loop") - - # if attn_of_alignment_heads is not None: - # seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND) - - # # Lets' now consider only the top hypothesis in the beam search - # top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0] - - # # debug print: how is the new token attended? - # new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:] - # logger.debug(f"New token attention shape: {new_token_attn.shape}") - # if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment - # logger.debug("no token generated") - # else: # it is, and the max attention is: - # new_token_max_attn, _ = new_token_attn.max(dim=-1) - # logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}") - - - # let's now operate only with the top beam hypothesis tokens_to_split = current_tokens[0, token_len_before_decoding:] + if fire_detected or is_last: new_hypothesis = tokens_to_split.flatten().tolist() + split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) else: # going to truncate the tokens after the last space split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist()) - generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]} - generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]} - -# text_to_split = self.tokenizer.decode(tokens_to_split) -# logger.debug(f"text_to_split: {text_to_split}") -# logger.debug("text at current step: {}".format(text_to_split.replace(" ", ""))) -# text_before_space = " ".join(text_to_split.split(" ")[:-1]) -# logger.debug("before the last space: {}".format(text_before_space.replace(" ", ""))) if len(split_words) > 1: new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist] else: new_hypothesis = [] - ### new hypothesis logger.debug(f"new_hypothesis: {new_hypothesis}") new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to( device=self.device, ) self.tokens.append(new_tokens) - # TODO: test if this is redundant or not -# ret = ret[ret