Refactor timed objects and data structures
This commit is contained in:
parent
83362c89c4
commit
e144abbbc7
4 changed files with 25 additions and 20 deletions
|
|
@ -6,7 +6,7 @@ text normalization, and word-level timestamp accuracy metrics with greedy alignm
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import unicodedata
|
import unicodedata
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
def normalize_text(text: str) -> str:
|
def normalize_text(text: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,6 @@ class SessionMetrics:
|
||||||
|
|
||||||
def log_summary(self) -> None:
|
def log_summary(self) -> None:
|
||||||
"""Emit a structured log line summarising the session."""
|
"""Emit a structured log line summarising the session."""
|
||||||
self.total_processing_time_s = sum(self.transcription_durations)
|
|
||||||
d = self.to_dict()
|
d = self.to_dict()
|
||||||
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
|
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
|
||||||
logger.info(f"SESSION_METRICS {d}")
|
logger.info(f"SESSION_METRICS {d}")
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,8 @@ Usage:
|
||||||
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,18 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
|
def format_time(seconds: float) -> str:
|
||||||
"""Format seconds as HH:MM:SS."""
|
"""Format seconds as H:MM:SS.cc (centisecond precision)."""
|
||||||
return str(timedelta(seconds=int(seconds)))
|
total_cs = int(round(seconds * 100))
|
||||||
|
cs = total_cs % 100
|
||||||
|
total_s = total_cs // 100
|
||||||
|
s = total_s % 60
|
||||||
|
total_m = total_s // 60
|
||||||
|
m = total_m % 60
|
||||||
|
h = total_m // 60
|
||||||
|
return f"{h}:{m:02d}:{s:02d}.{cs:02d}"
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Timed:
|
class Timed:
|
||||||
|
|
@ -18,10 +24,10 @@ class TimedText(Timed):
|
||||||
text: Optional[str] = ''
|
text: Optional[str] = ''
|
||||||
speaker: Optional[int] = -1
|
speaker: Optional[int] = -1
|
||||||
detected_language: Optional[str] = None
|
detected_language: Optional[str] = None
|
||||||
|
|
||||||
def has_punctuation(self) -> bool:
|
def has_punctuation(self) -> bool:
|
||||||
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
|
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
|
||||||
|
|
||||||
def is_within(self, other: 'TimedText') -> bool:
|
def is_within(self, other: 'TimedText') -> bool:
|
||||||
return other.contains_timespan(self)
|
return other.contains_timespan(self)
|
||||||
|
|
||||||
|
|
@ -30,10 +36,10 @@ class TimedText(Timed):
|
||||||
|
|
||||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||||
return self.start <= other.start and self.end >= other.end
|
return self.start <= other.start and self.end >= other.end
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
return bool(self.text)
|
return bool(self.text)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str(self.text)
|
return str(self.text)
|
||||||
|
|
||||||
|
|
@ -103,7 +109,7 @@ class Silence():
|
||||||
return None
|
return None
|
||||||
self.duration = self.end - self.start
|
self.duration = self.end - self.start
|
||||||
return self.duration
|
return self.duration
|
||||||
|
|
||||||
def is_silence(self) -> bool:
|
def is_silence(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
@ -127,9 +133,9 @@ class Segment(TimedText):
|
||||||
"""Return a normalized segment representing the provided tokens."""
|
"""Return a normalized segment representing the provided tokens."""
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
start_token = tokens[0]
|
start_token = tokens[0]
|
||||||
end_token = tokens[-1]
|
end_token = tokens[-1]
|
||||||
if is_silence:
|
if is_silence:
|
||||||
return cls(
|
return cls(
|
||||||
start=start_token.start,
|
start=start_token.start,
|
||||||
|
|
@ -176,7 +182,7 @@ class SilentSegment(Segment):
|
||||||
self.text = ''
|
self.text = ''
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrontData():
|
class FrontData():
|
||||||
status: str = ''
|
status: str = ''
|
||||||
error: str = ''
|
error: str = ''
|
||||||
|
|
@ -186,7 +192,7 @@ class FrontData():
|
||||||
buffer_translation: str = ''
|
buffer_translation: str = ''
|
||||||
remaining_time_transcription: float = 0.
|
remaining_time_transcription: float = 0.
|
||||||
remaining_time_diarization: float = 0.
|
remaining_time_diarization: float = 0.
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Serialize the front-end data payload."""
|
"""Serialize the front-end data payload."""
|
||||||
_dict: Dict[str, Any] = {
|
_dict: Dict[str, Any] = {
|
||||||
|
|
@ -202,15 +208,15 @@ class FrontData():
|
||||||
_dict['error'] = self.error
|
_dict['error'] = self.error
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChangeSpeaker:
|
class ChangeSpeaker:
|
||||||
speaker: int
|
speaker: int
|
||||||
start: int
|
start: int
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State():
|
class State():
|
||||||
"""Unified state class for audio processing.
|
"""Unified state class for audio processing.
|
||||||
|
|
||||||
Contains both persistent state (tokens, buffers) and temporary update buffers
|
Contains both persistent state (tokens, buffers) and temporary update buffers
|
||||||
(new_* fields) that are consumed by TokensAlignment.
|
(new_* fields) that are consumed by TokensAlignment.
|
||||||
"""
|
"""
|
||||||
|
|
@ -221,10 +227,10 @@ class State():
|
||||||
end_attributed_speaker: float = 0.0
|
end_attributed_speaker: float = 0.0
|
||||||
remaining_time_transcription: float = 0.0
|
remaining_time_transcription: float = 0.0
|
||||||
remaining_time_diarization: float = 0.0
|
remaining_time_diarization: float = 0.0
|
||||||
|
|
||||||
# Temporary update buffers (consumed by TokensAlignment.update())
|
# Temporary update buffers (consumed by TokensAlignment.update())
|
||||||
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
|
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
|
||||||
new_translation: List[Any] = field(default_factory=list)
|
new_translation: List[Any] = field(default_factory=list)
|
||||||
new_diarization: List[Any] = field(default_factory=list)
|
new_diarization: List[Any] = field(default_factory=list)
|
||||||
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
|
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
|
||||||
new_translation_buffer= TimedText()
|
new_translation_buffer: TimedText = field(default_factory=TimedText)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue