Refactor audio processor async pipeline
This commit is contained in:
parent
a282cbe75f
commit
d58365421f
1 changed file with 45 additions and 19 deletions
|
|
@@ -6,14 +6,16 @@ from typing import Any, AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from whisperlivekit.core import (TranscriptionEngine,
|
from whisperlivekit.core import (
|
||||||
online_diarization_factory, online_factory,
|
TranscriptionEngine,
|
||||||
online_translation_factory)
|
online_diarization_factory,
|
||||||
from whisperlivekit.metrics_collector import SessionMetrics
|
online_factory,
|
||||||
|
online_translation_factory,
|
||||||
|
)
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
|
from whisperlivekit.metrics_collector import SessionMetrics
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
|
||||||
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
from whisperlivekit.timed_objects import ChangeSpeaker, FrontData, Silence, State
|
||||||
Segment, Silence, State, Transcript)
|
|
||||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
|
@@ -57,6 +59,8 @@ class AudioProcessor:
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
"""Initialize the audio processor with configuration, models, and state."""
|
"""Initialize the audio processor with configuration, models, and state."""
|
||||||
|
# Extract per-session language override before passing to TranscriptionEngine
|
||||||
|
session_language = kwargs.pop('language', None)
|
||||||
|
|
||||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||||
models = kwargs['transcription_engine']
|
models = kwargs['transcription_engine']
|
||||||
|
|
@@ -126,7 +130,7 @@ class AudioProcessor:
|
||||||
self.diarization: Optional[Any] = None
|
self.diarization: Optional[Any] = None
|
||||||
|
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
self.transcription = online_factory(self.args, models.asr)
|
self.transcription = online_factory(self.args, models.asr, language=session_language)
|
||||||
self.sep = self.transcription.asr.sep
|
self.sep = self.transcription.asr.sep
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||||
|
|
@@ -175,7 +179,7 @@ class AudioProcessor:
|
||||||
self.metrics.n_silence_events += 1
|
self.metrics.n_silence_events += 1
|
||||||
if self.current_silence.duration is not None:
|
if self.current_silence.duration is not None:
|
||||||
self.metrics.total_silence_duration_s += self.current_silence.duration
|
self.metrics.total_silence_duration_s += self.current_silence.duration
|
||||||
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
if self.current_silence.duration and self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
||||||
self.state.new_tokens.append(self.current_silence)
|
self.state.new_tokens.append(self.current_silence)
|
||||||
# Push the completed silence as the end event (separate from the start event)
|
# Push the completed silence as the end event (separate from the start event)
|
||||||
await self._push_silence_event()
|
await self._push_silence_event()
|
||||||
|
|
@@ -287,6 +291,7 @@ class AudioProcessor:
|
||||||
final_tokens = final_tokens or []
|
final_tokens = final_tokens or []
|
||||||
if final_tokens:
|
if final_tokens:
|
||||||
logger.info(f"Finish flushed {len(final_tokens)} tokens")
|
logger.info(f"Finish flushed {len(final_tokens)} tokens")
|
||||||
|
self.metrics.n_tokens_produced += len(final_tokens)
|
||||||
_buffer_transcript = self.transcription.get_buffer()
|
_buffer_transcript = self.transcription.get_buffer()
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.state.tokens.extend(final_tokens)
|
self.state.tokens.extend(final_tokens)
|
||||||
|
|
@@ -307,8 +312,23 @@ class AudioProcessor:
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# item = await self.transcription_queue.get()
|
# Use a timeout so we periodically wake up and refresh the
|
||||||
item = await get_all_from_queue(self.transcription_queue)
|
# buffer state. Streaming backends (e.g. voxtral) may
|
||||||
|
# produce text tokens asynchronously; without a periodic
|
||||||
|
# drain, those tokens would sit unread until the next audio
|
||||||
|
# chunk arrives — causing the frontend to show nothing.
|
||||||
|
try:
|
||||||
|
item = await asyncio.wait_for(
|
||||||
|
get_all_from_queue(self.transcription_queue),
|
||||||
|
timeout=0.5,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# No new audio — just refresh buffer for streaming backends
|
||||||
|
_buffer_transcript = self.transcription.get_buffer()
|
||||||
|
async with self.lock:
|
||||||
|
self.state.buffer_transcription = _buffer_transcript
|
||||||
|
continue
|
||||||
|
|
||||||
if item is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||||
await self._finish_transcription()
|
await self._finish_transcription()
|
||||||
|
|
@@ -326,7 +346,7 @@ class AudioProcessor:
|
||||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
|
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
|
||||||
self.transcription.start_silence
|
self.transcription.start_silence
|
||||||
)
|
)
|
||||||
asr_processing_logs += f" + Silence starting"
|
asr_processing_logs += " + Silence starting"
|
||||||
if item.has_ended:
|
if item.has_ended:
|
||||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||||
cumulative_pcm_duration_stream_time += item.duration
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
|
|
@@ -404,7 +424,7 @@ class AudioProcessor:
|
||||||
item = await get_all_from_queue(self.diarization_queue)
|
item = await get_all_from_queue(self.diarization_queue)
|
||||||
if item is SENTINEL:
|
if item is SENTINEL:
|
||||||
break
|
break
|
||||||
elif type(item) is Silence:
|
elif isinstance(item, Silence):
|
||||||
if item.has_ended:
|
if item.has_ended:
|
||||||
self.diarization.insert_silence(item.duration)
|
self.diarization.insert_silence(item.duration)
|
||||||
continue
|
continue
|
||||||
|
|
@@ -431,7 +451,11 @@ class AudioProcessor:
|
||||||
if item is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Translation processor received sentinel. Finishing.")
|
logger.debug("Translation processor received sentinel. Finishing.")
|
||||||
break
|
break
|
||||||
elif type(item) is Silence:
|
|
||||||
|
new_translation = None
|
||||||
|
new_translation_buffer = None
|
||||||
|
|
||||||
|
if isinstance(item, Silence):
|
||||||
if item.is_starting:
|
if item.is_starting:
|
||||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||||
if item.has_ended:
|
if item.has_ended:
|
||||||
|
|
@@ -439,13 +463,14 @@ class AudioProcessor:
|
||||||
continue
|
continue
|
||||||
elif isinstance(item, ChangeSpeaker):
|
elif isinstance(item, ChangeSpeaker):
|
||||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
self.translation.insert_tokens(item)
|
self.translation.insert_tokens(item)
|
||||||
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
|
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
|
||||||
async with self.lock:
|
|
||||||
self.state.new_translation.append(new_translation)
|
if new_translation is not None:
|
||||||
self.state.new_translation_buffer = new_translation_buffer
|
async with self.lock:
|
||||||
|
self.state.new_translation.append(new_translation)
|
||||||
|
self.state.new_translation_buffer = new_translation_buffer
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in translation_processor: {e}")
|
logger.warning(f"Exception in translation_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
@@ -465,7 +490,8 @@ class AudioProcessor:
|
||||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
||||||
diarization=self.args.diarization,
|
diarization=self.args.diarization,
|
||||||
translation=bool(self.translation),
|
translation=bool(self.translation),
|
||||||
current_silence=self.current_silence
|
current_silence=self.current_silence,
|
||||||
|
audio_time=self.total_pcm_samples / self.sample_rate if self.sample_rate else None,
|
||||||
)
|
)
|
||||||
state = await self.get_current_state()
|
state = await self.get_current_state()
|
||||||
|
|
||||||
|
|
@@ -497,7 +523,7 @@ class AudioProcessor:
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue