Add test harness and test client
This commit is contained in:
parent
cd8df8e1aa
commit
586540ae36
3 changed files with 1503 additions and 0 deletions
393
whisperlivekit/test_client.py
Normal file
393
whisperlivekit/test_client.py
Normal file
|
|
@ -0,0 +1,393 @@
|
||||||
|
"""Headless test client for WhisperLiveKit.
|
||||||
|
|
||||||
|
Feeds audio files to the transcription pipeline via WebSocket
|
||||||
|
and collects results — no browser or microphone needed.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Against a running server (server must be started with --pcm-input):
|
||||||
|
python -m whisperlivekit.test_client audio.wav
|
||||||
|
|
||||||
|
# Custom server URL and speed:
|
||||||
|
python -m whisperlivekit.test_client audio.wav --url ws://localhost:9090/asr --speed 0
|
||||||
|
|
||||||
|
# Output raw JSON responses:
|
||||||
|
python -m whisperlivekit.test_client audio.wav --json
|
||||||
|
|
||||||
|
# Programmatic usage:
|
||||||
|
from whisperlivekit.test_client import transcribe_audio
|
||||||
|
result = asyncio.run(transcribe_audio("audio.wav"))
|
||||||
|
print(result.text)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
BYTES_PER_SAMPLE = 2 # s16le
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TranscriptionResult:
|
||||||
|
"""Collected transcription results from a session."""
|
||||||
|
|
||||||
|
responses: List[dict] = field(default_factory=list)
|
||||||
|
audio_duration: float = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
"""Full transcription text from the last response (committed lines + buffer)."""
|
||||||
|
if not self.responses:
|
||||||
|
return ""
|
||||||
|
for resp in reversed(self.responses):
|
||||||
|
lines = resp.get("lines", [])
|
||||||
|
buffer = resp.get("buffer_transcription", "")
|
||||||
|
if lines or buffer:
|
||||||
|
parts = [line["text"] for line in lines if line.get("text")]
|
||||||
|
if buffer:
|
||||||
|
parts.append(buffer)
|
||||||
|
return " ".join(parts)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def committed_text(self) -> str:
|
||||||
|
"""Only the committed (finalized) transcription lines, no buffer."""
|
||||||
|
if not self.responses:
|
||||||
|
return ""
|
||||||
|
for resp in reversed(self.responses):
|
||||||
|
lines = resp.get("lines", [])
|
||||||
|
if lines:
|
||||||
|
return " ".join(line["text"] for line in lines if line.get("text"))
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lines(self) -> List[dict]:
|
||||||
|
"""Committed lines from the last response."""
|
||||||
|
for resp in reversed(self.responses):
|
||||||
|
if resp.get("lines"):
|
||||||
|
return resp["lines"]
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_updates(self) -> int:
|
||||||
|
"""Number of non-empty updates received."""
|
||||||
|
return sum(
|
||||||
|
1 for r in self.responses
|
||||||
|
if r.get("lines") or r.get("buffer_transcription")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_state(msg: dict, lines: List[dict]) -> dict:
|
||||||
|
"""Reconstruct full state from a diff or snapshot message.
|
||||||
|
|
||||||
|
Mutates ``lines`` in-place (prune front, append new) and returns
|
||||||
|
a full-state dict compatible with TranscriptionResult.
|
||||||
|
"""
|
||||||
|
if msg.get("type") == "snapshot":
|
||||||
|
lines.clear()
|
||||||
|
lines.extend(msg.get("lines", []))
|
||||||
|
return msg
|
||||||
|
|
||||||
|
# Apply diff
|
||||||
|
n_pruned = msg.get("lines_pruned", 0)
|
||||||
|
if n_pruned > 0:
|
||||||
|
del lines[:n_pruned]
|
||||||
|
new_lines = msg.get("new_lines", [])
|
||||||
|
lines.extend(new_lines)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": msg.get("status", ""),
|
||||||
|
"lines": lines[:], # snapshot copy
|
||||||
|
"buffer_transcription": msg.get("buffer_transcription", ""),
|
||||||
|
"buffer_diarization": msg.get("buffer_diarization", ""),
|
||||||
|
"buffer_translation": msg.get("buffer_translation", ""),
|
||||||
|
"remaining_time_transcription": msg.get("remaining_time_transcription", 0),
|
||||||
|
"remaining_time_diarization": msg.get("remaining_time_diarization", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_pcm(audio_path: str, sample_rate: int = SAMPLE_RATE) -> bytes:
|
||||||
|
"""Load an audio file and convert to PCM s16le mono via ffmpeg.
|
||||||
|
|
||||||
|
Supports any format ffmpeg can decode (wav, mp3, flac, ogg, m4a, ...).
|
||||||
|
"""
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-i", str(audio_path),
|
||||||
|
"-f", "s16le", "-acodec", "pcm_s16le",
|
||||||
|
"-ar", str(sample_rate), "-ac", "1",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"pipe:1",
|
||||||
|
]
|
||||||
|
proc = subprocess.run(cmd, capture_output=True)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(f"ffmpeg conversion failed: {proc.stderr.decode().strip()}")
|
||||||
|
if not proc.stdout:
|
||||||
|
raise RuntimeError(f"ffmpeg produced no output for {audio_path}")
|
||||||
|
return proc.stdout
|
||||||
|
|
||||||
|
|
||||||
|
async def transcribe_audio(
|
||||||
|
audio_path: str,
|
||||||
|
url: str = "ws://localhost:8000/asr",
|
||||||
|
chunk_duration: float = 0.5,
|
||||||
|
speed: float = 1.0,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
on_response: Optional[callable] = None,
|
||||||
|
mode: str = "full",
|
||||||
|
) -> TranscriptionResult:
|
||||||
|
"""Feed an audio file to a running WhisperLiveKit server and collect results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: Path to an audio file (any format ffmpeg supports).
|
||||||
|
url: WebSocket URL of the /asr endpoint.
|
||||||
|
chunk_duration: Duration of each audio chunk sent (seconds).
|
||||||
|
speed: Playback speed multiplier (1.0 = real-time, 0 = as fast as possible).
|
||||||
|
timeout: Max seconds to wait for the server after audio finishes.
|
||||||
|
on_response: Optional callback invoked with each response dict as it arrives.
|
||||||
|
mode: Output mode — "full" (default) or "diff" for incremental updates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TranscriptionResult with collected responses and convenience accessors.
|
||||||
|
"""
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
result = TranscriptionResult()
|
||||||
|
|
||||||
|
# Convert audio to PCM for both modes (we need duration either way)
|
||||||
|
pcm_data = load_audio_pcm(audio_path)
|
||||||
|
result.audio_duration = len(pcm_data) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
|
||||||
|
logger.info("Loaded %s: %.1fs of audio", audio_path, result.audio_duration)
|
||||||
|
|
||||||
|
chunk_bytes = int(chunk_duration * SAMPLE_RATE * BYTES_PER_SAMPLE)
|
||||||
|
|
||||||
|
# Append mode query parameter if using diff mode
|
||||||
|
connect_url = url
|
||||||
|
if mode == "diff":
|
||||||
|
sep = "&" if "?" in url else "?"
|
||||||
|
connect_url = f"{url}{sep}mode=diff"
|
||||||
|
|
||||||
|
async with websockets.connect(connect_url) as ws:
|
||||||
|
# Server sends config on connect
|
||||||
|
config_raw = await ws.recv()
|
||||||
|
config_msg = json.loads(config_raw)
|
||||||
|
is_pcm = config_msg.get("useAudioWorklet", False)
|
||||||
|
logger.info("Server config: %s", config_msg)
|
||||||
|
|
||||||
|
if not is_pcm:
|
||||||
|
logger.warning(
|
||||||
|
"Server is not in PCM mode. Start the server with --pcm-input "
|
||||||
|
"for the test client. Attempting raw file streaming instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
done_event = asyncio.Event()
|
||||||
|
diff_lines: List[dict] = [] # running state for diff mode reconstruction
|
||||||
|
|
||||||
|
async def send_audio():
|
||||||
|
if is_pcm:
|
||||||
|
offset = 0
|
||||||
|
n_chunks = 0
|
||||||
|
while offset < len(pcm_data):
|
||||||
|
end = min(offset + chunk_bytes, len(pcm_data))
|
||||||
|
await ws.send(pcm_data[offset:end])
|
||||||
|
offset = end
|
||||||
|
n_chunks += 1
|
||||||
|
if speed > 0:
|
||||||
|
await asyncio.sleep(chunk_duration / speed)
|
||||||
|
logger.info("Sent %d PCM chunks (%.1fs)", n_chunks, result.audio_duration)
|
||||||
|
else:
|
||||||
|
# Non-PCM: send raw file bytes for server-side ffmpeg decoding
|
||||||
|
file_bytes = Path(audio_path).read_bytes()
|
||||||
|
raw_chunk_size = 32000
|
||||||
|
offset = 0
|
||||||
|
while offset < len(file_bytes):
|
||||||
|
end = min(offset + raw_chunk_size, len(file_bytes))
|
||||||
|
await ws.send(file_bytes[offset:end])
|
||||||
|
offset = end
|
||||||
|
if speed > 0:
|
||||||
|
await asyncio.sleep(0.5 / speed)
|
||||||
|
logger.info("Sent %d bytes of raw audio", len(file_bytes))
|
||||||
|
|
||||||
|
# Signal end of audio
|
||||||
|
await ws.send(b"")
|
||||||
|
logger.info("End-of-audio signal sent")
|
||||||
|
|
||||||
|
async def receive_results():
|
||||||
|
try:
|
||||||
|
async for raw_msg in ws:
|
||||||
|
data = json.loads(raw_msg)
|
||||||
|
if data.get("type") == "ready_to_stop":
|
||||||
|
logger.info("Server signaled ready_to_stop")
|
||||||
|
done_event.set()
|
||||||
|
return
|
||||||
|
# In diff mode, reconstruct full state for uniform API
|
||||||
|
if mode == "diff" and data.get("type") in ("snapshot", "diff"):
|
||||||
|
data = reconstruct_state(data, diff_lines)
|
||||||
|
result.responses.append(data)
|
||||||
|
if on_response:
|
||||||
|
on_response(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Receiver ended: %s", e)
|
||||||
|
done_event.set()
|
||||||
|
|
||||||
|
send_task = asyncio.create_task(send_audio())
|
||||||
|
recv_task = asyncio.create_task(receive_results())
|
||||||
|
|
||||||
|
# Total wait = time to send + time for server to process + timeout margin
|
||||||
|
send_time = result.audio_duration / speed if speed > 0 else 1.0
|
||||||
|
total_timeout = send_time + timeout
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.gather(send_task, recv_task),
|
||||||
|
timeout=total_timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("Timed out after %.0fs", total_timeout)
|
||||||
|
send_task.cancel()
|
||||||
|
recv_task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.gather(send_task, recv_task, return_exceptions=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Session complete: %d responses, %d updates",
|
||||||
|
len(result.responses), result.n_updates,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _print_result(result: TranscriptionResult, output_json: bool = False) -> None:
|
||||||
|
"""Print transcription results to stdout."""
|
||||||
|
if output_json:
|
||||||
|
for resp in result.responses:
|
||||||
|
print(json.dumps(resp))
|
||||||
|
return
|
||||||
|
|
||||||
|
if result.lines:
|
||||||
|
for line in result.lines:
|
||||||
|
speaker = line.get("speaker", "")
|
||||||
|
text = line.get("text", "")
|
||||||
|
start = line.get("start", "")
|
||||||
|
end = line.get("end", "")
|
||||||
|
prefix = f"[{start} -> {end}]"
|
||||||
|
if speaker and speaker != 1:
|
||||||
|
prefix += f" Speaker {speaker}"
|
||||||
|
print(f"{prefix} {text}")
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
if result.responses:
|
||||||
|
buffer = result.responses[-1].get("buffer_transcription", "")
|
||||||
|
if buffer:
|
||||||
|
print(f"[buffer] {buffer}")
|
||||||
|
|
||||||
|
if not result.lines and not buffer:
|
||||||
|
print("(no transcription received)")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n--- {len(result.responses)} responses | "
|
||||||
|
f"{result.n_updates} updates | "
|
||||||
|
f"{result.audio_duration:.1f}s audio ---"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="whisperlivekit-test-client",
|
||||||
|
description=(
|
||||||
|
"Headless test client for WhisperLiveKit. "
|
||||||
|
"Feeds audio files via WebSocket and prints the transcription."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("audio", help="Path to audio file (wav, mp3, flac, ...)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--url", default="ws://localhost:8000/asr",
|
||||||
|
help="WebSocket endpoint URL (default: ws://localhost:8000/asr)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speed", type=float, default=1.0,
|
||||||
|
help="Playback speed multiplier (1.0 = real-time, 0 = fastest, default: 1.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chunk-duration", type=float, default=0.5,
|
||||||
|
help="Chunk duration in seconds (default: 0.5)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout", type=float, default=60.0,
|
||||||
|
help="Max seconds to wait for server after audio ends (default: 60)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language", "-l", default=None,
|
||||||
|
help="Override transcription language for this session (e.g. en, fr, auto)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--json", action="store_true", help="Output raw JSON responses")
|
||||||
|
parser.add_argument(
|
||||||
|
"--diff", action="store_true",
|
||||||
|
help="Use diff protocol (only receive incremental changes from server)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--live", action="store_true",
|
||||||
|
help="Print transcription updates as they arrive",
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", "-v", action="store_true")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG if args.verbose else logging.WARNING,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_path = Path(args.audio)
|
||||||
|
if not audio_path.exists():
|
||||||
|
print(f"Error: file not found: {audio_path}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
live_callback = None
|
||||||
|
if args.live:
|
||||||
|
def live_callback(data):
|
||||||
|
lines = data.get("lines", [])
|
||||||
|
buf = data.get("buffer_transcription", "")
|
||||||
|
parts = [l["text"] for l in lines if l.get("text")]
|
||||||
|
if buf:
|
||||||
|
parts.append(f"[{buf}]")
|
||||||
|
if parts:
|
||||||
|
print("\r" + " ".join(parts), end="", flush=True)
|
||||||
|
|
||||||
|
# Build URL with query parameters for language and mode
|
||||||
|
url = args.url
|
||||||
|
params = []
|
||||||
|
if args.language:
|
||||||
|
params.append(f"language={args.language}")
|
||||||
|
if args.diff:
|
||||||
|
params.append("mode=diff")
|
||||||
|
if params:
|
||||||
|
sep = "&" if "?" in url else "?"
|
||||||
|
url = f"{url}{sep}{'&'.join(params)}"
|
||||||
|
|
||||||
|
result = asyncio.run(transcribe_audio(
|
||||||
|
audio_path=str(audio_path),
|
||||||
|
url=url,
|
||||||
|
chunk_duration=args.chunk_duration,
|
||||||
|
speed=args.speed,
|
||||||
|
timeout=args.timeout,
|
||||||
|
on_response=live_callback,
|
||||||
|
mode="diff" if args.diff else "full",
|
||||||
|
))
|
||||||
|
|
||||||
|
if args.live:
|
||||||
|
print() # newline after live output
|
||||||
|
|
||||||
|
_print_result(result, output_json=args.json)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
365
whisperlivekit/test_data.py
Normal file
365
whisperlivekit/test_data.py
Normal file
|
|
@ -0,0 +1,365 @@
|
||||||
|
"""Standard test audio samples for evaluating the WhisperLiveKit pipeline.
|
||||||
|
|
||||||
|
Downloads curated samples from public ASR datasets (LibriSpeech, AMI)
|
||||||
|
and caches them locally. Each sample includes the audio file path,
|
||||||
|
ground truth transcript, speaker info, and timing metadata.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
from whisperlivekit.test_data import get_samples, get_sample
|
||||||
|
|
||||||
|
# Download all standard test samples (first call downloads, then cached)
|
||||||
|
samples = get_samples()
|
||||||
|
|
||||||
|
for s in samples:
|
||||||
|
print(f"{s.name}: {s.duration:.1f}s, {s.n_speakers} speaker(s)")
|
||||||
|
print(f" Reference: {s.reference[:60]}...")
|
||||||
|
|
||||||
|
# Use with TestHarness
|
||||||
|
from whisperlivekit.test_harness import TestHarness
|
||||||
|
|
||||||
|
async with TestHarness(model_size="base", lan="en") as h:
|
||||||
|
sample = get_sample("librispeech_short")
|
||||||
|
await h.feed(sample.path, speed=0)
|
||||||
|
result = await h.finish()
|
||||||
|
print(f"WER: {result.wer(sample.reference):.2%}")
|
||||||
|
|
||||||
|
Requires: pip install whisperlivekit[test] (installs 'datasets' and 'librosa')
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import wave
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CACHE_DIR = Path.home() / ".cache" / "whisperlivekit" / "test_data"
|
||||||
|
METADATA_FILE = "metadata.json"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestSample:
|
||||||
|
"""A test audio sample with ground truth metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
path: str # absolute path to WAV file
|
||||||
|
reference: str # ground truth transcript
|
||||||
|
duration: float # audio duration in seconds
|
||||||
|
sample_rate: int = 16000
|
||||||
|
n_speakers: int = 1
|
||||||
|
language: str = "en"
|
||||||
|
source: str = "" # dataset name
|
||||||
|
# Per-utterance ground truth for multi-speaker: [(start, end, speaker, text), ...]
|
||||||
|
utterances: List[Dict] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_timestamps(self) -> bool:
|
||||||
|
return len(self.utterances) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
|
||||||
|
"""Save numpy audio array as 16-bit PCM WAV."""
|
||||||
|
# Ensure mono
|
||||||
|
if audio.ndim > 1:
|
||||||
|
audio = audio.mean(axis=-1)
|
||||||
|
# Normalize to int16 range
|
||||||
|
if audio.dtype in (np.float32, np.float64):
|
||||||
|
audio = np.clip(audio, -1.0, 1.0)
|
||||||
|
audio = (audio * 32767).astype(np.int16)
|
||||||
|
elif audio.dtype != np.int16:
|
||||||
|
audio = audio.astype(np.int16)
|
||||||
|
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with wave.open(str(path), "w") as wf:
|
||||||
|
wf.setnchannels(1)
|
||||||
|
wf.setsampwidth(2)
|
||||||
|
wf.setframerate(sample_rate)
|
||||||
|
wf.writeframes(audio.tobytes())
|
||||||
|
|
||||||
|
|
||||||
|
def _load_metadata() -> Dict:
|
||||||
|
"""Load cached metadata if it exists."""
|
||||||
|
meta_path = CACHE_DIR / METADATA_FILE
|
||||||
|
if meta_path.exists():
|
||||||
|
return json.loads(meta_path.read_text())
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _save_metadata(meta: Dict) -> None:
|
||||||
|
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
(CACHE_DIR / METADATA_FILE).write_text(json.dumps(meta, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_datasets():
|
||||||
|
"""Check that the datasets library is available."""
|
||||||
|
try:
|
||||||
|
import datasets # noqa: F401
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'datasets' package is required for test data download. "
|
||||||
|
"Install it with: pip install whisperlivekit[test]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_audio(audio_bytes: bytes) -> tuple:
|
||||||
|
"""Decode audio bytes using soundfile (avoids torchcodec dependency).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(audio_array, sample_rate) — float32 numpy array and int sample rate.
|
||||||
|
"""
|
||||||
|
import io
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
audio_array, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
|
||||||
|
return np.array(audio_array, dtype=np.float32), sr
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset-specific download functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _download_librispeech_samples(n_samples: int = 3) -> List[Dict]:
|
||||||
|
"""Download short samples from LibriSpeech test-clean."""
|
||||||
|
_ensure_datasets()
|
||||||
|
import datasets.config
|
||||||
|
datasets.config.TORCHCODEC_AVAILABLE = False
|
||||||
|
from datasets import Audio, load_dataset
|
||||||
|
|
||||||
|
logger.info("Downloading LibriSpeech test-clean samples (streaming)...")
|
||||||
|
ds = load_dataset(
|
||||||
|
"openslr/librispeech_asr",
|
||||||
|
"clean",
|
||||||
|
split="test",
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
ds = ds.cast_column("audio", Audio(decode=False))
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for i, item in enumerate(ds):
|
||||||
|
if i >= n_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
audio_array, sr = _decode_audio(item["audio"]["bytes"])
|
||||||
|
duration = len(audio_array) / sr
|
||||||
|
text = item["text"]
|
||||||
|
sample_id = item.get("id", f"librispeech_{i}")
|
||||||
|
|
||||||
|
# Save WAV
|
||||||
|
wav_name = f"librispeech_{i}.wav"
|
||||||
|
wav_path = CACHE_DIR / wav_name
|
||||||
|
_save_wav(wav_path, audio_array, sr)
|
||||||
|
|
||||||
|
# Name: first sample is "librispeech_short", rest are numbered
|
||||||
|
name = "librispeech_short" if i == 0 else f"librispeech_{i}"
|
||||||
|
|
||||||
|
samples.append({
|
||||||
|
"name": name,
|
||||||
|
"file": wav_name,
|
||||||
|
"reference": text,
|
||||||
|
"duration": round(duration, 2),
|
||||||
|
"sample_rate": sr,
|
||||||
|
"n_speakers": 1,
|
||||||
|
"language": "en",
|
||||||
|
"source": "openslr/librispeech_asr (test-clean)",
|
||||||
|
"source_id": str(sample_id),
|
||||||
|
"utterances": [],
|
||||||
|
})
|
||||||
|
logger.info(
|
||||||
|
" [%d] %.1fs - %s",
|
||||||
|
i, duration, text[:60] + ("..." if len(text) > 60 else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def _download_ami_sample() -> List[Dict]:
|
||||||
|
"""Download one AMI meeting segment with multiple speakers."""
|
||||||
|
_ensure_datasets()
|
||||||
|
import datasets.config
|
||||||
|
datasets.config.TORCHCODEC_AVAILABLE = False
|
||||||
|
from datasets import Audio, load_dataset
|
||||||
|
|
||||||
|
logger.info("Downloading AMI meeting test sample (streaming)...")
|
||||||
|
|
||||||
|
# Use the edinburghcstr/ami version which has pre-segmented utterances
|
||||||
|
# with speaker_id, begin_time, end_time, text
|
||||||
|
ds = load_dataset(
|
||||||
|
"edinburghcstr/ami",
|
||||||
|
"ihm",
|
||||||
|
split="test",
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
ds = ds.cast_column("audio", Audio(decode=False))
|
||||||
|
|
||||||
|
# Collect utterances from one meeting
|
||||||
|
meeting_utterances = []
|
||||||
|
meeting_id = None
|
||||||
|
audio_arrays = []
|
||||||
|
sample_rate = None
|
||||||
|
|
||||||
|
for item in ds:
|
||||||
|
mid = item.get("meeting_id", "unknown")
|
||||||
|
|
||||||
|
# Take the first meeting only
|
||||||
|
if meeting_id is None:
|
||||||
|
meeting_id = mid
|
||||||
|
elif mid != meeting_id:
|
||||||
|
# We've moved to a different meeting, stop
|
||||||
|
break
|
||||||
|
|
||||||
|
audio_array, sr = _decode_audio(item["audio"]["bytes"])
|
||||||
|
sample_rate = sr
|
||||||
|
|
||||||
|
meeting_utterances.append({
|
||||||
|
"start": round(item.get("begin_time", 0.0), 2),
|
||||||
|
"end": round(item.get("end_time", 0.0), 2),
|
||||||
|
"speaker": item.get("speaker_id", "unknown"),
|
||||||
|
"text": item.get("text", ""),
|
||||||
|
})
|
||||||
|
audio_arrays.append(audio_array)
|
||||||
|
|
||||||
|
# Limit to reasonable size (~60s of utterances)
|
||||||
|
total_dur = sum(u["end"] - u["start"] for u in meeting_utterances)
|
||||||
|
if total_dur > 60:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not audio_arrays:
|
||||||
|
logger.warning("No AMI samples found")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Concatenate all utterance audio
|
||||||
|
full_audio = np.concatenate(audio_arrays)
|
||||||
|
duration = len(full_audio) / sample_rate
|
||||||
|
|
||||||
|
# Build reference text
|
||||||
|
speakers = set(u["speaker"] for u in meeting_utterances)
|
||||||
|
reference = " ".join(u["text"] for u in meeting_utterances if u["text"])
|
||||||
|
|
||||||
|
wav_name = "ami_meeting.wav"
|
||||||
|
wav_path = CACHE_DIR / wav_name
|
||||||
|
_save_wav(wav_path, full_audio, sample_rate)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
" AMI meeting %s: %.1fs, %d speakers, %d utterances",
|
||||||
|
meeting_id, duration, len(speakers), len(meeting_utterances),
|
||||||
|
)
|
||||||
|
|
||||||
|
return [{
|
||||||
|
"name": "ami_meeting",
|
||||||
|
"file": wav_name,
|
||||||
|
"reference": reference,
|
||||||
|
"duration": round(duration, 2),
|
||||||
|
"sample_rate": sample_rate,
|
||||||
|
"n_speakers": len(speakers),
|
||||||
|
"language": "en",
|
||||||
|
"source": f"edinburghcstr/ami (ihm, meeting {meeting_id})",
|
||||||
|
"source_id": meeting_id,
|
||||||
|
"utterances": meeting_utterances,
|
||||||
|
}]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def download_test_samples(force: bool = False) -> List[TestSample]:
|
||||||
|
"""Download standard test audio samples.
|
||||||
|
|
||||||
|
Downloads samples from LibriSpeech (clean single-speaker) and
|
||||||
|
AMI (multi-speaker meetings) on first call. Subsequent calls
|
||||||
|
return cached data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: Re-download even if cached.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TestSample objects ready for use with TestHarness.
|
||||||
|
"""
|
||||||
|
meta = _load_metadata()
|
||||||
|
|
||||||
|
if meta.get("samples") and not force:
|
||||||
|
# Check all files still exist
|
||||||
|
all_exist = all(
|
||||||
|
(CACHE_DIR / s["file"]).exists()
|
||||||
|
for s in meta["samples"]
|
||||||
|
)
|
||||||
|
if all_exist:
|
||||||
|
return _meta_to_samples(meta["samples"])
|
||||||
|
|
||||||
|
logger.info("Downloading test samples to %s ...", CACHE_DIR)
|
||||||
|
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
all_samples = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_samples.extend(_download_librispeech_samples(n_samples=3))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to download LibriSpeech samples: %s", e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_samples.extend(_download_ami_sample())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to download AMI sample: %s", e)
|
||||||
|
|
||||||
|
if not all_samples:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to download any test samples. "
|
||||||
|
"Check your internet connection and ensure 'datasets' is installed: "
|
||||||
|
"pip install whisperlivekit[test]"
|
||||||
|
)
|
||||||
|
|
||||||
|
_save_metadata({"samples": all_samples})
|
||||||
|
logger.info("Downloaded %d test samples to %s", len(all_samples), CACHE_DIR)
|
||||||
|
|
||||||
|
return _meta_to_samples(all_samples)
|
||||||
|
|
||||||
|
|
||||||
|
def get_samples() -> List[TestSample]:
|
||||||
|
"""Get standard test samples (downloads on first call)."""
|
||||||
|
return download_test_samples()
|
||||||
|
|
||||||
|
|
||||||
|
def get_sample(name: str) -> TestSample:
|
||||||
|
"""Get a specific test sample by name.
|
||||||
|
|
||||||
|
Available names: 'librispeech_short', 'librispeech_1', 'librispeech_2',
|
||||||
|
'ami_meeting'.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the sample name is not found.
|
||||||
|
"""
|
||||||
|
samples = get_samples()
|
||||||
|
for s in samples:
|
||||||
|
if s.name == name:
|
||||||
|
return s
|
||||||
|
available = [s.name for s in samples]
|
||||||
|
raise KeyError(f"Sample '{name}' not found. Available: {available}")
|
||||||
|
|
||||||
|
|
||||||
|
def list_sample_names() -> List[str]:
|
||||||
|
"""List names of available test samples (downloads if needed)."""
|
||||||
|
return [s.name for s in get_samples()]
|
||||||
|
|
||||||
|
|
||||||
|
def _meta_to_samples(meta_list: List[Dict]) -> List[TestSample]:
|
||||||
|
"""Convert metadata dicts to TestSample objects."""
|
||||||
|
samples = []
|
||||||
|
for m in meta_list:
|
||||||
|
samples.append(TestSample(
|
||||||
|
name=m["name"],
|
||||||
|
path=str(CACHE_DIR / m["file"]),
|
||||||
|
reference=m["reference"],
|
||||||
|
duration=m["duration"],
|
||||||
|
sample_rate=m.get("sample_rate", 16000),
|
||||||
|
n_speakers=m.get("n_speakers", 1),
|
||||||
|
language=m.get("language", "en"),
|
||||||
|
source=m.get("source", ""),
|
||||||
|
utterances=m.get("utterances", []),
|
||||||
|
))
|
||||||
|
return samples
|
||||||
745
whisperlivekit/test_harness.py
Normal file
745
whisperlivekit/test_harness.py
Normal file
|
|
@ -0,0 +1,745 @@
|
||||||
|
"""In-process testing harness for the full WhisperLiveKit pipeline.
|
||||||
|
|
||||||
|
Wraps AudioProcessor to provide a controllable, observable interface
|
||||||
|
for testing transcription, diarization, silence detection, and timing
|
||||||
|
without needing a running server or WebSocket connection.
|
||||||
|
|
||||||
|
Designed for use by AI agents: feed audio with timeline control,
|
||||||
|
inspect state at any point, pause/resume to test silence detection,
|
||||||
|
cut to test abrupt termination.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from whisperlivekit.test_harness import TestHarness
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with TestHarness(model_size="base", lan="en") as h:
|
||||||
|
# Load audio with timeline control
|
||||||
|
player = h.load_audio("interview.wav")
|
||||||
|
|
||||||
|
# Play first 5 seconds at real-time speed
|
||||||
|
await player.play(5.0, speed=1.0)
|
||||||
|
print(h.state.text) # Check what's transcribed so far
|
||||||
|
|
||||||
|
# Pause for 7 seconds (triggers silence detection)
|
||||||
|
await h.pause(7.0, speed=1.0)
|
||||||
|
assert h.state.has_silence
|
||||||
|
|
||||||
|
# Resume playback
|
||||||
|
await player.play(5.0, speed=1.0)
|
||||||
|
|
||||||
|
# Finish and evaluate
|
||||||
|
result = await h.finish()
|
||||||
|
print(f"WER: {result.wer('expected transcription'):.2%}")
|
||||||
|
print(f"Speakers: {result.speakers}")
|
||||||
|
print(f"Silence segments: {len(result.silence_segments)}")
|
||||||
|
|
||||||
|
# Inspect historical state at specific audio position
|
||||||
|
snap = h.snapshot_at(3.0)
|
||||||
|
print(f"At 3s: '{snap.text}'")
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import FrontData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Engine cache: avoids reloading models when switching backends in tests.
|
||||||
|
# Key is a frozen config tuple, value is the TranscriptionEngine instance.
|
||||||
|
_engine_cache: Dict[Tuple, "Any"] = {}
|
||||||
|
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
BYTES_PER_SAMPLE = 2 # s16le
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_time(time_str: str) -> float:
|
||||||
|
"""Parse 'H:MM:SS.cc' timestamp string to seconds."""
|
||||||
|
parts = time_str.split(":")
|
||||||
|
if len(parts) == 3:
|
||||||
|
return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
|
||||||
|
if len(parts) == 2:
|
||||||
|
return int(parts[0]) * 60 + float(parts[1])
|
||||||
|
return float(parts[0])
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_pcm(audio_path: str, sample_rate: int = SAMPLE_RATE) -> bytes:
|
||||||
|
"""Load any audio file and convert to PCM s16le mono via ffmpeg."""
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-i", str(audio_path),
|
||||||
|
"-f", "s16le", "-acodec", "pcm_s16le",
|
||||||
|
"-ar", str(sample_rate), "-ac", "1",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"pipe:1",
|
||||||
|
]
|
||||||
|
proc = subprocess.run(cmd, capture_output=True)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(f"ffmpeg conversion failed: {proc.stderr.decode().strip()}")
|
||||||
|
if not proc.stdout:
|
||||||
|
raise RuntimeError(f"ffmpeg produced no output for {audio_path}")
|
||||||
|
return proc.stdout
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestState — observable transcription state
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestState:
|
||||||
|
"""Observable transcription state at a point in time.
|
||||||
|
|
||||||
|
Provides accessors for inspecting lines, buffers, speakers, timestamps,
|
||||||
|
silence segments, and computing evaluation metrics like WER.
|
||||||
|
|
||||||
|
All time-based queries accept seconds as floats.
|
||||||
|
"""
|
||||||
|
|
||||||
|
lines: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
buffer_transcription: str = ""
|
||||||
|
buffer_diarization: str = ""
|
||||||
|
buffer_translation: str = ""
|
||||||
|
remaining_time_transcription: float = 0.0
|
||||||
|
remaining_time_diarization: float = 0.0
|
||||||
|
audio_position: float = 0.0
|
||||||
|
status: str = ""
|
||||||
|
error: str = ""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_front_data(cls, front_data: FrontData, audio_position: float = 0.0) -> "TestState":
|
||||||
|
d = front_data.to_dict()
|
||||||
|
return cls(
|
||||||
|
lines=d.get("lines", []),
|
||||||
|
buffer_transcription=d.get("buffer_transcription", ""),
|
||||||
|
buffer_diarization=d.get("buffer_diarization", ""),
|
||||||
|
buffer_translation=d.get("buffer_translation", ""),
|
||||||
|
remaining_time_transcription=d.get("remaining_time_transcription", 0),
|
||||||
|
remaining_time_diarization=d.get("remaining_time_diarization", 0),
|
||||||
|
audio_position=audio_position,
|
||||||
|
status=d.get("status", ""),
|
||||||
|
error=d.get("error", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Text accessors ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
"""Full transcription: committed lines + buffer."""
|
||||||
|
parts = [l["text"] for l in self.lines if l.get("text")]
|
||||||
|
if self.buffer_transcription:
|
||||||
|
parts.append(self.buffer_transcription)
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def committed_text(self) -> str:
|
||||||
|
"""Only committed (finalized) lines, no buffer."""
|
||||||
|
return " ".join(l["text"] for l in self.lines if l.get("text"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def committed_word_count(self) -> int:
|
||||||
|
"""Number of words in committed lines."""
|
||||||
|
t = self.committed_text
|
||||||
|
return len(t.split()) if t.strip() else 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def buffer_word_count(self) -> int:
|
||||||
|
"""Number of words in the unconfirmed buffer."""
|
||||||
|
return len(self.buffer_transcription.split()) if self.buffer_transcription.strip() else 0
|
||||||
|
|
||||||
|
# ── Speaker accessors ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speakers(self) -> Set[int]:
|
||||||
|
"""Set of speaker IDs (excluding silence marker -2)."""
|
||||||
|
return {l["speaker"] for l in self.lines if l.get("speaker", 0) > 0}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_speakers(self) -> int:
|
||||||
|
return len(self.speakers)
|
||||||
|
|
||||||
|
def speaker_at(self, time_s: float) -> Optional[int]:
|
||||||
|
"""Speaker ID at the given timestamp, or None if no segment covers it."""
|
||||||
|
line = self.line_at(time_s)
|
||||||
|
return line["speaker"] if line else None
|
||||||
|
|
||||||
|
def speakers_in(self, start_s: float, end_s: float) -> Set[int]:
|
||||||
|
"""All speaker IDs active in the time range (excluding silence -2)."""
|
||||||
|
return {
|
||||||
|
l.get("speaker")
|
||||||
|
for l in self.lines_between(start_s, end_s)
|
||||||
|
if l.get("speaker", 0) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speaker_timeline(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Timeline: [{"start": float, "end": float, "speaker": int}] for all lines."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"start": _parse_time(l.get("start", "0:00:00")),
|
||||||
|
"end": _parse_time(l.get("end", "0:00:00")),
|
||||||
|
"speaker": l.get("speaker", -1),
|
||||||
|
}
|
||||||
|
for l in self.lines
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_speaker_changes(self) -> int:
|
||||||
|
"""Number of speaker transitions (excluding silence segments)."""
|
||||||
|
speech = [s for s in self.speaker_timeline if s["speaker"] != -2]
|
||||||
|
return sum(
|
||||||
|
1 for i in range(1, len(speech))
|
||||||
|
if speech[i]["speaker"] != speech[i - 1]["speaker"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Silence accessors ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_silence(self) -> bool:
|
||||||
|
"""Whether any silence segment (speaker=-2) exists."""
|
||||||
|
return any(l.get("speaker") == -2 for l in self.lines)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def silence_segments(self) -> List[Dict[str, Any]]:
|
||||||
|
"""All silence segments (raw line dicts)."""
|
||||||
|
return [l for l in self.lines if l.get("speaker") == -2]
|
||||||
|
|
||||||
|
def silence_at(self, time_s: float) -> bool:
|
||||||
|
"""True if time_s falls within a silence segment."""
|
||||||
|
line = self.line_at(time_s)
|
||||||
|
return line is not None and line.get("speaker") == -2
|
||||||
|
|
||||||
|
# ── Line / segment accessors ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speech_lines(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Lines excluding silence segments."""
|
||||||
|
return [l for l in self.lines if l.get("speaker", 0) != -2 and l.get("text")]
|
||||||
|
|
||||||
|
def line_at(self, time_s: float) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Find the line covering the given timestamp (seconds)."""
|
||||||
|
for line in self.lines:
|
||||||
|
start = _parse_time(line.get("start", "0:00:00"))
|
||||||
|
end = _parse_time(line.get("end", "0:00:00"))
|
||||||
|
if start <= time_s <= end:
|
||||||
|
return line
|
||||||
|
return None
|
||||||
|
|
||||||
|
def text_at(self, time_s: float) -> Optional[str]:
|
||||||
|
"""Text of the segment covering the given timestamp."""
|
||||||
|
line = self.line_at(time_s)
|
||||||
|
return line["text"] if line else None
|
||||||
|
|
||||||
|
def lines_between(self, start_s: float, end_s: float) -> List[Dict[str, Any]]:
|
||||||
|
"""All lines overlapping the time range [start_s, end_s]."""
|
||||||
|
result = []
|
||||||
|
for line in self.lines:
|
||||||
|
ls = _parse_time(line.get("start", "0:00:00"))
|
||||||
|
le = _parse_time(line.get("end", "0:00:00"))
|
||||||
|
if le >= start_s and ls <= end_s:
|
||||||
|
result.append(line)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def text_between(self, start_s: float, end_s: float) -> str:
|
||||||
|
"""Concatenated text of all lines overlapping the time range."""
|
||||||
|
return " ".join(
|
||||||
|
l["text"] for l in self.lines_between(start_s, end_s)
|
||||||
|
if l.get("text")
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Evaluation ──
|
||||||
|
|
||||||
|
def wer(self, reference: str) -> float:
|
||||||
|
"""Word Error Rate of committed text against reference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WER as a float (0.0 = perfect, 1.0 = 100% error rate).
|
||||||
|
"""
|
||||||
|
from whisperlivekit.metrics import compute_wer
|
||||||
|
result = compute_wer(reference, self.committed_text)
|
||||||
|
return result["wer"]
|
||||||
|
|
||||||
|
def wer_detailed(self, reference: str) -> Dict:
|
||||||
|
"""Full WER breakdown: substitutions, insertions, deletions, etc."""
|
||||||
|
from whisperlivekit.metrics import compute_wer
|
||||||
|
return compute_wer(reference, self.committed_text)
|
||||||
|
|
||||||
|
# ── Timing validation ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timestamps(self) -> List[Dict[str, Any]]:
|
||||||
|
"""All line timestamps as [{"start": float, "end": float, "speaker": int, "text": str}]."""
|
||||||
|
result = []
|
||||||
|
for line in self.lines:
|
||||||
|
result.append({
|
||||||
|
"start": _parse_time(line.get("start", "0:00:00")),
|
||||||
|
"end": _parse_time(line.get("end", "0:00:00")),
|
||||||
|
"speaker": line.get("speaker", -1),
|
||||||
|
"text": line.get("text", ""),
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timing_valid(self) -> bool:
|
||||||
|
"""All timestamps have start <= end and no negative values."""
|
||||||
|
for ts in self.timestamps:
|
||||||
|
if ts["start"] < 0 or ts["end"] < 0:
|
||||||
|
return False
|
||||||
|
if ts["end"] < ts["start"]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timing_monotonic(self) -> bool:
|
||||||
|
"""Line start times are non-decreasing."""
|
||||||
|
stamps = self.timestamps
|
||||||
|
for i in range(1, len(stamps)):
|
||||||
|
if stamps[i]["start"] < stamps[i - 1]["start"]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def timing_errors(self) -> List[str]:
|
||||||
|
"""Human-readable list of timing issues found."""
|
||||||
|
errors = []
|
||||||
|
stamps = self.timestamps
|
||||||
|
for i, ts in enumerate(stamps):
|
||||||
|
if ts["start"] < 0:
|
||||||
|
errors.append(f"Line {i}: negative start {ts['start']:.2f}s")
|
||||||
|
if ts["end"] < 0:
|
||||||
|
errors.append(f"Line {i}: negative end {ts['end']:.2f}s")
|
||||||
|
if ts["end"] < ts["start"]:
|
||||||
|
errors.append(
|
||||||
|
f"Line {i}: end ({ts['end']:.2f}s) < start ({ts['start']:.2f}s)"
|
||||||
|
)
|
||||||
|
for i in range(1, len(stamps)):
|
||||||
|
if stamps[i]["start"] < stamps[i - 1]["start"]:
|
||||||
|
errors.append(
|
||||||
|
f"Line {i}: start ({stamps[i]['start']:.2f}s) < previous start "
|
||||||
|
f"({stamps[i-1]['start']:.2f}s) — non-monotonic"
|
||||||
|
)
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AudioPlayer — timeline control for a loaded audio file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AudioPlayer:
|
||||||
|
"""Controls playback of a loaded audio file through the pipeline.
|
||||||
|
|
||||||
|
Tracks position in the audio, enabling play/pause/resume patterns::
|
||||||
|
|
||||||
|
player = h.load_audio("speech.wav")
|
||||||
|
await player.play(3.0) # Play first 3 seconds
|
||||||
|
await h.pause(7.0) # 7s silence (triggers detection)
|
||||||
|
await player.play(5.0) # Play next 5 seconds
|
||||||
|
await player.play() # Play all remaining audio
|
||||||
|
|
||||||
|
Args:
|
||||||
|
harness: The TestHarness instance.
|
||||||
|
pcm_data: Raw PCM s16le 16kHz mono bytes.
|
||||||
|
sample_rate: Audio sample rate (default 16000).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, harness: "TestHarness", pcm_data: bytes, sample_rate: int = SAMPLE_RATE):
|
||||||
|
self._harness = harness
|
||||||
|
self._pcm = pcm_data
|
||||||
|
self._sr = sample_rate
|
||||||
|
self._bps = sample_rate * BYTES_PER_SAMPLE # bytes per second
|
||||||
|
self._pos = 0 # current position in bytes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def position(self) -> float:
|
||||||
|
"""Current playback position in seconds."""
|
||||||
|
return self._pos / self._bps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self) -> float:
|
||||||
|
"""Total audio duration in seconds."""
|
||||||
|
return len(self._pcm) / self._bps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remaining(self) -> float:
|
||||||
|
"""Remaining audio in seconds."""
|
||||||
|
return max(0.0, (len(self._pcm) - self._pos) / self._bps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""True if all audio has been played."""
|
||||||
|
return self._pos >= len(self._pcm)
|
||||||
|
|
||||||
|
async def play(
|
||||||
|
self,
|
||||||
|
duration_s: Optional[float] = None,
|
||||||
|
speed: float = 1.0,
|
||||||
|
chunk_duration: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
"""Play audio from the current position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
duration_s: Seconds of audio to play. None = all remaining.
|
||||||
|
speed: 1.0 = real-time, 0 = instant, >1 = faster.
|
||||||
|
chunk_duration: Size of each chunk fed to the pipeline (seconds).
|
||||||
|
"""
|
||||||
|
if duration_s is None:
|
||||||
|
end_pos = len(self._pcm)
|
||||||
|
else:
|
||||||
|
end_pos = min(self._pos + int(duration_s * self._bps), len(self._pcm))
|
||||||
|
|
||||||
|
# Align to sample boundary
|
||||||
|
end_pos = (end_pos // BYTES_PER_SAMPLE) * BYTES_PER_SAMPLE
|
||||||
|
|
||||||
|
if end_pos <= self._pos:
|
||||||
|
return
|
||||||
|
|
||||||
|
segment = self._pcm[self._pos:end_pos]
|
||||||
|
self._pos = end_pos
|
||||||
|
await self._harness.feed_pcm(segment, speed=speed, chunk_duration=chunk_duration)
|
||||||
|
|
||||||
|
async def play_until(
|
||||||
|
self,
|
||||||
|
time_s: float,
|
||||||
|
speed: float = 1.0,
|
||||||
|
chunk_duration: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
"""Play until reaching time_s in the audio timeline."""
|
||||||
|
target = min(int(time_s * self._bps), len(self._pcm))
|
||||||
|
target = (target // BYTES_PER_SAMPLE) * BYTES_PER_SAMPLE
|
||||||
|
|
||||||
|
if target <= self._pos:
|
||||||
|
return
|
||||||
|
|
||||||
|
segment = self._pcm[self._pos:target]
|
||||||
|
self._pos = target
|
||||||
|
await self._harness.feed_pcm(segment, speed=speed, chunk_duration=chunk_duration)
|
||||||
|
|
||||||
|
def seek(self, time_s: float) -> None:
|
||||||
|
"""Move the playback cursor without feeding audio."""
|
||||||
|
pos = int(time_s * self._bps)
|
||||||
|
pos = (pos // BYTES_PER_SAMPLE) * BYTES_PER_SAMPLE
|
||||||
|
self._pos = max(0, min(pos, len(self._pcm)))
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset to the beginning of the audio."""
|
||||||
|
self._pos = 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestHarness — pipeline controller
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHarness:
|
||||||
|
"""In-process testing harness for the full WhisperLiveKit pipeline.
|
||||||
|
|
||||||
|
Use as an async context manager. Provides methods to feed audio,
|
||||||
|
pause/resume, inspect state, and evaluate results.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
load_audio(path) → AudioPlayer with play/seek controls
|
||||||
|
feed(path, speed) → feed entire audio file (simple mode)
|
||||||
|
pause(duration) → inject silence (triggers detection if > 5s)
|
||||||
|
drain(seconds) → let pipeline catch up
|
||||||
|
finish() → flush and return final state
|
||||||
|
cut() → abrupt stop, return partial state
|
||||||
|
wait_for(pred) → wait for condition on state
|
||||||
|
|
||||||
|
State inspection:
|
||||||
|
.state → current TestState
|
||||||
|
.history → all historical states
|
||||||
|
.snapshot_at(t) → state at audio position t
|
||||||
|
.metrics → SessionMetrics (latency, RTF, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
All keyword arguments passed to AudioProcessor.
|
||||||
|
Common: model_size, lan, backend, diarization, vac.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
kwargs.setdefault("pcm_input", True)
|
||||||
|
self._engine_kwargs = kwargs
|
||||||
|
self._processor = None
|
||||||
|
self._results_gen = None
|
||||||
|
self._collect_task = None
|
||||||
|
self._state = TestState()
|
||||||
|
self._audio_position = 0.0
|
||||||
|
self._history: List[TestState] = []
|
||||||
|
self._on_update: Optional[Callable[[TestState], None]] = None
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "TestHarness":
|
||||||
|
from whisperlivekit.audio_processor import AudioProcessor
|
||||||
|
from whisperlivekit.core import TranscriptionEngine
|
||||||
|
|
||||||
|
# Cache engines by config to avoid reloading models when switching
|
||||||
|
# backends between tests. The singleton is reset only when the
|
||||||
|
# requested config doesn't match any cached engine.
|
||||||
|
cache_key = tuple(sorted(self._engine_kwargs.items()))
|
||||||
|
|
||||||
|
if cache_key not in _engine_cache:
|
||||||
|
TranscriptionEngine.reset()
|
||||||
|
_engine_cache[cache_key] = TranscriptionEngine(**self._engine_kwargs)
|
||||||
|
|
||||||
|
engine = _engine_cache[cache_key]
|
||||||
|
|
||||||
|
self._processor = AudioProcessor(transcription_engine=engine)
|
||||||
|
self._results_gen = await self._processor.create_tasks()
|
||||||
|
self._collect_task = asyncio.create_task(self._collect_results())
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc: Any) -> None:
|
||||||
|
if self._processor:
|
||||||
|
await self._processor.cleanup()
|
||||||
|
if self._collect_task and not self._collect_task.done():
|
||||||
|
self._collect_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._collect_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _collect_results(self) -> None:
|
||||||
|
"""Background task: consume results from the pipeline."""
|
||||||
|
try:
|
||||||
|
async for front_data in self._results_gen:
|
||||||
|
self._state = TestState.from_front_data(front_data, self._audio_position)
|
||||||
|
self._history.append(self._state)
|
||||||
|
if self._on_update:
|
||||||
|
self._on_update(self._state)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Result collector ended: %s", e)
|
||||||
|
|
||||||
|
# ── Properties ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> TestState:
|
||||||
|
"""Current transcription state (updated live as results arrive)."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def history(self) -> List[TestState]:
|
||||||
|
"""All states received so far, in order."""
|
||||||
|
return self._history
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_position(self) -> float:
|
||||||
|
"""How many seconds of audio have been fed so far."""
|
||||||
|
return self._audio_position
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metrics(self):
|
||||||
|
"""Pipeline's SessionMetrics (latency, RTF, token counts, etc.)."""
|
||||||
|
if self._processor:
|
||||||
|
return self._processor.metrics
|
||||||
|
return None
|
||||||
|
|
||||||
|
def on_update(self, callback: Callable[[TestState], None]) -> None:
|
||||||
|
"""Register a callback invoked on each new state update."""
|
||||||
|
self._on_update = callback
|
||||||
|
|
||||||
|
# ── Audio loading and feeding ──
|
||||||
|
|
||||||
|
def load_audio(self, source) -> AudioPlayer:
|
||||||
|
"""Load audio and return a player with timeline control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: Path to audio file (str), or a TestSample with .path attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AudioPlayer with play/play_until/seek/reset methods.
|
||||||
|
"""
|
||||||
|
path = source.path if hasattr(source, "path") else str(source)
|
||||||
|
pcm = load_audio_pcm(path)
|
||||||
|
return AudioPlayer(self, pcm)
|
||||||
|
|
||||||
|
async def feed(
|
||||||
|
self,
|
||||||
|
audio_path: str,
|
||||||
|
speed: float = 1.0,
|
||||||
|
chunk_duration: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
"""Feed an entire audio file to the pipeline (simple mode).
|
||||||
|
|
||||||
|
For timeline control (play/pause/resume), use load_audio() instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: Path to any audio file ffmpeg can decode.
|
||||||
|
speed: Playback speed (1.0 = real-time, 0 = instant).
|
||||||
|
chunk_duration: Size of each PCM chunk in seconds.
|
||||||
|
"""
|
||||||
|
pcm = load_audio_pcm(audio_path)
|
||||||
|
await self.feed_pcm(pcm, speed=speed, chunk_duration=chunk_duration)
|
||||||
|
|
||||||
|
async def feed_pcm(
|
||||||
|
self,
|
||||||
|
pcm_data: bytes,
|
||||||
|
speed: float = 1.0,
|
||||||
|
chunk_duration: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
"""Feed raw PCM s16le 16kHz mono bytes to the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pcm_data: Raw PCM bytes.
|
||||||
|
speed: Playback speed multiplier.
|
||||||
|
chunk_duration: Duration of each chunk sent (seconds).
|
||||||
|
"""
|
||||||
|
chunk_bytes = int(chunk_duration * SAMPLE_RATE * BYTES_PER_SAMPLE)
|
||||||
|
offset = 0
|
||||||
|
while offset < len(pcm_data):
|
||||||
|
end = min(offset + chunk_bytes, len(pcm_data))
|
||||||
|
await self._processor.process_audio(pcm_data[offset:end])
|
||||||
|
chunk_seconds = (end - offset) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
|
||||||
|
self._audio_position += chunk_seconds
|
||||||
|
offset = end
|
||||||
|
if speed > 0:
|
||||||
|
await asyncio.sleep(chunk_duration / speed)
|
||||||
|
|
||||||
|
# ── Pause / silence ──
|
||||||
|
|
||||||
|
async def pause(self, duration_s: float, speed: float = 1.0) -> None:
|
||||||
|
"""Inject silence to simulate a pause in speech.
|
||||||
|
|
||||||
|
Pauses > 5s trigger silence segment detection (MIN_DURATION_REAL_SILENCE).
|
||||||
|
Pauses < 5s are treated as brief gaps and produce no silence segment
|
||||||
|
(provided speech resumes afterward).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
duration_s: Duration of silence in seconds.
|
||||||
|
speed: Playback speed (1.0 = real-time, 0 = instant).
|
||||||
|
"""
|
||||||
|
silent_pcm = bytes(int(duration_s * SAMPLE_RATE * BYTES_PER_SAMPLE))
|
||||||
|
await self.feed_pcm(silent_pcm, speed=speed)
|
||||||
|
|
||||||
|
async def silence(self, duration_s: float, speed: float = 1.0) -> None:
|
||||||
|
"""Alias for pause(). Inject silence for the given duration."""
|
||||||
|
await self.pause(duration_s, speed=speed)
|
||||||
|
|
||||||
|
# ── Waiting ──
|
||||||
|
|
||||||
|
async def wait_for(
|
||||||
|
self,
|
||||||
|
predicate: Callable[[TestState], bool],
|
||||||
|
timeout: float = 30.0,
|
||||||
|
poll_interval: float = 0.1,
|
||||||
|
) -> TestState:
|
||||||
|
"""Wait until predicate(state) returns True.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If the condition is not met within timeout.
|
||||||
|
"""
|
||||||
|
deadline = asyncio.get_event_loop().time() + timeout
|
||||||
|
while asyncio.get_event_loop().time() < deadline:
|
||||||
|
if predicate(self._state):
|
||||||
|
return self._state
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Condition not met within {timeout}s. "
|
||||||
|
f"Current state: {len(self._state.lines)} lines, "
|
||||||
|
f"buffer='{self._state.buffer_transcription[:50]}', "
|
||||||
|
f"audio_pos={self._audio_position:.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def wait_for_text(self, timeout: float = 30.0) -> TestState:
|
||||||
|
"""Wait until any transcription text appears."""
|
||||||
|
return await self.wait_for(lambda s: s.text.strip(), timeout=timeout)
|
||||||
|
|
||||||
|
async def wait_for_lines(self, n: int = 1, timeout: float = 30.0) -> TestState:
|
||||||
|
"""Wait until at least n committed speech lines exist."""
|
||||||
|
return await self.wait_for(lambda s: len(s.speech_lines) >= n, timeout=timeout)
|
||||||
|
|
||||||
|
async def wait_for_silence(self, timeout: float = 30.0) -> TestState:
|
||||||
|
"""Wait until a silence segment is detected."""
|
||||||
|
return await self.wait_for(lambda s: s.has_silence, timeout=timeout)
|
||||||
|
|
||||||
|
async def wait_for_speakers(self, n: int = 2, timeout: float = 30.0) -> TestState:
|
||||||
|
"""Wait until at least n distinct speakers are detected."""
|
||||||
|
return await self.wait_for(lambda s: s.n_speakers >= n, timeout=timeout)
|
||||||
|
|
||||||
|
async def drain(self, seconds: float = 2.0) -> None:
|
||||||
|
"""Let the pipeline process without feeding audio.
|
||||||
|
|
||||||
|
Useful after feeding audio to allow the ASR backend to catch up.
|
||||||
|
"""
|
||||||
|
await asyncio.sleep(seconds)
|
||||||
|
|
||||||
|
# ── Finishing ──
|
||||||
|
|
||||||
|
async def finish(self, timeout: float = 30.0) -> TestState:
|
||||||
|
"""Signal end of audio and wait for pipeline to flush all results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final TestState with all committed lines and empty buffer.
|
||||||
|
"""
|
||||||
|
await self._processor.process_audio(b"")
|
||||||
|
if self._collect_task:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._collect_task, timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("Timed out waiting for pipeline to finish after %.0fs", timeout)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
async def cut(self, timeout: float = 5.0) -> TestState:
|
||||||
|
"""Abrupt audio stop — signal EOF and return current state quickly.
|
||||||
|
|
||||||
|
Simulates user closing the connection mid-speech. Sends EOF but
|
||||||
|
uses a short timeout, so partial results are returned even if
|
||||||
|
the pipeline hasn't fully flushed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestState with whatever has been processed so far.
|
||||||
|
"""
|
||||||
|
await self._processor.process_audio(b"")
|
||||||
|
if self._collect_task:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._collect_task, timeout=timeout)
|
||||||
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||||
|
pass
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
# ── History inspection ──
|
||||||
|
|
||||||
|
def snapshot_at(self, audio_time: float) -> Optional[TestState]:
|
||||||
|
"""Find the historical state closest to when audio_time was reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_time: Audio position in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The TestState captured at that point, or None if no history.
|
||||||
|
"""
|
||||||
|
if not self._history:
|
||||||
|
return None
|
||||||
|
best = None
|
||||||
|
best_diff = float("inf")
|
||||||
|
for s in self._history:
|
||||||
|
diff = abs(s.audio_position - audio_time)
|
||||||
|
if diff < best_diff:
|
||||||
|
best_diff = diff
|
||||||
|
best = s
|
||||||
|
return best
|
||||||
|
|
||||||
|
# ── Debug ──
|
||||||
|
|
||||||
|
def print_state(self) -> None:
|
||||||
|
"""Print current state to stdout for debugging."""
|
||||||
|
s = self._state
|
||||||
|
print(f"--- Audio: {self._audio_position:.1f}s | Status: {s.status} ---")
|
||||||
|
for line in s.lines:
|
||||||
|
speaker = line.get("speaker", "?")
|
||||||
|
text = line.get("text", "")
|
||||||
|
start = line.get("start", "")
|
||||||
|
end = line.get("end", "")
|
||||||
|
tag = "SILENCE" if speaker == -2 else f"Speaker {speaker}"
|
||||||
|
print(f" [{start} -> {end}] {tag}: {text}")
|
||||||
|
if s.buffer_transcription:
|
||||||
|
print(f" [buffer] {s.buffer_transcription}")
|
||||||
|
if s.buffer_diarization:
|
||||||
|
print(f" [diar buffer] {s.buffer_diarization}")
|
||||||
|
print(f" Speakers: {s.speakers or 'none'} | Silence: {s.has_silence}")
|
||||||
|
print()
|
||||||
Loading…
Reference in a new issue