@dataclass
class TranscriptionResult:
    """Accumulates every server response received during one session.

    The convenience properties all look backwards through ``responses`` so
    that trailing empty updates do not hide the last meaningful one.
    """

    responses: List[dict] = field(default_factory=list)
    audio_duration: float = 0.0

    @property
    def text(self) -> str:
        """Committed lines plus the unconfirmed buffer of the newest non-empty response."""
        for response in reversed(self.responses):
            committed = response.get("lines", [])
            pending = response.get("buffer_transcription", "")
            if not (committed or pending):
                continue
            pieces = [entry["text"] for entry in committed if entry.get("text")]
            if pending:
                pieces.append(pending)
            return " ".join(pieces)
        return ""

    @property
    def committed_text(self) -> str:
        """Text of the finalized lines only (the buffer is ignored)."""
        for response in reversed(self.responses):
            committed = response.get("lines", [])
            if committed:
                return " ".join(
                    entry["text"] for entry in committed if entry.get("text")
                )
        return ""

    @property
    def lines(self) -> List[dict]:
        """Line dicts of the newest response that carries any lines."""
        return next(
            (r["lines"] for r in reversed(self.responses) if r.get("lines")),
            [],
        )

    @property
    def n_updates(self) -> int:
        """Count of responses that carried lines or buffered text."""
        return sum(
            bool(r.get("lines") or r.get("buffer_transcription"))
            for r in self.responses
        )
async def transcribe_audio(
    audio_path: str,
    url: str = "ws://localhost:8000/asr",
    chunk_duration: float = 0.5,
    speed: float = 1.0,
    timeout: float = 60.0,
    on_response: Optional[callable] = None,
    mode: str = "full",
) -> TranscriptionResult:
    """Feed an audio file to a running WhisperLiveKit server and collect results.

    Args:
        audio_path: Path to an audio file (any format ffmpeg supports).
        url: WebSocket URL of the /asr endpoint.
        chunk_duration: Duration of each audio chunk sent (seconds). Must be
            positive; it is rounded down to a whole number of samples, with a
            minimum of one sample per chunk.
        speed: Playback speed multiplier (1.0 = real-time, 0 = as fast as possible).
        timeout: Max seconds to wait for the server after audio finishes.
        on_response: Optional callback invoked with each response dict as it arrives.
        mode: Output mode — "full" (default) or "diff" for incremental updates.
            In diff mode ``mode=diff`` is added to the URL unless the URL
            already carries a ``mode`` query parameter.

    Returns:
        TranscriptionResult with collected responses and convenience accessors.

    Raises:
        ValueError: If ``chunk_duration`` is not positive.
    """
    # Kept function-local, like the websockets import, so importing this
    # module does not require the optional dependency.
    from urllib.parse import parse_qs, urlsplit

    import websockets

    if chunk_duration <= 0:
        # A zero/negative chunk size would make the send loop below never advance.
        raise ValueError(f"chunk_duration must be positive, got {chunk_duration!r}")

    result = TranscriptionResult()

    # Convert audio to PCM for both modes (we need duration either way)
    pcm_data = load_audio_pcm(audio_path)
    result.audio_duration = len(pcm_data) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
    logger.info("Loaded %s: %.1fs of audio", audio_path, result.audio_duration)

    # Whole samples only (never split an s16le sample across two messages),
    # and at least one sample so the loop always makes progress.
    chunk_bytes = max(1, int(chunk_duration * SAMPLE_RATE)) * BYTES_PER_SAMPLE

    # Append mode query parameter if using diff mode. Callers such as main()
    # may already have put it in the URL; do not duplicate it.
    connect_url = url
    if mode == "diff" and "mode" not in parse_qs(urlsplit(url).query):
        sep = "&" if "?" in url else "?"
        connect_url = f"{url}{sep}mode=diff"

    async with websockets.connect(connect_url) as ws:
        # Server sends config on connect
        config_raw = await ws.recv()
        config_msg = json.loads(config_raw)
        is_pcm = config_msg.get("useAudioWorklet", False)
        logger.info("Server config: %s", config_msg)

        if not is_pcm:
            logger.warning(
                "Server is not in PCM mode. Start the server with --pcm-input "
                "for the test client. Attempting raw file streaming instead."
            )

        diff_lines: List[dict] = []  # running state for diff mode reconstruction

        async def send_audio():
            if is_pcm:
                n_chunks = 0
                for offset in range(0, len(pcm_data), chunk_bytes):
                    await ws.send(pcm_data[offset:offset + chunk_bytes])
                    n_chunks += 1
                    if speed > 0:
                        await asyncio.sleep(chunk_duration / speed)
                logger.info("Sent %d PCM chunks (%.1fs)", n_chunks, result.audio_duration)
            else:
                # Non-PCM: send raw file bytes for server-side ffmpeg decoding
                file_bytes = Path(audio_path).read_bytes()
                raw_chunk_size = 32000
                for offset in range(0, len(file_bytes), raw_chunk_size):
                    await ws.send(file_bytes[offset:offset + raw_chunk_size])
                    if speed > 0:
                        await asyncio.sleep(0.5 / speed)
                logger.info("Sent %d bytes of raw audio", len(file_bytes))

            # Signal end of audio
            await ws.send(b"")
            logger.info("End-of-audio signal sent")

        async def receive_results():
            try:
                async for raw_msg in ws:
                    data = json.loads(raw_msg)
                    if data.get("type") == "ready_to_stop":
                        logger.info("Server signaled ready_to_stop")
                        return
                    # In diff mode, reconstruct full state for uniform API
                    if mode == "diff" and data.get("type") in ("snapshot", "diff"):
                        data = reconstruct_state(data, diff_lines)
                    result.responses.append(data)
                    if on_response:
                        on_response(data)
            except Exception as e:
                # Best-effort receiver: whatever was collected so far is
                # still returned to the caller.
                logger.debug("Receiver ended: %s", e)

        send_task = asyncio.create_task(send_audio())
        recv_task = asyncio.create_task(receive_results())

        # Total wait = time to send + time for server to process + timeout margin
        send_time = result.audio_duration / speed if speed > 0 else 1.0
        total_timeout = send_time + timeout

        try:
            await asyncio.wait_for(
                asyncio.gather(send_task, recv_task),
                timeout=total_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Timed out after %.0fs", total_timeout)
        finally:
            # Never let a task outlive the connection, whether we finished,
            # timed out, or one side raised. cancel() is a no-op on tasks
            # that already completed.
            for task in (send_task, recv_task):
                task.cancel()
            await asyncio.gather(send_task, recv_task, return_exceptions=True)

    logger.info(
        "Session complete: %d responses, %d updates",
        len(result.responses), result.n_updates,
    )
    return result
+ ), + ) + parser.add_argument("audio", help="Path to audio file (wav, mp3, flac, ...)") + parser.add_argument( + "--url", default="ws://localhost:8000/asr", + help="WebSocket endpoint URL (default: ws://localhost:8000/asr)", + ) + parser.add_argument( + "--speed", type=float, default=1.0, + help="Playback speed multiplier (1.0 = real-time, 0 = fastest, default: 1.0)", + ) + parser.add_argument( + "--chunk-duration", type=float, default=0.5, + help="Chunk duration in seconds (default: 0.5)", + ) + parser.add_argument( + "--timeout", type=float, default=60.0, + help="Max seconds to wait for server after audio ends (default: 60)", + ) + parser.add_argument( + "--language", "-l", default=None, + help="Override transcription language for this session (e.g. en, fr, auto)", + ) + parser.add_argument("--json", action="store_true", help="Output raw JSON responses") + parser.add_argument( + "--diff", action="store_true", + help="Use diff protocol (only receive incremental changes from server)", + ) + parser.add_argument( + "--live", action="store_true", + help="Print transcription updates as they arrive", + ) + parser.add_argument("--verbose", "-v", action="store_true") + + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + audio_path = Path(args.audio) + if not audio_path.exists(): + print(f"Error: file not found: {audio_path}", file=sys.stderr) + sys.exit(1) + + live_callback = None + if args.live: + def live_callback(data): + lines = data.get("lines", []) + buf = data.get("buffer_transcription", "") + parts = [l["text"] for l in lines if l.get("text")] + if buf: + parts.append(f"[{buf}]") + if parts: + print("\r" + " ".join(parts), end="", flush=True) + + # Build URL with query parameters for language and mode + url = args.url + params = [] + if args.language: + params.append(f"language={args.language}") + if args.diff: + 
params.append("mode=diff") + if params: + sep = "&" if "?" in url else "?" + url = f"{url}{sep}{'&'.join(params)}" + + result = asyncio.run(transcribe_audio( + audio_path=str(audio_path), + url=url, + chunk_duration=args.chunk_duration, + speed=args.speed, + timeout=args.timeout, + on_response=live_callback, + mode="diff" if args.diff else "full", + )) + + if args.live: + print() # newline after live output + + _print_result(result, output_json=args.json) + + +if __name__ == "__main__": + main() diff --git a/whisperlivekit/test_data.py b/whisperlivekit/test_data.py new file mode 100644 index 0000000..4943d15 --- /dev/null +++ b/whisperlivekit/test_data.py @@ -0,0 +1,365 @@ +"""Standard test audio samples for evaluating the WhisperLiveKit pipeline. + +Downloads curated samples from public ASR datasets (LibriSpeech, AMI) +and caches them locally. Each sample includes the audio file path, +ground truth transcript, speaker info, and timing metadata. + +Usage:: + + from whisperlivekit.test_data import get_samples, get_sample + + # Download all standard test samples (first call downloads, then cached) + samples = get_samples() + + for s in samples: + print(f"{s.name}: {s.duration:.1f}s, {s.n_speakers} speaker(s)") + print(f" Reference: {s.reference[:60]}...") + + # Use with TestHarness + from whisperlivekit.test_harness import TestHarness + + async with TestHarness(model_size="base", lan="en") as h: + sample = get_sample("librispeech_short") + await h.feed(sample.path, speed=0) + result = await h.finish() + print(f"WER: {result.wer(sample.reference):.2%}") + +Requires: pip install whisperlivekit[test] (installs 'datasets' and 'librosa') +""" + +import json +import logging +import wave +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List + +import numpy as np + +logger = logging.getLogger(__name__) + +CACHE_DIR = Path.home() / ".cache" / "whisperlivekit" / "test_data" +METADATA_FILE = "metadata.json" + + +@dataclass +class 
def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
    """Save a numpy audio array as a mono 16-bit PCM WAV file.

    Args:
        path: Destination file; missing parent directories are created.
        audio: Samples, either 1-D or with channels on the last axis.
            Floating-point input is treated as normalized to [-1.0, 1.0]
            and scaled; integer input is assumed to already be on the
            int16 scale and is saturated (not wrapped) if it exceeds it.
        sample_rate: Frame rate written to the WAV header.
    """
    audio = np.asarray(audio)
    # Decide how to scale from the *input* dtype: downmixing integer audio
    # yields floats that are still on the int16 scale and must not be
    # mistaken for normalized samples.
    is_normalized_float = np.issubdtype(audio.dtype, np.floating)

    # Ensure mono
    if audio.ndim > 1:
        audio = audio.mean(axis=-1)

    if is_normalized_float:
        if audio.dtype.itemsize < 4:
            # float16 cannot represent 32767 exactly; widen before scaling.
            audio = audio.astype(np.float32)
        pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    elif audio.dtype == np.int16:
        pcm = audio
    else:
        limits = np.iinfo(np.int16)
        pcm = np.clip(np.rint(audio), limits.min, limits.max).astype(np.int16)

    path.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(path), "w") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
+ + Returns: + (audio_array, sample_rate) — float32 numpy array and int sample rate. + """ + import io + + import soundfile as sf + audio_array, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32") + return np.array(audio_array, dtype=np.float32), sr + + +# --------------------------------------------------------------------------- +# Dataset-specific download functions +# --------------------------------------------------------------------------- + +def _download_librispeech_samples(n_samples: int = 3) -> List[Dict]: + """Download short samples from LibriSpeech test-clean.""" + _ensure_datasets() + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info("Downloading LibriSpeech test-clean samples (streaming)...") + ds = load_dataset( + "openslr/librispeech_asr", + "clean", + split="test", + streaming=True, + ) + ds = ds.cast_column("audio", Audio(decode=False)) + + samples = [] + for i, item in enumerate(ds): + if i >= n_samples: + break + + audio_array, sr = _decode_audio(item["audio"]["bytes"]) + duration = len(audio_array) / sr + text = item["text"] + sample_id = item.get("id", f"librispeech_{i}") + + # Save WAV + wav_name = f"librispeech_{i}.wav" + wav_path = CACHE_DIR / wav_name + _save_wav(wav_path, audio_array, sr) + + # Name: first sample is "librispeech_short", rest are numbered + name = "librispeech_short" if i == 0 else f"librispeech_{i}" + + samples.append({ + "name": name, + "file": wav_name, + "reference": text, + "duration": round(duration, 2), + "sample_rate": sr, + "n_speakers": 1, + "language": "en", + "source": "openslr/librispeech_asr (test-clean)", + "source_id": str(sample_id), + "utterances": [], + }) + logger.info( + " [%d] %.1fs - %s", + i, duration, text[:60] + ("..." 
def _download_ami_sample() -> List[Dict]:
    """Download one AMI meeting segment with multiple speakers.

    Streams utterances of the first meeting found in the ``ihm`` test split,
    stops once roughly 60 seconds of utterances have been collected (or the
    meeting changes), concatenates the clips into ``ami_meeting.wav`` in the
    cache directory and returns a single metadata dict in a list.

    Each entry of ``utterances`` keeps the dataset's ``start``/``end`` values
    and additionally carries ``audio_start``/``audio_end``: the position of
    that clip inside the concatenated WAV written here.

    Returns:
        A one-element list with the sample metadata, or an empty list when
        the dataset yielded no audio.
    """
    _ensure_datasets()
    import datasets.config
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading AMI meeting test sample (streaming)...")

    # Use the edinburghcstr/ami version which has pre-segmented utterances
    # with speaker_id, begin_time, end_time, text
    ds = load_dataset(
        "edinburghcstr/ami",
        "ihm",
        split="test",
        streaming=True,
    )
    ds = ds.cast_column("audio", Audio(decode=False))

    # Collect utterances from one meeting
    meeting_utterances = []
    meeting_id = None
    audio_arrays = []
    sample_rate = None
    total_dur = 0.0      # running sum of (end - start); avoids re-summing the list
    audio_offset = 0.0   # seconds of audio already appended to the output

    for item in ds:
        mid = item.get("meeting_id", "unknown")

        # Take the first meeting only
        if meeting_id is None:
            meeting_id = mid
        elif mid != meeting_id:
            # We've moved to a different meeting, stop
            break

        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        if sample_rate is None:
            sample_rate = sr
        elif sr != sample_rate:
            # Clips are concatenated sample-for-sample, so a clip at another
            # rate would play at the wrong speed in the output file.
            logger.warning(
                "Skipping AMI utterance with sample rate %s (expected %s)",
                sr, sample_rate,
            )
            continue

        # `or 0.0` also covers keys that are present but null.
        start = round(item.get("begin_time") or 0.0, 2)
        end = round(item.get("end_time") or 0.0, 2)
        clip_dur = len(audio_array) / sr

        meeting_utterances.append({
            # NOTE(review): begin_time/end_time presumably refer to the
            # original meeting recording, not to the file written below —
            # confirm against the dataset card.
            "start": start,
            "end": end,
            "speaker": item.get("speaker_id", "unknown"),
            "text": item.get("text", ""),
            "audio_start": round(audio_offset, 2),
            "audio_end": round(audio_offset + clip_dur, 2),
        })
        audio_arrays.append(audio_array)
        audio_offset += clip_dur

        # Limit to reasonable size (~60s of utterances)
        total_dur += end - start
        if total_dur > 60:
            break

    if not audio_arrays:
        logger.warning("No AMI samples found")
        return []

    # Concatenate all utterance audio
    full_audio = np.concatenate(audio_arrays)
    duration = len(full_audio) / sample_rate

    # Build reference text
    speakers = {u["speaker"] for u in meeting_utterances}
    reference = " ".join(u["text"] for u in meeting_utterances if u["text"])

    wav_name = "ami_meeting.wav"
    wav_path = CACHE_DIR / wav_name
    _save_wav(wav_path, full_audio, sample_rate)

    logger.info(
        " AMI meeting %s: %.1fs, %d speakers, %d utterances",
        meeting_id, duration, len(speakers), len(meeting_utterances),
    )

    return [{
        "name": "ami_meeting",
        "file": wav_name,
        "reference": reference,
        "duration": round(duration, 2),
        "sample_rate": sample_rate,
        "n_speakers": len(speakers),
        "language": "en",
        "source": f"edinburghcstr/ami (ihm, meeting {meeting_id})",
        "source_id": meeting_id,
        "utterances": meeting_utterances,
    }]
def get_sample(name: str) -> TestSample:
    """Look up one of the standard test samples by its name.

    Available names: 'librispeech_short', 'librispeech_1', 'librispeech_2',
    'ami_meeting'.

    Raises:
        KeyError: If no sample carries the requested name.
    """
    candidates = get_samples()
    found = next((c for c in candidates if c.name == name), None)
    if found is not None:
        return found
    available = [c.name for c in candidates]
    raise KeyError(f"Sample '{name}' not found. Available: {available}")
+ +Designed for use by AI agents: feed audio with timeline control, +inspect state at any point, pause/resume to test silence detection, +cut to test abrupt termination. + +Usage:: + + import asyncio + from whisperlivekit.test_harness import TestHarness + + async def main(): + async with TestHarness(model_size="base", lan="en") as h: + # Load audio with timeline control + player = h.load_audio("interview.wav") + + # Play first 5 seconds at real-time speed + await player.play(5.0, speed=1.0) + print(h.state.text) # Check what's transcribed so far + + # Pause for 7 seconds (triggers silence detection) + await h.pause(7.0, speed=1.0) + assert h.state.has_silence + + # Resume playback + await player.play(5.0, speed=1.0) + + # Finish and evaluate + result = await h.finish() + print(f"WER: {result.wer('expected transcription'):.2%}") + print(f"Speakers: {result.speakers}") + print(f"Silence segments: {len(result.silence_segments)}") + + # Inspect historical state at specific audio position + snap = h.snapshot_at(3.0) + print(f"At 3s: '{snap.text}'") + + asyncio.run(main()) +""" + +import asyncio +import logging +import subprocess +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +from whisperlivekit.timed_objects import FrontData + +logger = logging.getLogger(__name__) + +# Engine cache: avoids reloading models when switching backends in tests. +# Key is a frozen config tuple, value is the TranscriptionEngine instance. 
def _parse_time(time_str: Any) -> float:
    """Convert a line timestamp to seconds.

    Accepts 'H:MM:SS.cc', 'MM:SS.cc' and plain 'SS.cc' strings, as well as
    values that are already numeric. ``None`` and the empty string map to
    0.0, the same value the "0:00:00" default used for missing keys yields.
    """
    if isinstance(time_str, (int, float)):
        return float(time_str)
    text = "" if time_str is None else str(time_str).strip()
    if not text:
        return 0.0
    parts = text.split(":")
    if len(parts) == 3:
        return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
    if len(parts) == 2:
        return int(parts[0]) * 60 + float(parts[1])
    return float(parts[0])


@dataclass
class TestState:
    """Observable transcription state at a point in time.

    Provides accessors for inspecting lines, buffers, speakers, timestamps,
    silence segments, and computing evaluation metrics like WER.

    All time-based queries accept seconds as floats. Lines are plain dicts;
    a line whose ``speaker`` is -2 is a silence segment.
    """

    # The name starts with "Test" but this is a helper, not a test case:
    # keep pytest from trying to collect it when it is imported into tests.
    __test__ = False

    lines: List[Dict[str, Any]] = field(default_factory=list)
    buffer_transcription: str = ""
    buffer_diarization: str = ""
    buffer_translation: str = ""
    remaining_time_transcription: float = 0.0
    remaining_time_diarization: float = 0.0
    audio_position: float = 0.0
    status: str = ""
    error: str = ""

    @classmethod
    def from_front_data(cls, front_data: "FrontData", audio_position: float = 0.0) -> "TestState":
        """Build a state from ``front_data.to_dict()`` plus the audio position."""
        d = front_data.to_dict()
        return cls(
            lines=d.get("lines", []),
            buffer_transcription=d.get("buffer_transcription", ""),
            buffer_diarization=d.get("buffer_diarization", ""),
            buffer_translation=d.get("buffer_translation", ""),
            remaining_time_transcription=d.get("remaining_time_transcription", 0),
            remaining_time_diarization=d.get("remaining_time_diarization", 0),
            audio_position=audio_position,
            status=d.get("status", ""),
            error=d.get("error", ""),
        )

    @staticmethod
    def _bounds(line: Dict[str, Any]) -> Tuple[float, float]:
        """Return (start, end) of a line in seconds; missing keys count as 0."""
        return (
            _parse_time(line.get("start", "0:00:00")),
            _parse_time(line.get("end", "0:00:00")),
        )

    # ── Text accessors ──

    @property
    def text(self) -> str:
        """Full transcription: committed lines + buffer."""
        parts = [l["text"] for l in self.lines if l.get("text")]
        if self.buffer_transcription:
            parts.append(self.buffer_transcription)
        return " ".join(parts)

    @property
    def committed_text(self) -> str:
        """Only committed (finalized) lines, no buffer."""
        return " ".join(l["text"] for l in self.lines if l.get("text"))

    @property
    def committed_word_count(self) -> int:
        """Number of whitespace-separated words in committed lines."""
        return len(self.committed_text.split())

    @property
    def buffer_word_count(self) -> int:
        """Number of whitespace-separated words in the unconfirmed buffer."""
        return len(self.buffer_transcription.split())

    # ── Speaker accessors ──

    @property
    def speakers(self) -> Set[int]:
        """Set of positive speaker IDs (silence marker -2 and missing/None excluded)."""
        return {l["speaker"] for l in self.lines if (l.get("speaker") or 0) > 0}

    @property
    def n_speakers(self) -> int:
        """Number of distinct positive speaker IDs."""
        return len(self.speakers)

    def speaker_at(self, time_s: float) -> Optional[int]:
        """Speaker ID at the given timestamp, or None if no segment covers it."""
        line = self.line_at(time_s)
        return line.get("speaker") if line else None

    def speakers_in(self, start_s: float, end_s: float) -> Set[int]:
        """All positive speaker IDs active in the time range (silence -2 excluded)."""
        return {
            l.get("speaker")
            for l in self.lines_between(start_s, end_s)
            if (l.get("speaker") or 0) > 0
        }

    @property
    def speaker_timeline(self) -> List[Dict[str, Any]]:
        """Timeline: [{"start": float, "end": float, "speaker": int}] for all lines."""
        timeline = []
        for l in self.lines:
            start, end = self._bounds(l)
            timeline.append({"start": start, "end": end, "speaker": l.get("speaker", -1)})
        return timeline

    @property
    def n_speaker_changes(self) -> int:
        """Number of speaker transitions (excluding silence segments)."""
        # Only the speaker sequence matters here; timestamps are not parsed.
        speech = [
            l.get("speaker", -1) for l in self.lines if l.get("speaker", -1) != -2
        ]
        return sum(1 for prev, cur in zip(speech, speech[1:]) if prev != cur)

    # ── Silence accessors ──

    @property
    def has_silence(self) -> bool:
        """Whether any silence segment (speaker=-2) exists."""
        return any(l.get("speaker") == -2 for l in self.lines)

    @property
    def silence_segments(self) -> List[Dict[str, Any]]:
        """All silence segments (raw line dicts)."""
        return [l for l in self.lines if l.get("speaker") == -2]

    def silence_at(self, time_s: float) -> bool:
        """True if time_s falls within a silence segment."""
        line = self.line_at(time_s)
        return line is not None and line.get("speaker") == -2

    # ── Line / segment accessors ──

    @property
    def speech_lines(self) -> List[Dict[str, Any]]:
        """Lines that are not silence segments and carry non-empty text."""
        return [l for l in self.lines if l.get("speaker", 0) != -2 and l.get("text")]

    def line_at(self, time_s: float) -> Optional[Dict[str, Any]]:
        """First line whose [start, end] interval contains time_s (seconds)."""
        for line in self.lines:
            start, end = self._bounds(line)
            if start <= time_s <= end:
                return line
        return None

    def text_at(self, time_s: float) -> Optional[str]:
        """Text of the segment covering the timestamp.

        Returns None when no segment covers it, and "" when the covering
        segment has no ``text`` key.
        """
        line = self.line_at(time_s)
        return line.get("text", "") if line else None

    def lines_between(self, start_s: float, end_s: float) -> List[Dict[str, Any]]:
        """All lines overlapping the time range [start_s, end_s]."""
        result = []
        for line in self.lines:
            ls, le = self._bounds(line)
            if le >= start_s and ls <= end_s:
                result.append(line)
        return result

    def text_between(self, start_s: float, end_s: float) -> str:
        """Concatenated text of all lines overlapping the time range."""
        return " ".join(
            l["text"] for l in self.lines_between(start_s, end_s)
            if l.get("text")
        )

    # ── Evaluation ──

    def wer(self, reference: str) -> float:
        """Word Error Rate of committed text against reference.

        Returns:
            WER as a float (0.0 = perfect, 1.0 = 100% error rate).
        """
        from whisperlivekit.metrics import compute_wer
        result = compute_wer(reference, self.committed_text)
        return result["wer"]

    def wer_detailed(self, reference: str) -> Dict:
        """Full result dict of compute_wer for the committed text."""
        from whisperlivekit.metrics import compute_wer
        return compute_wer(reference, self.committed_text)

    # ── Timing validation ──

    @property
    def timestamps(self) -> List[Dict[str, Any]]:
        """All line timestamps as [{"start": float, "end": float, "speaker": int, "text": str}]."""
        result = []
        for line in self.lines:
            start, end = self._bounds(line)
            result.append({
                "start": start,
                "end": end,
                "speaker": line.get("speaker", -1),
                "text": line.get("text", ""),
            })
        return result

    @property
    def timing_valid(self) -> bool:
        """All timestamps have start <= end and no negative values."""
        return not any(
            ts["start"] < 0 or ts["end"] < 0 or ts["end"] < ts["start"]
            for ts in self.timestamps
        )

    @property
    def timing_monotonic(self) -> bool:
        """Line start times are non-decreasing."""
        stamps = self.timestamps
        return not any(
            cur["start"] < prev["start"] for prev, cur in zip(stamps, stamps[1:])
        )

    def timing_errors(self) -> List[str]:
        """Human-readable list of timing issues found."""
        errors = []
        stamps = self.timestamps
        for i, ts in enumerate(stamps):
            if ts["start"] < 0:
                errors.append(f"Line {i}: negative start {ts['start']:.2f}s")
            if ts["end"] < 0:
                errors.append(f"Line {i}: negative end {ts['end']:.2f}s")
            if ts["end"] < ts["start"]:
                errors.append(
                    f"Line {i}: end ({ts['end']:.2f}s) < start ({ts['start']:.2f}s)"
                )
        for i in range(1, len(stamps)):
            if stamps[i]["start"] < stamps[i - 1]["start"]:
                errors.append(
                    f"Line {i}: start ({stamps[i]['start']:.2f}s) < previous start "
                    f"({stamps[i-1]['start']:.2f}s) — non-monotonic"
                )
        return errors
# ---------------------------------------------------------------------------
# AudioPlayer — timeline control for a loaded audio file
# ---------------------------------------------------------------------------

class AudioPlayer:
    """Controls playback of a loaded audio file through the pipeline.

    Tracks position in the audio, enabling play/pause/resume patterns::

        player = h.load_audio("speech.wav")
        await player.play(3.0)      # Play first 3 seconds
        await h.pause(7.0)          # 7s silence (triggers detection)
        await player.play(5.0)      # Play next 5 seconds
        await player.play()         # Play all remaining audio

    Positions are kept as byte offsets and always sit on a sample boundary.
    A trailing partial sample (odd-length buffer) is never fed and does not
    count as remaining audio.

    NOTE: ``TestHarness.feed_pcm`` sizes chunks and counts audio time with
    the module-level ``SAMPLE_RATE``; ``sample_rate`` here only affects this
    player's own seconds <-> bytes conversion.

    Args:
        harness: The TestHarness instance.
        pcm_data: Raw PCM s16le mono bytes.
        sample_rate: Audio sample rate (defaults to ``SAMPLE_RATE``).
    """

    def __init__(self, harness: "TestHarness", pcm_data: bytes, sample_rate: int = SAMPLE_RATE):
        self._harness = harness
        self._pcm = pcm_data
        self._sr = sample_rate
        self._bps = sample_rate * BYTES_PER_SAMPLE  # bytes per second
        self._pos = 0  # current position in bytes
        # Last playable offset: the buffer length rounded down to a whole
        # sample. Without this an odd-length buffer could never reach `done`.
        self._end = len(pcm_data) - len(pcm_data) % BYTES_PER_SAMPLE

    @property
    def position(self) -> float:
        """Current playback position in seconds."""
        return self._pos / self._bps

    @property
    def duration(self) -> float:
        """Total audio duration in seconds (from the raw buffer length)."""
        return len(self._pcm) / self._bps

    @property
    def remaining(self) -> float:
        """Playable audio left after the cursor, in seconds."""
        return max(0.0, (self._end - self._pos) / self._bps)

    @property
    def done(self) -> bool:
        """True if all playable audio has been played."""
        return self._pos >= self._end

    def _clamp_offset(self, raw: int) -> int:
        """Round a byte offset down to a sample boundary and clamp it to the playable range."""
        aligned = raw - raw % BYTES_PER_SAMPLE
        return max(0, min(aligned, self._end))

    async def _feed_until(self, end_pos: int, speed: float, chunk_duration: float) -> None:
        """Feed the bytes between the cursor and ``end_pos``; no-op if the cursor is already there."""
        if end_pos <= self._pos:
            return
        segment = self._pcm[self._pos:end_pos]
        # The cursor is advanced before feeding, so `position` already
        # reflects the target while the chunks are being sent.
        self._pos = end_pos
        await self._harness.feed_pcm(segment, speed=speed, chunk_duration=chunk_duration)

    async def play(
        self,
        duration_s: Optional[float] = None,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Play audio from the current position.

        Args:
            duration_s: Seconds of audio to play. None = all remaining.
            speed: 1.0 = real-time, 0 = instant, >1 = faster.
            chunk_duration: Size of each chunk fed to the pipeline (seconds).
        """
        if duration_s is None:
            end_pos = self._end
        else:
            end_pos = self._clamp_offset(self._pos + int(duration_s * self._bps))
        await self._feed_until(end_pos, speed, chunk_duration)

    async def play_until(
        self,
        time_s: float,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Play until reaching ``time_s`` in the audio timeline.

        Does nothing when ``time_s`` is at or before the current position.
        """
        target = self._clamp_offset(int(time_s * self._bps))
        await self._feed_until(target, speed, chunk_duration)

    def seek(self, time_s: float) -> None:
        """Move the playback cursor without feeding audio."""
        self._pos = self._clamp_offset(int(time_s * self._bps))

    def reset(self) -> None:
        """Reset to the beginning of the audio."""
        self._pos = 0


# ---------------------------------------------------------------------------
# TestHarness — pipeline controller
# ---------------------------------------------------------------------------

class TestHarness:
    """In-process testing harness for the full WhisperLiveKit pipeline.

    Use as an async context manager. Provides methods to feed audio,
    pause/resume, inspect state, and evaluate results.

    Methods:
        load_audio(path)  → AudioPlayer with play/seek controls
        feed(path, speed) → feed entire audio file (simple mode)
        pause(duration)   → inject silence (triggers detection if > 5s)
        drain(seconds)    → let pipeline catch up
        finish()          → flush and return final state
        cut()             → abrupt stop, return partial state
        wait_for(pred)    → wait for condition on state

    State inspection:
        .state            → current TestState
        .history          → all historical states
        .snapshot_at(t)   → state at audio position t
        .metrics          → SessionMetrics (latency, RTF, etc.)

    Args:
        **kwargs: Keyword arguments forwarded to ``TranscriptionEngine``.
            ``pcm_input`` defaults to True when not given.
            Common: model_size, lan, backend, diarization, vac.
    """

    def __init__(self, **kwargs: Any):
        kwargs.setdefault("pcm_input", True)
        self._engine_kwargs = kwargs
        self._processor = None       # AudioProcessor, created in __aenter__
        self._results_gen = None     # async generator returned by create_tasks()
        self._collect_task = None    # background task running _collect_results()
        self._state = TestState()
        self._audio_position = 0.0   # seconds of audio fed so far
        self._history: List[TestState] = []
        self._on_update: Optional[Callable[[TestState], None]] = None

    async def __aenter__(self) -> "TestHarness":
        """Build (or reuse) the engine, start the processor and the result collector."""
        from whisperlivekit.audio_processor import AudioProcessor
        from whisperlivekit.core import TranscriptionEngine

        # Cache engines by config to avoid reloading models when switching
        # backends between tests. The singleton is reset only when the
        # requested config doesn't match any cached engine.
        cache_key = tuple(sorted(self._engine_kwargs.items()))

        if cache_key not in _engine_cache:
            TranscriptionEngine.reset()
            _engine_cache[cache_key] = TranscriptionEngine(**self._engine_kwargs)

        engine = _engine_cache[cache_key]

        self._processor = AudioProcessor(transcription_engine=engine)
        self._results_gen = await self._processor.create_tasks()
        self._collect_task = asyncio.create_task(self._collect_results())
        return self

    async def __aexit__(self, *exc: Any) -> None:
        """Clean up the processor and stop the collector task."""
        try:
            if self._processor:
                await self._processor.cleanup()
        finally:
            # Always stop the collector, even if cleanup() raised, so no
            # background task outlives the context manager.
            if self._collect_task and not self._collect_task.done():
                self._collect_task.cancel()
                try:
                    await self._collect_task
                except asyncio.CancelledError:
                    pass

    async def _collect_results(self) -> None:
        """Background task: consume results from the pipeline.

        Every result becomes a TestState stamped with the audio position at
        the time it arrived, is appended to the history, and is passed to
        the on_update callback if one is registered.
        """
        try:
            async for front_data in self._results_gen:
                self._state = TestState.from_front_data(front_data, self._audio_position)
                self._history.append(self._state)
                if self._on_update:
                    self._on_update(self._state)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning("Result collector ended: %s", e)

    # ── Properties ──

    @property
    def state(self) -> TestState:
        """Current transcription state (updated live as results arrive)."""
        return self._state

    @property
    def history(self) -> List[TestState]:
        """All states received so far, in order."""
        return self._history

    @property
    def audio_position(self) -> float:
        """How many seconds of audio have been fed so far."""
        return self._audio_position

    @property
    def metrics(self):
        """The processor's ``metrics`` attribute, or None before the harness is entered."""
        if self._processor:
            return self._processor.metrics
        return None

    def on_update(self, callback: Callable[[TestState], None]) -> None:
        """Register a callback invoked on each new state update (replaces any previous one)."""
        self._on_update = callback

    # ── Audio loading and feeding ──

    def load_audio(self, source) -> AudioPlayer:
        """Load audio and return a player with timeline control.

        Args:
            source: Path to audio file (str), or a TestSample with .path attribute.

        Returns:
            AudioPlayer with play/play_until/seek/reset methods.
        """
        path = source.path if hasattr(source, "path") else str(source)
        pcm = load_audio_pcm(path)
        return AudioPlayer(self, pcm)

    async def feed(
        self,
        audio_path: str,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Feed an entire audio file to the pipeline (simple mode).

        For timeline control (play/pause/resume), use load_audio() instead.

        Args:
            audio_path: Path to any audio file ffmpeg can decode.
            speed: Playback speed (1.0 = real-time, 0 = instant).
            chunk_duration: Size of each PCM chunk in seconds.
        """
        pcm = load_audio_pcm(audio_path)
        await self.feed_pcm(pcm, speed=speed, chunk_duration=chunk_duration)

    async def feed_pcm(
        self,
        pcm_data: bytes,
        speed: float = 1.0,
        chunk_duration: float = 0.5,
    ) -> None:
        """Feed raw PCM s16le 16kHz mono bytes to the pipeline.

        Chunks are cut on sample boundaries. After each chunk the call
        sleeps for that chunk's own audio duration divided by ``speed``, so
        a short final chunk does not wait a full ``chunk_duration``.

        Args:
            pcm_data: Raw PCM bytes.
            speed: Playback speed multiplier; 0 (or less) disables pacing.
            chunk_duration: Duration of each chunk sent (seconds).

        Raises:
            ValueError: If ``chunk_duration`` is not positive.
        """
        if not pcm_data:
            return
        if chunk_duration <= 0:
            # A zero-size chunk would never advance and would repeatedly
            # send b"", which is the end-of-audio marker used by finish().
            raise ValueError("chunk_duration must be positive")

        bytes_per_second = SAMPLE_RATE * BYTES_PER_SAMPLE
        chunk_bytes = int(chunk_duration * bytes_per_second)
        # Whole samples only, and never less than one sample per chunk.
        chunk_bytes = max(BYTES_PER_SAMPLE, chunk_bytes - chunk_bytes % BYTES_PER_SAMPLE)

        total = len(pcm_data)
        offset = 0
        while offset < total:
            end = min(offset + chunk_bytes, total)
            await self._processor.process_audio(pcm_data[offset:end])
            chunk_seconds = (end - offset) / bytes_per_second
            self._audio_position += chunk_seconds
            offset = end
            if speed > 0:
                await asyncio.sleep(chunk_seconds / speed)

    # ── Pause / silence ──

    async def pause(self, duration_s: float, speed: float = 1.0) -> None:
        """Inject silence to simulate a pause in speech.

        Pauses > 5s trigger silence segment detection (MIN_DURATION_REAL_SILENCE).
        Pauses < 5s are treated as brief gaps and produce no silence segment
        (provided speech resumes afterward).

        Args:
            duration_s: Duration of silence in seconds.
            speed: Playback speed (1.0 = real-time, 0 = instant).
        """
        n_bytes = int(duration_s * SAMPLE_RATE * BYTES_PER_SAMPLE)
        n_bytes -= n_bytes % BYTES_PER_SAMPLE  # keep the silence sample-aligned
        silent_pcm = bytes(n_bytes)  # zero-filled buffer == digital silence
        await self.feed_pcm(silent_pcm, speed=speed)

    async def silence(self, duration_s: float, speed: float = 1.0) -> None:
        """Alias for pause(). Inject silence for the given duration."""
        await self.pause(duration_s, speed=speed)

    # ── Waiting ──

    async def wait_for(
        self,
        predicate: Callable[[TestState], bool],
        timeout: float = 30.0,
        poll_interval: float = 0.1,
    ) -> TestState:
        """Wait until predicate(state) returns True.

        The predicate is evaluated at least once, and once more after the
        final sleep, so a state that arrives right at the deadline is not
        missed.

        Args:
            predicate: Condition evaluated against the current state.
            timeout: Maximum time to wait, in seconds.
            poll_interval: Delay between evaluations, in seconds.

        Returns:
            The state for which the predicate held.

        Raises:
            TimeoutError: If the condition is not met within timeout.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            if predicate(self._state):
                return self._state
            time_left = deadline - loop.time()
            if time_left <= 0:
                break
            # Never sleep past the deadline.
            await asyncio.sleep(min(poll_interval, time_left))
        raise TimeoutError(
            f"Condition not met within {timeout}s. "
            f"Current state: {len(self._state.lines)} lines, "
            f"buffer='{self._state.buffer_transcription[:50]}', "
            f"audio_pos={self._audio_position:.1f}s"
        )

    async def wait_for_text(self, timeout: float = 30.0) -> TestState:
        """Wait until any transcription text appears."""
        return await self.wait_for(lambda s: s.text.strip(), timeout=timeout)

    async def wait_for_lines(self, n: int = 1, timeout: float = 30.0) -> TestState:
        """Wait until at least n committed speech lines exist."""
        return await self.wait_for(lambda s: len(s.speech_lines) >= n, timeout=timeout)

    async def wait_for_silence(self, timeout: float = 30.0) -> TestState:
        """Wait until a silence segment is detected."""
        return await self.wait_for(lambda s: s.has_silence, timeout=timeout)

    async def wait_for_speakers(self, n: int = 2, timeout: float = 30.0) -> TestState:
        """Wait until at least n distinct speakers are detected."""
        return await self.wait_for(lambda s: s.n_speakers >= n, timeout=timeout)

    async def drain(self, seconds: float = 2.0) -> None:
        """Let the pipeline process without feeding audio.

        Useful after feeding audio to allow the ASR backend to catch up.
        This is a plain sleep; it does not wait on any pipeline signal.
        """
        await asyncio.sleep(seconds)

    # ── Finishing ──

    async def _send_eof_and_wait(self, timeout: float) -> bool:
        """Send the empty end-of-audio chunk and wait for the collector; return False on timeout."""
        await self._processor.process_audio(b"")
        if not self._collect_task:
            return True
        try:
            # On timeout asyncio.wait_for cancels the collector task, so no
            # further state updates arrive afterwards.
            await asyncio.wait_for(self._collect_task, timeout=timeout)
        except asyncio.TimeoutError:
            return False
        except asyncio.CancelledError:
            pass
        return True

    async def finish(self, timeout: float = 30.0) -> TestState:
        """Signal end of audio and wait for pipeline to flush all results.

        Args:
            timeout: Maximum time to wait for the collector, in seconds.

        Returns:
            The latest TestState. If the wait times out a warning is logged
            and the state may still be partial.
        """
        flushed = await self._send_eof_and_wait(timeout)
        if not flushed:
            logger.warning("Timed out waiting for pipeline to finish after %.0fs", timeout)
        return self._state

    async def cut(self, timeout: float = 5.0) -> TestState:
        """Abrupt audio stop — signal EOF and return current state quickly.

        Simulates user closing the connection mid-speech. Sends EOF but
        uses a short timeout, so partial results are returned even if
        the pipeline hasn't fully flushed. A timeout is not logged.

        Returns:
            TestState with whatever has been processed so far.
        """
        await self._send_eof_and_wait(timeout)
        return self._state

    # ── History inspection ──

    def snapshot_at(self, audio_time: float) -> Optional[TestState]:
        """Find the historical state closest to when audio_time was reached.

        Args:
            audio_time: Audio position in seconds.

        Returns:
            The recorded TestState whose audio_position is nearest to
            ``audio_time`` (the earliest one on ties), or None if no history.
        """
        if not self._history:
            return None
        # min() keeps the first of equally-close candidates.
        return min(self._history, key=lambda s: abs(s.audio_position - audio_time))

    # ── Debug ──

    def print_state(self) -> None:
        """Print current state to stdout for debugging."""
        s = self._state
        print(f"--- Audio: {self._audio_position:.1f}s | Status: {s.status} ---")
        for line in s.lines:
            speaker = line.get("speaker", "?")
            text = line.get("text", "")
            start = line.get("start", "")
            end = line.get("end", "")
            # Speaker id -2 is rendered as a silence segment.
            tag = "SILENCE" if speaker == -2 else f"Speaker {speaker}"
            print(f"  [{start} -> {end}] {tag}: {text}")
        if s.buffer_transcription:
            print(f"  [buffer] {s.buffer_transcription}")
        if s.buffer_diarization:
            print(f"  [diar buffer] {s.buffer_diarization}")
        print(f"  Speakers: {s.speakers or 'none'} | Silence: {s.has_silence}")
        print()