# Patch summary (reviewer notes; the patch text below is unchanged).
#
# Files touched:
#   1. whisperlivekit/simul_whisper/simul_whisper.py  (hunk inside class AlignAtt)
#   2. whisperlivekit/tokens_alignment.py             (class TokensAlignment)
#
# 1. simul_whisper.py -- `except TypeError:` fallback after
#    `torch.as_tensor(encoder_feature_ctranslate, device=self.device)` fails.
#    Old behavior: build `np.array(encoder_feature_ctranslate)` and, when the
#    result has object dtype, rebuild it from `arr.tolist()` as float32.
#    New behavior, tried in order:
#      a. `np.asarray(encoder_feature_ctranslate, dtype=np.float32)`;
#      b. on TypeError/ValueError: `np.array(encoder_feature_ctranslate)`, and
#         if that array has object dtype, `np.stack` of each element of
#         `arr.flat` converted with `np.asarray(item, dtype=np.float32)`;
#      c. on TypeError/ValueError from (b): a nested list comprehension calling
#         `float(x)` on every value of every element of `arr.flat`, wrapped in
#         `np.array(..., dtype=np.float32)`.
#    The resulting `arr` is then passed to `torch.as_tensor(arr, device=self.device)`.
#
# 2. tokens_alignment.py
#    - Adds the module constant `_DEFAULT_RETENTION_SECONDS = 300.0` and stores it
#      on the instance as `self._retention_seconds` in `__init__`.
#    - Removes `self._tokens_index`, `self._diarization_index` and
#      `self._translation_index` from `__init__`.
#    - Adds `_prune()`: returns immediately when `all_tokens` is empty or when
#      `all_tokens[-1].end - _retention_seconds` is <= 0; otherwise, for each of
#      `all_tokens`, `all_diarization_segments`, `all_translation_segments` and
#      `validated_segments`, it finds the index of the first item whose
#      `end >= cutoff` and rebinds the attribute to the slice starting there.
#      When no item in a list satisfies the test, the index equals `len(items)`
#      and that list becomes empty.
#    - Calls `self._prune()` in the last hunk, before
#      `return segments, diarization_buffer, self.new_translation_buffer.text`.
#
# NOTE(review): `_find_cutoff` stops at the first item with `end >= cutoff`, so
#   it presumably relies on every list being ordered by `end` -- confirm.
# NOTE(review): confirm nothing outside these hunks still reads the removed
#   `_tokens_index` / `_diarization_index` / `_translation_index` attributes.
# NOTE(review): `_prune` compares `item.end` numerically; presumably `end` is
#   never None for items in these four lists -- verify against the producers.
# NOTE(review): the original line breaks and indentation of this patch are not
#   preserved in this copy, so nesting (for example whether `self._prune()` sits
#   inside `if translation:`) cannot be determined from here -- check the
#   source patch before applying.
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index f59c800..c5d7923 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -282,10 +282,20 @@ class AlignAtt(AlignAttBase): try: encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device) except TypeError: - # Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32 - arr = np.array(encoder_feature_ctranslate) - if arr.dtype == np.object_: - arr = np.array(arr.tolist(), dtype=np.float32) + try: + arr = np.asarray(encoder_feature_ctranslate, dtype=np.float32) + except (TypeError, ValueError): + arr = np.array(encoder_feature_ctranslate) + if arr.dtype == np.object_: + try: + arr = np.stack([ + np.asarray(item, dtype=np.float32) for item in arr.flat + ]) + except (TypeError, ValueError): + arr = np.array( + [[float(x) for x in row] for row in arr.flat], + dtype=np.float32, + ) encoder_feature = torch.as_tensor(arr, device=self.device) else: mel_padded = log_mel_spectrogram( diff --git a/whisperlivekit/tokens_alignment.py b/whisperlivekit/tokens_alignment.py index 7760ac6..74d451e 100644 --- a/whisperlivekit/tokens_alignment.py +++ b/whisperlivekit/tokens_alignment.py @@ -5,15 +5,14 @@ from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silenc SilentSegment, SpeakerSegment, TimedText) +_DEFAULT_RETENTION_SECONDS: float = 300.0 + class TokensAlignment: def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None: self.state = state self.diarization = args.diarization - self._tokens_index: int = 0 - self._diarization_index: int = 0 - self._translation_index: int = 0 self.all_tokens: List[ASRToken] = [] self.all_diarization_segments: List[SpeakerSegment] = [] @@ -35,6 +34,8 @@ class TokensAlignment: self.last_uncompleted_punc_segment: PuncSegment = None self.unvalidated_tokens: PuncSegment = [] + self._retention_seconds: float = 
_DEFAULT_RETENTION_SECONDS + def update(self) -> None: """Drain state buffers into the running alignment context.""" self.new_tokens, self.state.new_tokens = self.state.new_tokens, [] @@ -47,6 +48,39 @@ class TokensAlignment: self.all_translation_segments.extend(self.new_translation) self.new_translation_buffer = self.state.new_translation_buffer + def _prune(self) -> None: + """Drop tokens/segments older than ``_retention_seconds`` from the latest token.""" + if not self.all_tokens: + return + + latest = self.all_tokens[-1].end + cutoff = latest - self._retention_seconds + if cutoff <= 0: + return + + def _find_cutoff(items: list) -> int: + """Return the index of the first item whose end >= cutoff.""" + for i, item in enumerate(items): + if item.end >= cutoff: + return i + return len(items) + + idx = _find_cutoff(self.all_tokens) + if idx: + self.all_tokens = self.all_tokens[idx:] + + idx = _find_cutoff(self.all_diarization_segments) + if idx: + self.all_diarization_segments = self.all_diarization_segments[idx:] + + idx = _find_cutoff(self.all_translation_segments) + if idx: + self.all_translation_segments = self.all_translation_segments[idx:] + + idx = _find_cutoff(self.validated_segments) + if idx: + self.validated_segments = self.validated_segments[idx:] + def add_translation(self, segment: Segment) -> None: """Append translated text segments that overlap with a segment.""" if segment.translation is None: @@ -217,4 +251,7 @@ class TokensAlignment: )) if translation: [self.add_translation(segment) for segment in segments if not segment.is_silence()] + + self._prune() + return segments, diarization_buffer, self.new_translation_buffer.text