add segmentation and embedding model options to configuration
This commit is contained in:
parent
b01b81bad0
commit
8532a91c7a
4 changed files with 35 additions and 7 deletions
|
|
@ -240,6 +240,8 @@ WhisperLiveKit offers extensive configuration options:
|
||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||||
|
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model | `pyannote/segmentation-3.0` |
|
||||||
|
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||||
|
|
||||||
## 🔧 How It Works
|
## 🔧 How It Works
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,8 +39,10 @@ class TranscriptionEngine:
|
||||||
"log_level": "DEBUG",
|
"log_level": "DEBUG",
|
||||||
"ssl_certfile": None,
|
"ssl_certfile": None,
|
||||||
"ssl_keyfile": None,
|
"ssl_keyfile": None,
|
||||||
"transcription": True,
|
"transcription": True,
|
||||||
"vad": True,
|
"vad": True,
|
||||||
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
|
"embedding_model": "pyannote/embedding",
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
config_dict = {**defaults, **kwargs}
|
||||||
|
|
@ -69,6 +71,10 @@ class TranscriptionEngine:
|
||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||||
self.diarization = DiartDiarization(block_duration=self.args.min_chunk_size)
|
self.diarization = DiartDiarization(
|
||||||
|
block_duration=self.args.min_chunk_size,
|
||||||
|
segmentation_model_name=self.args.segmentation_model,
|
||||||
|
embedding_model_name=self.args.embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,6 @@ from typing import Tuple, Any, List
|
||||||
from pyannote.core import Annotation
|
from pyannote.core import Annotation
|
||||||
import diart.models as m
|
import diart.models as m
|
||||||
|
|
||||||
segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation-3.0")
|
|
||||||
embedding = m.EmbeddingModel.from_pretrained("speechbrain/spkrec-ecapa-voxceleb")
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def extract_number(s: str) -> int:
|
def extract_number(s: str) -> int:
|
||||||
|
|
@ -168,7 +165,16 @@ class WebSocketAudioSource(AudioSource):
|
||||||
|
|
||||||
|
|
||||||
class DiartDiarization:
|
class DiartDiarization:
|
||||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5):
|
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
||||||
|
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||||
|
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = SpeakerDiarizationConfig(
|
||||||
|
segmentation=segmentation_model,
|
||||||
|
embedding=embedding_model,
|
||||||
|
)
|
||||||
|
|
||||||
self.pipeline = SpeakerDiarization(config=config)
|
self.pipeline = SpeakerDiarization(config=config)
|
||||||
self.observer = DiarizationObserver()
|
self.observer = DiarizationObserver()
|
||||||
self.lag_diart = None
|
self.lag_diart = None
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,20 @@ def parse_args():
|
||||||
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
|
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--segmentation-model",
|
||||||
|
type=str,
|
||||||
|
default="pyannote/segmentation-3.0",
|
||||||
|
help="Hugging Face model ID for pyannote.audio segmentation model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="pyannote/embedding",
|
||||||
|
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-transcription",
|
"--no-transcription",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|
@ -145,4 +159,4 @@ def parse_args():
|
||||||
delattr(args, 'no_transcription')
|
delattr(args, 'no_transcription')
|
||||||
delattr(args, 'no_vad')
|
delattr(args, 'no_vad')
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue