Improve diarization backends
This commit is contained in:
parent
32de7b1276
commit
e30f9a2573
2 changed files with 78 additions and 81 deletions
|
|
@ -20,25 +20,25 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DiarizationObserver(Observer):
|
class DiarizationObserver(Observer):
|
||||||
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.diarization_segments = []
|
self.diarization_segments = []
|
||||||
self.processed_time = 0
|
self.processed_time = 0
|
||||||
self.segment_lock = threading.Lock()
|
self.segment_lock = threading.Lock()
|
||||||
self.global_time_offset = 0.0
|
self.global_time_offset = 0.0
|
||||||
|
|
||||||
def on_next(self, value: Tuple[Annotation, Any]):
|
def on_next(self, value: Tuple[Annotation, Any]):
|
||||||
annotation, audio = value
|
annotation, audio = value
|
||||||
|
|
||||||
logger.debug("\n--- New Diarization Result ---")
|
logger.debug("\n--- New Diarization Result ---")
|
||||||
|
|
||||||
duration = audio.extent.end - audio.extent.start
|
duration = audio.extent.end - audio.extent.start
|
||||||
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
|
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
|
||||||
logger.debug(f"Audio shape: {audio.data.shape}")
|
logger.debug(f"Audio shape: {audio.data.shape}")
|
||||||
|
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
if audio.extent.end > self.processed_time:
|
if audio.extent.end > self.processed_time:
|
||||||
self.processed_time = audio.extent.end
|
self.processed_time = audio.extent.end
|
||||||
if annotation and len(annotation._labels) > 0:
|
if annotation and len(annotation._labels) > 0:
|
||||||
logger.debug("\nSpeaker segments:")
|
logger.debug("\nSpeaker segments:")
|
||||||
for speaker, label in annotation._labels.items():
|
for speaker, label in annotation._labels.items():
|
||||||
|
|
@ -51,25 +51,25 @@ class DiarizationObserver(Observer):
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
logger.debug("\nNo speakers detected in this segment")
|
logger.debug("\nNo speakers detected in this segment")
|
||||||
|
|
||||||
def get_segments(self) -> List[SpeakerSegment]:
|
def get_segments(self) -> List[SpeakerSegment]:
|
||||||
"""Get a copy of the current speaker segments."""
|
"""Get a copy of the current speaker segments."""
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
return self.diarization_segments.copy()
|
return self.diarization_segments.copy()
|
||||||
|
|
||||||
def clear_old_segments(self, older_than: float = 30.0):
|
def clear_old_segments(self, older_than: float = 30.0):
|
||||||
"""Clear segments older than the specified time."""
|
"""Clear segments older than the specified time."""
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
current_time = self.processed_time
|
current_time = self.processed_time
|
||||||
self.diarization_segments = [
|
self.diarization_segments = [
|
||||||
segment for segment in self.diarization_segments
|
segment for segment in self.diarization_segments
|
||||||
if current_time - segment.end < older_than
|
if current_time - segment.end < older_than
|
||||||
]
|
]
|
||||||
|
|
||||||
def on_error(self, error):
|
def on_error(self, error):
|
||||||
"""Handle an error in the stream."""
|
"""Handle an error in the stream."""
|
||||||
logger.debug(f"Error in diarization stream: {error}")
|
logger.debug(f"Error in diarization stream: {error}")
|
||||||
|
|
||||||
def on_completed(self):
|
def on_completed(self):
|
||||||
"""Handle the completion of the stream."""
|
"""Handle the completion of the stream."""
|
||||||
logger.debug("Diarization stream completed")
|
logger.debug("Diarization stream completed")
|
||||||
|
|
@ -96,7 +96,7 @@ class WebSocketAudioSource(AudioSource):
|
||||||
self._processing_thread = threading.Thread(target=self._process_chunks)
|
self._processing_thread = threading.Thread(target=self._process_chunks)
|
||||||
self._processing_thread.daemon = True
|
self._processing_thread.daemon = True
|
||||||
self._processing_thread.start()
|
self._processing_thread.start()
|
||||||
|
|
||||||
self._close_event.wait()
|
self._close_event.wait()
|
||||||
if self._processing_thread:
|
if self._processing_thread:
|
||||||
self._processing_thread.join(timeout=2.0)
|
self._processing_thread.join(timeout=2.0)
|
||||||
|
|
@ -106,30 +106,30 @@ class WebSocketAudioSource(AudioSource):
|
||||||
while not self._closed:
|
while not self._closed:
|
||||||
try:
|
try:
|
||||||
audio_chunk = self._queue.get(timeout=0.1)
|
audio_chunk = self._queue.get(timeout=0.1)
|
||||||
|
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
self._buffer = np.concatenate([self._buffer, audio_chunk])
|
self._buffer = np.concatenate([self._buffer, audio_chunk])
|
||||||
|
|
||||||
while len(self._buffer) >= self.block_size:
|
while len(self._buffer) >= self.block_size:
|
||||||
chunk = self._buffer[:self.block_size]
|
chunk = self._buffer[:self.block_size]
|
||||||
self._buffer = self._buffer[self.block_size:]
|
self._buffer = self._buffer[self.block_size:]
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
time_since_last = current_time - self._last_chunk_time
|
time_since_last = current_time - self._last_chunk_time
|
||||||
if time_since_last < self.block_duration:
|
if time_since_last < self.block_duration:
|
||||||
time.sleep(self.block_duration - time_since_last)
|
time.sleep(self.block_duration - time_since_last)
|
||||||
|
|
||||||
chunk_reshaped = chunk.reshape(1, -1)
|
chunk_reshaped = chunk.reshape(1, -1)
|
||||||
self.stream.on_next(chunk_reshaped)
|
self.stream.on_next(chunk_reshaped)
|
||||||
self._last_chunk_time = time.time()
|
self._last_chunk_time = time.time()
|
||||||
|
|
||||||
except Empty:
|
except Empty:
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
|
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
|
||||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||||
padded_chunk[:len(self._buffer)] = self._buffer
|
padded_chunk[:len(self._buffer)] = self._buffer
|
||||||
self._buffer = np.array([], dtype=np.float32)
|
self._buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||||
self.stream.on_next(chunk_reshaped)
|
self.stream.on_next(chunk_reshaped)
|
||||||
self._last_chunk_time = time.time()
|
self._last_chunk_time = time.time()
|
||||||
|
|
@ -137,14 +137,14 @@ class WebSocketAudioSource(AudioSource):
|
||||||
logger.error(f"Error in audio processing thread: {e}")
|
logger.error(f"Error in audio processing thread: {e}")
|
||||||
self.stream.on_error(e)
|
self.stream.on_error(e)
|
||||||
break
|
break
|
||||||
|
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
if len(self._buffer) > 0:
|
if len(self._buffer) > 0:
|
||||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||||
padded_chunk[:len(self._buffer)] = self._buffer
|
padded_chunk[:len(self._buffer)] = self._buffer
|
||||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||||
self.stream.on_next(chunk_reshaped)
|
self.stream.on_next(chunk_reshaped)
|
||||||
|
|
||||||
self.stream.on_completed()
|
self.stream.on_completed()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
@ -165,27 +165,27 @@ class DiartDiarization:
|
||||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = SpeakerDiarizationConfig(
|
config = SpeakerDiarizationConfig(
|
||||||
segmentation=segmentation_model,
|
segmentation=segmentation_model,
|
||||||
embedding=embedding_model,
|
embedding=embedding_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pipeline = SpeakerDiarization(config=config)
|
self.pipeline = SpeakerDiarization(config=config)
|
||||||
self.observer = DiarizationObserver()
|
self.observer = DiarizationObserver()
|
||||||
|
|
||||||
if use_microphone:
|
if use_microphone:
|
||||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||||
self.custom_source = None
|
self.custom_source = None
|
||||||
else:
|
else:
|
||||||
self.custom_source = WebSocketAudioSource(
|
self.custom_source = WebSocketAudioSource(
|
||||||
uri="websocket_source",
|
uri="websocket_source",
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
block_duration=block_duration
|
block_duration=block_duration
|
||||||
)
|
)
|
||||||
self.source = self.custom_source
|
self.source = self.custom_source
|
||||||
|
|
||||||
self.inference = StreamingInference(
|
self.inference = StreamingInference(
|
||||||
pipeline=self.pipeline,
|
pipeline=self.pipeline,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
|
|
@ -205,14 +205,14 @@ class DiartDiarization:
|
||||||
|
|
||||||
async def diarize(self):
|
async def diarize(self):
|
||||||
"""Return the current speaker segments from the diarization pipeline."""
|
"""Return the current speaker segments from the diarization pipeline."""
|
||||||
return self.observer.get_segments()
|
return self.observer.get_segments()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the audio source."""
|
"""Close the audio source."""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.close()
|
self.custom_source.close()
|
||||||
|
|
||||||
|
|
||||||
def concatenate_speakers(segments):
|
def concatenate_speakers(segments):
|
||||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
|
|
@ -223,7 +223,7 @@ def concatenate_speakers(segments):
|
||||||
segments_concatenated[-1]['end'] = segment.end
|
segments_concatenated[-1]['end'] = segment.end
|
||||||
# print("Segments concatenated:")
|
# print("Segments concatenated:")
|
||||||
# for entry in segments_concatenated:
|
# for entry in segments_concatenated:
|
||||||
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||||
return segments_concatenated
|
return segments_concatenated
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -281,4 +281,4 @@ def visualize_tokens(tokens):
|
||||||
conversation[-1]['text'] += token.text
|
conversation[-1]['text'] += token.text
|
||||||
print("Conversation:")
|
print("Conversation:")
|
||||||
for entry in conversation:
|
for entry in conversation:
|
||||||
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
import wave
|
import wave
|
||||||
from queue import Empty, SimpleQueue
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -54,7 +52,7 @@ class SortformerDiarization:
|
||||||
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
||||||
"""
|
"""
|
||||||
self._load_model(model_name)
|
self._load_model(model_name)
|
||||||
|
|
||||||
def _load_model(self, model_name: str):
|
def _load_model(self, model_name: str):
|
||||||
"""Load and configure the Sortformer model for streaming."""
|
"""Load and configure the Sortformer model for streaming."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -63,12 +61,12 @@ class SortformerDiarization:
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.diar_model.to(device)
|
self.diar_model.to(device)
|
||||||
|
|
||||||
## to test
|
## to test
|
||||||
# for name, param in self.diar_model.named_parameters():
|
# for name, param in self.diar_model.named_parameters():
|
||||||
# if param.device != device:
|
# if param.device != device:
|
||||||
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
||||||
|
|
||||||
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
||||||
|
|
||||||
self.diar_model.sortformer_modules.chunk_len = 10
|
self.diar_model.sortformer_modules.chunk_len = 10
|
||||||
|
|
@ -80,16 +78,16 @@ class SortformerDiarization:
|
||||||
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||||
self.diar_model.sortformer_modules.log = False
|
self.diar_model.sortformer_modules.log = False
|
||||||
self.diar_model.sortformer_modules._check_streaming_parameters()
|
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load Sortformer model: {e}")
|
logger.error(f"Failed to load Sortformer model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
class SortformerDiarizationOnline:
|
class SortformerDiarizationOnline:
|
||||||
def __init__(self, shared_model, sample_rate: int = 16000):
|
def __init__(self, shared_model, sample_rate: int = 16000):
|
||||||
"""
|
"""
|
||||||
Initialize the streaming Sortformer diarization system.
|
Initialize the streaming Sortformer diarization system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sample_rate: Audio sample rate (default: 16000)
|
sample_rate: Audio sample rate (default: 16000)
|
||||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||||
|
|
@ -101,9 +99,9 @@ class SortformerDiarizationOnline:
|
||||||
self.segment_lock = threading.Lock()
|
self.segment_lock = threading.Lock()
|
||||||
self.global_time_offset = 0.0
|
self.global_time_offset = 0.0
|
||||||
self.debug = False
|
self.debug = False
|
||||||
|
|
||||||
self.diar_model = shared_model.diar_model
|
self.diar_model = shared_model.diar_model
|
||||||
|
|
||||||
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||||
window_size=0.025,
|
window_size=0.025,
|
||||||
normalize="NA",
|
normalize="NA",
|
||||||
|
|
@ -112,26 +110,26 @@ class SortformerDiarizationOnline:
|
||||||
pad_to=0
|
pad_to=0
|
||||||
)
|
)
|
||||||
self.audio2mel.to(self.diar_model.device)
|
self.audio2mel.to(self.diar_model.device)
|
||||||
|
|
||||||
self.chunk_duration_seconds = (
|
self.chunk_duration_seconds = (
|
||||||
self.diar_model.sortformer_modules.chunk_len *
|
self.diar_model.sortformer_modules.chunk_len *
|
||||||
self.diar_model.sortformer_modules.subsampling_factor *
|
self.diar_model.sortformer_modules.subsampling_factor *
|
||||||
self.diar_model.preprocessor._cfg.window_stride
|
self.diar_model.preprocessor._cfg.window_stride
|
||||||
)
|
)
|
||||||
|
|
||||||
self._init_streaming_state()
|
self._init_streaming_state()
|
||||||
|
|
||||||
self._previous_chunk_features = None
|
self._previous_chunk_features = None
|
||||||
self._chunk_index = 0
|
self._chunk_index = 0
|
||||||
self._len_prediction = None
|
self._len_prediction = None
|
||||||
|
|
||||||
# Audio buffer to store PCM chunks for debugging
|
# Audio buffer to store PCM chunks for debugging
|
||||||
self.audio_buffer = []
|
self.audio_buffer = []
|
||||||
|
|
||||||
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
||||||
self.audio_chunk_buffer = []
|
self.audio_chunk_buffer = []
|
||||||
self.accumulated_duration = 0.0
|
self.accumulated_duration = 0.0
|
||||||
|
|
||||||
logger.info("SortformerDiarization initialized successfully")
|
logger.info("SortformerDiarization initialized successfully")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -139,30 +137,30 @@ class SortformerDiarizationOnline:
|
||||||
"""Initialize the streaming state for the model."""
|
"""Initialize the streaming state for the model."""
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
device = self.diar_model.device
|
device = self.diar_model.device
|
||||||
|
|
||||||
self.streaming_state = StreamingSortformerState()
|
self.streaming_state = StreamingSortformerState()
|
||||||
self.streaming_state.spkcache = torch.zeros(
|
self.streaming_state.spkcache = torch.zeros(
|
||||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
self.streaming_state.spkcache_preds = torch.zeros(
|
self.streaming_state.spkcache_preds = torch.zeros(
|
||||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
self.streaming_state.fifo = torch.zeros(
|
self.streaming_state.fifo = torch.zeros(
|
||||||
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||||
|
|
||||||
def insert_silence(self, silence_duration: Optional[float]):
|
def insert_silence(self, silence_duration: Optional[float]):
|
||||||
"""
|
"""
|
||||||
Insert silence period by adjusting the global time offset.
|
Insert silence period by adjusting the global time offset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
silence_duration: Duration of silence in seconds
|
silence_duration: Duration of silence in seconds
|
||||||
"""
|
"""
|
||||||
|
|
@ -174,48 +172,48 @@ class SortformerDiarizationOnline:
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.audio_buffer.append(pcm_array.copy())
|
self.audio_buffer.append(pcm_array.copy())
|
||||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||||
|
|
||||||
|
|
||||||
async def diarize(self):
|
async def diarize(self):
|
||||||
"""
|
"""
|
||||||
Process audio data for diarization in streaming fashion.
|
Process audio data for diarization in streaming fashion.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pcm_array: Audio data as numpy array
|
pcm_array: Audio data as numpy array
|
||||||
"""
|
"""
|
||||||
|
|
||||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||||
|
|
||||||
if not len(self.buffer_audio) >= threshold:
|
if not len(self.buffer_audio) >= threshold:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
audio = self.buffer_audio[:threshold]
|
audio = self.buffer_audio[:threshold]
|
||||||
self.buffer_audio = self.buffer_audio[threshold:]
|
self.buffer_audio = self.buffer_audio[threshold:]
|
||||||
|
|
||||||
device = self.diar_model.device
|
device = self.diar_model.device
|
||||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||||
|
|
||||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||||
audio_signal_chunk, audio_signal_length_chunk
|
audio_signal_chunk, audio_signal_length_chunk
|
||||||
)
|
)
|
||||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||||
|
|
||||||
if self._previous_chunk_features is not None:
|
if self._previous_chunk_features is not None:
|
||||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||||
else:
|
else:
|
||||||
total_features = processed_signal_chunk.to(device)
|
total_features = processed_signal_chunk.to(device)
|
||||||
|
|
||||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||||
|
|
||||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
left_offset = 8 if self._chunk_index > 0 else 0
|
left_offset = 8 if self._chunk_index > 0 else 0
|
||||||
right_offset = 8
|
right_offset = 8
|
||||||
|
|
||||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||||
processed_signal=chunk_feat_seq_t,
|
processed_signal=chunk_feat_seq_t,
|
||||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||||
|
|
@ -223,9 +221,9 @@ class SortformerDiarizationOnline:
|
||||||
total_preds=self.total_preds,
|
total_preds=self.total_preds,
|
||||||
left_offset=left_offset,
|
left_offset=left_offset,
|
||||||
right_offset=right_offset,
|
right_offset=right_offset,
|
||||||
)
|
)
|
||||||
new_segments = self._process_predictions()
|
new_segments = self._process_predictions()
|
||||||
|
|
||||||
self._chunk_index += 1
|
self._chunk_index += 1
|
||||||
return new_segments
|
return new_segments
|
||||||
|
|
||||||
|
|
@ -233,13 +231,13 @@ class SortformerDiarizationOnline:
|
||||||
"""Process model predictions and convert to speaker segments."""
|
"""Process model predictions and convert to speaker segments."""
|
||||||
preds_np = self.total_preds[0].cpu().numpy()
|
preds_np = self.total_preds[0].cpu().numpy()
|
||||||
active_speakers = np.argmax(preds_np, axis=1)
|
active_speakers = np.argmax(preds_np, axis=1)
|
||||||
|
|
||||||
if self._len_prediction is None:
|
if self._len_prediction is None:
|
||||||
self._len_prediction = len(active_speakers) #12
|
self._len_prediction = len(active_speakers) #12
|
||||||
|
|
||||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||||
|
|
||||||
new_segments = []
|
new_segments = []
|
||||||
|
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
|
|
@ -264,7 +262,7 @@ class SortformerDiarizationOnline:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return new_segments
|
return new_segments
|
||||||
|
|
||||||
def get_segments(self) -> List[SpeakerSegment]:
|
def get_segments(self) -> List[SpeakerSegment]:
|
||||||
"""Get a copy of the current speaker segments."""
|
"""Get a copy of the current speaker segments."""
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
|
|
@ -275,10 +273,10 @@ class SortformerDiarizationOnline:
|
||||||
logger.info("Closing SortformerDiarization")
|
logger.info("Closing SortformerDiarization")
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
self.diarization_segments.clear()
|
self.diarization_segments.clear()
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||||
wav_file.setnchannels(1) # mono audio
|
wav_file.setnchannels(1) # mono audio
|
||||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||||
|
|
@ -287,14 +285,13 @@ class SortformerDiarizationOnline:
|
||||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||||
|
|
||||||
|
|
||||||
from whisperlivekit.diarization.utils import extract_number
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""TEST ONLY."""
|
"""TEST ONLY."""
|
||||||
an4_audio = 'diarization_audio.wav'
|
an4_audio = 'diarization_audio.wav'
|
||||||
|
|
@ -304,24 +301,24 @@ if __name__ == '__main__':
|
||||||
print("\n" + "=" * 50)
|
print("\n" + "=" * 50)
|
||||||
print("ground truth:")
|
print("ground truth:")
|
||||||
print("Speaker 0: 0:00 - 0:09")
|
print("Speaker 0: 0:00 - 0:09")
|
||||||
print("Speaker 1: 0:09 - 0:19")
|
print("Speaker 1: 0:09 - 0:19")
|
||||||
print("Speaker 2: 0:19 - 0:25")
|
print("Speaker 2: 0:19 - 0:25")
|
||||||
print("Speaker 0: 0:25 - 0:30")
|
print("Speaker 0: 0:25 - 0:30")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
diarization_backend = SortformerDiarization()
|
diarization_backend = SortformerDiarization()
|
||||||
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
|
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
|
||||||
chunk_size = 1600
|
chunk_size = 1600
|
||||||
|
|
||||||
for i in range(0, len(signal), chunk_size):
|
for i in range(0, len(signal), chunk_size):
|
||||||
chunk = signal[i:i+chunk_size]
|
chunk = signal[i:i+chunk_size]
|
||||||
new_segments = await diarization.diarize(chunk)
|
new_segments = await diarization.diarize(chunk)
|
||||||
print(f"Processed chunk {i // chunk_size + 1}")
|
print(f"Processed chunk {i // chunk_size + 1}")
|
||||||
print(new_segments)
|
print(new_segments)
|
||||||
|
|
||||||
segments = diarization.get_segments()
|
segments = diarization.get_segments()
|
||||||
print("\nDiarization results:")
|
print("\nDiarization results:")
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue