This commit is contained in:
Quentin Fuxa 2025-09-14 17:03:00 +02:00
parent cd160caaa1
commit 99dc96c644
6 changed files with 100 additions and 29 deletions

View file

@ -257,12 +257,11 @@ class AudioProcessor:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s" asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.tokens: if self.tokens:
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |" asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs) logger.info(asr_processing_logs)
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0) self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue continue
logger.info(asr_processing_logs)
if isinstance(item, np.ndarray): if isinstance(item, np.ndarray):
pcm_array = item pcm_array = item
@ -301,7 +300,7 @@ class AudioProcessor:
new_tokens, buffer_text, new_end_buffer new_tokens, buffer_text, new_end_buffer
) )
if new_tokens and self.args.target_language and self.translation_queue: if self.translation_queue:
for token in new_tokens: for token in new_tokens:
await self.translation_queue.put(token) await self.translation_queue.put(token)
@ -326,13 +325,11 @@ class AudioProcessor:
logger.debug("Diarization processor received sentinel. Finishing.") logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done() self.diarization_queue.task_done()
break break
elif type(item) is Silence:
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration cumulative_pcm_duration_stream_time += item.duration
diarization_obj.insert_silence(item.duration) diarization_obj.insert_silence(item.duration)
continue continue
elif isinstance(item, np.ndarray):
if isinstance(item, np.ndarray):
pcm_array = item pcm_array = item
else: else:
raise Exception('item should be pcm_array') raise Exception('item should be pcm_array')
@ -365,14 +362,17 @@ class AudioProcessor:
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
while True: while True:
try: try:
token = await self.translation_queue.get() #block until at least 1 token item = await self.translation_queue.get() #block until at least 1 token
if token is SENTINEL: if item is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.") logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done() self.translation_queue.task_done()
break break
elif type(item) is Silence:
online_translation.insert_silence(item.duration)
continue
# get all the available tokens for translation. The more words, the more precise # get all the available tokens for translation. The more words, the more precise
tokens_to_process = [token] tokens_to_process = [item]
additional_tokens = await get_all_from_queue(self.translation_queue) additional_tokens = await get_all_from_queue(self.translation_queue)
sentinel_found = False sentinel_found = False
@ -396,7 +396,7 @@ class AudioProcessor:
except Exception as e: except Exception as e:
logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and token is not SENTINEL: if 'token' in locals() and item is not SENTINEL:
self.translation_queue.task_done() self.translation_queue.task_done()
if 'additional_tokens' in locals(): if 'additional_tokens' in locals():
for _ in additional_tokens: for _ in additional_tokens:
@ -446,7 +446,7 @@ class AudioProcessor:
if not state.tokens and not buffer_transcription and not buffer_diarization: if not state.tokens and not buffer_transcription and not buffer_diarization:
response_status = "no_audio_detected" response_status = "no_audio_detected"
lines = [] lines = []
elif response_status == "active_transcription" and not lines: elif not lines:
lines = [Line( lines = [Line(
speaker=1, speaker=1,
start=state.get("end_buffer", 0), start=state.get("end_buffer", 0),
@ -638,6 +638,8 @@ class AudioProcessor:
await self.transcription_queue.put(silence_buffer) await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue: if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer) await self.diarization_queue.put(silence_buffer)
if self.translation_queue:
await self.translation_queue.put(silence_buffer)
if not self.silence: if not self.silence:
if self.args.transcription and self.transcription_queue: if self.args.transcription and self.transcription_queue:

View file

@ -39,7 +39,7 @@ def blank_to_silence(tokens):
) )
else: else:
if silence_token: #there was silence but no more if silence_token: #there was silence but no more
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION: if silence_token.duration() >= MIN_SILENCE_DURATION:
cleaned_tokens.append( cleaned_tokens.append(
silence_token silence_token
) )

View file

@ -123,14 +123,33 @@ def format_output(state, silence, current_time, args, debug, sep):
append_token_to_last_line(lines, sep, token, debug_info) append_token_to_last_line(lines, sep, token, debug_info)
if lines and translated_segments: if lines and translated_segments:
cts_idx = 0 # current_translated_segment_idx unassigned_translated_segments = []
for line in lines: for ts in translated_segments:
while cts_idx < len(translated_segments): assigned = False
ts = translated_segments[cts_idx] for line in lines:
if ts and ts.start and ts.start >= line.start and ts.end <= line.end: if ts and ts.overlaps_with(line):
line.translation += ts.text + ' ' if ts.is_within(line):
cts_idx += 1 line.translation += ts.text + ' '
else: assigned = True
break break
return lines, undiarized_text, buffer_transcription, '' else:
ts0, ts1 = ts.approximate_cut_at(line.end)
if ts0 and line.overlaps_with(ts0):
line.translation += ts0.text + ' '
if ts1:
unassigned_translated_segments.append(ts1)
assigned = True
break
if not assigned:
unassigned_translated_segments.append(ts)
if unassigned_translated_segments:
for line in lines:
remaining_segments = []
for ts in unassigned_translated_segments:
if ts and ts.overlaps_with(line):
line.translation += ts.text + ' '
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
return lines, undiarized_text, buffer_transcription, ''

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, Any
from datetime import timedelta from datetime import timedelta
def format_time(seconds: float) -> str: def format_time(seconds: float) -> str:
@ -15,6 +15,21 @@ class TimedText:
speaker: Optional[int] = -1 speaker: Optional[int] = -1
probability: Optional[float] = None probability: Optional[float] = None
is_dummy: Optional[bool] = False is_dummy: Optional[bool] = False
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self)
def duration(self) -> float:
return self.end - self.start
def contains_time(self, time: float) -> bool:
return self.start <= time <= self.end
def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end
@dataclass @dataclass
class ASRToken(TimedText): class ASRToken(TimedText):
@ -41,6 +56,34 @@ class SpeakerSegment(TimedText):
class Translation(TimedText): class Translation(TimedText):
pass pass
def approximate_cut_at(self, cut_time):
"""
Each word in text is considered to be of duration (end-start)/len(words in text)
"""
if not self.text or not self.contains_time(cut_time):
return self, None
words = self.text.split()
num_words = len(words)
if num_words == 0:
return self, None
duration_per_word = self.duration() / num_words
cut_word_index = int((cut_time - self.start) / duration_per_word)
if cut_word_index >= num_words:
cut_word_index = num_words -1
text0 = " ".join(words[:cut_word_index])
text1 = " ".join(words[cut_word_index:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass @dataclass
class Silence(): class Silence():
duration: float duration: float
@ -91,4 +134,4 @@ class State():
end_buffer: float end_buffer: float
end_attributed_speaker: float end_attributed_speaker: float
remaining_time_transcription: float remaining_time_transcription: float
remaining_time_diarization: float remaining_time_diarization: float

View file

@ -1,3 +1,4 @@
import logging
import ctranslate2 import ctranslate2
import torch import torch
import transformers import transformers
@ -6,11 +7,14 @@ import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code from whisperlivekit.translation.mapping_languages import get_nllb_code
from whisperlivekit.timed_objects import Translation from whisperlivekit.timed_objects import Translation
logger = logging.getLogger(__name__)
#In diarization case, we may want to translate just one speaker, or at least start the sentences there #In diarization case, we may want to translate just one speaker, or at least start the sentences there
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''} PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
# sentence is not finished.
@dataclass @dataclass
class TranslationModel(): class TranslationModel():
@ -109,7 +113,11 @@ class OnlineTranslation:
self.translation_remaining = self.translate_tokens(self.buffer) self.translation_remaining = self.translate_tokens(self.buffer)
self.len_processed_buffer = len(self.buffer) self.len_processed_buffer = len(self.buffer)
return self.validated + [self.translation_remaining] return self.validated + [self.translation_remaining]
def insert_silence(self, silence_duration: float):
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
self.buffer = []
self.validated += [self.translation_remaining]
if __name__ == '__main__': if __name__ == '__main__':
output_lang = 'fr' output_lang = 'fr'

View file

@ -438,7 +438,6 @@ label {
font-size: 13px; font-size: 13px;
border-radius: 30px; border-radius: 30px;
padding: 2px 10px; padding: 2px 10px;
display: none;
} }
.loading { .loading {