- Add fallback chain for StorageView to numpy conversion - Prune old tokens/segments after 5min to bound memory
This commit is contained in:
parent
8bc0937c46
commit
451535d48f
2 changed files with 54 additions and 7 deletions
|
|
@ -282,10 +282,20 @@ class AlignAtt(AlignAttBase):
|
||||||
try:
|
try:
|
||||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
|
try:
|
||||||
arr = np.array(encoder_feature_ctranslate)
|
arr = np.asarray(encoder_feature_ctranslate, dtype=np.float32)
|
||||||
if arr.dtype == np.object_:
|
except (TypeError, ValueError):
|
||||||
arr = np.array(arr.tolist(), dtype=np.float32)
|
arr = np.array(encoder_feature_ctranslate)
|
||||||
|
if arr.dtype == np.object_:
|
||||||
|
try:
|
||||||
|
arr = np.stack([
|
||||||
|
np.asarray(item, dtype=np.float32) for item in arr.flat
|
||||||
|
])
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
arr = np.array(
|
||||||
|
[[float(x) for x in row] for row in arr.flat],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
encoder_feature = torch.as_tensor(arr, device=self.device)
|
encoder_feature = torch.as_tensor(arr, device=self.device)
|
||||||
else:
|
else:
|
||||||
mel_padded = log_mel_spectrogram(
|
mel_padded = log_mel_spectrogram(
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,14 @@ from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silenc
|
||||||
SilentSegment, SpeakerSegment,
|
SilentSegment, SpeakerSegment,
|
||||||
TimedText)
|
TimedText)
|
||||||
|
|
||||||
|
_DEFAULT_RETENTION_SECONDS: float = 300.0
|
||||||
|
|
||||||
|
|
||||||
class TokensAlignment:
|
class TokensAlignment:
|
||||||
|
|
||||||
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
||||||
self.state = state
|
self.state = state
|
||||||
self.diarization = args.diarization
|
self.diarization = args.diarization
|
||||||
self._tokens_index: int = 0
|
|
||||||
self._diarization_index: int = 0
|
|
||||||
self._translation_index: int = 0
|
|
||||||
|
|
||||||
self.all_tokens: List[ASRToken] = []
|
self.all_tokens: List[ASRToken] = []
|
||||||
self.all_diarization_segments: List[SpeakerSegment] = []
|
self.all_diarization_segments: List[SpeakerSegment] = []
|
||||||
|
|
@ -35,6 +34,8 @@ class TokensAlignment:
|
||||||
self.last_uncompleted_punc_segment: PuncSegment = None
|
self.last_uncompleted_punc_segment: PuncSegment = None
|
||||||
self.unvalidated_tokens: PuncSegment = []
|
self.unvalidated_tokens: PuncSegment = []
|
||||||
|
|
||||||
|
self._retention_seconds: float = _DEFAULT_RETENTION_SECONDS
|
||||||
|
|
||||||
def update(self) -> None:
|
def update(self) -> None:
|
||||||
"""Drain state buffers into the running alignment context."""
|
"""Drain state buffers into the running alignment context."""
|
||||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||||
|
|
@ -47,6 +48,39 @@ class TokensAlignment:
|
||||||
self.all_translation_segments.extend(self.new_translation)
|
self.all_translation_segments.extend(self.new_translation)
|
||||||
self.new_translation_buffer = self.state.new_translation_buffer
|
self.new_translation_buffer = self.state.new_translation_buffer
|
||||||
|
|
||||||
|
def _prune(self) -> None:
|
||||||
|
"""Drop tokens/segments older than ``_retention_seconds`` from the latest token."""
|
||||||
|
if not self.all_tokens:
|
||||||
|
return
|
||||||
|
|
||||||
|
latest = self.all_tokens[-1].end
|
||||||
|
cutoff = latest - self._retention_seconds
|
||||||
|
if cutoff <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _find_cutoff(items: list) -> int:
|
||||||
|
"""Return the index of the first item whose end >= cutoff."""
|
||||||
|
for i, item in enumerate(items):
|
||||||
|
if item.end >= cutoff:
|
||||||
|
return i
|
||||||
|
return len(items)
|
||||||
|
|
||||||
|
idx = _find_cutoff(self.all_tokens)
|
||||||
|
if idx:
|
||||||
|
self.all_tokens = self.all_tokens[idx:]
|
||||||
|
|
||||||
|
idx = _find_cutoff(self.all_diarization_segments)
|
||||||
|
if idx:
|
||||||
|
self.all_diarization_segments = self.all_diarization_segments[idx:]
|
||||||
|
|
||||||
|
idx = _find_cutoff(self.all_translation_segments)
|
||||||
|
if idx:
|
||||||
|
self.all_translation_segments = self.all_translation_segments[idx:]
|
||||||
|
|
||||||
|
idx = _find_cutoff(self.validated_segments)
|
||||||
|
if idx:
|
||||||
|
self.validated_segments = self.validated_segments[idx:]
|
||||||
|
|
||||||
def add_translation(self, segment: Segment) -> None:
|
def add_translation(self, segment: Segment) -> None:
|
||||||
"""Append translated text segments that overlap with a segment."""
|
"""Append translated text segments that overlap with a segment."""
|
||||||
if segment.translation is None:
|
if segment.translation is None:
|
||||||
|
|
@ -217,4 +251,7 @@ class TokensAlignment:
|
||||||
))
|
))
|
||||||
if translation:
|
if translation:
|
||||||
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
||||||
|
|
||||||
|
self._prune()
|
||||||
|
|
||||||
return segments, diarization_buffer, self.new_translation_buffer.text
|
return segments, diarization_buffer, self.new_translation_buffer.text
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue