translation device determined with torch.device
This commit is contained in:
parent
4209d7f7c0
commit
b6164aa59b
1 changed files with 8 additions and 4 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
|
|
@ -10,13 +11,16 @@ class TranslationModel():
|
||||||
tokenizer: transformers.AutoTokenizer
|
tokenizer: transformers.AutoTokenizer
|
||||||
|
|
||||||
def load_model(src_lang):
|
def load_model(src_lang):
|
||||||
huggingface_hub.snapshot_download('entai2965/nllb-200-distilled-600M-ctranslate2',local_dir='nllb-200-distilled-600M-ctranslate2')
|
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||||
translator = ctranslate2.Translator("nllb-200-distilled-600M-ctranslate2",device="cpu")
|
MODEL_GUY = 'entai2965'
|
||||||
tokenizer = transformers.AutoTokenizer.from_pretrained("nllb-200-distilled-600M-ctranslate2", src_lang=src_lang, clean_up_tokenization_spaces=True)
|
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
translator = ctranslate2.Translator(MODEL,device=device)
|
||||||
|
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||||
return TranslationModel(
|
return TranslationModel(
|
||||||
translator=translator,
|
translator=translator,
|
||||||
tokenizer=tokenizer
|
tokenizer=tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
def translate(input, translation_model, tgt_lang):
|
def translate(input, translation_model, tgt_lang):
|
||||||
if not input:
|
if not input:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue