in buffer while language not detected »
This commit is contained in:
parent
a5503308c5
commit
674b20d3af
5 changed files with 86 additions and 78 deletions
|
|
@ -4,7 +4,7 @@ from time import time, sleep
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State
|
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript
|
||||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||||
from whisperlivekit.results_formater import format_output
|
from whisperlivekit.results_formater import format_output
|
||||||
|
|
@ -58,7 +58,7 @@ class AudioProcessor:
|
||||||
self.silence_duration = 0.0
|
self.silence_duration = 0.0
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.translated_segments = []
|
self.translated_segments = []
|
||||||
self.buffer_transcription = ""
|
self.buffer_transcription = Transcript()
|
||||||
self.buffer_diarization = ""
|
self.buffer_diarization = ""
|
||||||
self.end_buffer = 0
|
self.end_buffer = 0
|
||||||
self.end_attributed_speaker = 0
|
self.end_attributed_speaker = 0
|
||||||
|
|
@ -114,20 +114,6 @@ class AudioProcessor:
|
||||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
async def update_transcription(self, new_tokens, buffer, end_buffer):
|
|
||||||
"""Thread-safe update of transcription with new data."""
|
|
||||||
async with self.lock:
|
|
||||||
self.tokens.extend(new_tokens)
|
|
||||||
self.buffer_transcription = buffer
|
|
||||||
self.end_buffer = end_buffer
|
|
||||||
|
|
||||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
|
||||||
"""Thread-safe update of diarization with new data."""
|
|
||||||
async with self.lock:
|
|
||||||
self.end_attributed_speaker = end_attributed_speaker
|
|
||||||
if buffer_diarization:
|
|
||||||
self.buffer_diarization = buffer_diarization
|
|
||||||
|
|
||||||
async def add_dummy_token(self):
|
async def add_dummy_token(self):
|
||||||
"""Placeholder token when no transcription is available."""
|
"""Placeholder token when no transcription is available."""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
|
|
@ -168,7 +154,7 @@ class AudioProcessor:
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.translated_segments = []
|
self.translated_segments = []
|
||||||
self.buffer_transcription = self.buffer_diarization = ""
|
self.buffer_transcription = self.buffer_diarization = Transcript()
|
||||||
self.end_buffer = self.end_attributed_speaker = 0
|
self.end_buffer = self.end_attributed_speaker = 0
|
||||||
self.beg_loop = time()
|
self.beg_loop = time()
|
||||||
|
|
||||||
|
|
@ -264,30 +250,28 @@ class AudioProcessor:
|
||||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
|
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
|
||||||
|
|
||||||
# Get buffer information
|
_buffer_transcript = self.online.get_buffer()
|
||||||
_buffer_transcript_obj = self.online.get_buffer()
|
buffer_text = _buffer_transcript.text
|
||||||
buffer_text = _buffer_transcript_obj.text
|
|
||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||||
if buffer_text.startswith(validated_text):
|
if buffer_text.startswith(validated_text):
|
||||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||||
|
|
||||||
candidate_end_times = [self.end_buffer]
|
candidate_end_times = [self.end_buffer]
|
||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
candidate_end_times.append(new_tokens[-1].end)
|
candidate_end_times.append(new_tokens[-1].end)
|
||||||
|
|
||||||
if _buffer_transcript_obj.end is not None:
|
if _buffer_transcript.end is not None:
|
||||||
candidate_end_times.append(_buffer_transcript_obj.end)
|
candidate_end_times.append(_buffer_transcript.end)
|
||||||
|
|
||||||
candidate_end_times.append(current_audio_processed_upto)
|
candidate_end_times.append(current_audio_processed_upto)
|
||||||
|
|
||||||
new_end_buffer = max(candidate_end_times)
|
async with self.lock:
|
||||||
|
self.tokens.extend(new_tokens)
|
||||||
await self.update_transcription(
|
self.buffer_transcription = _buffer_transcript
|
||||||
new_tokens, buffer_text, new_end_buffer
|
self.end_buffer = max(candidate_end_times)
|
||||||
)
|
|
||||||
|
|
||||||
if self.translation_queue:
|
if self.translation_queue:
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
|
|
@ -438,8 +422,8 @@ class AudioProcessor:
|
||||||
sep=self.sep
|
sep=self.sep
|
||||||
)
|
)
|
||||||
if end_w_silence:
|
if end_w_silence:
|
||||||
buffer_transcription = ''
|
buffer_transcription = Transcript()
|
||||||
buffer_diarization = ''
|
buffer_diarization = Transcript()
|
||||||
else:
|
else:
|
||||||
buffer_transcription = state.buffer_transcription
|
buffer_transcription = state.buffer_transcription
|
||||||
buffer_diarization = state.buffer_diarization
|
buffer_diarization = state.buffer_diarization
|
||||||
|
|
@ -449,8 +433,13 @@ class AudioProcessor:
|
||||||
combined = self.sep.join(undiarized_text)
|
combined = self.sep.join(undiarized_text)
|
||||||
if buffer_transcription:
|
if buffer_transcription:
|
||||||
combined += self.sep
|
combined += self.sep
|
||||||
await self.update_diarization(state.end_attributed_speaker, combined)
|
|
||||||
buffer_diarization = combined
|
async with self.lock:
|
||||||
|
self.end_attributed_speaker = state.end_attributed_speaker
|
||||||
|
if buffer_diarization:
|
||||||
|
self.buffer_diarization = buffer_diarization
|
||||||
|
|
||||||
|
buffer_diarization.text = combined
|
||||||
|
|
||||||
response_status = "active_transcription"
|
response_status = "active_transcription"
|
||||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||||
|
|
@ -466,8 +455,8 @@ class AudioProcessor:
|
||||||
response = FrontData(
|
response = FrontData(
|
||||||
status=response_status,
|
status=response_status,
|
||||||
lines=lines,
|
lines=lines,
|
||||||
buffer_transcription=buffer_transcription,
|
buffer_transcription=buffer_transcription.text,
|
||||||
buffer_diarization=buffer_diarization,
|
buffer_diarization=buffer_transcription.text,
|
||||||
remaining_time_transcription=state.remaining_time_transcription,
|
remaining_time_transcription=state.remaining_time_transcription,
|
||||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -150,4 +150,8 @@ def format_output(state, silence, current_time, args, debug, sep):
|
||||||
else:
|
else:
|
||||||
remaining_segments.append(ts)
|
remaining_segments.append(ts)
|
||||||
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||||
|
|
||||||
|
if state.buffer_transcription and lines:
|
||||||
|
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||||
|
|
||||||
return lines, undiarized_text, end_w_silence
|
return lines, undiarized_text, end_w_silence
|
||||||
|
|
|
||||||
|
|
@ -52,12 +52,7 @@ class SimulStreamingOnlineProcessor:
|
||||||
self.asr = asr
|
self.asr = asr
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.end = 0.0
|
self.end = 0.0
|
||||||
self.buffer = Transcript(
|
self.buffer = []
|
||||||
start=None,
|
|
||||||
end=None,
|
|
||||||
text='',
|
|
||||||
probability=None
|
|
||||||
)
|
|
||||||
self.committed: List[ASRToken] = []
|
self.committed: List[ASRToken] = []
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
self.load_new_backend()
|
self.load_new_backend()
|
||||||
|
|
@ -103,8 +98,9 @@ class SimulStreamingOnlineProcessor:
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
|
|
||||||
def get_buffer(self):
|
def get_buffer(self):
|
||||||
return self.buffer
|
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||||
|
return concat_buffer
|
||||||
|
|
||||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
"""
|
"""
|
||||||
Process accumulated audio chunks using SimulStreaming.
|
Process accumulated audio chunks using SimulStreaming.
|
||||||
|
|
@ -112,9 +108,10 @@ class SimulStreamingOnlineProcessor:
|
||||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
new_tokens = self.model.infer(is_last=is_last)
|
timestamped_words, timestamped_buffer_language = self.model.infer(is_last=is_last)
|
||||||
self.committed.extend(new_tokens)
|
self.buffer = timestamped_buffer_language
|
||||||
return new_tokens, self.end
|
self.committed.extend(timestamped_words)
|
||||||
|
return timestamped_words, self.end
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,7 @@ class PaddedAlignAttWhisper:
|
||||||
)
|
)
|
||||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
|
# self.create_tokenizer('en')
|
||||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
self.global_time_offset = 0.0
|
self.global_time_offset = 0.0
|
||||||
self.reset_tokenizer_to_auto_next_call = False
|
self.reset_tokenizer_to_auto_next_call = False
|
||||||
|
|
@ -433,21 +434,18 @@ class PaddedAlignAttWhisper:
|
||||||
end_encode = time()
|
end_encode = time()
|
||||||
# print('Encoder duration:', end_encode-beg_encode)
|
# print('Encoder duration:', end_encode-beg_encode)
|
||||||
|
|
||||||
# if self.cfg.language == "auto" and self.detected_language is None:
|
if self.cfg.language == "auto" and self.detected_language is None:
|
||||||
# seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
|
seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
|
||||||
# if seconds_since_start >= 3.0:
|
if seconds_since_start >= 3.0:
|
||||||
# language_tokens, language_probs = self.lang_id(encoder_feature)
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
# logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
# top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
# logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
self.create_tokenizer(top_lan)
|
||||||
# #self.tokenizer.language = top_lan
|
self.refresh_segment(complete=True)
|
||||||
# #self.tokenizer.__post_init__()
|
self.detected_language = top_lan
|
||||||
# self.create_tokenizer(top_lan)
|
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||||
# self.detected_language = top_lan
|
else:
|
||||||
# self.init_tokens()
|
logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
|
||||||
# logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
|
||||||
# else:
|
|
||||||
# logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
|
|
||||||
|
|
||||||
self.trim_context()
|
self.trim_context()
|
||||||
current_tokens = self._current_tokens()
|
current_tokens = self._current_tokens()
|
||||||
|
|
@ -495,19 +493,6 @@ class PaddedAlignAttWhisper:
|
||||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||||
self.debug_print_tokens(current_tokens)
|
self.debug_print_tokens(current_tokens)
|
||||||
|
|
||||||
# # Early stop on sentence-ending punctuation when language is auto
|
|
||||||
# if not completed and self.cfg.language == "auto":
|
|
||||||
# last_token_id = current_tokens[0, -1].item()
|
|
||||||
# last_token_text = self.tokenizer.decode([last_token_id]).strip()
|
|
||||||
# if last_token_text in PUNCTUATION_MARKS:
|
|
||||||
# logger.debug(f"Punctuation boundary '{last_token_text}' hit; stopping early to allow language re-check.")
|
|
||||||
# punctuation_stop = True
|
|
||||||
# # Ensure next call starts with auto language (re-detect for new sentence)
|
|
||||||
# self.reset_tokenizer_to_auto_next_call = True
|
|
||||||
# self.detected_language = None
|
|
||||||
# self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
|
|
||||||
# break
|
|
||||||
|
|
||||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||||
for i, attn_mat in enumerate(self.dec_attns):
|
for i, attn_mat in enumerate(self.dec_attns):
|
||||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||||
|
|
@ -617,10 +602,17 @@ class PaddedAlignAttWhisper:
|
||||||
timestamp_entry = ASRToken(
|
timestamp_entry = ASRToken(
|
||||||
start=current_timestamp,
|
start=current_timestamp,
|
||||||
end=current_timestamp + 0.1,
|
end=current_timestamp + 0.1,
|
||||||
text=word,
|
text= word,
|
||||||
probability=0.95
|
probability=0.95,
|
||||||
|
language=self.detected_language
|
||||||
).with_offset(
|
).with_offset(
|
||||||
self.global_time_offset
|
self.global_time_offset
|
||||||
)
|
)
|
||||||
timestamped_words.append(timestamp_entry)
|
timestamped_words.append(timestamp_entry)
|
||||||
return timestamped_words
|
|
||||||
|
if self.detected_language is None and self.cfg.language == "auto":
|
||||||
|
timestamped_buffer_language, timestamped_words = timestamped_words, []
|
||||||
|
else:
|
||||||
|
timestamped_buffer_language = []
|
||||||
|
|
||||||
|
return timestamped_words, timestamped_buffer_language
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, List
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||||
|
|
@ -17,6 +17,7 @@ class TimedText:
|
||||||
speaker: Optional[int] = -1
|
speaker: Optional[int] = -1
|
||||||
probability: Optional[float] = None
|
probability: Optional[float] = None
|
||||||
is_dummy: Optional[bool] = False
|
is_dummy: Optional[bool] = False
|
||||||
|
language: str = None
|
||||||
|
|
||||||
def is_punctuation(self):
|
def is_punctuation(self):
|
||||||
return self.text.strip() in PUNCTUATION_MARKS
|
return self.text.strip() in PUNCTUATION_MARKS
|
||||||
|
|
@ -35,6 +36,10 @@ class TimedText:
|
||||||
|
|
||||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||||
return self.start <= other.start and self.end >= other.end
|
return self.start <= other.start and self.end >= other.end
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self.text)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ASRToken(TimedText):
|
class ASRToken(TimedText):
|
||||||
|
|
@ -48,7 +53,28 @@ class Sentence(TimedText):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Transcript(TimedText):
|
class Transcript(TimedText):
|
||||||
pass
|
"""
|
||||||
|
represents a concatenation of several ASRToken
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tokens(
|
||||||
|
cls,
|
||||||
|
tokens: List[ASRToken],
|
||||||
|
sep: Optional[str] = None,
|
||||||
|
offset: float = 0
|
||||||
|
) -> "Transcript":
|
||||||
|
sep = sep if sep is not None else ' '
|
||||||
|
text = sep.join(token.text for token in tokens)
|
||||||
|
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||||
|
if tokens:
|
||||||
|
start = offset + tokens[0].start
|
||||||
|
end = offset + tokens[-1].end
|
||||||
|
else:
|
||||||
|
start = None
|
||||||
|
end = None
|
||||||
|
return cls(start, end, text, probability=probability)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpeakerSegment(TimedText):
|
class SpeakerSegment(TimedText):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue