cuda or cpu auto detection
This commit is contained in:
parent
4d1aa4421a
commit
0d874fb515
1 changed files with 6 additions and 4 deletions
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import io
|
||||
import soundfile as sf
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -102,11 +102,13 @@ class FasterWhisperASR(ASRBase):
|
|||
else:
|
||||
raise ValueError("modelsize or model_dir parameter must be set")
|
||||
|
||||
# this worked fast and reliably on NVIDIA L40
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "float32"
|
||||
|
||||
model = WhisperModel(
|
||||
model_size_or_path,
|
||||
device="cuda",
|
||||
compute_type="float16",
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
download_root=cache_dir,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue