Improve online ASR processor
This commit is contained in:
parent
e144abbbc7
commit
e1823dd99c
3 changed files with 22 additions and 26 deletions
|
|
@ -44,13 +44,13 @@ class WhisperASR(ASRBase):
|
||||||
from whisperlivekit.whisper import load_model as load_whisper_model
|
from whisperlivekit.whisper import load_model as load_whisper_model
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
resolved_path = resolve_model_path(model_dir)
|
resolved_path = resolve_model_path(model_dir)
|
||||||
if resolved_path.is_dir():
|
if resolved_path.is_dir():
|
||||||
model_info = detect_model_format(resolved_path)
|
model_info = detect_model_format(resolved_path)
|
||||||
if not model_info.has_pytorch:
|
if not model_info.has_pytorch:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||||
)
|
)
|
||||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||||
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
|
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
|
||||||
|
|
||||||
|
|
@ -116,7 +116,7 @@ class FasterWhisperASR(ASRBase):
|
||||||
raise ValueError("Either model_size or model_dir must be set")
|
raise ValueError("Either model_size or model_dir must be set")
|
||||||
device = "auto" # Allow CTranslate2 to decide available device
|
device = "auto" # Allow CTranslate2 to decide available device
|
||||||
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
||||||
|
|
||||||
|
|
||||||
model = WhisperModel(
|
model = WhisperModel(
|
||||||
model_size_or_path,
|
model_size_or_path,
|
||||||
|
|
|
||||||
|
|
@ -28,8 +28,8 @@ class HypothesisBuffer:
|
||||||
|
|
||||||
def insert(self, new_tokens: List[ASRToken], offset: float):
|
def insert(self, new_tokens: List[ASRToken], offset: float):
|
||||||
"""
|
"""
|
||||||
Insert new tokens (after applying a time offset) and compare them with the
|
Insert new tokens (after applying a time offset) and compare them with the
|
||||||
already committed tokens. Only tokens that extend the committed hypothesis
|
already committed tokens. Only tokens that extend the committed hypothesis
|
||||||
are added.
|
are added.
|
||||||
"""
|
"""
|
||||||
# Apply the offset to each token.
|
# Apply the offset to each token.
|
||||||
|
|
@ -98,7 +98,7 @@ class OnlineASRProcessor:
|
||||||
"""
|
"""
|
||||||
Processes incoming audio in a streaming fashion, calling the ASR system
|
Processes incoming audio in a streaming fashion, calling the ASR system
|
||||||
periodically, and uses a hypothesis buffer to commit and trim recognized text.
|
periodically, and uses a hypothesis buffer to commit and trim recognized text.
|
||||||
|
|
||||||
The processor supports two types of buffer trimming:
|
The processor supports two types of buffer trimming:
|
||||||
- "sentence": trims at sentence boundaries (using a sentence tokenizer)
|
- "sentence": trims at sentence boundaries (using a sentence tokenizer)
|
||||||
- "segment": trims at fixed segment durations.
|
- "segment": trims at fixed segment durations.
|
||||||
|
|
@ -187,7 +187,7 @@ class OnlineASRProcessor:
|
||||||
def prompt(self) -> Tuple[str, str]:
|
def prompt(self) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Returns a tuple: (prompt, context), where:
|
Returns a tuple: (prompt, context), where:
|
||||||
- prompt is a 200-character suffix of committed text that falls
|
- prompt is a 200-character suffix of committed text that falls
|
||||||
outside the current audio buffer.
|
outside the current audio buffer.
|
||||||
- context is the committed text within the current audio buffer.
|
- context is the committed text within the current audio buffer.
|
||||||
"""
|
"""
|
||||||
|
|
@ -213,7 +213,7 @@ class OnlineASRProcessor:
|
||||||
Get the unvalidated buffer in string format.
|
Get the unvalidated buffer in string format.
|
||||||
"""
|
"""
|
||||||
return self.concatenate_tokens(self.transcript_buffer.buffer)
|
return self.concatenate_tokens(self.transcript_buffer.buffer)
|
||||||
|
|
||||||
|
|
||||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -262,9 +262,6 @@ class OnlineASRProcessor:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
||||||
)
|
)
|
||||||
if self.global_time_offset:
|
|
||||||
for token in committed_tokens:
|
|
||||||
token = token.with_offset(self.global_time_offset)
|
|
||||||
return committed_tokens, current_audio_processed_upto
|
return committed_tokens, current_audio_processed_upto
|
||||||
|
|
||||||
def chunk_completed_sentence(self):
|
def chunk_completed_sentence(self):
|
||||||
|
|
@ -273,19 +270,19 @@ class OnlineASRProcessor:
|
||||||
buffer at the end time of the penultimate sentence.
|
buffer at the end time of the penultimate sentence.
|
||||||
Also ensures chunking happens if audio buffer exceeds a time limit.
|
Also ensures chunking happens if audio buffer exceeds a time limit.
|
||||||
"""
|
"""
|
||||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||||
if not self.committed:
|
if not self.committed:
|
||||||
if buffer_duration > self.buffer_trimming_sec:
|
if buffer_duration > self.buffer_trimming_sec:
|
||||||
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
||||||
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
||||||
self.chunk_at(chunk_time)
|
self.chunk_at(chunk_time)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
|
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
|
||||||
sentences = self.words_to_sentences(self.committed)
|
sentences = self.words_to_sentences(self.committed)
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
logger.debug(f"\tSentence: {sentence.text}")
|
logger.debug(f"\tSentence: {sentence.text}")
|
||||||
|
|
||||||
chunk_done = False
|
chunk_done = False
|
||||||
if len(sentences) >= 2:
|
if len(sentences) >= 2:
|
||||||
while len(sentences) > 2:
|
while len(sentences) > 2:
|
||||||
|
|
@ -294,7 +291,7 @@ class OnlineASRProcessor:
|
||||||
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
|
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
|
||||||
self.chunk_at(chunk_time)
|
self.chunk_at(chunk_time)
|
||||||
chunk_done = True
|
chunk_done = True
|
||||||
|
|
||||||
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
||||||
last_committed_time = self.committed[-1].end
|
last_committed_time = self.committed[-1].end
|
||||||
logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}")
|
logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}")
|
||||||
|
|
@ -305,17 +302,17 @@ class OnlineASRProcessor:
|
||||||
Chunk the audio buffer based on segment-end timestamps reported by the ASR.
|
Chunk the audio buffer based on segment-end timestamps reported by the ASR.
|
||||||
Also ensures chunking happens if audio buffer exceeds a time limit.
|
Also ensures chunking happens if audio buffer exceeds a time limit.
|
||||||
"""
|
"""
|
||||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||||
if not self.committed:
|
if not self.committed:
|
||||||
if buffer_duration > self.buffer_trimming_sec:
|
if buffer_duration > self.buffer_trimming_sec:
|
||||||
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
||||||
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
||||||
self.chunk_at(chunk_time)
|
self.chunk_at(chunk_time)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("Processing committed tokens for segmenting")
|
logger.debug("Processing committed tokens for segmenting")
|
||||||
ends = self.asr.segments_end_ts(res)
|
ends = self.asr.segments_end_ts(res)
|
||||||
last_committed_time = self.committed[-1].end
|
last_committed_time = self.committed[-1].end
|
||||||
chunk_done = False
|
chunk_done = False
|
||||||
if len(ends) > 1:
|
if len(ends) > 1:
|
||||||
logger.debug("Multiple segments available for chunking")
|
logger.debug("Multiple segments available for chunking")
|
||||||
|
|
@ -331,13 +328,13 @@ class OnlineASRProcessor:
|
||||||
logger.debug("--- Last segment not within committed area")
|
logger.debug("--- Last segment not within committed area")
|
||||||
else:
|
else:
|
||||||
logger.debug("--- Not enough segments to chunk")
|
logger.debug("--- Not enough segments to chunk")
|
||||||
|
|
||||||
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
||||||
logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}")
|
logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}")
|
||||||
self.chunk_at(last_committed_time)
|
self.chunk_at(last_committed_time)
|
||||||
|
|
||||||
logger.debug("Segment chunking complete")
|
logger.debug("Segment chunking complete")
|
||||||
|
|
||||||
def chunk_at(self, time: float):
|
def chunk_at(self, time: float):
|
||||||
"""
|
"""
|
||||||
Trim both the hypothesis and audio buffer at the given time.
|
Trim both the hypothesis and audio buffer at the given time.
|
||||||
|
|
@ -367,7 +364,7 @@ class OnlineASRProcessor:
|
||||||
if self.tokenize:
|
if self.tokenize:
|
||||||
try:
|
try:
|
||||||
sentence_texts = self.tokenize(full_text)
|
sentence_texts = self.tokenize(full_text)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
|
# Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
|
||||||
try:
|
try:
|
||||||
sentence_texts = self.tokenize([full_text])
|
sentence_texts = self.tokenize([full_text])
|
||||||
|
|
@ -398,7 +395,7 @@ class OnlineASRProcessor:
|
||||||
)
|
)
|
||||||
sentences.append(sentence)
|
sentences.append(sentence)
|
||||||
return sentences
|
return sentences
|
||||||
|
|
||||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||||
"""
|
"""
|
||||||
Flush the remaining transcript when processing ends.
|
Flush the remaining transcript when processing ends.
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,7 @@ import logging
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from whisperlivekit.backend_support import (faster_backend_available,
|
from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
|
||||||
mlx_backend_available)
|
|
||||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||||
from whisperlivekit.warmup import warmup_asr
|
from whisperlivekit.warmup import warmup_asr
|
||||||
|
|
||||||
|
|
@ -39,7 +38,7 @@ def create_tokenizer(lan):
|
||||||
lan
|
lan
|
||||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||||
):
|
):
|
||||||
from mosestokenizer import MosesSentenceSplitter
|
from mosestokenizer import MosesSentenceSplitter
|
||||||
|
|
||||||
return MosesSentenceSplitter(lan)
|
return MosesSentenceSplitter(lan)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue