Improve diarization backends

This commit is contained in:
Quentin Fuxa 2026-02-15 14:55:00 +01:00
parent 32de7b1276
commit e30f9a2573
2 changed files with 78 additions and 81 deletions

View file

@ -20,25 +20,25 @@ logger = logging.getLogger(__name__)
class DiarizationObserver(Observer): class DiarizationObserver(Observer):
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
def __init__(self): def __init__(self):
self.diarization_segments = [] self.diarization_segments = []
self.processed_time = 0 self.processed_time = 0
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0 self.global_time_offset = 0.0
def on_next(self, value: Tuple[Annotation, Any]): def on_next(self, value: Tuple[Annotation, Any]):
annotation, audio = value annotation, audio = value
logger.debug("\n--- New Diarization Result ---") logger.debug("\n--- New Diarization Result ---")
duration = audio.extent.end - audio.extent.start duration = audio.extent.end - audio.extent.start
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)") logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
logger.debug(f"Audio shape: {audio.data.shape}") logger.debug(f"Audio shape: {audio.data.shape}")
with self.segment_lock: with self.segment_lock:
if audio.extent.end > self.processed_time: if audio.extent.end > self.processed_time:
self.processed_time = audio.extent.end self.processed_time = audio.extent.end
if annotation and len(annotation._labels) > 0: if annotation and len(annotation._labels) > 0:
logger.debug("\nSpeaker segments:") logger.debug("\nSpeaker segments:")
for speaker, label in annotation._labels.items(): for speaker, label in annotation._labels.items():
@ -51,25 +51,25 @@ class DiarizationObserver(Observer):
)) ))
else: else:
logger.debug("\nNo speakers detected in this segment") logger.debug("\nNo speakers detected in this segment")
def get_segments(self) -> List[SpeakerSegment]: def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments.""" """Get a copy of the current speaker segments."""
with self.segment_lock: with self.segment_lock:
return self.diarization_segments.copy() return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0): def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time.""" """Clear segments older than the specified time."""
with self.segment_lock: with self.segment_lock:
current_time = self.processed_time current_time = self.processed_time
self.diarization_segments = [ self.diarization_segments = [
segment for segment in self.diarization_segments segment for segment in self.diarization_segments
if current_time - segment.end < older_than if current_time - segment.end < older_than
] ]
def on_error(self, error): def on_error(self, error):
"""Handle an error in the stream.""" """Handle an error in the stream."""
logger.debug(f"Error in diarization stream: {error}") logger.debug(f"Error in diarization stream: {error}")
def on_completed(self): def on_completed(self):
"""Handle the completion of the stream.""" """Handle the completion of the stream."""
logger.debug("Diarization stream completed") logger.debug("Diarization stream completed")
@ -96,7 +96,7 @@ class WebSocketAudioSource(AudioSource):
self._processing_thread = threading.Thread(target=self._process_chunks) self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True self._processing_thread.daemon = True
self._processing_thread.start() self._processing_thread.start()
self._close_event.wait() self._close_event.wait()
if self._processing_thread: if self._processing_thread:
self._processing_thread.join(timeout=2.0) self._processing_thread.join(timeout=2.0)
@ -106,30 +106,30 @@ class WebSocketAudioSource(AudioSource):
while not self._closed: while not self._closed:
try: try:
audio_chunk = self._queue.get(timeout=0.1) audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock: with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk]) self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size: while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size] chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:] self._buffer = self._buffer[self.block_size:]
current_time = time.time() current_time = time.time()
time_since_last = current_time - self._last_chunk_time time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration: if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last) time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1) chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped) self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time() self._last_chunk_time = time.time()
except Empty: except Empty:
with self._buffer_lock: with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration: if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32) padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32) self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1) chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped) self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time() self._last_chunk_time = time.time()
@ -137,14 +137,14 @@ class WebSocketAudioSource(AudioSource):
logger.error(f"Error in audio processing thread: {e}") logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e) self.stream.on_error(e)
break break
with self._buffer_lock: with self._buffer_lock:
if len(self._buffer) > 0: if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32) padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1) chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped) self.stream.on_next(chunk_reshaped)
self.stream.on_completed() self.stream.on_completed()
def close(self): def close(self):
@ -165,27 +165,27 @@ class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"): def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name) segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name) embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None: if config is None:
config = SpeakerDiarizationConfig( config = SpeakerDiarizationConfig(
segmentation=segmentation_model, segmentation=segmentation_model,
embedding=embedding_model, embedding=embedding_model,
) )
self.pipeline = SpeakerDiarization(config=config) self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver() self.observer = DiarizationObserver()
if use_microphone: if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration) self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None self.custom_source = None
else: else:
self.custom_source = WebSocketAudioSource( self.custom_source = WebSocketAudioSource(
uri="websocket_source", uri="websocket_source",
sample_rate=sample_rate, sample_rate=sample_rate,
block_duration=block_duration block_duration=block_duration
) )
self.source = self.custom_source self.source = self.custom_source
self.inference = StreamingInference( self.inference = StreamingInference(
pipeline=self.pipeline, pipeline=self.pipeline,
source=self.source, source=self.source,
@ -205,14 +205,14 @@ class DiartDiarization:
async def diarize(self): async def diarize(self):
"""Return the current speaker segments from the diarization pipeline.""" """Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments() return self.observer.get_segments()
def close(self): def close(self):
"""Close the audio source.""" """Close the audio source."""
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def concatenate_speakers(segments): def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}] segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
for segment in segments: for segment in segments:
@ -223,7 +223,7 @@ def concatenate_speakers(segments):
segments_concatenated[-1]['end'] = segment.end segments_concatenated[-1]['end'] = segment.end
# print("Segments concatenated:") # print("Segments concatenated:")
# for entry in segments_concatenated: # for entry in segments_concatenated:
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s") # print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
return segments_concatenated return segments_concatenated
@ -281,4 +281,4 @@ def visualize_tokens(tokens):
conversation[-1]['text'] += token.text conversation[-1]['text'] += token.text
print("Conversation:") print("Conversation:")
for entry in conversation: for entry in conversation:
print(f"Speaker {entry['speaker']}: {entry['text']}") print(f"Speaker {entry['speaker']}: {entry['text']}")

View file

@ -1,8 +1,6 @@
import logging import logging
import threading import threading
import time
import wave import wave
from queue import Empty, SimpleQueue
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@ -54,7 +52,7 @@ class SortformerDiarization:
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized. Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
""" """
self._load_model(model_name) self._load_model(model_name)
def _load_model(self, model_name: str): def _load_model(self, model_name: str):
"""Load and configure the Sortformer model for streaming.""" """Load and configure the Sortformer model for streaming."""
try: try:
@ -63,12 +61,12 @@ class SortformerDiarization:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.diar_model.to(device) self.diar_model.to(device)
## to test ## to test
# for name, param in self.diar_model.named_parameters(): # for name, param in self.diar_model.named_parameters():
# if param.device != device: # if param.device != device:
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}") # raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
logger.info(f"Using {device.type.upper()} for Sortformer model") logger.info(f"Using {device.type.upper()} for Sortformer model")
self.diar_model.sortformer_modules.chunk_len = 10 self.diar_model.sortformer_modules.chunk_len = 10
@ -80,16 +78,16 @@ class SortformerDiarization:
self.diar_model.sortformer_modules.spkcache_update_period = 144 self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters() self.diar_model.sortformer_modules._check_streaming_parameters()
except Exception as e: except Exception as e:
logger.error(f"Failed to load Sortformer model: {e}") logger.error(f"Failed to load Sortformer model: {e}")
raise raise
class SortformerDiarizationOnline: class SortformerDiarizationOnline:
def __init__(self, shared_model, sample_rate: int = 16000): def __init__(self, shared_model, sample_rate: int = 16000):
""" """
Initialize the streaming Sortformer diarization system. Initialize the streaming Sortformer diarization system.
Args: Args:
sample_rate: Audio sample rate (default: 16000) sample_rate: Audio sample rate (default: 16000)
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2") model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
@ -101,9 +99,9 @@ class SortformerDiarizationOnline:
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0 self.global_time_offset = 0.0
self.debug = False self.debug = False
self.diar_model = shared_model.diar_model self.diar_model = shared_model.diar_model
self.audio2mel = AudioToMelSpectrogramPreprocessor( self.audio2mel = AudioToMelSpectrogramPreprocessor(
window_size=0.025, window_size=0.025,
normalize="NA", normalize="NA",
@ -112,26 +110,26 @@ class SortformerDiarizationOnline:
pad_to=0 pad_to=0
) )
self.audio2mel.to(self.diar_model.device) self.audio2mel.to(self.diar_model.device)
self.chunk_duration_seconds = ( self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len * self.diar_model.sortformer_modules.chunk_len *
self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.sortformer_modules.subsampling_factor *
self.diar_model.preprocessor._cfg.window_stride self.diar_model.preprocessor._cfg.window_stride
) )
self._init_streaming_state() self._init_streaming_state()
self._previous_chunk_features = None self._previous_chunk_features = None
self._chunk_index = 0 self._chunk_index = 0
self._len_prediction = None self._len_prediction = None
# Audio buffer to store PCM chunks for debugging # Audio buffer to store PCM chunks for debugging
self.audio_buffer = [] self.audio_buffer = []
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds # Buffer for accumulating audio chunks until reaching chunk_duration_seconds
self.audio_chunk_buffer = [] self.audio_chunk_buffer = []
self.accumulated_duration = 0.0 self.accumulated_duration = 0.0
logger.info("SortformerDiarization initialized successfully") logger.info("SortformerDiarization initialized successfully")
@ -139,30 +137,30 @@ class SortformerDiarizationOnline:
"""Initialize the streaming state for the model.""" """Initialize the streaming state for the model."""
batch_size = 1 batch_size = 1
device = self.diar_model.device device = self.diar_model.device
self.streaming_state = StreamingSortformerState() self.streaming_state = StreamingSortformerState()
self.streaming_state.spkcache = torch.zeros( self.streaming_state.spkcache = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model), (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
device=device device=device
) )
self.streaming_state.spkcache_preds = torch.zeros( self.streaming_state.spkcache_preds = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk), (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
device=device device=device
) )
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.fifo = torch.zeros( self.streaming_state.fifo = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model), (batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
device=device device=device
) )
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device) self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device) self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: Optional[float]): def insert_silence(self, silence_duration: Optional[float]):
""" """
Insert silence period by adjusting the global time offset. Insert silence period by adjusting the global time offset.
Args: Args:
silence_duration: Duration of silence in seconds silence_duration: Duration of silence in seconds
""" """
@ -174,48 +172,48 @@ class SortformerDiarizationOnline:
if self.debug: if self.debug:
self.audio_buffer.append(pcm_array.copy()) self.audio_buffer.append(pcm_array.copy())
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()]) self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
async def diarize(self): async def diarize(self):
""" """
Process audio data for diarization in streaming fashion. Process audio data for diarization in streaming fashion.
Args: Args:
pcm_array: Audio data as numpy array pcm_array: Audio data as numpy array
""" """
threshold = int(self.chunk_duration_seconds * self.sample_rate) threshold = int(self.chunk_duration_seconds * self.sample_rate)
if not len(self.buffer_audio) >= threshold: if not len(self.buffer_audio) >= threshold:
return [] return []
audio = self.buffer_audio[:threshold] audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:] self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device) audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk audio_signal_chunk, audio_signal_length_chunk
) )
processed_signal_chunk = processed_signal_chunk.to(device) processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device) processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None: if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device) to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device) total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else: else:
total_features = processed_signal_chunk.to(device) total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device) self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device) chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode(): with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0 left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8 right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step( self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t, processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device), processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
@ -223,9 +221,9 @@ class SortformerDiarizationOnline:
total_preds=self.total_preds, total_preds=self.total_preds,
left_offset=left_offset, left_offset=left_offset,
right_offset=right_offset, right_offset=right_offset,
) )
new_segments = self._process_predictions() new_segments = self._process_predictions()
self._chunk_index += 1 self._chunk_index += 1
return new_segments return new_segments
@ -233,13 +231,13 @@ class SortformerDiarizationOnline:
"""Process model predictions and convert to speaker segments.""" """Process model predictions and convert to speaker segments."""
preds_np = self.total_preds[0].cpu().numpy() preds_np = self.total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1) active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None: if self._len_prediction is None:
self._len_prediction = len(active_speakers) #12 self._len_prediction = len(active_speakers) #12
frame_duration = self.chunk_duration_seconds / self._len_prediction frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:] current_chunk_preds = active_speakers[-self._len_prediction:]
new_segments = [] new_segments = []
with self.segment_lock: with self.segment_lock:
@ -264,7 +262,7 @@ class SortformerDiarizationOnline:
) )
) )
return new_segments return new_segments
def get_segments(self) -> List[SpeakerSegment]: def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments.""" """Get a copy of the current speaker segments."""
with self.segment_lock: with self.segment_lock:
@ -275,10 +273,10 @@ class SortformerDiarizationOnline:
logger.info("Closing SortformerDiarization") logger.info("Closing SortformerDiarization")
with self.segment_lock: with self.segment_lock:
self.diarization_segments.clear() self.diarization_segments.clear()
if self.debug: if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer) concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16) audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file: with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16) wav_file.setsampwidth(2) # 2 bytes per sample (int16)
@ -287,14 +285,13 @@ class SortformerDiarizationOnline:
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav") logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
from whisperlivekit.diarization.utils import extract_number
if __name__ == '__main__': if __name__ == '__main__':
import asyncio import asyncio
import librosa import librosa
async def main(): async def main():
"""TEST ONLY.""" """TEST ONLY."""
an4_audio = 'diarization_audio.wav' an4_audio = 'diarization_audio.wav'
@ -304,24 +301,24 @@ if __name__ == '__main__':
print("\n" + "=" * 50) print("\n" + "=" * 50)
print("ground truth:") print("ground truth:")
print("Speaker 0: 0:00 - 0:09") print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19") print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25") print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30") print("Speaker 0: 0:25 - 0:30")
print("=" * 50) print("=" * 50)
diarization_backend = SortformerDiarization() diarization_backend = SortformerDiarization()
diarization = SortformerDiarizationOnline(shared_model = diarization_backend) diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
chunk_size = 1600 chunk_size = 1600
for i in range(0, len(signal), chunk_size): for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size] chunk = signal[i:i+chunk_size]
new_segments = await diarization.diarize(chunk) new_segments = await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}") print(f"Processed chunk {i // chunk_size + 1}")
print(new_segments) print(new_segments)
segments = diarization.get_segments() segments = diarization.get_segments()
print("\nDiarization results:") print("\nDiarization results:")
for segment in segments: for segment in segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s") print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
asyncio.run(main()) asyncio.run(main())