Refactor timed objects and data structures
This commit is contained in:
parent
83362c89c4
commit
e144abbbc7
4 changed files with 25 additions and 20 deletions
|
|
@ -6,7 +6,7 @@ text normalization, and word-level timestamp accuracy metrics with greedy alignm
|
|||
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ class SessionMetrics:
|
|||
|
||||
def log_summary(self) -> None:
|
||||
"""Emit a structured log line summarising the session."""
|
||||
self.total_processing_time_s = sum(self.transcription_durations)
|
||||
d = self.to_dict()
|
||||
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
|
||||
logger.info(f"SESSION_METRICS {d}")
|
||||
|
|
|
|||
|
|
@ -20,8 +20,8 @@ Usage:
|
|||
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
def format_time(seconds: float) -> str:
    """Format a duration in seconds as ``H:MM:SS.cc`` (centisecond precision).

    The hours field is not zero-padded and is not wrapped into days, so
    ``90000`` seconds renders as ``25:00:00.00``.

    Args:
        seconds: Duration in seconds. Negative values are rendered as the
            formatted magnitude with a leading ``-``.

    Returns:
        The formatted timestamp string.
    """
    # Round once, on the magnitude, so a carry (e.g. 59.999 -> 1:00.00)
    # propagates cleanly through seconds, minutes and hours.
    total_cs = int(round(abs(seconds) * 100))
    # Only emit a sign when something non-zero remains after rounding;
    # this avoids printing "-0:00:00.00" for tiny negative inputs.
    sign = "-" if seconds < 0 and total_cs else ""
    total_s, cs = divmod(total_cs, 100)
    total_m, s = divmod(total_s, 60)
    h, m = divmod(total_m, 60)
    return f"{sign}{h}:{m:02d}:{s:02d}.{cs:02d}"
|
||||
|
||||
@dataclass
|
||||
class Timed:
|
||||
|
|
@ -18,10 +24,10 @@ class TimedText(Timed):
|
|||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
|
||||
def has_punctuation(self) -> bool:
|
||||
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
|
||||
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
|
||||
|
|
@ -30,10 +36,10 @@ class TimedText(Timed):
|
|||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.text)
|
||||
|
||||
|
|
@ -103,7 +109,7 @@ class Silence():
|
|||
return None
|
||||
self.duration = self.end - self.start
|
||||
return self.duration
|
||||
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return True
|
||||
|
||||
|
|
@ -127,9 +133,9 @@ class Segment(TimedText):
|
|||
"""Return a normalized segment representing the provided tokens."""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
|
||||
start_token = tokens[0]
|
||||
end_token = tokens[-1]
|
||||
end_token = tokens[-1]
|
||||
if is_silence:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
|
|
@ -176,7 +182,7 @@ class SilentSegment(Segment):
|
|||
self.text = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class FrontData():
|
||||
status: str = ''
|
||||
error: str = ''
|
||||
|
|
@ -186,7 +192,7 @@ class FrontData():
|
|||
buffer_translation: str = ''
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the front-end data payload."""
|
||||
_dict: Dict[str, Any] = {
|
||||
|
|
@ -202,15 +208,15 @@ class FrontData():
|
|||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
start: int
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class State():
|
||||
"""Unified state class for audio processing.
|
||||
|
||||
|
||||
Contains both persistent state (tokens, buffers) and temporary update buffers
|
||||
(new_* fields) that are consumed by TokensAlignment.
|
||||
"""
|
||||
|
|
@ -221,10 +227,10 @@ class State():
|
|||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
|
||||
|
||||
# Temporary update buffers (consumed by TokensAlignment.update())
|
||||
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
|
||||
new_translation: List[Any] = field(default_factory=list)
|
||||
new_diarization: List[Any] = field(default_factory=list)
|
||||
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
|
||||
new_translation_buffer= TimedText()
|
||||
new_translation_buffer: TimedText = field(default_factory=TimedText)
|
||||
|
|
|
|||
Loading…
Reference in a new issue