Improve online ASR processor

This commit is contained in:
Quentin Fuxa 2026-01-17 09:35:00 +01:00
parent e144abbbc7
commit e1823dd99c
3 changed files with 22 additions and 26 deletions

View file

@ -44,13 +44,13 @@ class WhisperASR(ASRBase):
from whisperlivekit.whisper import load_model as load_whisper_model from whisperlivekit.whisper import load_model as load_whisper_model
if model_dir is not None: if model_dir is not None:
resolved_path = resolve_model_path(model_dir) resolved_path = resolve_model_path(model_dir)
if resolved_path.is_dir(): if resolved_path.is_dir():
model_info = detect_model_format(resolved_path) model_info = detect_model_format(resolved_path)
if not model_info.has_pytorch: if not model_info.has_pytorch:
raise FileNotFoundError( raise FileNotFoundError(
f"No supported PyTorch checkpoint found under {resolved_path}" f"No supported PyTorch checkpoint found under {resolved_path}"
) )
logger.debug(f"Loading Whisper model from custom path {resolved_path}") logger.debug(f"Loading Whisper model from custom path {resolved_path}")
return load_whisper_model(str(resolved_path), lora_path=self.lora_path) return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
@ -116,7 +116,7 @@ class FasterWhisperASR(ASRBase):
raise ValueError("Either model_size or model_dir must be set") raise ValueError("Either model_size or model_dir must be set")
device = "auto" # Allow CTranslate2 to decide available device device = "auto" # Allow CTranslate2 to decide available device
compute_type = "auto" # Allow CTranslate2 to decide faster compute type compute_type = "auto" # Allow CTranslate2 to decide faster compute type
model = WhisperModel( model = WhisperModel(
model_size_or_path, model_size_or_path,

View file

@ -28,8 +28,8 @@ class HypothesisBuffer:
def insert(self, new_tokens: List[ASRToken], offset: float): def insert(self, new_tokens: List[ASRToken], offset: float):
""" """
Insert new tokens (after applying a time offset) and compare them with the Insert new tokens (after applying a time offset) and compare them with the
already committed tokens. Only tokens that extend the committed hypothesis already committed tokens. Only tokens that extend the committed hypothesis
are added. are added.
""" """
# Apply the offset to each token. # Apply the offset to each token.
@ -98,7 +98,7 @@ class OnlineASRProcessor:
""" """
Processes incoming audio in a streaming fashion, calling the ASR system Processes incoming audio in a streaming fashion, calling the ASR system
periodically, and uses a hypothesis buffer to commit and trim recognized text. periodically, and uses a hypothesis buffer to commit and trim recognized text.
The processor supports two types of buffer trimming: The processor supports two types of buffer trimming:
- "sentence": trims at sentence boundaries (using a sentence tokenizer) - "sentence": trims at sentence boundaries (using a sentence tokenizer)
- "segment": trims at fixed segment durations. - "segment": trims at fixed segment durations.
@ -187,7 +187,7 @@ class OnlineASRProcessor:
def prompt(self) -> Tuple[str, str]: def prompt(self) -> Tuple[str, str]:
""" """
Returns a tuple: (prompt, context), where: Returns a tuple: (prompt, context), where:
- prompt is a 200-character suffix of committed text that falls - prompt is a 200-character suffix of committed text that falls
outside the current audio buffer. outside the current audio buffer.
- context is the committed text within the current audio buffer. - context is the committed text within the current audio buffer.
""" """
@ -213,7 +213,7 @@ class OnlineASRProcessor:
Get the unvalidated buffer in string format. Get the unvalidated buffer in string format.
""" """
return self.concatenate_tokens(self.transcript_buffer.buffer) return self.concatenate_tokens(self.transcript_buffer.buffer)
def process_iter(self) -> Tuple[List[ASRToken], float]: def process_iter(self) -> Tuple[List[ASRToken], float]:
""" """
@ -262,9 +262,6 @@ class OnlineASRProcessor:
logger.debug( logger.debug(
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds" f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
) )
if self.global_time_offset:
for token in committed_tokens:
token = token.with_offset(self.global_time_offset)
return committed_tokens, current_audio_processed_upto return committed_tokens, current_audio_processed_upto
def chunk_completed_sentence(self): def chunk_completed_sentence(self):
@ -273,19 +270,19 @@ class OnlineASRProcessor:
buffer at the end time of the penultimate sentence. buffer at the end time of the penultimate sentence.
Also ensures chunking happens if audio buffer exceeds a time limit. Also ensures chunking happens if audio buffer exceeds a time limit.
""" """
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
if not self.committed: if not self.committed:
if buffer_duration > self.buffer_trimming_sec: if buffer_duration > self.buffer_trimming_sec:
chunk_time = self.buffer_time_offset + (buffer_duration / 2) chunk_time = self.buffer_time_offset + (buffer_duration / 2)
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}") logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
self.chunk_at(chunk_time) self.chunk_at(chunk_time)
return return
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed)) logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
sentences = self.words_to_sentences(self.committed) sentences = self.words_to_sentences(self.committed)
for sentence in sentences: for sentence in sentences:
logger.debug(f"\tSentence: {sentence.text}") logger.debug(f"\tSentence: {sentence.text}")
chunk_done = False chunk_done = False
if len(sentences) >= 2: if len(sentences) >= 2:
while len(sentences) > 2: while len(sentences) > 2:
@ -294,7 +291,7 @@ class OnlineASRProcessor:
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}") logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
self.chunk_at(chunk_time) self.chunk_at(chunk_time)
chunk_done = True chunk_done = True
if not chunk_done and buffer_duration > self.buffer_trimming_sec: if not chunk_done and buffer_duration > self.buffer_trimming_sec:
last_committed_time = self.committed[-1].end last_committed_time = self.committed[-1].end
logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}") logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}")
@ -305,17 +302,17 @@ class OnlineASRProcessor:
Chunk the audio buffer based on segment-end timestamps reported by the ASR. Chunk the audio buffer based on segment-end timestamps reported by the ASR.
Also ensures chunking happens if audio buffer exceeds a time limit. Also ensures chunking happens if audio buffer exceeds a time limit.
""" """
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
if not self.committed: if not self.committed:
if buffer_duration > self.buffer_trimming_sec: if buffer_duration > self.buffer_trimming_sec:
chunk_time = self.buffer_time_offset + (buffer_duration / 2) chunk_time = self.buffer_time_offset + (buffer_duration / 2)
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}") logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
self.chunk_at(chunk_time) self.chunk_at(chunk_time)
return return
logger.debug("Processing committed tokens for segmenting") logger.debug("Processing committed tokens for segmenting")
ends = self.asr.segments_end_ts(res) ends = self.asr.segments_end_ts(res)
last_committed_time = self.committed[-1].end last_committed_time = self.committed[-1].end
chunk_done = False chunk_done = False
if len(ends) > 1: if len(ends) > 1:
logger.debug("Multiple segments available for chunking") logger.debug("Multiple segments available for chunking")
@ -331,13 +328,13 @@ class OnlineASRProcessor:
logger.debug("--- Last segment not within committed area") logger.debug("--- Last segment not within committed area")
else: else:
logger.debug("--- Not enough segments to chunk") logger.debug("--- Not enough segments to chunk")
if not chunk_done and buffer_duration > self.buffer_trimming_sec: if not chunk_done and buffer_duration > self.buffer_trimming_sec:
logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}") logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}")
self.chunk_at(last_committed_time) self.chunk_at(last_committed_time)
logger.debug("Segment chunking complete") logger.debug("Segment chunking complete")
def chunk_at(self, time: float): def chunk_at(self, time: float):
""" """
Trim both the hypothesis and audio buffer at the given time. Trim both the hypothesis and audio buffer at the given time.
@ -367,7 +364,7 @@ class OnlineASRProcessor:
if self.tokenize: if self.tokenize:
try: try:
sentence_texts = self.tokenize(full_text) sentence_texts = self.tokenize(full_text)
except Exception as e: except Exception:
# Some tokenizers (e.g., MosesSentenceSplitter) expect a list input. # Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
try: try:
sentence_texts = self.tokenize([full_text]) sentence_texts = self.tokenize([full_text])
@ -398,7 +395,7 @@ class OnlineASRProcessor:
) )
sentences.append(sentence) sentences.append(sentence)
return sentences return sentences
def finish(self) -> Tuple[List[ASRToken], float]: def finish(self) -> Tuple[List[ASRToken], float]:
""" """
Flush the remaining transcript when processing ends. Flush the remaining transcript when processing ends.

View file

@ -3,8 +3,7 @@ import logging
import platform import platform
import time import time
from whisperlivekit.backend_support import (faster_backend_available, from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.warmup import warmup_asr from whisperlivekit.warmup import warmup_asr
@ -39,7 +38,7 @@ def create_tokenizer(lan):
lan lan
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split() in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
): ):
from mosestokenizer import MosesSentenceSplitter from mosestokenizer import MosesSentenceSplitter
return MosesSentenceSplitter(lan) return MosesSentenceSplitter(lan)