diff --git a/DEV_NOTES.md b/DEV_NOTES.md index c41016f..f9c3c4a 100644 --- a/DEV_NOTES.md +++ b/DEV_NOTES.md @@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes Encoder weights: 15268874 bytes +# 2. Translation: Faster model for each system -# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm +## Benchmark Results + +Testing on MacBook M3 with NLLB-200-distilled-600M model: + +### Standard Transformers vs CTranslate2 + +| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup | +|-----------|-------------------------|---------------------------|---------| +| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x | +| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x | +| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x | +| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x | +| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x | + +**Results:** +- Total Standard time: 4.1068s +- Total CTranslate2 time: 8.5476s +- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation. + + +# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions. @@ -67,4 +88,4 @@ ELSE: AS_2 ← B to finish -``` \ No newline at end of file +``` diff --git a/README.md b/README.md index 2d0cb83..656b5fb 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,11 @@ An important list of parameters can be changed. But what *should* you change? | `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` | +| Translation options | Description | Default | +|-----------|-------------|---------| +| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` | +| `--nllb-size` | `600M` or `1.3B` | `600M` | + > For diarization using Diart, you need access to pyannote.audio models: > 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model > 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index df02a52..fd9307b 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -179,12 +179,11 @@ class AudioProcessor: asr_processing_logs += f" + Silence of = {item.duration:.2f}s" if self.tokens: asr_processing_logs += f" | last_end = {self.tokens[-1].end} |" - logger.info(asr_processing_logs) - - if type(item) is Silence: + logger.info(asr_processing_logs) cumulative_pcm_duration_stream_time += item.duration self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0) continue + logger.info(asr_processing_logs) if isinstance(item, np.ndarray): pcm_array = item @@ -223,7 +222,7 @@ class AudioProcessor: new_tokens, buffer_text, new_end_buffer ) - if new_tokens and self.args.target_language and self.translation_queue: + if self.translation_queue: for token in new_tokens: await self.translation_queue.put(token) @@ -256,13 +255,11 @@ class AudioProcessor: logger.debug("Diarization processor received sentinel. Finishing.") self.diarization_queue.task_done() break - - if type(item) is Silence: + elif type(item) is Silence: cumulative_pcm_duration_stream_time += item.duration diarization_obj.insert_silence(item.duration) continue - - if isinstance(item, np.ndarray): + elif isinstance(item, np.ndarray): pcm_array = item else: raise Exception('item should be pcm_array') @@ -295,14 +292,17 @@ class AudioProcessor: # in the future we want to have different languages for each speaker etc, so it will be more complex. while True: try: - token = await self.translation_queue.get() #block until at least 1 token - if token is SENTINEL: + item = await self.translation_queue.get() #block until at least 1 token + if item is SENTINEL: logger.debug("Translation processor received sentinel. Finishing.") self.translation_queue.task_done() break + elif type(item) is Silence: + online_translation.insert_silence(item.duration) + continue # get all the available tokens for translation. The more words, the more precise - tokens_to_process = [token] + tokens_to_process = [item] additional_tokens = await get_all_from_queue(self.translation_queue) sentinel_found = False @@ -326,7 +326,7 @@ class AudioProcessor: except Exception as e: logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") - if 'token' in locals() and token is not SENTINEL: + if 'token' in locals() and item is not SENTINEL: self.translation_queue.task_done() if 'additional_tokens' in locals(): for _ in additional_tokens: @@ -367,7 +367,7 @@ class AudioProcessor: if not state.tokens and not buffer_transcription and not buffer_diarization: response_status = "no_audio_detected" lines = [] - elif response_status == "active_transcription" and not lines: + elif not lines: lines = [Line( speaker=1, start=state.get("end_buffer", 0), @@ -528,6 +528,8 @@ class AudioProcessor: await self.transcription_queue.put(silence_buffer) if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(silence_buffer) + if self.translation_queue: + await self.translation_queue.put(silence_buffer) if not self.silence: if self.args.transcription and self.transcription_queue: diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index fd290d5..578e624 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -43,10 +43,12 @@ class TranscriptionEngine: "transcription": True, "vad": True, "pcm_input": False, + # whisperstreaming params: "buffer_trimming": "segment", "confidence_validation": False, "buffer_trimming_sec": 15, + # simulstreaming params: "disable_fast_encoder": False, "frame_threshold": 25, @@ -61,10 +63,15 @@ class TranscriptionEngine: "max_context_tokens": None, "model_path": './base.pt', "diarization_backend": "sortformer", + # diarization params: "disable_punctuation_split" : False, "segmentation_model": "pyannote/segmentation-3.0", - "embedding_model": "pyannote/embedding", + "embedding_model": "pyannote/embedding", + + # translation params: + "nllb_backend": "ctranslate2", + "nllb_size": "600M" } config_dict = {**defaults, **kwargs} @@ -142,8 +149,7 @@ class TranscriptionEngine: raise Exception('Translation cannot be set with language auto') else: from whisperlivekit.translation.translation import load_model - self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers - + self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers TranscriptionEngine._initialized = True diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 3ef74bf..55d4173 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -287,6 +287,20 @@ def parse_args(): help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).", ) + simulstreaming_group.add_argument( + "--nllb-backend", + type=str, + default="ctranslate2", + help="transformers or ctranslate2", + ) + + simulstreaming_group.add_argument( + "--nllb-size", + type=str, + default="600M", + help="600M or 1.3B", + ) + args = parser.parse_args() args.transcription = not args.no_transcription diff --git a/whisperlivekit/remove_silences.py b/whisperlivekit/remove_silences.py index dc207fc..3e4edb1 100644 --- a/whisperlivekit/remove_silences.py +++ b/whisperlivekit/remove_silences.py @@ -39,7 +39,7 @@ def blank_to_silence(tokens): ) else: if silence_token: #there was silence but no more - if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION: + if silence_token.duration() >= MIN_SILENCE_DURATION: cleaned_tokens.append( silence_token ) diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index 1526ef1..1556ac9 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -123,14 +123,33 @@ def format_output(state, silence, current_time, args, debug, sep): append_token_to_last_line(lines, sep, token, debug_info) if lines and translated_segments: - cts_idx = 0 # current_translated_segment_idx - for line in lines: - while cts_idx < len(translated_segments): - ts = translated_segments[cts_idx] - if ts and ts.start and ts.start >= line.start and ts.end <= line.end: - line.translation += ts.text + ' ' - cts_idx += 1 - else: - break - return lines, undiarized_text, buffer_transcription, '' - + unassigned_translated_segments = [] + for ts in translated_segments: + assigned = False + for line in lines: + if ts and ts.overlaps_with(line): + if ts.is_within(line): + line.translation += ts.text + ' ' + assigned = True + break + else: + ts0, ts1 = ts.approximate_cut_at(line.end) + if ts0 and line.overlaps_with(ts0): + line.translation += ts0.text + ' ' + if ts1: + unassigned_translated_segments.append(ts1) + assigned = True + break + if not assigned: + unassigned_translated_segments.append(ts) + + if unassigned_translated_segments: + for line in lines: + remaining_segments = [] + for ts in unassigned_translated_segments: + if ts and ts.overlaps_with(line): + line.translation += ts.text + ' ' + else: + remaining_segments.append(ts) + unassigned_translated_segments = remaining_segments #maybe do smth in the future about that + return lines, undiarized_text, buffer_transcription, '' diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 3acf7c8..a9df490 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Any from datetime import timedelta def format_time(seconds: float) -> str: @@ -15,6 +15,21 @@ class TimedText: speaker: Optional[int] = -1 probability: Optional[float] = None is_dummy: Optional[bool] = False + + def overlaps_with(self, other: 'TimedText') -> bool: + return not (self.end <= other.start or other.end <= self.start) + + def is_within(self, other: 'TimedText') -> bool: + return other.contains_timespan(self) + + def duration(self) -> float: + return self.end - self.start + + def contains_time(self, time: float) -> bool: + return self.start <= time <= self.end + + def contains_timespan(self, other: 'TimedText') -> bool: + return self.start <= other.start and self.end >= other.end @dataclass class ASRToken(TimedText): @@ -41,6 +56,34 @@ class SpeakerSegment(TimedText): class Translation(TimedText): pass + def approximate_cut_at(self, cut_time): + """ + Each word in text is considered to be of duration (end-start)/len(words in text) + """ + if not self.text or not self.contains_time(cut_time): + return self, None + + words = self.text.split() + num_words = len(words) + if num_words == 0: + return self, None + + duration_per_word = self.duration() / num_words + + cut_word_index = int((cut_time - self.start) / duration_per_word) + + if cut_word_index >= num_words: + cut_word_index = num_words -1 + + text0 = " ".join(words[:cut_word_index]) + text1 = " ".join(words[cut_word_index:]) + + segment0 = Translation(start=self.start, end=cut_time, text=text0) + segment1 = Translation(start=cut_time, end=self.end, text=text1) + + return segment0, segment1 + + @dataclass class Silence(): duration: float @@ -91,4 +134,4 @@ class State(): end_buffer: float end_attributed_speaker: float remaining_time_transcription: float - remaining_time_diarization: float \ No newline at end of file + remaining_time_diarization: float diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index a28f2fa..c08f190 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -1,3 +1,5 @@ +import logging +import time import ctranslate2 import torch import transformers @@ -6,38 +8,42 @@ import huggingface_hub from whisperlivekit.translation.mapping_languages import get_nllb_code from whisperlivekit.timed_objects import Translation +logger = logging.getLogger(__name__) #In diarization case, we may want to translate just one speaker, or at least start the sentences there PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} +MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous +# sentence is not finished. @dataclass class TranslationModel(): translator: ctranslate2.Translator tokenizer: dict + device: str + backend_type: str = 'ctranslate2' -def load_model(src_langs): - MODEL = 'nllb-200-distilled-600M-ctranslate2' - MODEL_GUY = 'entai2965' - huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL) +def load_model(src_langs, backend='ctranslate2', model_size='600M'): device = "cuda" if torch.cuda.is_available() else "cpu" - translator = ctranslate2.Translator(MODEL,device=device) + MODEL = f'nllb-200-distilled-{model_size}-ctranslate2' + if backend=='ctranslate2': + MODEL_GUY = 'entai2965' + huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL) + translator = ctranslate2.Translator(MODEL,device=device) + elif backend=='transformers': + translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}") tokenizer = dict() for src_lang in src_langs: tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) + return TranslationModel( translator=translator, - tokenizer=tokenizer + tokenizer=tokenizer, + backend_type=backend, + device = device ) -def translate(input, translation_model, tgt_lang): - source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input)) - target_prefix = [tgt_lang] - results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix]) - target = results[0].hypotheses[0][1:] - return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target)) - class OnlineTranslation: def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): self.buffer = [] @@ -68,12 +74,19 @@ class OnlineTranslation: output_lang = self.output_languages[0] nllb_output_lang = get_nllb_code(output_lang) - source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input)) - results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff. - target = results[0].hypotheses[0][1:] - results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target)) - return results - + tokenizer = self.translation_model.tokenizer[input_lang] + tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device) + + if self.translation_model.backend_type == 'ctranslate2': + source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0]) + results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) + target = results[0].hypotheses[0][1:] + result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target)) + else: + translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang)) + result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] + return result + def translate_tokens(self, tokens): if tokens: text = ' '.join([token.text for token in tokens]) @@ -88,7 +101,6 @@ class OnlineTranslation: return translation return None - def insert_tokens(self, tokens): self.buffer.extend(tokens) @@ -109,7 +121,11 @@ class OnlineTranslation: self.translation_remaining = self.translate_tokens(self.buffer) self.len_processed_buffer = len(self.buffer) return self.validated + [self.translation_remaining] - + + def insert_silence(self, silence_duration: float): + if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER: + self.buffer = [] + self.validated += [self.translation_remaining] if __name__ == '__main__': output_lang = 'fr' @@ -122,16 +138,13 @@ if __name__ == '__main__': test = test_string.split(' ') step = len(test) // 3 - shared_model = load_model([input_lang]) + shared_model = load_model([input_lang], backend='ctranslate2') online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang]) - + + beg_inference = time.time() for id in range(5): val = test[id*step : (id+1)*step] val_str = ' '.join(val) result = online_translation.translate(val_str) print(result) - - - - - # print(result) \ No newline at end of file + print('inference time:', time.time() - beg_inference) \ No newline at end of file diff --git a/whisperlivekit/web/live_transcription.css b/whisperlivekit/web/live_transcription.css index 422d156..3cf5007 100644 --- a/whisperlivekit/web/live_transcription.css +++ b/whisperlivekit/web/live_transcription.css @@ -438,7 +438,6 @@ label { font-size: 13px; border-radius: 30px; padding: 2px 10px; - display: none; } .loading {