diff --git a/whisperlivekit/diarization/diart_backend.py b/whisperlivekit/diarization/diart_backend.py index ccfdbfc..5d7c249 100644 --- a/whisperlivekit/diarization/diart_backend.py +++ b/whisperlivekit/diarization/diart_backend.py @@ -20,25 +20,25 @@ logger = logging.getLogger(__name__) class DiarizationObserver(Observer): """Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" - + def __init__(self): self.diarization_segments = [] self.processed_time = 0 self.segment_lock = threading.Lock() self.global_time_offset = 0.0 - + def on_next(self, value: Tuple[Annotation, Any]): annotation, audio = value - + logger.debug("\n--- New Diarization Result ---") - + duration = audio.extent.end - audio.extent.start logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)") logger.debug(f"Audio shape: {audio.data.shape}") - + with self.segment_lock: if audio.extent.end > self.processed_time: - self.processed_time = audio.extent.end + self.processed_time = audio.extent.end if annotation and len(annotation._labels) > 0: logger.debug("\nSpeaker segments:") for speaker, label in annotation._labels.items(): @@ -51,25 +51,25 @@ class DiarizationObserver(Observer): )) else: logger.debug("\nNo speakers detected in this segment") - + def get_segments(self) -> List[SpeakerSegment]: """Get a copy of the current speaker segments.""" with self.segment_lock: return self.diarization_segments.copy() - + def clear_old_segments(self, older_than: float = 30.0): """Clear segments older than the specified time.""" with self.segment_lock: current_time = self.processed_time self.diarization_segments = [ - segment for segment in self.diarization_segments + segment for segment in self.diarization_segments if current_time - segment.end < older_than ] - + def on_error(self, error): """Handle an error in the stream.""" logger.debug(f"Error in diarization stream: {error}") - + def on_completed(self): """Handle the completion of the stream.""" logger.debug("Diarization stream completed") @@ -96,7 +96,7 @@ class WebSocketAudioSource(AudioSource): self._processing_thread = threading.Thread(target=self._process_chunks) self._processing_thread.daemon = True self._processing_thread.start() - + self._close_event.wait() if self._processing_thread: self._processing_thread.join(timeout=2.0) @@ -106,30 +106,30 @@ class WebSocketAudioSource(AudioSource): while not self._closed: try: audio_chunk = self._queue.get(timeout=0.1) - + with self._buffer_lock: self._buffer = np.concatenate([self._buffer, audio_chunk]) - + while len(self._buffer) >= self.block_size: chunk = self._buffer[:self.block_size] self._buffer = self._buffer[self.block_size:] - + current_time = time.time() time_since_last = current_time - self._last_chunk_time if time_since_last < self.block_duration: time.sleep(self.block_duration - time_since_last) - + chunk_reshaped = chunk.reshape(1, -1) self.stream.on_next(chunk_reshaped) self._last_chunk_time = time.time() - + except Empty: with self._buffer_lock: if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration: padded_chunk = np.zeros(self.block_size, dtype=np.float32) padded_chunk[:len(self._buffer)] = self._buffer self._buffer = np.array([], dtype=np.float32) - + chunk_reshaped = padded_chunk.reshape(1, -1) self.stream.on_next(chunk_reshaped) self._last_chunk_time = time.time() @@ -137,14 +137,14 @@ class WebSocketAudioSource(AudioSource): logger.error(f"Error in audio processing thread: {e}") self.stream.on_error(e) break - + with self._buffer_lock: if len(self._buffer) > 0: padded_chunk = np.zeros(self.block_size, dtype=np.float32) padded_chunk[:len(self._buffer)] = self._buffer chunk_reshaped = padded_chunk.reshape(1, -1) self.stream.on_next(chunk_reshaped) - + self.stream.on_completed() def close(self): @@ -165,27 +165,27 @@ class DiartDiarization: def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"): segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name) embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name) - + if config is None: config = SpeakerDiarizationConfig( segmentation=segmentation_model, embedding=embedding_model, ) - - self.pipeline = SpeakerDiarization(config=config) + + self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() - + if use_microphone: self.source = MicrophoneAudioSource(block_duration=block_duration) self.custom_source = None else: self.custom_source = WebSocketAudioSource( - uri="websocket_source", + uri="websocket_source", sample_rate=sample_rate, block_duration=block_duration ) self.source = self.custom_source - + self.inference = StreamingInference( pipeline=self.pipeline, source=self.source, @@ -205,14 +205,14 @@ class DiartDiarization: async def diarize(self): """Return the current speaker segments from the diarization pipeline.""" - return self.observer.get_segments() + return self.observer.get_segments() def close(self): """Close the audio source.""" if self.custom_source: self.custom_source.close() - + def concatenate_speakers(segments): segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}] for segment in segments: @@ -223,7 +223,7 @@ def concatenate_speakers(segments): segments_concatenated[-1]['end'] = segment.end # print("Segments concatenated:") # for entry in segments_concatenated: - # print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s") + # print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s") return segments_concatenated @@ -281,4 +281,4 @@ def visualize_tokens(tokens): conversation[-1]['text'] += token.text print("Conversation:") for entry in conversation: - print(f"Speaker {entry['speaker']}: {entry['text']}") \ No newline at end of file + print(f"Speaker {entry['speaker']}: {entry['text']}") diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index 2f60a46..b06525e 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -1,8 +1,6 @@ import logging import threading -import time import wave -from queue import Empty, SimpleQueue from typing import List, Optional import numpy as np @@ -54,7 +52,7 @@ class SortformerDiarization: Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized. """ self._load_model(model_name) - + def _load_model(self, model_name: str): """Load and configure the Sortformer model for streaming.""" try: @@ -63,12 +61,12 @@ class SortformerDiarization: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.diar_model.to(device) - + ## to test # for name, param in self.diar_model.named_parameters(): # if param.device != device: # raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}") - + logger.info(f"Using {device.type.upper()} for Sortformer model") self.diar_model.sortformer_modules.chunk_len = 10 @@ -80,16 +78,16 @@ class SortformerDiarization: self.diar_model.sortformer_modules.spkcache_update_period = 144 self.diar_model.sortformer_modules.log = False self.diar_model.sortformer_modules._check_streaming_parameters() - + except Exception as e: logger.error(f"Failed to load Sortformer model: {e}") raise - + class SortformerDiarizationOnline: def __init__(self, shared_model, sample_rate: int = 16000): """ Initialize the streaming Sortformer diarization system. - + Args: sample_rate: Audio sample rate (default: 16000) model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2") @@ -101,9 +99,9 @@ class SortformerDiarizationOnline: self.segment_lock = threading.Lock() self.global_time_offset = 0.0 self.debug = False - + self.diar_model = shared_model.diar_model - + self.audio2mel = AudioToMelSpectrogramPreprocessor( window_size=0.025, normalize="NA", @@ -112,26 +110,26 @@ class SortformerDiarizationOnline: pad_to=0 ) self.audio2mel.to(self.diar_model.device) - + self.chunk_duration_seconds = ( - self.diar_model.sortformer_modules.chunk_len * - self.diar_model.sortformer_modules.subsampling_factor * + self.diar_model.sortformer_modules.chunk_len * + self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride ) - + self._init_streaming_state() - + self._previous_chunk_features = None self._chunk_index = 0 self._len_prediction = None - + # Audio buffer to store PCM chunks for debugging self.audio_buffer = [] - + # Buffer for accumulating audio chunks until reaching chunk_duration_seconds self.audio_chunk_buffer = [] self.accumulated_duration = 0.0 - + logger.info("SortformerDiarization initialized successfully") @@ -139,30 +137,30 @@ class SortformerDiarizationOnline: """Initialize the streaming state for the model.""" batch_size = 1 device = self.diar_model.device - + self.streaming_state = StreamingSortformerState() self.streaming_state.spkcache = torch.zeros( - (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model), + (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model), device=device ) self.streaming_state.spkcache_preds = torch.zeros( - (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk), + (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk), device=device ) self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.fifo = torch.zeros( - (batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model), + (batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model), device=device ) self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device) - self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) + self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device) def insert_silence(self, silence_duration: Optional[float]): """ Insert silence period by adjusting the global time offset. - + Args: silence_duration: Duration of silence in seconds """ @@ -174,48 +172,48 @@ class SortformerDiarizationOnline: if self.debug: self.audio_buffer.append(pcm_array.copy()) self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()]) - + async def diarize(self): """ Process audio data for diarization in streaming fashion. - + Args: pcm_array: Audio data as numpy array """ threshold = int(self.chunk_duration_seconds * self.sample_rate) - + if not len(self.buffer_audio) >= threshold: return [] - + audio = self.buffer_audio[:threshold] self.buffer_audio = self.buffer_audio[threshold:] - + device = self.diar_model.device audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device) - + processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( audio_signal_chunk, audio_signal_length_chunk ) processed_signal_chunk = processed_signal_chunk.to(device) processed_signal_length_chunk = processed_signal_length_chunk.to(device) - + if self._previous_chunk_features is not None: to_add = self._previous_chunk_features[:, :, -99:].to(device) total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device) else: total_features = processed_signal_chunk.to(device) - + self._previous_chunk_features = processed_signal_chunk.to(device) - + chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device) - + with torch.inference_mode(): left_offset = 8 if self._chunk_index > 0 else 0 right_offset = 8 - + self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step( processed_signal=chunk_feat_seq_t, processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device), @@ -223,9 +221,9 @@ class SortformerDiarizationOnline: total_preds=self.total_preds, left_offset=left_offset, right_offset=right_offset, - ) + ) new_segments = self._process_predictions() - + self._chunk_index += 1 return new_segments @@ -233,13 +231,13 @@ class SortformerDiarizationOnline: """Process model predictions and convert to speaker segments.""" preds_np = self.total_preds[0].cpu().numpy() active_speakers = np.argmax(preds_np, axis=1) - + if self._len_prediction is None: self._len_prediction = len(active_speakers) #12 - + frame_duration = self.chunk_duration_seconds / self._len_prediction current_chunk_preds = active_speakers[-self._len_prediction:] - + new_segments = [] with self.segment_lock: @@ -264,7 +262,7 @@ class SortformerDiarizationOnline: ) ) return new_segments - + def get_segments(self) -> List[SpeakerSegment]: """Get a copy of the current speaker segments.""" with self.segment_lock: @@ -275,10 +273,10 @@ class SortformerDiarizationOnline: logger.info("Closing SortformerDiarization") with self.segment_lock: self.diarization_segments.clear() - + if self.debug: concatenated_audio = np.concatenate(self.audio_buffer) - audio_data_int16 = (concatenated_audio * 32767).astype(np.int16) + audio_data_int16 = (concatenated_audio * 32767).astype(np.int16) with wave.open("diarization_audio.wav", "wb") as wav_file: wav_file.setnchannels(1) # mono audio wav_file.setsampwidth(2) # 2 bytes per sample (int16) @@ -287,14 +285,13 @@ class SortformerDiarizationOnline: logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav") -from whisperlivekit.diarization.utils import extract_number if __name__ == '__main__': import asyncio import librosa - + async def main(): """TEST ONLY.""" an4_audio = 'diarization_audio.wav' @@ -304,24 +301,24 @@ if __name__ == '__main__': print("\n" + "=" * 50) print("ground truth:") print("Speaker 0: 0:00 - 0:09") - print("Speaker 1: 0:09 - 0:19") + print("Speaker 1: 0:09 - 0:19") print("Speaker 2: 0:19 - 0:25") print("Speaker 0: 0:25 - 0:30") print("=" * 50) - + diarization_backend = SortformerDiarization() - diarization = SortformerDiarizationOnline(shared_model = diarization_backend) + diarization = SortformerDiarizationOnline(shared_model = diarization_backend) chunk_size = 1600 - + for i in range(0, len(signal), chunk_size): chunk = signal[i:i+chunk_size] new_segments = await diarization.diarize(chunk) print(f"Processed chunk {i // chunk_size + 1}") print(new_segments) - + segments = diarization.get_segments() print("\nDiarization results:") for segment in segments: print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s") - + asyncio.run(main())