add nllb-backend and translation perf test in dev_notes
This commit is contained in:
parent
99dc96c644
commit
bbba1d9bb7
5 changed files with 66 additions and 17 deletions
23
DEV_NOTES.md
23
DEV_NOTES.md
|
|
@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes
|
||||||
Encoder weights: 15268874 bytes
|
Encoder weights: 15268874 bytes
|
||||||
|
|
||||||
|
|
||||||
|
# 2. Translation: Faster model for each system
|
||||||
|
|
||||||
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
## Benchmark Results
|
||||||
|
|
||||||
|
Tested on a MacBook M3 with the NLLB-200-distilled-600M model:
|
||||||
|
|
||||||
|
### Standard Transformers vs CTranslate2
|
||||||
|
|
||||||
|
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|
||||||
|
|-----------|-------------------------|---------------------------|---------|
|
||||||
|
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
|
||||||
|
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
|
||||||
|
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
|
||||||
|
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
|
||||||
|
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
|
||||||
|
|
||||||
|
**Results:**
|
||||||
|
- Total Standard time: 4.1068s
|
||||||
|
- Total CTranslate2 time: 8.5476s
|
||||||
|
- CTranslate2 is roughly 2x slower than standard Transformers on this system (8.5476s vs 4.1068s in total) --> use the Transformers backend here; ideally we would have an MLX implementation.
|
||||||
|
|
||||||
|
|
||||||
|
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||||
|
|
||||||
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,10 @@ An important list of parameters can be changed. But what *should* you change?
|
||||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||||
|
|
||||||
|
|
||||||
|
| Translation options | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--nllb-backend` | [NOT FUNCTIONAL YET] NLLB translation backend: `transformers` or `ctranslate2` | `ctranslate2` |
|
||||||
|
|
||||||
> For diarization using Diart, you need access to pyannote.audio models:
|
> For diarization using Diart, you need access to pyannote.audio models:
|
||||||
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||||
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||||
|
|
|
||||||
|
|
@ -43,10 +43,12 @@ class TranscriptionEngine:
|
||||||
"transcription": True,
|
"transcription": True,
|
||||||
"vad": True,
|
"vad": True,
|
||||||
"pcm_input": False,
|
"pcm_input": False,
|
||||||
|
|
||||||
# whisperstreaming params:
|
# whisperstreaming params:
|
||||||
"buffer_trimming": "segment",
|
"buffer_trimming": "segment",
|
||||||
"confidence_validation": False,
|
"confidence_validation": False,
|
||||||
"buffer_trimming_sec": 15,
|
"buffer_trimming_sec": 15,
|
||||||
|
|
||||||
# simulstreaming params:
|
# simulstreaming params:
|
||||||
"disable_fast_encoder": False,
|
"disable_fast_encoder": False,
|
||||||
"frame_threshold": 25,
|
"frame_threshold": 25,
|
||||||
|
|
@ -61,10 +63,14 @@ class TranscriptionEngine:
|
||||||
"max_context_tokens": None,
|
"max_context_tokens": None,
|
||||||
"model_path": './base.pt',
|
"model_path": './base.pt',
|
||||||
"diarization_backend": "sortformer",
|
"diarization_backend": "sortformer",
|
||||||
|
|
||||||
# diarization params:
|
# diarization params:
|
||||||
"disable_punctuation_split" : False,
|
"disable_punctuation_split" : False,
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
"embedding_model": "pyannote/embedding",
|
"embedding_model": "pyannote/embedding",
|
||||||
|
|
||||||
|
# translation params:
|
||||||
|
"nllb_backend": "ctranslate2"
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
config_dict = {**defaults, **kwargs}
|
||||||
|
|
@ -142,7 +148,7 @@ class TranscriptionEngine:
|
||||||
raise Exception('Translation cannot be set with language auto')
|
raise Exception('Translation cannot be set with language auto')
|
||||||
else:
|
else:
|
||||||
from whisperlivekit.translation.translation import load_model
|
from whisperlivekit.translation.translation import load_model
|
||||||
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
|
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend) #in the future we want to handle different languages for different speakers
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -287,6 +287,13 @@ def parse_args():
|
||||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--nllb-backend",
|
||||||
|
type=str,
|
||||||
|
default="ctranslate2",
|
||||||
|
help="transformer or ctranslate2",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.transcription = not args.no_transcription
|
args.transcription = not args.no_transcription
|
||||||
|
|
|
||||||
|
|
@ -21,26 +21,37 @@ class TranslationModel():
|
||||||
translator: ctranslate2.Translator
|
translator: ctranslate2.Translator
|
||||||
tokenizer: dict
|
tokenizer: dict
|
||||||
|
|
||||||
def load_model(src_langs):
|
def load_model(src_langs, backend='ctranslate2'):
|
||||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
if backend=='ctranslate2':
|
||||||
MODEL_GUY = 'entai2965'
|
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
MODEL_GUY = 'entai2965'
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||||
translator = ctranslate2.Translator(MODEL,device=device)
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
tokenizer = dict()
|
translator = ctranslate2.Translator(MODEL,device=device)
|
||||||
for src_lang in src_langs:
|
tokenizer = dict()
|
||||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
for src_lang in src_langs:
|
||||||
|
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||||
|
elif backend=='transformers':
|
||||||
|
raise Exception('not implemented yet')
|
||||||
return TranslationModel(
|
return TranslationModel(
|
||||||
translator=translator,
|
translator=translator,
|
||||||
tokenizer=tokenizer
|
tokenizer=tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
def translate(input, translation_model, tgt_lang):
|
def translate(input, translation_model, tgt_lang, src_lang="en"):
|
||||||
source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
|
# Get the specific tokenizer for the source language
|
||||||
|
tokenizer = translation_model.tokenizer[src_lang]
|
||||||
|
|
||||||
|
# Convert input to tokens
|
||||||
|
source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input))
|
||||||
|
|
||||||
|
# Translate with target language prefix
|
||||||
target_prefix = [tgt_lang]
|
target_prefix = [tgt_lang]
|
||||||
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
||||||
|
|
||||||
|
# Get translated tokens and decode
|
||||||
target = results[0].hypotheses[0][1:]
|
target = results[0].hypotheses[0][1:]
|
||||||
return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
|
return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||||
|
|
||||||
class OnlineTranslation:
|
class OnlineTranslation:
|
||||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue