From a153e11fe06b7ad08118c5dc7cde5301f020cb4e Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 28 Sep 2025 11:04:00 +0200 Subject: [PATCH] update when self.diarization_before_transcription --- whisperlivekit/audio_processor.py | 59 +++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 97a1c76..f2cd342 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -16,6 +16,18 @@ logger.setLevel(logging.DEBUG) SENTINEL = object() # unique sentinel object for end of stream marker +def cut_at(cumulative_pcm, cut_sec): + cumulative_len = 0 + cut_sample = int(cut_sec * 16000) + + for ind, pcm_array in enumerate(cumulative_pcm): + if (cumulative_len + len(pcm_array)) >= cut_sample: + cut_chunk = cut_sample - cumulative_len + before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]]) + after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:] + return before, after + cumulative_len += len(pcm_array) + return np.concatenate(cumulative_pcm), [] async def get_all_from_queue(queue): items = [] @@ -62,14 +74,18 @@ class AudioProcessor: self.end_buffer = 0 self.end_attributed_speaker = 0 self.lock = asyncio.Lock() - self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio + self.beg_loop = 0.0 #to deal with a potential little lag at the websocket initialization, this is now set in process_audio self.sep = " " # Default separator self.last_response_content = FrontData() self.last_detected_speaker = None self.speaker_languages = {} - self.cumulative_pcm_len = 0 self.diarization_before_transcription = False + if self.diarization_before_transcription: + self.cumulative_pcm = [] + self.last_start = 0.0 + self.last_end = 0.0 + # Models and processing self.asr = models.asr self.vac_model = models.vac_model @@ -296,7 +312,9 @@ class AudioProcessor: async def diarization_processor(self, diarization_obj): """Process audio chunks for speaker diarization.""" - self.current_speaker = 0 + if self.diarization_before_transcription: + self.current_speaker = 0 + await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0)) while True: try: item = await self.diarization_queue.get() @@ -312,26 +330,39 @@ class AudioProcessor: else: raise Exception('item should be pcm_array') + + # Process diarization await diarization_obj.diarize(pcm_array) segments = diarization_obj.get_segments() - if self.diarization_before_transcription: - if segments and segments[-1].speaker != self.current_speaker: - self.current_speaker = segments[-1].speaker - cut_at = int(segments[-1].start*16000 - (self.cumulative_pcm_len)) - await self.transcription_queue.put(pcm_array[cut_at:]) - await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=cut_at)) - await self.transcription_queue.put(pcm_array[:cut_at]) - else: - await self.transcription_queue.put(pcm_array) - else: + self.cumulative_pcm.append(pcm_array) + if self.segments: + last_segment = segments[-1] + if last_segment.speaker != self.current_speaker: + cut_sec = last_segment.start - self.last_end + to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) + await self.transcription_queue.put(to_transcript) + + self.current_speaker = last_segment.speaker + await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=last_segment.start)) + + cut_sec = last_segment.end - last_segment.start + to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) + await self.transcription_queue.put(to_transcript) + self.last_start = last_segment.start + self.last_end = last_segment.end + else: + cut_sec = last_segment.end - self.last_end + to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) + await self.transcription_queue.put(to_transcript) + self.last_end = last_segment.end + elif not self.diarization_before_transcription: async with self.lock: self.tokens = diarization_obj.assign_speakers_to_tokens( self.tokens, use_punctuation_split=self.args.punctuation_split ) - self.cumulative_pcm_len += len(pcm_array) if len(self.tokens) > 0: self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker) self.diarization_queue.task_done()