Clean up config and model paths
This commit is contained in:
parent
74c4dc791d
commit
83362c89c4
3 changed files with 57 additions and 57 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
"""Typed configuration for the WhisperLiveKit pipeline."""
|
"""Typed configuration for the WhisperLiveKit pipeline."""
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, fields
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -56,7 +56,7 @@ class WhisperLiveKitConfig:
|
||||||
frame_threshold: int = 25
|
frame_threshold: int = 25
|
||||||
beams: int = 1
|
beams: int = 1
|
||||||
decoder_type: Optional[str] = None
|
decoder_type: Optional[str] = None
|
||||||
audio_max_len: float = 20.0
|
audio_max_len: float = 30.0
|
||||||
audio_min_len: float = 0.0
|
audio_min_len: float = 0.0
|
||||||
cif_ckpt_path: Optional[str] = None
|
cif_ckpt_path: Optional[str] = None
|
||||||
never_fire: bool = False
|
never_fire: bool = False
|
||||||
|
|
|
||||||
|
|
@ -7,20 +7,20 @@ from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
"""Information about detected model format and files in a directory."""
|
"""Information about detected model format and files in a directory."""
|
||||||
path: Optional[Path] = None
|
path: Optional[Path] = None
|
||||||
pytorch_files: List[Path] = field(default_factory=list)
|
pytorch_files: List[Path] = field(default_factory=list)
|
||||||
compatible_whisper_mlx: bool = False
|
compatible_whisper_mlx: bool = False
|
||||||
compatible_faster_whisper: bool = False
|
compatible_faster_whisper: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_pytorch(self) -> bool:
|
def has_pytorch(self) -> bool:
|
||||||
return len(self.pytorch_files) > 0
|
return len(self.pytorch_files) > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_sharded(self) -> bool:
|
def is_sharded(self) -> bool:
|
||||||
return len(self.pytorch_files) > 1
|
return len(self.pytorch_files) > 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def primary_pytorch_file(self) -> Optional[Path]:
|
def primary_pytorch_file(self) -> Optional[Path]:
|
||||||
"""Return the primary PyTorch file (or first shard for sharded models)."""
|
"""Return the primary PyTorch file (or first shard for sharded models)."""
|
||||||
|
|
@ -40,15 +40,15 @@ CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.j
|
||||||
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
|
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
|
||||||
|
|
||||||
CTranslate2 models have specific companion files that distinguish them
|
CTranslate2 models have specific companion files that distinguish them
|
||||||
from PyTorch .bin files.
|
from PyTorch .bin files.
|
||||||
"""
|
"""
|
||||||
n_indicators = 0
|
n_indicators = 0
|
||||||
for indicator in CT2_INDICATOR_FILES: #test 1
|
for indicator in CT2_INDICATOR_FILES: #test 1
|
||||||
if (directory / indicator).exists():
|
if (directory / indicator).exists():
|
||||||
n_indicators += 1
|
n_indicators += 1
|
||||||
|
|
||||||
if n_indicators == 0:
|
if n_indicators == 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -61,19 +61,19 @@ def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
||||||
return False
|
return False
|
||||||
except (json.JSONDecodeError, IOError):
|
except (json.JSONDecodeError, IOError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _collect_pytorch_files(directory: Path) -> List[Path]:
|
def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||||
"""
|
"""
|
||||||
Collect all PyTorch checkpoint files from a directory.
|
Collect all PyTorch checkpoint files from a directory.
|
||||||
|
|
||||||
Handles:
|
Handles:
|
||||||
- Single files: model.safetensors, pytorch_model.bin, *.pt
|
- Single files: model.safetensors, pytorch_model.bin, *.pt
|
||||||
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
|
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
|
||||||
- Index-based sharded models (reads index file to find shards)
|
- Index-based sharded models (reads index file to find shards)
|
||||||
|
|
||||||
Returns files sorted appropriately (shards in order, or single file).
|
Returns files sorted appropriately (shards in order, or single file).
|
||||||
"""
|
"""
|
||||||
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
||||||
|
|
@ -90,20 +90,20 @@ def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||||
return shards
|
return shards
|
||||||
except (json.JSONDecodeError, IOError):
|
except (json.JSONDecodeError, IOError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
sharded_groups = {}
|
sharded_groups = {}
|
||||||
single_files = {}
|
single_files = {}
|
||||||
|
|
||||||
for file in directory.iterdir():
|
for file in directory.iterdir():
|
||||||
if not file.is_file():
|
if not file.is_file():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
filename = file.name
|
filename = file.name
|
||||||
suffix = file.suffix.lower()
|
suffix = file.suffix.lower()
|
||||||
|
|
||||||
if filename.startswith("adapter_"):
|
if filename.startswith("adapter_"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
match = SHARDED_PATTERN.match(filename)
|
match = SHARDED_PATTERN.match(filename)
|
||||||
if match:
|
if match:
|
||||||
base_name, shard_idx, total_shards, ext = match.groups()
|
base_name, shard_idx, total_shards, ext = match.groups()
|
||||||
|
|
@ -112,7 +112,7 @@ def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||||
sharded_groups[key] = []
|
sharded_groups[key] = []
|
||||||
sharded_groups[key].append((int(shard_idx), file))
|
sharded_groups[key].append((int(shard_idx), file))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if filename == "model.safetensors":
|
if filename == "model.safetensors":
|
||||||
single_files[0] = file # Highest priority
|
single_files[0] = file # Highest priority
|
||||||
elif filename == "pytorch_model.bin":
|
elif filename == "pytorch_model.bin":
|
||||||
|
|
@ -121,68 +121,68 @@ def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||||
single_files[2] = file
|
single_files[2] = file
|
||||||
elif suffix == ".safetensors" and not filename.startswith("adapter"):
|
elif suffix == ".safetensors" and not filename.startswith("adapter"):
|
||||||
single_files[3] = file
|
single_files[3] = file
|
||||||
|
|
||||||
for (base_name, ext, total_shards), shards in sharded_groups.items():
|
for (base_name, ext, total_shards), shards in sharded_groups.items():
|
||||||
if len(shards) == total_shards:
|
if len(shards) == total_shards:
|
||||||
return [path for _, path in sorted(shards)]
|
return [path for _, path in sorted(shards)]
|
||||||
|
|
||||||
for priority in sorted(single_files.keys()):
|
for priority in sorted(single_files.keys()):
|
||||||
return [single_files[priority]]
|
return [single_files[priority]]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
|
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
|
||||||
"""
|
"""
|
||||||
Detect the model format in a given path.
|
Detect the model format in a given path.
|
||||||
|
|
||||||
This function analyzes a file or directory to determine:
|
This function analyzes a file or directory to determine:
|
||||||
- What PyTorch checkpoint files are available (including sharded models)
|
- What PyTorch checkpoint files are available (including sharded models)
|
||||||
- Whether the directory contains MLX Whisper weights
|
- Whether the directory contains MLX Whisper weights
|
||||||
- Whether the directory contains Faster-Whisper (CTranslate2) weights
|
- Whether the directory contains Faster-Whisper (CTranslate2) weights
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Path to a model file or directory
|
model_path: Path to a model file or directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelInfo with detected format information
|
ModelInfo with detected format information
|
||||||
"""
|
"""
|
||||||
path = Path(model_path)
|
path = Path(model_path)
|
||||||
info = ModelInfo(path=path)
|
info = ModelInfo(path=path)
|
||||||
|
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
suffix = path.suffix.lower()
|
suffix = path.suffix.lower()
|
||||||
if suffix in {".pt", ".safetensors", ".bin"}:
|
if suffix in {".pt", ".safetensors", ".bin"}:
|
||||||
info.pytorch_files = [path]
|
info.pytorch_files = [path]
|
||||||
return info
|
return info
|
||||||
|
|
||||||
if not path.is_dir():
|
if not path.is_dir():
|
||||||
return info
|
return info
|
||||||
|
|
||||||
for file in path.iterdir():
|
for file in path.iterdir():
|
||||||
if not file.is_file():
|
if not file.is_file():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
filename = file.name.lower()
|
filename = file.name.lower()
|
||||||
|
|
||||||
if filename in MLX_WHISPER_MARKERS:
|
if filename in MLX_WHISPER_MARKERS:
|
||||||
info.compatible_whisper_mlx = True
|
info.compatible_whisper_mlx = True
|
||||||
|
|
||||||
if filename in FASTER_WHISPER_MARKERS:
|
if filename in FASTER_WHISPER_MARKERS:
|
||||||
if _is_ct2_model_bin(path, filename):
|
if _is_ct2_model_bin(path, filename):
|
||||||
info.compatible_faster_whisper = True
|
info.compatible_faster_whisper = True
|
||||||
|
|
||||||
info.pytorch_files = _collect_pytorch_files(path)
|
info.pytorch_files = _collect_pytorch_files(path)
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||||
"""
|
"""
|
||||||
Inspect the provided path and determine which model formats are available.
|
Inspect the provided path and determine which model formats are available.
|
||||||
|
|
||||||
This is a compatibility wrapper around detect_model_format().
|
This is a compatibility wrapper around detect_model_format().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
|
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
|
||||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||||
|
|
|
||||||
|
|
@ -72,20 +72,20 @@ def parse_args():
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable transcription to only see live diarization results.",
|
help="Disable transcription to only see live diarization results.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-punctuation-split",
|
"--disable-punctuation-split",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable the split parameter.",
|
help="Disable the split parameter.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-chunk-size",
|
"--min-chunk-size",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.1,
|
default=0.1,
|
||||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -93,7 +93,7 @@ def parse_args():
|
||||||
dest='model_size',
|
dest='model_size',
|
||||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_cache_dir",
|
"--model_cache_dir",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -127,14 +127,14 @@ def parse_args():
|
||||||
default=False,
|
default=False,
|
||||||
help="Use Whisper to directly translate to english.",
|
help="Use Whisper to directly translate to english.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--target-language",
|
"--target-language",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
dest="target_language",
|
dest="target_language",
|
||||||
help="Target language for translation. Not functional yet.",
|
help="Target language for translation. Not functional yet.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend-policy",
|
"--backend-policy",
|
||||||
|
|
@ -147,8 +147,8 @@ def parse_args():
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
|
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3"],
|
||||||
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
|
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon. Use 'qwen3' for Qwen3-ASR.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-vac",
|
"--no-vac",
|
||||||
|
|
@ -165,7 +165,7 @@ def parse_args():
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable VAD (voice activity detection).",
|
help="Disable VAD (voice activity detection).",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--buffer_trimming",
|
"--buffer_trimming",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -213,7 +213,7 @@ def parse_args():
|
||||||
default=None,
|
default=None,
|
||||||
help="Use your own alignment heads, useful when `--model-dir` is used",
|
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--frame-threshold",
|
"--frame-threshold",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -221,7 +221,7 @@ def parse_args():
|
||||||
dest="frame_threshold",
|
dest="frame_threshold",
|
||||||
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
|
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--beams",
|
"--beams",
|
||||||
"-b",
|
"-b",
|
||||||
|
|
@ -229,7 +229,7 @@ def parse_args():
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
|
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--decoder",
|
"--decoder",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -238,7 +238,7 @@ def parse_args():
|
||||||
choices=["beam", "greedy"],
|
choices=["beam", "greedy"],
|
||||||
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
|
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--audio-max-len",
|
"--audio-max-len",
|
||||||
type=float,
|
type=float,
|
||||||
|
|
@ -246,7 +246,7 @@ def parse_args():
|
||||||
dest="audio_max_len",
|
dest="audio_max_len",
|
||||||
help="Max length of the audio buffer, in seconds.",
|
help="Max length of the audio buffer, in seconds.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--audio-min-len",
|
"--audio-min-len",
|
||||||
type=float,
|
type=float,
|
||||||
|
|
@ -254,7 +254,7 @@ def parse_args():
|
||||||
dest="audio_min_len",
|
dest="audio_min_len",
|
||||||
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
|
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--cif-ckpt-path",
|
"--cif-ckpt-path",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -262,7 +262,7 @@ def parse_args():
|
||||||
dest="cif_ckpt_path",
|
dest="cif_ckpt_path",
|
||||||
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
|
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--never-fire",
|
"--never-fire",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|
@ -270,7 +270,7 @@ def parse_args():
|
||||||
dest="never_fire",
|
dest="never_fire",
|
||||||
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
|
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--init-prompt",
|
"--init-prompt",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -278,7 +278,7 @@ def parse_args():
|
||||||
dest="init_prompt",
|
dest="init_prompt",
|
||||||
help="Init prompt for the model. It should be in the target language.",
|
help="Init prompt for the model. It should be in the target language.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--static-init-prompt",
|
"--static-init-prompt",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -286,7 +286,7 @@ def parse_args():
|
||||||
dest="static_init_prompt",
|
dest="static_init_prompt",
|
||||||
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
|
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--max-context-tokens",
|
"--max-context-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -294,7 +294,7 @@ def parse_args():
|
||||||
dest="max_context_tokens",
|
dest="max_context_tokens",
|
||||||
help="Max context tokens for the model. Default is 0.",
|
help="Max context tokens for the model. Default is 0.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--model-path",
|
"--model-path",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -302,14 +302,14 @@ def parse_args():
|
||||||
dest="model_path",
|
dest="model_path",
|
||||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--nllb-backend",
|
"--nllb-backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="transformers",
|
default="transformers",
|
||||||
help="transformers or ctranslate2",
|
help="transformers or ctranslate2",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--nllb-size",
|
"--nllb-size",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue