Clean up config and model paths
This commit is contained in:
parent
74c4dc791d
commit
83362c89c4
3 changed files with 57 additions and 57 deletions
|
|
@ -1,6 +1,6 @@
|
|||
"""Typed configuration for the WhisperLiveKit pipeline."""
|
||||
import logging
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -56,7 +56,7 @@ class WhisperLiveKitConfig:
|
|||
frame_threshold: int = 25
|
||||
beams: int = 1
|
||||
decoder_type: Optional[str] = None
|
||||
audio_max_len: float = 20.0
|
||||
audio_max_len: float = 30.0
|
||||
audio_min_len: float = 0.0
|
||||
cif_ckpt_path: Optional[str] = None
|
||||
never_fire: bool = False
|
||||
|
|
|
|||
|
|
@ -7,20 +7,20 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about detected model format and files in a directory."""
|
||||
path: Optional[Path] = None
|
||||
"""Information about detected model format and files in a directory."""
|
||||
path: Optional[Path] = None
|
||||
pytorch_files: List[Path] = field(default_factory=list)
|
||||
compatible_whisper_mlx: bool = False
|
||||
compatible_faster_whisper: bool = False
|
||||
|
||||
|
||||
@property
|
||||
def has_pytorch(self) -> bool:
|
||||
return len(self.pytorch_files) > 0
|
||||
|
||||
|
||||
@property
|
||||
def is_sharded(self) -> bool:
|
||||
return len(self.pytorch_files) > 1
|
||||
|
||||
|
||||
@property
|
||||
def primary_pytorch_file(self) -> Optional[Path]:
|
||||
"""Return the primary PyTorch file (or first shard for sharded models)."""
|
||||
|
|
@ -40,15 +40,15 @@ CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.j
|
|||
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
||||
"""
|
||||
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
|
||||
|
||||
|
||||
CTranslate2 models have specific companion files that distinguish them
|
||||
from PyTorch .bin files.
|
||||
"""
|
||||
n_indicators = 0
|
||||
for indicator in CT2_INDICATOR_FILES: #test 1
|
||||
if (directory / indicator).exists():
|
||||
if (directory / indicator).exists():
|
||||
n_indicators += 1
|
||||
|
||||
|
||||
if n_indicators == 0:
|
||||
return False
|
||||
|
||||
|
|
@ -61,19 +61,19 @@ def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
|||
return False
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||
"""
|
||||
Collect all PyTorch checkpoint files from a directory.
|
||||
|
||||
|
||||
Handles:
|
||||
- Single files: model.safetensors, pytorch_model.bin, *.pt
|
||||
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
|
||||
- Index-based sharded models (reads index file to find shards)
|
||||
|
||||
|
||||
Returns files sorted appropriately (shards in order, or single file).
|
||||
"""
|
||||
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
||||
|
|
@ -90,20 +90,20 @@ def _collect_pytorch_files(directory: Path) -> List[Path]:
|
|||
return shards
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
|
||||
|
||||
sharded_groups = {}
|
||||
single_files = {}
|
||||
|
||||
|
||||
for file in directory.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
|
||||
filename = file.name
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
|
||||
if filename.startswith("adapter_"):
|
||||
continue
|
||||
|
||||
|
||||
match = SHARDED_PATTERN.match(filename)
|
||||
if match:
|
||||
base_name, shard_idx, total_shards, ext = match.groups()
|
||||
|
|
@ -112,7 +112,7 @@ def _collect_pytorch_files(directory: Path) -> List[Path]:
|
|||
sharded_groups[key] = []
|
||||
sharded_groups[key].append((int(shard_idx), file))
|
||||
continue
|
||||
|
||||
|
||||
if filename == "model.safetensors":
|
||||
single_files[0] = file # Highest priority
|
||||
elif filename == "pytorch_model.bin":
|
||||
|
|
@ -121,68 +121,68 @@ def _collect_pytorch_files(directory: Path) -> List[Path]:
|
|||
single_files[2] = file
|
||||
elif suffix == ".safetensors" and not filename.startswith("adapter"):
|
||||
single_files[3] = file
|
||||
|
||||
|
||||
for (base_name, ext, total_shards), shards in sharded_groups.items():
|
||||
if len(shards) == total_shards:
|
||||
return [path for _, path in sorted(shards)]
|
||||
|
||||
|
||||
for priority in sorted(single_files.keys()):
|
||||
return [single_files[priority]]
|
||||
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
|
||||
"""
|
||||
Detect the model format in a given path.
|
||||
|
||||
|
||||
This function analyzes a file or directory to determine:
|
||||
- What PyTorch checkpoint files are available (including sharded models)
|
||||
- Whether the directory contains MLX Whisper weights
|
||||
- Whether the directory contains Faster-Whisper (CTranslate2) weights
|
||||
|
||||
|
||||
Args:
|
||||
model_path: Path to a model file or directory
|
||||
|
||||
|
||||
Returns:
|
||||
ModelInfo with detected format information
|
||||
"""
|
||||
path = Path(model_path)
|
||||
info = ModelInfo(path=path)
|
||||
|
||||
|
||||
if path.is_file():
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in {".pt", ".safetensors", ".bin"}:
|
||||
info.pytorch_files = [path]
|
||||
return info
|
||||
|
||||
|
||||
if not path.is_dir():
|
||||
return info
|
||||
|
||||
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
|
||||
filename = file.name.lower()
|
||||
|
||||
|
||||
if filename in MLX_WHISPER_MARKERS:
|
||||
info.compatible_whisper_mlx = True
|
||||
|
||||
|
||||
if filename in FASTER_WHISPER_MARKERS:
|
||||
if _is_ct2_model_bin(path, filename):
|
||||
info.compatible_faster_whisper = True
|
||||
|
||||
|
||||
info.pytorch_files = _collect_pytorch_files(path)
|
||||
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
|
||||
This is a compatibility wrapper around detect_model_format().
|
||||
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
|
|
|
|||
|
|
@ -72,20 +72,20 @@ def parse_args():
|
|||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-punctuation-split",
|
||||
action="store_true",
|
||||
help="Disable the split parameter.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
|
|
@ -93,7 +93,7 @@ def parse_args():
|
|||
dest='model_size',
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--model_cache_dir",
|
||||
type=str,
|
||||
|
|
@ -127,14 +127,14 @@ def parse_args():
|
|||
default=False,
|
||||
help="Use Whisper to directly translate to english.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--target-language",
|
||||
type=str,
|
||||
default="",
|
||||
dest="target_language",
|
||||
help="Target language for translation. Not functional yet.",
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend-policy",
|
||||
|
|
@ -147,8 +147,8 @@ def parse_args():
|
|||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
|
||||
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3"],
|
||||
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon. Use 'qwen3' for Qwen3-ASR.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
|
|
@ -165,7 +165,7 @@ def parse_args():
|
|||
action="store_true",
|
||||
help="Disable VAD (voice activity detection).",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--buffer_trimming",
|
||||
type=str,
|
||||
|
|
@ -213,7 +213,7 @@ def parse_args():
|
|||
default=None,
|
||||
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
type=int,
|
||||
|
|
@ -221,7 +221,7 @@ def parse_args():
|
|||
dest="frame_threshold",
|
||||
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--beams",
|
||||
"-b",
|
||||
|
|
@ -229,7 +229,7 @@ def parse_args():
|
|||
default=1,
|
||||
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
|
|
@ -238,7 +238,7 @@ def parse_args():
|
|||
choices=["beam", "greedy"],
|
||||
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--audio-max-len",
|
||||
type=float,
|
||||
|
|
@ -246,7 +246,7 @@ def parse_args():
|
|||
dest="audio_max_len",
|
||||
help="Max length of the audio buffer, in seconds.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--audio-min-len",
|
||||
type=float,
|
||||
|
|
@ -254,7 +254,7 @@ def parse_args():
|
|||
dest="audio_min_len",
|
||||
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--cif-ckpt-path",
|
||||
type=str,
|
||||
|
|
@ -262,7 +262,7 @@ def parse_args():
|
|||
dest="cif_ckpt_path",
|
||||
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--never-fire",
|
||||
action="store_true",
|
||||
|
|
@ -270,7 +270,7 @@ def parse_args():
|
|||
dest="never_fire",
|
||||
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--init-prompt",
|
||||
type=str,
|
||||
|
|
@ -278,7 +278,7 @@ def parse_args():
|
|||
dest="init_prompt",
|
||||
help="Init prompt for the model. It should be in the target language.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--static-init-prompt",
|
||||
type=str,
|
||||
|
|
@ -286,7 +286,7 @@ def parse_args():
|
|||
dest="static_init_prompt",
|
||||
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--max-context-tokens",
|
||||
type=int,
|
||||
|
|
@ -294,7 +294,7 @@ def parse_args():
|
|||
dest="max_context_tokens",
|
||||
help="Max context tokens for the model. Default is 0.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
|
|
@ -302,14 +302,14 @@ def parse_args():
|
|||
dest="model_path",
|
||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="transformers",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
|
|
|
|||
Loading…
Reference in a new issue