From e144abbbc7030bc26b9177739eae91869378e72f Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 11 Jan 2026 16:08:00 +0100 Subject: [PATCH] Refactor timed objects and data structures --- whisperlivekit/metrics.py | 2 +- whisperlivekit/metrics_collector.py | 1 - whisperlivekit/thread_safety.py | 2 +- whisperlivekit/timed_objects.py | 40 +++++++++++++++++------------ 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/whisperlivekit/metrics.py b/whisperlivekit/metrics.py index 8bbd9af..804be44 100644 --- a/whisperlivekit/metrics.py +++ b/whisperlivekit/metrics.py @@ -6,7 +6,7 @@ text normalization, and word-level timestamp accuracy metrics with greedy alignm import re import unicodedata -from typing import Dict, List, Optional +from typing import Dict, List def normalize_text(text: str) -> str: diff --git a/whisperlivekit/metrics_collector.py b/whisperlivekit/metrics_collector.py index 03db5dc..1c48447 100644 --- a/whisperlivekit/metrics_collector.py +++ b/whisperlivekit/metrics_collector.py @@ -78,7 +78,6 @@ class SessionMetrics: def log_summary(self) -> None: """Emit a structured log line summarising the session.""" - self.total_processing_time_s = sum(self.transcription_durations) d = self.to_dict() d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0 logger.info(f"SESSION_METRICS {d}") diff --git a/whisperlivekit/thread_safety.py b/whisperlivekit/thread_safety.py index 18e8303..aefa89e 100644 --- a/whisperlivekit/thread_safety.py +++ b/whisperlivekit/thread_safety.py @@ -20,8 +20,8 @@ Usage: export WHISPERLIVEKIT_LOCK_TIMEOUT=60 """ -import os import logging +import os import threading logger = logging.getLogger(__name__) diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 80b0276..3e40f95 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -1,12 +1,18 @@ from dataclasses import dataclass, field -from datetime import timedelta from typing import Any, Dict, List, Optional, Union PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} def format_time(seconds: float) -> str: - """Format seconds as HH:MM:SS.""" - return str(timedelta(seconds=int(seconds))) + """Format seconds as H:MM:SS.cc (centisecond precision).""" + total_cs = int(round(seconds * 100)) + cs = total_cs % 100 + total_s = total_cs // 100 + s = total_s % 60 + total_m = total_s // 60 + m = total_m % 60 + h = total_m // 60 + return f"{h}:{m:02d}:{s:02d}.{cs:02d}" @dataclass class Timed: @@ -18,10 +24,10 @@ class TimedText(Timed): text: Optional[str] = '' speaker: Optional[int] = -1 detected_language: Optional[str] = None - + def has_punctuation(self) -> bool: return any(char in PUNCTUATION_MARKS for char in self.text.strip()) - + def is_within(self, other: 'TimedText') -> bool: return other.contains_timespan(self) @@ -30,10 +36,10 @@ class TimedText(Timed): def contains_timespan(self, other: 'TimedText') -> bool: return self.start <= other.start and self.end >= other.end - + def __bool__(self) -> bool: return bool(self.text) - + def __str__(self) -> str: return str(self.text) @@ -103,7 +109,7 @@ class Silence(): return None self.duration = self.end - self.start return self.duration - + def is_silence(self) -> bool: return True @@ -127,9 +133,9 @@ class Segment(TimedText): """Return a normalized segment representing the provided tokens.""" if not tokens: return None - + start_token = tokens[0] - end_token = tokens[-1] + end_token = tokens[-1] if is_silence: return cls( start=start_token.start, @@ -176,7 +182,7 @@ class SilentSegment(Segment): self.text = '' -@dataclass +@dataclass class FrontData(): status: str = '' error: str = '' @@ -186,7 +192,7 @@ class FrontData(): buffer_translation: str = '' remaining_time_transcription: float = 0. remaining_time_diarization: float = 0. - + def to_dict(self) -> Dict[str, Any]: """Serialize the front-end data payload.""" _dict: Dict[str, Any] = { @@ -202,15 +208,15 @@ class FrontData(): _dict['error'] = self.error return _dict -@dataclass +@dataclass class ChangeSpeaker: speaker: int start: int -@dataclass +@dataclass class State(): """Unified state class for audio processing. - + Contains both persistent state (tokens, buffers) and temporary update buffers (new_* fields) that are consumed by TokensAlignment. """ @@ -221,10 +227,10 @@ class State(): end_attributed_speaker: float = 0.0 remaining_time_transcription: float = 0.0 remaining_time_diarization: float = 0.0 - + # Temporary update buffers (consumed by TokensAlignment.update()) new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list) new_translation: List[Any] = field(default_factory=list) new_diarization: List[Any] = field(default_factory=list) new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement - new_translation_buffer= TimedText() \ No newline at end of file + new_translation_buffer: TimedText = field(default_factory=TimedText)