improve diarization with lag diarization subtraction

This commit is contained in:
Quentin Fuxa 2025-06-19 16:18:49 +02:00
parent 0f79d442ee
commit b01b81bad0
3 changed files with 11 additions and 12 deletions

View file

@ -24,7 +24,7 @@ class TranscriptionEngine:
"warmup_file": None,
"confidence_validation": False,
"diarization": False,
"punctuation_split": True,
"punctuation_split": False,
"min_chunk_size": 0.5,
"model": "tiny",
"model_cache_dir": None,

View file

@ -171,6 +171,7 @@ class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5):
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration)
@ -222,10 +223,11 @@ class DiartDiarization:
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
# First pass: assign speakers based on timing overlap
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
for token in tokens:
for segment in segments:
if not (segment.end <= token.start or segment.start >= token.end):
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
@ -239,13 +241,12 @@ class DiartDiarization:
for segment in segments:
speaker_num = extract_number(segment.speaker) + 1
segment_map.append((segment.start, segment.end, speaker_num))
segment_map.sort(key=lambda x: x[0]) # Sort by start time
segment_map.sort(key=lambda x: x[0])
i = 0
while i < len(tokens):
current_token = tokens[i]
# Check if current token ends with sentence-ending punctuation
is_sentence_end = False
if current_token.text and current_token.text.strip():
text = current_token.text.strip()
@ -254,15 +255,11 @@ class DiartDiarization:
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
if is_sentence_end and current_token.speaker != -1:
# Find the dominant speaker for tokens after this punctuation
punctuation_time = current_token.end
current_speaker = current_token.speaker
# Look ahead to find where the next sentence starts and ends
j = i + 1
next_sentence_tokens = []
# Collect tokens until we hit another sentence-ending punctuation or run out
next_sentence_tokens = []
while j < len(tokens):
next_token = tokens[j]
next_sentence_tokens.append(j)
@ -307,7 +304,6 @@ class DiartDiarization:
tokens[idx].speaker = current_speaker
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
i += 1
i += 1
return end_attributed_speaker

View file

@ -26,4 +26,7 @@ class Transcript(TimedText):
@dataclass
class SpeakerSegment(TimedText):
    """A stretch of audio attributed to one specific speaker.

    The class declares no fields or methods of its own; everything comes
    from ``TimedText``. Neither text nor a probability is associated with
    a speaker segment.
    """