fixes #224
This commit is contained in:
parent
cd160caaa1
commit
99dc96c644
6 changed files with 100 additions and 29 deletions
|
|
@ -257,12 +257,11 @@ class AudioProcessor:
|
||||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||||
if self.tokens:
|
if self.tokens:
|
||||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||||
logger.info(asr_processing_logs)
|
logger.info(asr_processing_logs)
|
||||||
|
|
||||||
if type(item) is Silence:
|
|
||||||
cumulative_pcm_duration_stream_time += item.duration
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||||
continue
|
continue
|
||||||
|
logger.info(asr_processing_logs)
|
||||||
|
|
||||||
if isinstance(item, np.ndarray):
|
if isinstance(item, np.ndarray):
|
||||||
pcm_array = item
|
pcm_array = item
|
||||||
|
|
@ -301,7 +300,7 @@ class AudioProcessor:
|
||||||
new_tokens, buffer_text, new_end_buffer
|
new_tokens, buffer_text, new_end_buffer
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_tokens and self.args.target_language and self.translation_queue:
|
if self.translation_queue:
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
await self.translation_queue.put(token)
|
await self.translation_queue.put(token)
|
||||||
|
|
||||||
|
|
@ -326,13 +325,11 @@ class AudioProcessor:
|
||||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
break
|
break
|
||||||
|
elif type(item) is Silence:
|
||||||
if type(item) is Silence:
|
|
||||||
cumulative_pcm_duration_stream_time += item.duration
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
diarization_obj.insert_silence(item.duration)
|
diarization_obj.insert_silence(item.duration)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(item, np.ndarray):
|
||||||
if isinstance(item, np.ndarray):
|
|
||||||
pcm_array = item
|
pcm_array = item
|
||||||
else:
|
else:
|
||||||
raise Exception('item should be pcm_array')
|
raise Exception('item should be pcm_array')
|
||||||
|
|
@ -365,14 +362,17 @@ class AudioProcessor:
|
||||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
token = await self.translation_queue.get() #block until at least 1 token
|
item = await self.translation_queue.get() #block until at least 1 token
|
||||||
if token is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Translation processor received sentinel. Finishing.")
|
logger.debug("Translation processor received sentinel. Finishing.")
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
break
|
break
|
||||||
|
elif type(item) is Silence:
|
||||||
|
online_translation.insert_silence(item.duration)
|
||||||
|
continue
|
||||||
|
|
||||||
# get all the available tokens for translation. The more words, the more precise
|
# get all the available tokens for translation. The more words, the more precise
|
||||||
tokens_to_process = [token]
|
tokens_to_process = [item]
|
||||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||||
|
|
||||||
sentinel_found = False
|
sentinel_found = False
|
||||||
|
|
@ -396,7 +396,7 @@ class AudioProcessor:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in translation_processor: {e}")
|
logger.warning(f"Exception in translation_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
if 'token' in locals() and token is not SENTINEL:
|
if 'token' in locals() and item is not SENTINEL:
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
if 'additional_tokens' in locals():
|
if 'additional_tokens' in locals():
|
||||||
for _ in additional_tokens:
|
for _ in additional_tokens:
|
||||||
|
|
@ -446,7 +446,7 @@ class AudioProcessor:
|
||||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||||
response_status = "no_audio_detected"
|
response_status = "no_audio_detected"
|
||||||
lines = []
|
lines = []
|
||||||
elif response_status == "active_transcription" and not lines:
|
elif not lines:
|
||||||
lines = [Line(
|
lines = [Line(
|
||||||
speaker=1,
|
speaker=1,
|
||||||
start=state.get("end_buffer", 0),
|
start=state.get("end_buffer", 0),
|
||||||
|
|
@ -638,6 +638,8 @@ class AudioProcessor:
|
||||||
await self.transcription_queue.put(silence_buffer)
|
await self.transcription_queue.put(silence_buffer)
|
||||||
if self.args.diarization and self.diarization_queue:
|
if self.args.diarization and self.diarization_queue:
|
||||||
await self.diarization_queue.put(silence_buffer)
|
await self.diarization_queue.put(silence_buffer)
|
||||||
|
if self.translation_queue:
|
||||||
|
await self.translation_queue.put(silence_buffer)
|
||||||
|
|
||||||
if not self.silence:
|
if not self.silence:
|
||||||
if self.args.transcription and self.transcription_queue:
|
if self.args.transcription and self.transcription_queue:
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ def blank_to_silence(tokens):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if silence_token: #there was silence but no more
|
if silence_token: #there was silence but no more
|
||||||
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||||
cleaned_tokens.append(
|
cleaned_tokens.append(
|
||||||
silence_token
|
silence_token
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -123,14 +123,33 @@ def format_output(state, silence, current_time, args, debug, sep):
|
||||||
|
|
||||||
append_token_to_last_line(lines, sep, token, debug_info)
|
append_token_to_last_line(lines, sep, token, debug_info)
|
||||||
if lines and translated_segments:
|
if lines and translated_segments:
|
||||||
cts_idx = 0 # current_translated_segment_idx
|
unassigned_translated_segments = []
|
||||||
for line in lines:
|
for ts in translated_segments:
|
||||||
while cts_idx < len(translated_segments):
|
assigned = False
|
||||||
ts = translated_segments[cts_idx]
|
for line in lines:
|
||||||
if ts and ts.start and ts.start >= line.start and ts.end <= line.end:
|
if ts and ts.overlaps_with(line):
|
||||||
line.translation += ts.text + ' '
|
if ts.is_within(line):
|
||||||
cts_idx += 1
|
line.translation += ts.text + ' '
|
||||||
else:
|
assigned = True
|
||||||
break
|
break
|
||||||
return lines, undiarized_text, buffer_transcription, ''
|
else:
|
||||||
|
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||||
|
if ts0 and line.overlaps_with(ts0):
|
||||||
|
line.translation += ts0.text + ' '
|
||||||
|
if ts1:
|
||||||
|
unassigned_translated_segments.append(ts1)
|
||||||
|
assigned = True
|
||||||
|
break
|
||||||
|
if not assigned:
|
||||||
|
unassigned_translated_segments.append(ts)
|
||||||
|
|
||||||
|
if unassigned_translated_segments:
|
||||||
|
for line in lines:
|
||||||
|
remaining_segments = []
|
||||||
|
for ts in unassigned_translated_segments:
|
||||||
|
if ts and ts.overlaps_with(line):
|
||||||
|
line.translation += ts.text + ' '
|
||||||
|
else:
|
||||||
|
remaining_segments.append(ts)
|
||||||
|
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||||
|
return lines, undiarized_text, buffer_transcription, ''
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, Any
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
|
def format_time(seconds: float) -> str:
|
||||||
|
|
@ -15,6 +15,21 @@ class TimedText:
|
||||||
speaker: Optional[int] = -1
|
speaker: Optional[int] = -1
|
||||||
probability: Optional[float] = None
|
probability: Optional[float] = None
|
||||||
is_dummy: Optional[bool] = False
|
is_dummy: Optional[bool] = False
|
||||||
|
|
||||||
|
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||||
|
return not (self.end <= other.start or other.end <= self.start)
|
||||||
|
|
||||||
|
def is_within(self, other: 'TimedText') -> bool:
|
||||||
|
return other.contains_timespan(self)
|
||||||
|
|
||||||
|
def duration(self) -> float:
|
||||||
|
return self.end - self.start
|
||||||
|
|
||||||
|
def contains_time(self, time: float) -> bool:
|
||||||
|
return self.start <= time <= self.end
|
||||||
|
|
||||||
|
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||||
|
return self.start <= other.start and self.end >= other.end
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ASRToken(TimedText):
|
class ASRToken(TimedText):
|
||||||
|
|
@ -41,6 +56,34 @@ class SpeakerSegment(TimedText):
|
||||||
class Translation(TimedText):
|
class Translation(TimedText):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def approximate_cut_at(self, cut_time):
|
||||||
|
"""
|
||||||
|
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||||
|
"""
|
||||||
|
if not self.text or not self.contains_time(cut_time):
|
||||||
|
return self, None
|
||||||
|
|
||||||
|
words = self.text.split()
|
||||||
|
num_words = len(words)
|
||||||
|
if num_words == 0:
|
||||||
|
return self, None
|
||||||
|
|
||||||
|
duration_per_word = self.duration() / num_words
|
||||||
|
|
||||||
|
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||||
|
|
||||||
|
if cut_word_index >= num_words:
|
||||||
|
cut_word_index = num_words -1
|
||||||
|
|
||||||
|
text0 = " ".join(words[:cut_word_index])
|
||||||
|
text1 = " ".join(words[cut_word_index:])
|
||||||
|
|
||||||
|
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||||
|
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||||
|
|
||||||
|
return segment0, segment1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Silence():
|
class Silence():
|
||||||
duration: float
|
duration: float
|
||||||
|
|
@ -91,4 +134,4 @@ class State():
|
||||||
end_buffer: float
|
end_buffer: float
|
||||||
end_attributed_speaker: float
|
end_attributed_speaker: float
|
||||||
remaining_time_transcription: float
|
remaining_time_transcription: float
|
||||||
remaining_time_diarization: float
|
remaining_time_diarization: float
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
@ -6,11 +7,14 @@ import huggingface_hub
|
||||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||||
from whisperlivekit.timed_objects import Translation
|
from whisperlivekit.timed_objects import Translation
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
||||||
|
|
||||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||||
|
|
||||||
|
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
|
||||||
|
# sentence is not finished.
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TranslationModel():
|
class TranslationModel():
|
||||||
|
|
@ -109,7 +113,11 @@ class OnlineTranslation:
|
||||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||||
self.len_processed_buffer = len(self.buffer)
|
self.len_processed_buffer = len(self.buffer)
|
||||||
return self.validated + [self.translation_remaining]
|
return self.validated + [self.translation_remaining]
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration: float):
|
||||||
|
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
|
||||||
|
self.buffer = []
|
||||||
|
self.validated += [self.translation_remaining]
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
output_lang = 'fr'
|
output_lang = 'fr'
|
||||||
|
|
|
||||||
|
|
@ -438,7 +438,6 @@ label {
|
||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
border-radius: 30px;
|
border-radius: 30px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
display: none;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.loading {
|
.loading {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue