WhisperLiveKit/whisperlivekit/simul_whisper/backend.py

217 lines
8.2 KiB
Python

import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
import os
logger = logging.getLogger(__name__)
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
logfile=sys.stderr,
):
self.asr = asr
self.logfile = logfile
self.is_last = False
self.beg = 0.0
self.end = 0.0
self.cumulative_audio_duration = 0.0
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.model = PaddedAlignAttWhisper(
cfg=asr.cfg,
loaded_model=asr.whisper_model)
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
# Update timing
chunk_duration = len(audio) / self.SAMPLING_RATE
self.cumulative_audio_duration += chunk_duration
if audio_stream_end_time is not None:
self.end = audio_stream_end_time
else:
self.end = self.cumulative_audio_duration
self.model.insert_audio(audio_tensor)
def get_buffer(self):
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, tokens, generation):
# From the simulstreaming repo. self.model to self.asr.model
pr = generation["progress"]
if "result" not in generation:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
else:
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
frames = [p["most_attended_frames"][0] for p in pr]
tokens = tokens.copy()
ret = []
for sw,st in zip(split_words,split_tokens):
b = None
for stt in st:
t,f = tokens.pop(0), frames.pop(0)
if t != stt:
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
if b is None:
b = f
e = f
out = (b*0.02, e*0.02, sw)
ret.append(out)
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
return ret
def process_iter(self) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
tokens, generation_progress = self.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
)
new_tokens.append(token)
self.committed.extend(new_tokens)
return new_tokens, self.end
except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.model = self.load_model(modelsize)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
def load_model(self, modelsize):
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.whisper_model = load_model(name=model_name, download_root=model_path)
def set_translate_task(self):
"""Set up translation task."""
return tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
task="translate"
)
# def warmup(self, audio, init_prompt=""):
# """Warmup the SimulStreaming model."""
# try:
# if isinstance(audio, np.ndarray):
# audio = torch.from_numpy(audio).float()
# self.model.insert_audio(audio)
# self.model.infer(True)
# self.model.refresh_segment(complete=True)
# logger.info("SimulStreaming model warmed up successfully")
# except Exception as e:
# logger.exception(f"SimulStreaming warmup failed: {e}")