diff --git a/whisperlivekit/diff_protocol.py b/whisperlivekit/diff_protocol.py new file mode 100644 index 0000000..9baa144 --- /dev/null +++ b/whisperlivekit/diff_protocol.py @@ -0,0 +1,105 @@ +"""Diff-based WebSocket output protocol for WhisperLiveKit. + +Instead of sending the full FrontData state on every update, the DiffTracker +computes incremental diffs — only sending new/changed lines and volatile fields. + +Protocol +-------- +Opt-in via query parameter: ``ws://host:port/asr?mode=diff`` + +First message from server: + ``{"type": "snapshot", "seq": 1, ...full state...}`` + +Subsequent messages: + ``{"type": "diff", "seq": N, "new_lines": [...], ...}`` + +The client reconstructs state by: +1. On ``"snapshot"``: replace all state. +2. On ``"diff"``: + - If ``lines_pruned`` > 0: drop that many lines from the front. + - Append ``new_lines`` to the end. + - Replace ``buffer_*`` and ``remaining_time_*`` fields. + - Use ``n_lines`` to verify sync (total expected line count). +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List + +from whisperlivekit.timed_objects import FrontData + + +@dataclass +class DiffTracker: + """Tracks FrontData state and computes incremental diffs.""" + + seq: int = 0 + _prev_lines: List[Dict[str, Any]] = field(default_factory=list) + _sent_snapshot: bool = False + + def to_message(self, front_data: FrontData) -> Dict[str, Any]: + """Convert a FrontData into a diff or snapshot message. + + First call returns a full snapshot. Subsequent calls return diffs + containing only changed/new data. + """ + self.seq += 1 + full = front_data.to_dict() + current_lines = full["lines"] + + if not self._sent_snapshot: + self._sent_snapshot = True + self._prev_lines = current_lines[:] + return {"type": "snapshot", "seq": self.seq, **full} + + # Compute diff + msg: Dict[str, Any] = { + "type": "diff", + "seq": self.seq, + "status": full["status"], + "n_lines": len(current_lines), + "buffer_transcription": full["buffer_transcription"], + "buffer_diarization": full["buffer_diarization"], + "buffer_translation": full["buffer_translation"], + "remaining_time_transcription": full["remaining_time_transcription"], + "remaining_time_diarization": full["remaining_time_diarization"], + } + if full.get("error"): + msg["error"] = full["error"] + + # Detect front-pruning: find where current[0] appears in prev + prune_offset = 0 + if current_lines and self._prev_lines: + first_current = current_lines[0] + for i, prev_line in enumerate(self._prev_lines): + if prev_line == first_current: + prune_offset = i + break + else: + # current[0] not found in prev — treat all prev as pruned + prune_offset = len(self._prev_lines) + elif not current_lines: + prune_offset = len(self._prev_lines) + + if prune_offset > 0: + msg["lines_pruned"] = prune_offset + + # Find common prefix starting after pruned lines + common = 0 + remaining_prev = len(self._prev_lines) - prune_offset + min_len = min(remaining_prev, len(current_lines)) + while common < min_len and self._prev_lines[prune_offset + common] == current_lines[common]: + common += 1 + + # New or changed lines after the common prefix + new_lines = current_lines[common:] + if new_lines: + msg["new_lines"] = new_lines + + self._prev_lines = current_lines[:] + return msg + + def reset(self) -> None: + """Reset state so the next call produces a fresh snapshot.""" + self.seq = 0 + self._prev_lines = [] + self._sent_snapshot = False