import copy
from time import time
from typing import Any, List, Optional, Tuple, Union

from whisperlivekit.timed_objects import (
    ASRToken,
    PuncSegment,
    Segment,
    Silence,
    SilentSegment,
    SpeakerSegment,
    TimedText,
)

# Items ending more than this many seconds before the latest token's end are
# dropped by ``TokensAlignment._prune``.
_DEFAULT_RETENTION_SECONDS: float = 300.0


class TokensAlignment:
    """Align ASR tokens, diarization and translation output into segments.

    ``update`` drains the pending buffers of ``state`` into running lists, and
    ``get_lines`` turns the accumulated data into display segments.
    """

    def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
        self.state = state
        self.diarization = args.diarization

        # Running history, trimmed by ``_prune``.
        self.all_tokens: List[ASRToken] = []
        self.all_diarization_segments: List[SpeakerSegment] = []
        self.all_translation_segments: List[Any] = []

        # Items drained from ``state`` by the most recent ``update`` call.
        self.new_tokens: List[ASRToken] = []
        self.new_diarization: List[SpeakerSegment] = []
        self.new_translation: List[Any] = []
        self.new_translation_buffer: Union[TimedText, str] = TimedText()
        self.new_tokens_buffer: List[Any] = []

        self.sep: str = sep if sep is not None else ' '
        # Wall-clock start of processing; only used by ``get_lines`` when no
        # ``audio_time`` is given. Never assigned in this class — presumably
        # set by the caller (verify).
        self.beg_loop: Optional[float] = None

        self.validated_segments: List[Segment] = []
        self.current_line_tokens: List[ASRToken] = []
        self.diarization_buffer: List[ASRToken] = []
        self.last_punctuation = None
        self.last_uncompleted_punc_segment: Optional[PuncSegment] = None
        # Tokens not yet closed by punctuation or silence; carried over between
        # ``compute_new_punctuations_segments`` calls.
        self.unvalidated_tokens: List[ASRToken] = []
        self._retention_seconds: float = _DEFAULT_RETENTION_SECONDS

    def update(self) -> None:
        """Drain state buffers into the running alignment context."""
        self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
        self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
        self.new_translation, self.state.new_translation = self.state.new_translation, []
        self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []

        self.all_tokens.extend(self.new_tokens)
        self.all_diarization_segments.extend(self.new_diarization)
        self.all_translation_segments.extend(self.new_translation)
        # The translation buffer is replaced, not accumulated.
        self.new_translation_buffer = self.state.new_translation_buffer

    def _prune(self) -> None:
        """Drop tokens/segments older than ``_retention_seconds`` from the latest token."""
        if not self.all_tokens:
            return
        latest = self.all_tokens[-1].end
        cutoff = latest - self._retention_seconds
        if cutoff <= 0:
            return

        def _find_cutoff(items: list) -> int:
            """Return the index of the first item whose end >= cutoff."""
            for i, item in enumerate(items):
                if item.end >= cutoff:
                    return i
            return len(items)

        idx = _find_cutoff(self.all_tokens)
        if idx:
            self.all_tokens = self.all_tokens[idx:]

        idx = _find_cutoff(self.all_diarization_segments)
        if idx:
            self.all_diarization_segments = self.all_diarization_segments[idx:]

        idx = _find_cutoff(self.all_translation_segments)
        if idx:
            self.all_translation_segments = self.all_translation_segments[idx:]

        idx = _find_cutoff(self.validated_segments)
        if idx:
            self.validated_segments = self.validated_segments[idx:]

    def add_translation(self, segment: Segment) -> None:
        """Set ``segment.translation`` from the translation segments within it.

        Each non-empty translation segment that ``is_within`` the segment
        contributes its text followed by ``self.sep``. The value is rebuilt
        from scratch on every call: ``get_lines`` passes the persistent
        validated segments here again on each call, so appending to the
        previous value would duplicate the translation every time.
        """
        parts: List[str] = []
        for ts in self.all_translation_segments:
            if ts.is_within(segment):
                if ts.text:
                    parts.append(ts.text + self.sep)
            elif parts:
                # Stop at the first non-overlapping translation segment once
                # some text was collected (assumes segments are time-ordered).
                break
        segment.translation = ''.join(parts)

    @staticmethod
    def _split_on_punctuation(tokens: List[ASRToken]) -> Tuple[List[PuncSegment], int]:
        """Split ``tokens`` at silences and punctuation.

        Returns the closed segments and the index of the first token that was
        not consumed, i.e. the start of the trailing open segment.
        """
        segments: List[PuncSegment] = []
        start = 0
        for i, token in enumerate(tokens):
            if token.is_silence():
                # Close whatever text preceded the silence (may be empty).
                previous_segment = PuncSegment.from_tokens(
                    tokens=tokens[start:i],
                )
                if previous_segment:
                    segments.append(previous_segment)
                segments.append(PuncSegment.from_tokens(tokens=[token], is_silence=True))
                start = i + 1
            elif token.has_punctuation():
                # The punctuated token ends (and belongs to) the segment.
                segments.append(PuncSegment.from_tokens(tokens=tokens[start:i + 1]))
                start = i + 1
        return segments, start

    def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
        """Group tokens into segments split by punctuation and explicit silence.

        Args:
            tokens: Tokens to segment; defaults to ``self.all_tokens``.

        Returns:
            All segments, including a trailing unpunctuated one if non-empty.
        """
        source = self.all_tokens if tokens is None else tokens
        segments, start = self._split_on_punctuation(source)
        final_segment = PuncSegment.from_tokens(
            tokens=source[start:],
        )
        if final_segment:
            segments.append(final_segment)
        return segments

    def compute_new_punctuations_segments(self) -> List[PuncSegment]:
        """Segment pending tokens incrementally.

        ``self.new_tokens`` are appended to ``self.unvalidated_tokens``; the
        closed segments are returned and the trailing open tokens are kept in
        ``self.unvalidated_tokens`` for the next call.
        """
        self.unvalidated_tokens += self.new_tokens
        new_punc_segments, start = self._split_on_punctuation(self.unvalidated_tokens)
        self.unvalidated_tokens = self.unvalidated_tokens[start:]
        return new_punc_segments

    def concatenate_diar_segments(self) -> List[SpeakerSegment]:
        """Merge consecutive diarization slices that share the same speaker.

        The merged segments are shallow copies, so extending a merged ``end``
        does not modify the segments stored in ``all_diarization_segments``.
        """
        if not self.all_diarization_segments:
            return []
        merged = [copy.copy(self.all_diarization_segments[0])]
        for segment in self.all_diarization_segments[1:]:
            if segment.speaker == merged[-1].speaker:
                merged[-1].end = segment.end
            else:
                merged.append(copy.copy(segment))
        return merged

    @staticmethod
    def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
        """Return the overlap duration between two timed segments (0.0 if none)."""
        start = max(seg1.start, seg2.start)
        end = min(seg1.end, seg2.end)
        return max(0.0, end - start)

    def get_lines_diarization(self) -> Tuple[List[Segment], str]:
        """Build segments when diarization is enabled and track overflow buffer.

        Each non-silence punctuation segment is assigned the (1-based) speaker
        whose merged diarization segment overlaps it most, defaulting to 1.
        Segments starting at or after the end of the last diarization segment
        have their text added to the returned buffer instead; their speaker is
        left unassigned. Consecutive segments with the same speaker are then
        merged.

        Returns:
            ``(segments, diarization_buffer)``.
        """
        diarization_buffer = ''
        punctuation_segments = self.compute_punctuations_segments()
        diarization_segments = self.concatenate_diar_segments()
        for punctuation_segment in punctuation_segments:
            if punctuation_segment.is_silence():
                continue
            if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
                # Diarization has not caught up with this text yet.
                diarization_buffer += punctuation_segment.text
                continue
            max_overlap = 0.0
            max_overlap_speaker = 1
            for diarization_segment in diarization_segments:
                intersec = self.intersection_duration(punctuation_segment, diarization_segment)
                if intersec > max_overlap:
                    max_overlap = intersec
                    max_overlap_speaker = diarization_segment.speaker + 1
            punctuation_segment.speaker = max_overlap_speaker

        segments: List[Segment] = []
        if punctuation_segments:
            segments = [punctuation_segments[0]]
            for segment in punctuation_segments[1:]:
                if segment.speaker == segments[-1].speaker:
                    # NOTE(review): text is only appended when the previous
                    # segment already has text; an empty/None predecessor
                    # drops this segment's text — confirm intent.
                    if segments[-1].text:
                        segments[-1].text += segment.text
                    segments[-1].end = segment.end
                else:
                    segments.append(segment)

        return segments, diarization_buffer

    def _ongoing_silence_end(self, audio_time: Optional[float], fallback: float) -> float:
        """Return the provisional end time of a silence that has not ended.

        Prefers the audio stream position. Otherwise uses wall-clock time since
        ``beg_loop``. Returns ``fallback`` if neither is available, instead of
        failing on ``time() - None``.
        """
        if audio_time is not None:
            return audio_time
        if self.beg_loop is not None:
            return time() - self.beg_loop
        return fallback

    def _get_lines_plain(self, audio_time: Optional[float]) -> List[Segment]:
        """Fold ``new_tokens`` into validated segments (no diarization).

        Silences close the current line and extend or append a silent
        validated segment. Returns a copy of the validated segments plus the
        still-open current line, if any.
        """
        for token in self.new_tokens:
            if not isinstance(token, Silence):
                self.current_line_tokens.append(token)
                continue
            if self.current_line_tokens:
                self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
                self.current_line_tokens = []
            end_silence = (
                token.end
                if token.has_ended
                else self._ongoing_silence_end(audio_time, token.start)
            )
            if self.validated_segments and self.validated_segments[-1].is_silence():
                self.validated_segments[-1].end = end_silence
            else:
                self.validated_segments.append(SilentSegment(
                    start=token.start,
                    end=end_silence,
                ))

        segments = list(self.validated_segments)
        if self.current_line_tokens:
            segments.append(Segment.from_tokens(self.current_line_tokens))
        return segments

    def get_lines(
        self,
        diarization: bool = False,
        translation: bool = False,
        current_silence: Optional[Silence] = None,
        audio_time: Optional[float] = None,
    ) -> Tuple[List[Segment], str, Union[str, TimedText]]:
        """Return the formatted segments plus buffers, optionally with diarization/translation.

        Args:
            diarization: Build speaker-attributed segments via
                ``get_lines_diarization``.
            translation: Attach translations to non-silence segments.
            current_silence: An ongoing or just-finished silence to append to
                (or merge into) the trailing segment.
            audio_time: Current audio stream position in seconds. Used as
                fallback for ongoing silence end time instead of wall-clock
                (which breaks when audio is fed faster or slower than
                real-time).

        Returns:
            ``(segments, diarization_buffer, translation_buffer_text)``.
        """
        if diarization:
            segments, diarization_buffer = self.get_lines_diarization()
        else:
            diarization_buffer = ''
            segments = self._get_lines_plain(audio_time)

        if current_silence:
            end_silence = (
                current_silence.end
                if current_silence.has_ended
                else self._ongoing_silence_end(audio_time, current_silence.start)
            )
            if segments and segments[-1].is_silence():
                # Replace (not mutate) so stored validated segments stay intact.
                segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
            else:
                segments.append(SilentSegment(
                    start=current_silence.start,
                    end=end_silence,
                ))

        if translation:
            for segment in segments:
                if not segment.is_silence():
                    self.add_translation(segment)

        self._prune()
        buffer = self.new_translation_buffer
        # The buffer is annotated as TimedText or plain str; handle both.
        translation_buffer = buffer if isinstance(buffer, str) else buffer.text
        return segments, diarization_buffer, translation_buffer