Refactor timed objects and data structures

This commit is contained in:
Quentin Fuxa 2026-01-11 16:08:00 +01:00
parent 83362c89c4
commit e144abbbc7
4 changed files with 25 additions and 20 deletions

View file

@ -6,7 +6,7 @@ text normalization, and word-level timestamp accuracy metrics with greedy alignm
import re import re
import unicodedata import unicodedata
from typing import Dict, List, Optional from typing import Dict, List
def normalize_text(text: str) -> str: def normalize_text(text: str) -> str:

View file

@ -78,7 +78,6 @@ class SessionMetrics:
def log_summary(self) -> None: def log_summary(self) -> None:
"""Emit a structured log line summarising the session.""" """Emit a structured log line summarising the session."""
self.total_processing_time_s = sum(self.transcription_durations)
d = self.to_dict() d = self.to_dict()
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0 d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
logger.info(f"SESSION_METRICS {d}") logger.info(f"SESSION_METRICS {d}")

View file

@ -20,8 +20,8 @@ Usage:
export WHISPERLIVEKIT_LOCK_TIMEOUT=60 export WHISPERLIVEKIT_LOCK_TIMEOUT=60
""" """
import os
import logging import logging
import os
import threading import threading
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -1,12 +1,18 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''} PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
def format_time(seconds: float) -> str: def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS.""" """Format seconds as H:MM:SS.cc (centisecond precision)."""
return str(timedelta(seconds=int(seconds))) total_cs = int(round(seconds * 100))
cs = total_cs % 100
total_s = total_cs // 100
s = total_s % 60
total_m = total_s // 60
m = total_m % 60
h = total_m // 60
return f"{h}:{m:02d}:{s:02d}.{cs:02d}"
@dataclass @dataclass
class Timed: class Timed:
@ -18,10 +24,10 @@ class TimedText(Timed):
text: Optional[str] = '' text: Optional[str] = ''
speaker: Optional[int] = -1 speaker: Optional[int] = -1
detected_language: Optional[str] = None detected_language: Optional[str] = None
def has_punctuation(self) -> bool: def has_punctuation(self) -> bool:
return any(char in PUNCTUATION_MARKS for char in self.text.strip()) return any(char in PUNCTUATION_MARKS for char in self.text.strip())
def is_within(self, other: 'TimedText') -> bool: def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self) return other.contains_timespan(self)
@ -30,10 +36,10 @@ class TimedText(Timed):
def contains_timespan(self, other: 'TimedText') -> bool: def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end return self.start <= other.start and self.end >= other.end
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self.text) return bool(self.text)
def __str__(self) -> str: def __str__(self) -> str:
return str(self.text) return str(self.text)
@ -103,7 +109,7 @@ class Silence():
return None return None
self.duration = self.end - self.start self.duration = self.end - self.start
return self.duration return self.duration
def is_silence(self) -> bool: def is_silence(self) -> bool:
return True return True
@ -127,9 +133,9 @@ class Segment(TimedText):
"""Return a normalized segment representing the provided tokens.""" """Return a normalized segment representing the provided tokens."""
if not tokens: if not tokens:
return None return None
start_token = tokens[0] start_token = tokens[0]
end_token = tokens[-1] end_token = tokens[-1]
if is_silence: if is_silence:
return cls( return cls(
start=start_token.start, start=start_token.start,
@ -176,7 +182,7 @@ class SilentSegment(Segment):
self.text = '' self.text = ''
@dataclass @dataclass
class FrontData(): class FrontData():
status: str = '' status: str = ''
error: str = '' error: str = ''
@ -186,7 +192,7 @@ class FrontData():
buffer_translation: str = '' buffer_translation: str = ''
remaining_time_transcription: float = 0. remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0. remaining_time_diarization: float = 0.
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""Serialize the front-end data payload.""" """Serialize the front-end data payload."""
_dict: Dict[str, Any] = { _dict: Dict[str, Any] = {
@ -202,15 +208,15 @@ class FrontData():
_dict['error'] = self.error _dict['error'] = self.error
return _dict return _dict
@dataclass @dataclass
class ChangeSpeaker: class ChangeSpeaker:
speaker: int speaker: int
start: int start: int
@dataclass @dataclass
class State(): class State():
"""Unified state class for audio processing. """Unified state class for audio processing.
Contains both persistent state (tokens, buffers) and temporary update buffers Contains both persistent state (tokens, buffers) and temporary update buffers
(new_* fields) that are consumed by TokensAlignment. (new_* fields) that are consumed by TokensAlignment.
""" """
@ -221,10 +227,10 @@ class State():
end_attributed_speaker: float = 0.0 end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0 remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0 remaining_time_diarization: float = 0.0
# Temporary update buffers (consumed by TokensAlignment.update()) # Temporary update buffers (consumed by TokensAlignment.update())
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list) new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
new_translation: List[Any] = field(default_factory=list) new_translation: List[Any] = field(default_factory=list)
new_diarization: List[Any] = field(default_factory=list) new_diarization: List[Any] = field(default_factory=list)
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
new_translation_buffer= TimedText() new_translation_buffer: TimedText = field(default_factory=TimedText)