translation: use of get_nllb_code
This commit is contained in:
parent
84890b8e61
commit
72f33be6f2
3 changed files with 11 additions and 8 deletions
|
|
@ -133,12 +133,14 @@ class TranscriptionEngine:
|
||||||
self.diarization_model = SortformerDiarization()
|
self.diarization_model = SortformerDiarization()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||||
|
|
||||||
|
self.translation_model = None
|
||||||
if self.args.target_language:
|
if self.args.target_language:
|
||||||
if self.args.language == 'auto':
|
if self.args.language == 'auto':
|
||||||
raise Exception('Translation cannot be set with language auto')
|
raise Exception('Translation cannot be set with language auto')
|
||||||
else:
|
else:
|
||||||
from whisperlivekit.translation.translation import load_model
|
from whisperlivekit.translation.translation import load_model
|
||||||
|
self.translation_model = load_model()
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,7 @@ NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
|
||||||
|
|
||||||
|
|
||||||
def get_nllb_code(crowdin_code):
|
def get_nllb_code(crowdin_code):
|
||||||
return CROWDIN_TO_NLLB.get(crowdin_code, crowdin_code)
|
return CROWDIN_TO_NLLB.get(crowdin_code, None)
|
||||||
|
|
||||||
|
|
||||||
def get_crowdin_code(nllb_code):
|
def get_crowdin_code(nllb_code):
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@ import ctranslate2
|
||||||
import transformers
|
import transformers
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
|
from .mapping_languages import get_nllb_code
|
||||||
src_lang = "eng_Latn"
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TranslationModel():
|
class TranslationModel():
|
||||||
|
|
@ -30,8 +29,10 @@ def translate(input, translation_model, tgt_lang):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tgt_lang = "fra_Latn"
|
tgt_lang = 'fr'
|
||||||
src_lang = "eng_Latn"
|
src_lang = "en"
|
||||||
translation_model = load_model(src_lang)
|
nllb_tgt_lang = get_nllb_code(tgt_lang)
|
||||||
result = translate('Hello world', translation_model=translation_model, tgt_lang=tgt_lang)
|
nllb_src_lang = get_nllb_code(src_lang)
|
||||||
|
translation_model = load_model(nllb_src_lang)
|
||||||
|
result = translate('Hello world', translation_model=translation_model, tgt_lang=nllb_tgt_lang)
|
||||||
print(result)
|
print(result)
|
||||||
Loading…
Reference in a new issue