diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index ae6fa71..f101099 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -24,7 +24,7 @@ class TranscriptionEngine: "warmup_file": None, "confidence_validation": False, "diarization": False, - "punctuation_split": True, + "punctuation_split": False, "min_chunk_size": 0.5, "model": "tiny", "model_cache_dir": None, diff --git a/whisperlivekit/diarization/diarization_online.py b/whisperlivekit/diarization/diarization_online.py index 978c47a..f3ced1b 100644 --- a/whisperlivekit/diarization/diarization_online.py +++ b/whisperlivekit/diarization/diarization_online.py @@ -171,6 +171,7 @@ class DiartDiarization: def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5): self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() + self.lag_diart = None if use_microphone: self.source = MicrophoneAudioSource(block_duration=block_duration) @@ -222,10 +223,11 @@ class DiartDiarization: for i, seg in enumerate(segments[:5]): # Show first 5 segments logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]") - # First pass: assign speakers based on timing overlap + if not self.lag_diart and segments and tokens: + self.lag_diart = segments[0].start - tokens[0].start for token in tokens: for segment in segments: - if not (segment.end <= token.start or segment.start >= token.end): + if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart): token.speaker = extract_number(segment.speaker) + 1 end_attributed_speaker = max(token.end, end_attributed_speaker) @@ -239,13 +241,12 @@ class DiartDiarization: for segment in segments: speaker_num = extract_number(segment.speaker) + 1 segment_map.append((segment.start, segment.end, speaker_num)) - segment_map.sort(key=lambda x: x[0]) # Sort by start time + segment_map.sort(key=lambda x: x[0]) i = 0 while i < len(tokens): current_token = tokens[i] - # Check if current token ends with sentence-ending punctuation is_sentence_end = False if current_token.text and current_token.text.strip(): text = current_token.text.strip() @@ -254,15 +255,11 @@ class DiartDiarization: logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s") if is_sentence_end and current_token.speaker != -1: - # Find the dominant speaker for tokens after this punctuation punctuation_time = current_token.end current_speaker = current_token.speaker - # Look ahead to find where the next sentence starts and ends j = i + 1 - next_sentence_tokens = [] - - # Collect tokens until we hit another sentence-ending punctuation or run out + next_sentence_tokens = [] while j < len(tokens): next_token = tokens[j] next_sentence_tokens.append(j) @@ -307,7 +304,6 @@ class DiartDiarization: tokens[idx].speaker = current_speaker end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker) - i += 1 - + i += 1 return end_attributed_speaker diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index c347b9a..82fd1f6 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -26,4 +26,7 @@ class Transcript(TimedText): @dataclass class SpeakerSegment(TimedText): + """Represents a segment of audio attributed to a specific speaker. + No text nor probability is associated with this segment. + """ pass \ No newline at end of file