- Add a fallback chain for StorageView-to-NumPy conversion
- Prune tokens/segments older than 5 minutes to bound memory
This commit is contained in:
parent
8bc0937c46
commit
451535d48f
2 changed files with 54 additions and 7 deletions
|
|
@ -282,10 +282,20 @@ class AlignAtt(AlignAttBase):
|
|||
try:
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError:
|
||||
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
|
||||
arr = np.array(encoder_feature_ctranslate)
|
||||
if arr.dtype == np.object_:
|
||||
arr = np.array(arr.tolist(), dtype=np.float32)
|
||||
try:
|
||||
arr = np.asarray(encoder_feature_ctranslate, dtype=np.float32)
|
||||
except (TypeError, ValueError):
|
||||
arr = np.array(encoder_feature_ctranslate)
|
||||
if arr.dtype == np.object_:
|
||||
try:
|
||||
arr = np.stack([
|
||||
np.asarray(item, dtype=np.float32) for item in arr.flat
|
||||
])
|
||||
except (TypeError, ValueError):
|
||||
arr = np.array(
|
||||
[[float(x) for x in row] for row in arr.flat],
|
||||
dtype=np.float32,
|
||||
)
|
||||
encoder_feature = torch.as_tensor(arr, device=self.device)
|
||||
else:
|
||||
mel_padded = log_mel_spectrogram(
|
||||
|
|
|
|||
|
|
@ -5,15 +5,14 @@ from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silenc
|
|||
SilentSegment, SpeakerSegment,
|
||||
TimedText)
|
||||
|
||||
_DEFAULT_RETENTION_SECONDS: float = 300.0
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
|
||||
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
||||
self.state = state
|
||||
self.diarization = args.diarization
|
||||
self._tokens_index: int = 0
|
||||
self._diarization_index: int = 0
|
||||
self._translation_index: int = 0
|
||||
|
||||
self.all_tokens: List[ASRToken] = []
|
||||
self.all_diarization_segments: List[SpeakerSegment] = []
|
||||
|
|
@ -35,6 +34,8 @@ class TokensAlignment:
|
|||
self.last_uncompleted_punc_segment: PuncSegment = None
|
||||
self.unvalidated_tokens: PuncSegment = []
|
||||
|
||||
self._retention_seconds: float = _DEFAULT_RETENTION_SECONDS
|
||||
|
||||
def update(self) -> None:
|
||||
"""Drain state buffers into the running alignment context."""
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
|
|
@ -47,6 +48,39 @@ class TokensAlignment:
|
|||
self.all_translation_segments.extend(self.new_translation)
|
||||
self.new_translation_buffer = self.state.new_translation_buffer
|
||||
|
||||
def _prune(self) -> None:
|
||||
"""Drop tokens/segments older than ``_retention_seconds`` from the latest token."""
|
||||
if not self.all_tokens:
|
||||
return
|
||||
|
||||
latest = self.all_tokens[-1].end
|
||||
cutoff = latest - self._retention_seconds
|
||||
if cutoff <= 0:
|
||||
return
|
||||
|
||||
def _find_cutoff(items: list) -> int:
|
||||
"""Return the index of the first item whose end >= cutoff."""
|
||||
for i, item in enumerate(items):
|
||||
if item.end >= cutoff:
|
||||
return i
|
||||
return len(items)
|
||||
|
||||
idx = _find_cutoff(self.all_tokens)
|
||||
if idx:
|
||||
self.all_tokens = self.all_tokens[idx:]
|
||||
|
||||
idx = _find_cutoff(self.all_diarization_segments)
|
||||
if idx:
|
||||
self.all_diarization_segments = self.all_diarization_segments[idx:]
|
||||
|
||||
idx = _find_cutoff(self.all_translation_segments)
|
||||
if idx:
|
||||
self.all_translation_segments = self.all_translation_segments[idx:]
|
||||
|
||||
idx = _find_cutoff(self.validated_segments)
|
||||
if idx:
|
||||
self.validated_segments = self.validated_segments[idx:]
|
||||
|
||||
def add_translation(self, segment: Segment) -> None:
|
||||
"""Append translated text segments that overlap with a segment."""
|
||||
if segment.translation is None:
|
||||
|
|
@ -217,4 +251,7 @@ class TokensAlignment:
|
|||
))
|
||||
if translation:
|
||||
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
||||
|
||||
self._prune()
|
||||
|
||||
return segments, diarization_buffer, self.new_translation_buffer.text
|
||||
|
|
|
|||
Loading…
Reference in a new issue