clean simulwhisper backend and online
This commit is contained in:
parent
197293e25e
commit
b05297a96d
2 changed files with 5 additions and 97 deletions
|
|
@ -3,16 +3,15 @@ import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
import logging
|
import logging
|
||||||
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"""SimulStreaming dependencies are not available.
|
"""SimulStreaming dependencies are not available.
|
||||||
|
|
@ -28,23 +27,14 @@ class SimulStreamingOnlineProcessor:
|
||||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||||
confidence_validation = False,
|
confidence_validation = False,
|
||||||
logfile=sys.stderr,
|
logfile=sys.stderr,
|
||||||
):
|
):
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise ImportError("SimulStreaming dependencies are not available.")
|
|
||||||
|
|
||||||
self.asr = asr
|
self.asr = asr
|
||||||
self.tokenize = tokenize_method
|
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.confidence_validation = confidence_validation
|
self.confidence_validation = confidence_validation
|
||||||
self.init()
|
|
||||||
|
|
||||||
# buffer does not work yet
|
# buffer does not work yet
|
||||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||||
|
|
||||||
def init(self, offset: Optional[float] = None):
|
|
||||||
"""Initialize or reset the processing state."""
|
|
||||||
self.audio_chunks = []
|
self.audio_chunks = []
|
||||||
self.offset = offset if offset is not None else 0.0
|
self.offset = 0.0
|
||||||
self.is_last = False
|
self.is_last = False
|
||||||
self.beg = self.offset
|
self.beg = self.offset
|
||||||
self.end = self.offset
|
self.end = self.offset
|
||||||
|
|
@ -56,14 +46,8 @@ class SimulStreamingOnlineProcessor:
|
||||||
self.buffer_content = ""
|
self.buffer_content = ""
|
||||||
self.processed_audio_duration = 0.0
|
self.processed_audio_duration = 0.0
|
||||||
|
|
||||||
def get_audio_buffer_end_time(self) -> float:
|
|
||||||
"""Returns the absolute end time of the current audio buffer."""
|
|
||||||
return self.end
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
if torch is None:
|
|
||||||
raise ImportError("PyTorch is required for SimulStreaming but not available")
|
|
||||||
|
|
||||||
# Convert numpy array to torch tensor
|
# Convert numpy array to torch tensor
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
|
|
@ -79,13 +63,6 @@ class SimulStreamingOnlineProcessor:
|
||||||
else:
|
else:
|
||||||
self.end = self.offset + self.cumulative_audio_duration
|
self.end = self.offset + self.cumulative_audio_duration
|
||||||
|
|
||||||
def prompt(self) -> Tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Returns a tuple: (prompt, context).
|
|
||||||
SimulStreaming handles prompting internally, so we return empty strings.
|
|
||||||
"""
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
def get_buffer(self):
|
def get_buffer(self):
|
||||||
"""
|
"""
|
||||||
Get the unvalidated buffer content.
|
Get the unvalidated buffer content.
|
||||||
|
|
@ -150,7 +127,6 @@ class SimulStreamingOnlineProcessor:
|
||||||
self.asr.model.insert_audio(audio)
|
self.asr.model.insert_audio(audio)
|
||||||
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
|
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
|
||||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||||
text = self.asr.model.tokenizer.decode(tokens)
|
|
||||||
|
|
||||||
new_tokens = []
|
new_tokens = []
|
||||||
for ts_word in ts_words:
|
for ts_word in ts_words:
|
||||||
|
|
@ -172,55 +148,6 @@ class SimulStreamingOnlineProcessor:
|
||||||
logger.exception(f"SimulStreaming processing error: {e}")
|
logger.exception(f"SimulStreaming processing error: {e}")
|
||||||
return [], self.end
|
return [], self.end
|
||||||
|
|
||||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
logger.debug("SimulStreaming finish() called")
|
|
||||||
self.is_last = True
|
|
||||||
final_tokens, final_time = self.process_iter()
|
|
||||||
self.is_last = False
|
|
||||||
return final_tokens, final_time
|
|
||||||
|
|
||||||
def concatenate_tokens(
|
|
||||||
self,
|
|
||||||
tokens: List[ASRToken],
|
|
||||||
sep: Optional[str] = None,
|
|
||||||
offset: float = 0
|
|
||||||
) -> Transcript:
|
|
||||||
"""Concatenate tokens into a Transcript object."""
|
|
||||||
sep = sep if sep is not None else self.asr.sep
|
|
||||||
text = sep.join(token.text for token in tokens)
|
|
||||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
|
||||||
if tokens:
|
|
||||||
start = offset + tokens[0].start
|
|
||||||
end = offset + tokens[-1].end
|
|
||||||
else:
|
|
||||||
start = None
|
|
||||||
end = None
|
|
||||||
return Transcript(start, end, text, probability=probability)
|
|
||||||
|
|
||||||
def chunk_at(self, time: float):
|
|
||||||
"""
|
|
||||||
useless but kept for compatibility
|
|
||||||
"""
|
|
||||||
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
|
|
||||||
"""
|
|
||||||
Create simple sentences.
|
|
||||||
"""
|
|
||||||
if not tokens:
|
|
||||||
return []
|
|
||||||
|
|
||||||
full_text = " ".join(token.text for token in tokens)
|
|
||||||
sentence = Sentence(
|
|
||||||
start=tokens[0].start,
|
|
||||||
end=tokens[-1].end,
|
|
||||||
text=full_text
|
|
||||||
)
|
|
||||||
return [sentence]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SimulStreamingASR():
|
class SimulStreamingASR():
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
sep = ""
|
sep = ""
|
||||||
|
|
@ -247,7 +174,7 @@ class SimulStreamingASR():
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
self.model_path = model_dir
|
self.model_path = model_dir
|
||||||
elif modelsize is not None: #For the moment the .en.pt models do not work!
|
elif modelsize is not None:
|
||||||
model_mapping = {
|
model_mapping = {
|
||||||
'tiny': './tiny.pt',
|
'tiny': './tiny.pt',
|
||||||
'base': './base.pt',
|
'base': './base.pt',
|
||||||
|
|
@ -297,13 +224,6 @@ class SimulStreamingASR():
|
||||||
logger.error(f"Failed to load SimulStreaming model: {e}")
|
logger.error(f"Failed to load SimulStreaming model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def segments_end_ts(self, result) -> List[float]:
|
|
||||||
"""Get segment end timestamps."""
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
num_tokens = len(result)
|
|
||||||
return [num_tokens * 0.1] # rough estimate
|
|
||||||
return [1.0]
|
|
||||||
|
|
||||||
def set_translate_task(self):
|
def set_translate_task(self):
|
||||||
"""Set up translation task."""
|
"""Set up translation task."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# simulStreaming imports - we check if the files are here
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("SimulStreaming dependencies not available for online processor.")
|
|
||||||
SIMULSTREAMING_AVAILABLE = False
|
|
||||||
OnlineProcessorInterface = None
|
|
||||||
torch = None
|
|
||||||
|
|
||||||
|
|
||||||
class HypothesisBuffer:
|
class HypothesisBuffer:
|
||||||
"""
|
"""
|
||||||
Buffer to store and process ASR hypothesis tokens.
|
Buffer to store and process ASR hypothesis tokens.
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue