add segmentation and embedding model options to configuration

This commit is contained in:
Quentin Fuxa 2025-06-19 16:29:25 +02:00
parent b01b81bad0
commit 8532a91c7a
4 changed files with 35 additions and 7 deletions

View file

@ -240,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model | `speechbrain/spkrec-ecapa-voxceleb` |
## 🔧 How It Works

View file

@ -39,8 +39,10 @@ class TranscriptionEngine:
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"transcription": True,
"transcription": True,
"vad": True,
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
@ -69,6 +71,10 @@ class TranscriptionEngine:
if self.args.diarization:
from whisperlivekit.diarization.diarization_online import DiartDiarization
self.diarization = DiartDiarization(block_duration=self.args.min_chunk_size)
self.diarization = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
TranscriptionEngine._initialized = True

View file

@ -16,9 +16,6 @@ from typing import Tuple, Any, List
from pyannote.core import Annotation
import diart.models as m
segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation-3.0")
embedding = m.EmbeddingModel.from_pretrained("speechbrain/spkrec-ecapa-voxceleb")
logger = logging.getLogger(__name__)
def extract_number(s: str) -> int:
@ -168,7 +165,16 @@ class WebSocketAudioSource(AudioSource):
class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5):
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
self.lag_diart = None

View file

@ -44,6 +44,20 @@ def parse_args():
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
)
parser.add_argument(
"--segmentation-model",
type=str,
default="pyannote/segmentation-3.0",
help="Hugging Face model ID for pyannote.audio segmentation model.",
)
parser.add_argument(
"--embedding-model",
type=str,
default="pyannote/embedding",
help="Hugging Face model ID for pyannote.audio embedding model.",
)
parser.add_argument(
"--no-transcription",
action="store_true",
@ -145,4 +159,4 @@ def parse_args():
delattr(args, 'no_transcription')
delattr(args, 'no_vad')
return args
return args