Fix ctranslate2 encoder conversion (#345) and memory leak in TokensAlignment (#344)

- Add fallback chain for StorageView to numpy conversion
- Prune old tokens/segments after 5min to bound memory
This commit is contained in:
Quentin Fuxa 2026-03-10 22:37:00 +01:00
parent 8bc0937c46
commit 451535d48f
2 changed files with 54 additions and 7 deletions

View file

@ -282,10 +282,20 @@ class AlignAtt(AlignAttBase):
try:
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
except TypeError:
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
arr = np.array(encoder_feature_ctranslate)
if arr.dtype == np.object_:
arr = np.array(arr.tolist(), dtype=np.float32)
try:
arr = np.asarray(encoder_feature_ctranslate, dtype=np.float32)
except (TypeError, ValueError):
arr = np.array(encoder_feature_ctranslate)
if arr.dtype == np.object_:
try:
arr = np.stack([
np.asarray(item, dtype=np.float32) for item in arr.flat
])
except (TypeError, ValueError):
arr = np.array(
[[float(x) for x in row] for row in arr.flat],
dtype=np.float32,
)
encoder_feature = torch.as_tensor(arr, device=self.device)
else:
mel_padded = log_mel_spectrogram(

View file

@ -5,15 +5,14 @@ from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silenc
SilentSegment, SpeakerSegment,
TimedText)
_DEFAULT_RETENTION_SECONDS: float = 300.0
class TokensAlignment:
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
self.state = state
self.diarization = args.diarization
self._tokens_index: int = 0
self._diarization_index: int = 0
self._translation_index: int = 0
self.all_tokens: List[ASRToken] = []
self.all_diarization_segments: List[SpeakerSegment] = []
@ -35,6 +34,8 @@ class TokensAlignment:
self.last_uncompleted_punc_segment: PuncSegment = None
self.unvalidated_tokens: PuncSegment = []
self._retention_seconds: float = _DEFAULT_RETENTION_SECONDS
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
@ -47,6 +48,39 @@ class TokensAlignment:
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def _prune(self) -> None:
"""Drop tokens/segments older than ``_retention_seconds`` from the latest token."""
if not self.all_tokens:
return
latest = self.all_tokens[-1].end
cutoff = latest - self._retention_seconds
if cutoff <= 0:
return
def _find_cutoff(items: list) -> int:
"""Return the index of the first item whose end >= cutoff."""
for i, item in enumerate(items):
if item.end >= cutoff:
return i
return len(items)
idx = _find_cutoff(self.all_tokens)
if idx:
self.all_tokens = self.all_tokens[idx:]
idx = _find_cutoff(self.all_diarization_segments)
if idx:
self.all_diarization_segments = self.all_diarization_segments[idx:]
idx = _find_cutoff(self.all_translation_segments)
if idx:
self.all_translation_segments = self.all_translation_segments[idx:]
idx = _find_cutoff(self.validated_segments)
if idx:
self.validated_segments = self.validated_segments[idx:]
def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment."""
if segment.translation is None:
@ -217,4 +251,7 @@ class TokensAlignment:
))
if translation:
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
self._prune()
return segments, diarization_buffer, self.new_translation_buffer.text