improve diarization with lag diarization substraction
This commit is contained in:
parent
0f79d442ee
commit
b01b81bad0
3 changed files with 11 additions and 12 deletions
|
|
@ -24,7 +24,7 @@ class TranscriptionEngine:
|
|||
"warmup_file": None,
|
||||
"confidence_validation": False,
|
||||
"diarization": False,
|
||||
"punctuation_split": True,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
"model": "tiny",
|
||||
"model_cache_dir": None,
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ class DiartDiarization:
|
|||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5):
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
|
|
@ -222,10 +223,11 @@ class DiartDiarization:
|
|||
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||
|
||||
# First pass: assign speakers based on timing overlap
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
|
||||
|
|
@ -239,13 +241,12 @@ class DiartDiarization:
|
|||
for segment in segments:
|
||||
speaker_num = extract_number(segment.speaker) + 1
|
||||
segment_map.append((segment.start, segment.end, speaker_num))
|
||||
segment_map.sort(key=lambda x: x[0]) # Sort by start time
|
||||
segment_map.sort(key=lambda x: x[0])
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
current_token = tokens[i]
|
||||
|
||||
# Check if current token ends with sentence-ending punctuation
|
||||
is_sentence_end = False
|
||||
if current_token.text and current_token.text.strip():
|
||||
text = current_token.text.strip()
|
||||
|
|
@ -254,15 +255,11 @@ class DiartDiarization:
|
|||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
||||
|
||||
if is_sentence_end and current_token.speaker != -1:
|
||||
# Find the dominant speaker for tokens after this punctuation
|
||||
punctuation_time = current_token.end
|
||||
current_speaker = current_token.speaker
|
||||
|
||||
# Look ahead to find where the next sentence starts and ends
|
||||
j = i + 1
|
||||
next_sentence_tokens = []
|
||||
|
||||
# Collect tokens until we hit another sentence-ending punctuation or run out
|
||||
next_sentence_tokens = []
|
||||
while j < len(tokens):
|
||||
next_token = tokens[j]
|
||||
next_sentence_tokens.append(j)
|
||||
|
|
@ -307,7 +304,6 @@ class DiartDiarization:
|
|||
tokens[idx].speaker = current_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
|
||||
i += 1
|
||||
|
||||
i += 1
|
||||
|
||||
return end_attributed_speaker
|
||||
|
|
|
|||
|
|
@ -26,4 +26,7 @@ class Transcript(TimedText):
|
|||
|
||||
@dataclass
|
||||
class SpeakerSegment(TimedText):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
pass
|
||||
Loading…
Reference in a new issue