diff --git a/README.md b/README.md index e0e5969..e508609 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,8 @@ WhisperLiveKit offers extensive configuration options: | `--warmup-file` | Audio file path for model warmup | `jfk.wav` | | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | +| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model | `pyannote/segmentation-3.0` | +| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model | `speechbrain/spkrec-ecapa-voxceleb` | ## 🔧 How It Works diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index f101099..a236407 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -39,8 +39,10 @@ class TranscriptionEngine: "log_level": "DEBUG", "ssl_certfile": None, "ssl_keyfile": None, - "transcription": True, + "transcription": True, "vad": True, + "segmentation_model": "pyannote/segmentation-3.0", + "embedding_model": "pyannote/embedding", } config_dict = {**defaults, **kwargs} @@ -69,6 +71,10 @@ class TranscriptionEngine: if self.args.diarization: from whisperlivekit.diarization.diarization_online import DiartDiarization - self.diarization = DiartDiarization(block_duration=self.args.min_chunk_size) + self.diarization = DiartDiarization( + block_duration=self.args.min_chunk_size, + segmentation_model_name=self.args.segmentation_model, + embedding_model_name=self.args.embedding_model + ) TranscriptionEngine._initialized = True diff --git a/whisperlivekit/diarization/diarization_online.py b/whisperlivekit/diarization/diarization_online.py index f3ced1b..04ce0ac 100644 --- a/whisperlivekit/diarization/diarization_online.py +++ b/whisperlivekit/diarization/diarization_online.py @@ -16,9 +16,6 @@ from typing import Tuple, Any, List from pyannote.core import Annotation import diart.models as m -segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation-3.0") -embedding = m.EmbeddingModel.from_pretrained("speechbrain/spkrec-ecapa-voxceleb") - logger = logging.getLogger(__name__) def extract_number(s: str) -> int: @@ -168,7 +165,16 @@ class WebSocketAudioSource(AudioSource): class DiartDiarization: - def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5): + def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"): + segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name) + embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name) + + if config is None: + config = SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + ) + self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() self.lag_diart = None diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index da28532..18db453 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -44,6 +44,20 @@ def parse_args(): help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.", ) + parser.add_argument( + "--segmentation-model", + type=str, + default="pyannote/segmentation-3.0", + help="Hugging Face model ID for pyannote.audio segmentation model.", + ) + + parser.add_argument( + "--embedding-model", + type=str, + default="pyannote/embedding", + help="Hugging Face model ID for pyannote.audio embedding model.", + ) + parser.add_argument( "--no-transcription", action="store_true", @@ -145,4 +159,4 @@ def parse_args(): delattr(args, 'no_transcription') delattr(args, 'no_vad') - return args \ No newline at end of file + return args