diff --git a/BENCHMARK.md b/BENCHMARK.md deleted file mode 100644 index 81239f0..0000000 --- a/BENCHMARK.md +++ /dev/null @@ -1,205 +0,0 @@ -# WhisperLiveKit Benchmark Report - -Benchmark comparing all supported ASR backends, streaming policies, and model sizes on Apple Silicon. -All tests run through the full AudioProcessor pipeline (same code path as production WebSocket). - -## Test Environment - -| Property | Value | -|----------|-------| -| Hardware | Apple M4, 32 GB RAM | -| OS | macOS 25.3.0 (arm64) | -| Python | 3.13 | -| faster-whisper | 1.2.1 | -| mlx-whisper | installed (via mlx) | -| Voxtral MLX | native MLX backend | -| Voxtral (HF) | transformers-based | -| VAC (Silero VAD) | enabled unless noted | -| Chunk size | 100 ms | -| Pacing | no-realtime (as fast as possible) | - -## Audio Test Files - -| File | Duration | Language | Speakers | Description | -|------|----------|----------|----------|-------------| -| `00_00_07_english_1_speaker.wav` | 7.2 s | English | 1 | Short dictation with pauses | -| `00_00_16_french_1_speaker.wav` | 16.3 s | French | 1 | French speech with intentional silence gaps | -| `00_00_30_english_3_speakers.wav` | 30.0 s | English | 3 | Multi-speaker conversation | - -Ground truth transcripts (`.transcript.json`) with per-word timestamps are hand-verified. - ---- - -## Results - -### English -- Short (7.2 s, 1 speaker) - -| Backend | Policy | Model | RTF | WER | Timestamp MAE | -|---------|--------|-------|-----|-----|---------------| -| faster-whisper | LocalAgreement | base | 0.20x | 21.1% | 0.080 s | -| faster-whisper | SimulStreaming | base | 0.14x | 0.0% | 0.239 s | -| faster-whisper | LocalAgreement | small | 0.59x | 21.1% | 0.089 s | -| faster-whisper | SimulStreaming | small | 0.39x | 0.0% | 0.221 s | -| mlx-whisper | LocalAgreement | base | 0.05x | 21.1% | 0.080 s | -| mlx-whisper | SimulStreaming | base | 0.14x | 10.5% | 0.245 s | -| mlx-whisper | LocalAgreement | small | 0.16x | 21.1% | 0.089 s | -| mlx-whisper | SimulStreaming | small | 0.20x | 10.5% | 0.226 s | -| voxtral-mlx | voxtral | 4B | 0.32x | 0.0% | 0.254 s | -| voxtral (HF) | voxtral | 4B | 1.29x | 0.0% | 1.876 s | - -### English -- Multi-speaker (30.0 s, 3 speakers) - -| Backend | Policy | Model | RTF | WER | Timestamp MAE | -|---------|--------|-------|-----|-----|---------------| -| faster-whisper | LocalAgreement | base | 0.24x | 44.7% | 0.235 s | -| faster-whisper | SimulStreaming | base | 0.10x | 5.3% | 0.398 s | -| faster-whisper | LocalAgreement | small | 0.59x | 25.0% | 0.226 s | -| faster-whisper | SimulStreaming | small | 0.26x | 5.3% | 0.387 s | -| mlx-whisper | LocalAgreement | base | 0.06x | 23.7% | 0.237 s | -| mlx-whisper | SimulStreaming | base | 0.11x | 5.3% | 0.395 s | -| mlx-whisper | LocalAgreement | small | 0.13x | 25.0% | 0.226 s | -| mlx-whisper | SimulStreaming | small | 0.20x | 5.3% | 0.394 s | -| voxtral-mlx | voxtral | 4B | 0.31x | 9.2% | 0.176 s | -| voxtral (HF) | voxtral | 4B | 1.00x | 32.9% | 1.034 s | - -

-Benchmark comparison on 30s English -

- -

-Speed vs Accuracy tradeoff -

- -### French (16.3 s, 1 speaker, `--language fr`) - -| Backend | Policy | Model | RTF | WER | Timestamp MAE | -|---------|--------|-------|-----|-----|---------------| -| faster-whisper | LocalAgreement | base | 0.22x | 25.7% | 3.460 s | -| faster-whisper | SimulStreaming | base | 0.10x | 31.4% | 3.660 s | -| faster-whisper | LocalAgreement | small | 0.76x | 42.9% | 0.051 s | -| faster-whisper | SimulStreaming | small | 0.29x | 25.7% | 0.219 s | -| mlx-whisper | LocalAgreement | base | 0.09x | ~45%* | ~5.0 s* | -| mlx-whisper | SimulStreaming | base | 0.09x | 40.0% | 3.540 s | -| mlx-whisper | LocalAgreement | small | 0.14x | 25.7% | 0.083 s | -| mlx-whisper | SimulStreaming | small | 0.17x | 31.4% | 0.203 s | -| voxtral-mlx | voxtral | 4B | 0.18x | 37.1% | 3.422 s | -| voxtral (HF) | voxtral | 4B | 0.63x | 28.6% | 4.040 s | - -\* mlx-whisper + LocalAgreement + base is unstable on this French file (WER fluctuates 34-1037% across runs due to hallucination loops). The `small` model does not have this problem. - -**Timestamp note:** The base model produces very high timestamp MAE (3.4-3.7s) on this French file because it misaligns words around the silence gaps. The small model handles this much better (0.05-0.22s MAE). Voxtral also drifts on the silence gaps. - ---- - -## Model Size Comparison (base vs small) - -| | base | small | Observation | -|--|------|-------|-------------| -| **RTF** | 0.05-0.24x | 0.13-0.76x | small is 2-3x slower | -| **English WER (SS)** | 0-5.3% | 0-5.3% | No improvement: SimulStreaming already saturates on base | -| **English WER (LA)** | 21-44.7% | 21-25% | small reduces LA errors on longer audio | -| **French WER** | 25-40% | 25-43% | Mixed: depends on backend/policy combo | -| **French timestamps** | 3.4-5.0s MAE | 0.05-0.22s MAE | small is dramatically better for French timestamps | - -In short: **base + SimulStreaming** gives the best speed/accuracy tradeoff for English. The small model only helps if you need LocalAgreement (for subtitle-grade timestamps) or non-English languages. - ---- - -## Key Findings - -### Speed (RTF = processing time / audio duration, lower is better) - -1. **mlx-whisper + LocalAgreement + base** is the fastest combo on Apple Silicon: 0.05-0.06x RTF on English. 30 seconds of audio in under 2 seconds. -2. For **faster-whisper**, SimulStreaming is faster than LocalAgreement. For **mlx-whisper**, it is the opposite: LocalAgreement (0.05-0.06x) outperforms SimulStreaming (0.11-0.14x) on speed. -3. **voxtral-mlx** runs at 0.18-0.32x RTF -- 3-5x slower than mlx-whisper base, but well within real-time. -4. **voxtral (HF transformers)** hits 1.0-1.3x RTF. At the real-time boundary on Apple Silicon. Use the MLX variant instead. -5. The **small** model is 2-3x slower than base across all backends. - -### Accuracy (WER = Word Error Rate, lower is better) - -1. **SimulStreaming** gives dramatically lower WER than LocalAgreement on the whisper backends. On the 30s English file: 5.3% vs 23-44%. -2. **voxtral-mlx** hits 0% on short English and 9.2% on multi-speaker. It auto-detects language natively. Whisper also supports `--language auto`, but tends to bias towards English on short segments. -3. **LocalAgreement** tends to repeat the last sentence at end-of-stream (a known LCP artifact), inflating WER. This is visible in the 21% WER on the 7s file -- the same 4 extra words appear in every LA run. -4. On **French** with the correct `--language fr`, whisper base achieves 25-40% WER -- comparable to Voxtral's 28-37%. The small model does not consistently improve French WER. - -### Timestamps (MAE = Mean Absolute Error on word start times) - -1. **LocalAgreement** gives the best timestamps on English (0.08-0.09s MAE). -2. **SimulStreaming** is less precise (0.22-0.40s MAE) but good enough for most applications. -3. On French with silence gaps, **base model timestamps are unreliable** (3.4-5s MAE). The **small model fixes this** (0.05-0.22s MAE). This is the strongest argument for using `small` over `base`. -4. **voxtral-mlx** has good timestamps on English (0.18-0.25s MAE) but drifts on audio with long silence gaps (3.4s MAE on the French file). - -### VAC (Voice Activity Classification) Impact - -| Backend | Policy | VAC | 7s English WER | 30s English WER | -|---------|--------|-----|----------------|-----------------| -| faster-whisper | LocalAgreement | on | 21.1% | 44.7% | -| faster-whisper | LocalAgreement | off | 100.0% | 100.0% | -| voxtral-mlx | voxtral | on | 0.0% | 9.2% | -| voxtral-mlx | voxtral | off | 0.0% | 9.2% | - -- **Whisper backends need VAC** to work in streaming mode. Without it the buffer logic breaks down and you get empty or garbage output. -- **Voxtral is unaffected by VAC** since it handles its own internal chunking. Identical results with or without. VAC still saves compute on silent segments. - ---- - -## Recommendations - -| Use Case | Backend | Policy | Model | Notes | -|----------|---------|--------|-------|-------| -| Fastest English (Apple Silicon) | mlx-whisper | SimulStreaming | base | 0.11x RTF, 5.3% WER | -| Fastest English (Linux/GPU) | faster-whisper | SimulStreaming | base | 0.10x RTF, 5.3% WER | -| Best accuracy, English | faster-whisper | SimulStreaming | small | 0.26x RTF, 5.3% WER, still fast | -| Multilingual / auto-detect | voxtral-mlx | voxtral | 4B | 100+ languages, 0.18-0.32x RTF | -| Best timestamps | any | LocalAgreement | small | 0.05-0.09s MAE, good for subtitles | -| Low memory / embedded | mlx-whisper | SimulStreaming | base | Smallest footprint, fastest response | - ---- - -## Caveats - -- **3 test files, ~53 seconds total.** Results give relative rankings between backends but should not be taken as definitive WER numbers. Run on your own data for production decisions. -- **RTF varies between runs** (up to +/-30%) depending on thermal state, background processes, and model caching. The numbers above are single sequential runs on a warm machine. -- **Only base and small tested.** Medium and large-v3 would likely improve WER at the cost of higher RTF. We did not test them here because they are slow on Apple Silicon without GPU. - ---- - -## Reproducing These Benchmarks - -```bash -# Install test dependencies -pip install -e ".[test]" - -# Single backend test -python test_backend_offline.py --backend faster-whisper --policy simulstreaming --model base --no-realtime - -# With a specific language -python test_backend_offline.py --backend mlx-whisper --policy simulstreaming --model small --lan fr --no-realtime - -# Multi-backend auto-detect benchmark -python test_backend_offline.py --benchmark --no-realtime - -# Export to JSON -python test_backend_offline.py --benchmark --no-realtime --json results.json - -# Test with your own audio -python test_backend_offline.py --backend voxtral-mlx --audio your_file.wav --no-realtime -``` - -The benchmark harness computes WER and timestamp accuracy automatically when ground truth -`.transcript.json` files exist alongside the audio files. See `audio_tests/` for the format. - ---- - -## Help Us Benchmark on More Hardware - -These results are from a single Apple M4 machine. We'd love to see numbers from other setups: Linux with CUDA GPUs, older Macs, different CPU architectures, cloud instances, etc. - -If you run the benchmark on your hardware, please open an issue or PR with your results and we will add them here. The more data points we have, the better the recommendations get. - -What we are especially interested in: -- **NVIDIA GPUs** (RTX 3090, 4090, A100, T4, etc.) with faster-whisper -- **Older Apple Silicon** (M1, M2, M3) with mlx-whisper and voxtral-mlx -- **Medium and large-v3 models** (we only tested base and small so far) -- **Longer audio files** or domain-specific audio (medical, legal, call center) -- **Other languages** beyond English and French diff --git a/README.md b/README.md index 1d04f68..4f9c17c 100644 --- a/README.md +++ b/README.md @@ -134,10 +134,13 @@ uv sync --extra cu129 --extra voxtral-hf --extra translation See **Parameters & Configuration** below on how to use them.

-Speed vs Accuracy tradeoff +Speed vs Accuracy — English, compute-unaware +

+

+Speed vs Accuracy — English, compute-aware

-See **[BENCHMARK.md](BENCHMARK.md)** for the full benchmark with tables, model size comparison, and more. +Benchmarks use public audio from [LibriSpeech](https://huggingface.co/datasets/openslr/librispeech_asr) and [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) — fully reproducible with `python scripts/run_scatter_benchmark.py`. We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR! @@ -371,7 +374,7 @@ docker compose up --build wlk-cpu # Quick benchmark with the CLI wlk bench wlk bench --backend faster-whisper --model large-v3 -wlk bench --json results.json +wlk bench --languages all --json results.json # Install test dependencies for full suite pip install -e ".[test]" @@ -379,13 +382,11 @@ pip install -e ".[test]" # Run unit tests (no model download required) pytest tests/ -v -# Detailed multi-backend benchmark -python test_backend_offline.py --benchmark --no-realtime -python test_backend_offline.py --benchmark --no-realtime --json results.json +# Speed vs Accuracy scatter plot (all backends, compute-aware + unaware) +python scripts/create_long_samples.py # generate ~90s test samples (cached) +python scripts/run_scatter_benchmark.py # English (both modes) +python scripts/run_scatter_benchmark.py --lang fr # French ``` -See [BENCHMARK.md](BENCHMARK.md) for a full comparison of backends, policies, WER, speed, and -timestamp accuracy on Apple Silicon. - ## Use Cases Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service... diff --git a/audio_tests/00_00_07_english_1_speaker.transcript.json b/audio_tests/00_00_07_english_1_speaker.transcript.json deleted file mode 100644 index 43ca785..0000000 --- a/audio_tests/00_00_07_english_1_speaker.transcript.json +++ /dev/null @@ -1,97 +0,0 @@ -[ - { - "word": "This", - "start": 0.0, - "end": 0.24 - }, - { - "word": "is", - "start": 0.24, - "end": 0.56 - }, - { - "word": "a", - "start": 0.56, - "end": 0.76 - }, - { - "word": "transcription", - "start": 0.76, - "end": 1.32 - }, - { - "word": "test.", - "start": 1.32, - "end": 2.0 - }, - { - "word": "We", - "start": 2.4, - "end": 2.5 - }, - { - "word": "want", - "start": 2.5, - "end": 2.66 - }, - { - "word": "to", - "start": 2.66, - "end": 2.84 - }, - { - "word": "see", - "start": 2.84, - "end": 3.1 - }, - { - "word": "if", - "start": 3.1, - "end": 3.34 - }, - { - "word": "we", - "start": 3.34, - "end": 3.5 - }, - { - "word": "can", - "start": 3.5, - "end": 3.68 - }, - { - "word": "use", - "start": 3.68, - "end": 4.04 - }, - { - "word": "smaller", - "start": 4.04, - "end": 4.76 - }, - { - "word": "chunks.", - "start": 4.76, - "end": 5.16 - }, - { - "word": "What", - "start": 6.06, - "end": 6.32 - }, - { - "word": "do", - "start": 6.32, - "end": 6.44 - }, - { - "word": "you", - "start": 6.44, - "end": 6.58 - }, - { - "word": "think?", - "start": 6.58, - "end": 6.84 - } -] \ No newline at end of file diff --git a/audio_tests/00_00_16_french_1_speaker.transcript.json b/audio_tests/00_00_16_french_1_speaker.transcript.json deleted file mode 100644 index 07c0b31..0000000 --- a/audio_tests/00_00_16_french_1_speaker.transcript.json +++ /dev/null @@ -1,177 +0,0 @@ -[ - { - "word": "Ok,", - "start": 2.02, - "end": 2.38 - }, - { - "word": "là", - "start": 2.52, - "end": 2.58 - }, - { - "word": "c", - "start": 2.58, - "end": 2.74 - }, - { - "word": "'est", - "start": 2.74, - "end": 2.76 - }, - { - "word": "un", - "start": 2.76, - "end": 2.86 - }, - { - "word": "test,", - "start": 2.86, - "end": 3.2 - }, - { - "word": "on", - "start": 3.34, - "end": 3.34 - }, - { - "word": "veut", - "start": 3.34, - "end": 3.48 - }, - { - "word": "voir", - "start": 3.48, - "end": 3.86 - }, - { - "word": "si", - "start": 3.86, - "end": 4.14 - }, - { - "word": "ça", - "start": 4.14, - "end": 4.26 - }, - { - "word": "arrive", - "start": 4.26, - "end": 4.36 - }, - { - "word": "à", - "start": 4.36, - "end": 4.5 - }, - { - "word": "capté", - "start": 4.5, - "end": 4.78 - }, - { - "word": "le", - "start": 4.78, - "end": 4.9 - }, - { - "word": "silence.", - "start": 4.9, - "end": 5.44 - }, - { - "word": "Là", - "start": 9.24, - "end": 9.6 - }, - { - "word": "il", - "start": 9.6, - "end": 9.78 - }, - { - "word": "est", - "start": 9.78, - "end": 9.84 - }, - { - "word": "une", - "start": 9.84, - "end": 9.96 - }, - { - "word": "telle", - "start": 9.96, - "end": 10.12 - }, - { - "word": "seconde", - "start": 10.12, - "end": 10.38 - }, - { - "word": "de", - "start": 10.38, - "end": 10.48 - }, - { - "word": "silence", - "start": 10.48, - "end": 10.78 - }, - { - "word": "et", - "start": 10.78, - "end": 11.06 - }, - { - "word": "je", - "start": 11.06, - "end": 11.16 - }, - { - "word": "vous", - "start": 11.16, - "end": 11.32 - }, - { - "word": "parle.", - "start": 11.32, - "end": 11.68 - }, - { - "word": "Et", - "start": 13.28, - "end": 13.64 - }, - { - "word": "voilà,", - "start": 13.64, - "end": 13.96 - }, - { - "word": "allez", - "start": 14.36, - "end": 14.62 - }, - { - "word": "on", - "start": 14.62, - "end": 14.78 - }, - { - "word": "va", - "start": 14.78, - "end": 14.88 - }, - { - "word": "tester", - "start": 14.88, - "end": 15.06 - }, - { - "word": "ça.", - "start": 15.06, - "end": 15.36 - } -] \ No newline at end of file diff --git a/audio_tests/00_00_30_english_3_speakers.transcript.json b/audio_tests/00_00_30_english_3_speakers.transcript.json deleted file mode 100644 index bb9d097..0000000 --- a/audio_tests/00_00_30_english_3_speakers.transcript.json +++ /dev/null @@ -1,382 +0,0 @@ -[ - { - "word": "Transcription", - "start": 0.0, - "end": 0.6 - }, - { - "word": "technology", - "start": 0.6, - "end": 1.24 - }, - { - "word": "has", - "start": 1.24, - "end": 1.5 - }, - { - "word": "improved", - "start": 1.5, - "end": 1.96 - }, - { - "word": "so", - "start": 1.96, - "end": 2.32 - }, - { - "word": "much", - "start": 2.32, - "end": 2.68 - }, - { - "word": "in", - "start": 2.68, - "end": 2.94 - }, - { - "word": "the", - "start": 2.94, - "end": 3.02 - }, - { - "word": "past", - "start": 3.02, - "end": 3.24 - }, - { - "word": "few", - "start": 3.24, - "end": 3.5 - }, - { - "word": "years.", - "start": 3.5, - "end": 3.96 - }, - { - "word": "Have", - "start": 4.56, - "end": 4.74 - }, - { - "word": "you", - "start": 4.74, - "end": 4.9 - }, - { - "word": "noticed", - "start": 4.9, - "end": 5.26 - }, - { - "word": "how", - "start": 5.26, - "end": 5.52 - }, - { - "word": "accurate", - "start": 5.52, - "end": 6.08 - }, - { - "word": "real", - "start": 6.08, - "end": 6.42 - }, - { - "word": "-time", - "start": 6.42, - "end": 6.74 - }, - { - "word": "speech", - "start": 6.74, - "end": 7.24 - }, - { - "word": "to", - "start": 7.24, - "end": 7.46 - }, - { - "word": "text", - "start": 7.46, - "end": 7.78 - }, - { - "word": "is", - "start": 7.78, - "end": 8.0 - }, - { - "word": "now?", - "start": 8.0, - "end": 8.3 - }, - { - "word": "Absolutely.", - "start": 8.7, - "end": 9.16 - }, - { - "word": "I", - "start": 10.04, - "end": 10.38 - }, - { - "word": "use", - "start": 10.38, - "end": 10.56 - }, - { - "word": "it", - "start": 10.56, - "end": 10.76 - }, - { - "word": "all", - "start": 10.76, - "end": 10.9 - }, - { - "word": "the", - "start": 10.9, - "end": 11.04 - }, - { - "word": "time", - "start": 11.04, - "end": 11.32 - }, - { - "word": "for", - "start": 11.32, - "end": 11.54 - }, - { - "word": "taking", - "start": 11.54, - "end": 11.86 - }, - { - "word": "notes", - "start": 11.86, - "end": 12.16 - }, - { - "word": "during", - "start": 12.16, - "end": 12.54 - }, - { - "word": "meetings.", - "start": 12.54, - "end": 12.94 - }, - { - "word": "It's", - "start": 13.6, - "end": 13.8 - }, - { - "word": "amazing", - "start": 13.8, - "end": 14.1 - }, - { - "word": "how", - "start": 14.1, - "end": 14.48 - }, - { - "word": "it", - "start": 14.48, - "end": 14.62 - }, - { - "word": "can", - "start": 14.62, - "end": 14.74 - }, - { - "word": "recognise", - "start": 14.74, - "end": 15.24 - }, - { - "word": "different", - "start": 15.24, - "end": 15.68 - }, - { - "word": "speakers", - "start": 15.68, - "end": 16.16 - }, - { - "word": "and", - "start": 16.16, - "end": 16.8 - }, - { - "word": "even", - "start": 16.8, - "end": 17.1 - }, - { - "word": "add", - "start": 17.1, - "end": 17.44 - }, - { - "word": "punctuation.", - "start": 17.44, - "end": 18.36 - }, - { - "word": "Yeah,", - "start": 18.88, - "end": 19.16 - }, - { - "word": "but", - "start": 19.36, - "end": 19.52 - }, - { - "word": "sometimes", - "start": 19.52, - "end": 20.16 - }, - { - "word": "noise", - "start": 20.16, - "end": 20.54 - }, - { - "word": "can", - "start": 20.54, - "end": 20.8 - }, - { - "word": "still", - "start": 20.8, - "end": 21.1 - }, - { - "word": "cause", - "start": 21.1, - "end": 21.44 - }, - { - "word": "mistakes.", - "start": 21.44, - "end": 21.94 - }, - { - "word": "Does", - "start": 22.68, - "end": 22.9 - }, - { - "word": "this", - "start": 22.9, - "end": 23.12 - }, - { - "word": "system", - "start": 23.12, - "end": 23.46 - }, - { - "word": "handle", - "start": 23.46, - "end": 23.88 - }, - { - "word": "that", - "start": 23.88, - "end": 24.12 - }, - { - "word": "well?", - "start": 24.12, - "end": 24.42 - }, - { - "word": "It", - "start": 24.42, - "end": 25.32 - }, - { - "word": "does", - "start": 25.32, - "end": 25.48 - }, - { - "word": "a", - "start": 25.48, - "end": 25.62 - }, - { - "word": "pretty", - "start": 25.62, - "end": 25.88 - }, - { - "word": "good", - "start": 25.88, - "end": 26.08 - }, - { - "word": "job", - "start": 26.08, - "end": 26.32 - }, - { - "word": "filtering", - "start": 26.32, - "end": 26.8 - }, - { - "word": "noise,", - "start": 26.8, - "end": 27.18 - }, - { - "word": "especially", - "start": 27.36, - "end": 28.0 - }, - { - "word": "with", - "start": 28.0, - "end": 28.28 - }, - { - "word": "models", - "start": 28.28, - "end": 28.62 - }, - { - "word": "that", - "start": 28.62, - "end": 28.94 - }, - { - "word": "use", - "start": 28.94, - "end": 29.22 - }, - { - "word": "voice", - "start": 29.22, - "end": 29.54 - }, - { - "word": "active.", - "start": 29.54, - "end": 29.9 - } -] \ No newline at end of file diff --git a/audio_tests/generate_transcripts.py b/audio_tests/generate_transcripts.py deleted file mode 100644 index 6749a5c..0000000 --- a/audio_tests/generate_transcripts.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -"""Generate word-level timestamped transcripts using faster-whisper (offline). - -Produces one JSON file per audio with: [{word, start, end}, ...] -""" - -import json -import os - -from faster_whisper import WhisperModel - -AUDIO_DIR = os.path.dirname(os.path.abspath(__file__)) - -FILES = [ - ("00_00_07_english_1_speaker.wav", "en"), - ("00_00_16_french_1_speaker.wav", "fr"), - ("00_00_30_english_3_speakers.wav", "en"), -] - -def main(): - print("Loading faster-whisper model (base, cpu, float32)...") - model = WhisperModel("base", device="cpu", compute_type="float32") - - for filename, lang in FILES: - audio_path = os.path.join(AUDIO_DIR, filename) - out_path = os.path.join( - AUDIO_DIR, filename.rsplit(".", 1)[0] + ".transcript.json" - ) - - print(f"\n{'='*60}") - print(f"Transcribing: {filename} (language={lang})") - print(f"{'='*60}") - - segments, info = model.transcribe( - audio_path, word_timestamps=True, language=lang - ) - - words = [] - for segment in segments: - if segment.words: - for w in segment.words: - words.append({ - "word": w.word.strip(), - "start": round(w.start, 3), - "end": round(w.end, 3), - }) - print(f" {w.start:6.2f} - {w.end:6.2f} {w.word.strip()}") - - with open(out_path, "w", encoding="utf-8") as f: - json.dump(words, f, indent=2, ensure_ascii=False) - - print(f"\n -> {len(words)} words written to {os.path.basename(out_path)}") - - print("\nDone.") - - -if __name__ == "__main__": - main() diff --git a/benchmark_chart.png b/benchmark_chart.png deleted file mode 100644 index 20123bd..0000000 Binary files a/benchmark_chart.png and /dev/null differ diff --git a/benchmark_scatter.png b/benchmark_scatter.png deleted file mode 100644 index 9f62bf3..0000000 Binary files a/benchmark_scatter.png and /dev/null differ diff --git a/benchmark_scatter_en_aware.png b/benchmark_scatter_en_aware.png new file mode 100644 index 0000000..43a500f Binary files /dev/null and b/benchmark_scatter_en_aware.png differ diff --git a/benchmark_scatter_en_unaware.png b/benchmark_scatter_en_unaware.png new file mode 100644 index 0000000..97ac7fd Binary files /dev/null and b/benchmark_scatter_en_unaware.png differ diff --git a/benchmark_scatter_fr_aware.png b/benchmark_scatter_fr_aware.png new file mode 100644 index 0000000..c5e78ac Binary files /dev/null and b/benchmark_scatter_fr_aware.png differ diff --git a/benchmark_scatter_fr_unaware.png b/benchmark_scatter_fr_unaware.png new file mode 100644 index 0000000..ea3f816 Binary files /dev/null and b/benchmark_scatter_fr_unaware.png differ diff --git a/pyproject.toml b/pyproject.toml index 3a36468..dfba364 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "whisperlivekit" version = "0.2.20" -description = "Real-time speech-to-text with speaker diarization using Whisper" +description = "Real-time speech-to-text models" readme = "README.md" authors = [{ name = "Quentin Fuxa" }] license = { file = "LICENSE" } @@ -144,6 +144,7 @@ packages = [ "whisperlivekit.local_agreement", "whisperlivekit.voxtral_mlx", "whisperlivekit.silero_vad_models", + "whisperlivekit.benchmark", ] [tool.setuptools.package-data] diff --git a/run_benchmark.py b/run_benchmark.py deleted file mode 100644 index 8c737fb..0000000 --- a/run_benchmark.py +++ /dev/null @@ -1,290 +0,0 @@ -#!/usr/bin/env python3 -""" -Comprehensive benchmark runner for WhisperLiveKit. - -Tests all available backend+policy combinations across multiple audio files, -model sizes, and VAC on/off configurations. Outputs structured JSON that -is consumed by the report generator. - -Usage: - python run_benchmark.py # full benchmark - python run_benchmark.py --quick # subset (tiny models, fewer combos) - python run_benchmark.py --json results.json # custom output path -""" - -import argparse -import asyncio -import gc -import json -import logging -import platform -import subprocess -import sys -import time -from dataclasses import asdict -from pathlib import Path - -logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(name)s: %(message)s") -logger = logging.getLogger("benchmark") -logger.setLevel(logging.INFO) - -# Re-use harness functions -sys.path.insert(0, str(Path(__file__).parent)) -from test_backend_offline import ( - AUDIO_TESTS_DIR, - SAMPLE_RATE, - create_engine, - discover_audio_files, - download_sample_audio, - load_audio, - run_test, -) - -CACHE_DIR = Path(__file__).parent / ".test_cache" - - -def get_system_info() -> dict: - """Collect system metadata for the report.""" - info = { - "platform": platform.platform(), - "machine": platform.machine(), - "processor": platform.processor(), - "python_version": platform.python_version(), - } - - # macOS: get chip info - try: - chip = subprocess.check_output( - ["sysctl", "-n", "machdep.cpu.brand_string"], text=True - ).strip() - info["cpu"] = chip - except Exception: - info["cpu"] = platform.processor() - - # RAM - try: - mem_bytes = int( - subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip() - ) - info["ram_gb"] = round(mem_bytes / (1024**3)) - except Exception: - info["ram_gb"] = None - - # Backend versions - versions = {} - try: - import faster_whisper - versions["faster-whisper"] = faster_whisper.__version__ - except ImportError: - pass - try: - import mlx_whisper # noqa: F401 - versions["mlx-whisper"] = "installed" - except ImportError: - pass - try: - import mlx.core as mx - versions["mlx"] = mx.__version__ - except ImportError: - pass - try: - import transformers - versions["transformers"] = transformers.__version__ - except ImportError: - pass - try: - import torch - versions["torch"] = torch.__version__ - except ImportError: - pass - - info["backend_versions"] = versions - return info - - -def detect_combos(quick: bool = False) -> list: - """Build list of (backend, policy, model_size) combos to test.""" - combos = [] - - # Model sizes to test - model_sizes = ["tiny", "base", "small"] if not quick else ["tiny", "base"] - - # faster-whisper - try: - import faster_whisper # noqa: F401 - for model in model_sizes: - combos.append({"backend": "faster-whisper", "policy": "localagreement", "model": model}) - combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "model": model}) - except ImportError: - pass - - # mlx-whisper - try: - import mlx_whisper # noqa: F401 - for model in model_sizes: - combos.append({"backend": "mlx-whisper", "policy": "localagreement", "model": model}) - combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "model": model}) - except ImportError: - pass - - # voxtral-mlx (single model, single policy) - try: - from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401 - combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "model": ""}) - except ImportError: - pass - - # voxtral HF (single model, single policy) - try: - from transformers import AutoModelForSpeechSeq2Seq # noqa: F401 - combos.append({"backend": "voxtral", "policy": "voxtral", "model": ""}) - except ImportError: - pass - - return combos - - -def collect_audio_files() -> list: - """Collect all benchmark audio files.""" - files = [] - - # audio_tests/ directory - if AUDIO_TESTS_DIR.is_dir(): - files.extend(discover_audio_files(str(AUDIO_TESTS_DIR))) - - # JFK sample - jfk = CACHE_DIR / "jfk.wav" - if not jfk.exists(): - jfk = download_sample_audio() - if jfk.exists(): - files.append(jfk) - - return files - - -async def run_single_combo( - combo: dict, audio_files: list, vac: bool, lan: str, max_duration: float, -) -> list: - """Run one backend+policy+model combo across all audio files.""" - backend = combo["backend"] - policy = combo["policy"] - model = combo["model"] - - results = [] - try: - engine = create_engine( - backend=backend, - model_size=model, - lan=lan, - vac=vac, - policy=policy, - ) - - # Quiet noisy loggers - for mod in ( - "whisperlivekit.audio_processor", - "whisperlivekit.simul_whisper", - "whisperlivekit.tokens_alignment", - "whisperlivekit.simul_whisper.align_att_base", - "whisperlivekit.simul_whisper.simul_whisper", - ): - logging.getLogger(mod).setLevel(logging.WARNING) - - for audio_path in audio_files: - duration = len(load_audio(str(audio_path))) / SAMPLE_RATE - if duration > max_duration: - logger.info(f" Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s)") - continue - - file_lan = lan - if "french" in audio_path.name.lower() and lan == "en": - file_lan = "fr" - - audio = load_audio(str(audio_path)) - result = await run_test( - engine, audio, chunk_ms=100, realtime=False, - audio_file=audio_path.name, backend=backend, - policy=policy, lan=file_lan, - ) - # Tag with extra metadata - result_dict = asdict(result) - result_dict["model_size"] = model - result_dict["vac"] = vac - results.append(result_dict) - - except Exception as e: - logger.error(f" FAILED: {e}") - import traceback - traceback.print_exc() - - return results - - -async def run_full_benchmark(combos, audio_files, max_duration=60.0): - """Run all combos with VAC on and off.""" - all_results = [] - total = len(combos) * 2 # x2 for VAC on/off - idx = 0 - - for combo in combos: - for vac in [True, False]: - idx += 1 - vac_str = "VAC=on" if vac else "VAC=off" - desc = f"{combo['backend']} / {combo['policy']}" - if combo["model"]: - desc += f" / {combo['model']}" - desc += f" / {vac_str}" - - print(f"\n{'='*70}") - print(f"[{idx}/{total}] {desc}") - print(f"{'='*70}") - - results = await run_single_combo( - combo, audio_files, vac=vac, lan="en", max_duration=max_duration, - ) - all_results.extend(results) - - # Free memory between combos - gc.collect() - - return all_results - - -def main(): - parser = argparse.ArgumentParser(description="Run comprehensive WhisperLiveKit benchmark") - parser.add_argument("--quick", action="store_true", help="Quick mode: fewer models and combos") - parser.add_argument("--json", default="benchmark_results.json", dest="json_output", help="Output JSON path") - parser.add_argument("--max-duration", type=float, default=60.0, help="Max audio duration in seconds") - args = parser.parse_args() - - system_info = get_system_info() - combos = detect_combos(quick=args.quick) - audio_files = collect_audio_files() - - print(f"System: {system_info.get('cpu', 'unknown')}, {system_info.get('ram_gb', '?')}GB RAM") - print(f"Backends: {list(system_info['backend_versions'].keys())}") - print(f"Combos to test: {len(combos)} x 2 (VAC on/off) = {len(combos)*2}") - print(f"Audio files: {[f.name for f in audio_files]}") - print() - - t0 = time.time() - all_results = asyncio.run( - run_full_benchmark(combos, audio_files, max_duration=args.max_duration) - ) - total_time = time.time() - t0 - - output = { - "system_info": system_info, - "benchmark_date": time.strftime("%Y-%m-%d %H:%M"), - "total_benchmark_time_s": round(total_time, 1), - "n_combos": len(combos) * 2, - "n_audio_files": len(audio_files), - "results": all_results, - } - - Path(args.json_output).write_text(json.dumps(output, indent=2, ensure_ascii=False)) - print(f"\nBenchmark complete in {total_time:.0f}s. Results: {args.json_output}") - - -if __name__ == "__main__": - main() diff --git a/scripts/create_long_samples.py b/scripts/create_long_samples.py new file mode 100644 index 0000000..c925b72 --- /dev/null +++ b/scripts/create_long_samples.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +"""Create long benchmark samples (5min+) by concatenating utterances from public datasets.""" + +import io +import json +import logging +import wave +from pathlib import Path + +import numpy as np + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + +CACHE = Path.home() / ".cache/whisperlivekit/benchmark_data" +CACHE.mkdir(parents=True, exist_ok=True) +SR = 16000 + + +def save_wav(path, audio, sr=SR): + audio = np.clip(audio, -1, 1) + audio_int = (audio * 32767).astype(np.int16) + with wave.open(str(path), "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sr) + wf.writeframes(audio_int.tobytes()) + + +def decode_audio(audio_bytes): + import soundfile as sf + arr, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32") + return np.array(arr, dtype=np.float32), sr + + +def download_long_librispeech(config, lang_code, target_dur=300): + """Concatenate LibriSpeech utterances into a ~5min sample.""" + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info(f"Downloading LibriSpeech {config} for {lang_code} (~{target_dur}s)...") + ds = load_dataset("openslr/librispeech_asr", config, split="test", streaming=True) + ds = ds.cast_column("audio", Audio(decode=False)) + + chunks, texts = [], [] + total = 0 + for item in ds: + arr, sr = decode_audio(item["audio"]["bytes"]) + chunks.append(arr) + texts.append(item["text"]) + total += len(arr) / sr + if total >= target_dur: + break + if len(chunks) % 20 == 0: + logger.info(f" {total:.0f}s / {target_dur}s ({len(chunks)} utterances)") + + # Insert small silences between utterances for natural transitions + silence = np.zeros(int(0.5 * sr), dtype=np.float32) + interleaved = [] + for i, chunk in enumerate(chunks): + if i > 0: + interleaved.append(silence) + interleaved.append(chunk) + full = np.concatenate(interleaved) + total = len(full) / sr + ref = " ".join(texts) + name = f"{lang_code}_long_{config}" + path = CACHE / f"{name}.wav" + save_wav(path, full) + logger.info(f" -> {name}: {total:.1f}s ({len(texts)} utterances)") + return {"name": name, "path": str(path), "reference": ref, + "duration": round(total, 2), "language": lang_code.split("_")[0]} + + +def download_long_mls(config, lang_code, target_dur=300): + """Concatenate MLS utterances into a ~5min sample.""" + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info(f"Downloading MLS {config} for {lang_code} (~{target_dur}s)...") + ds = load_dataset("facebook/multilingual_librispeech", config, split="test", streaming=True) + ds = ds.cast_column("audio", Audio(decode=False)) + + chunks, texts = [], [] + total = 0 + for item in ds: + arr, sr = decode_audio(item["audio"]["bytes"]) + chunks.append(arr) + texts.append(item.get("text", item.get("transcript", ""))) + total += len(arr) / sr + if total >= target_dur: + break + if len(chunks) % 20 == 0: + logger.info(f" {total:.0f}s / {target_dur}s ({len(chunks)} utterances)") + + silence = np.zeros(int(0.5 * sr), dtype=np.float32) + interleaved = [] + for i, chunk in enumerate(chunks): + if i > 0: + interleaved.append(silence) + interleaved.append(chunk) + full = np.concatenate(interleaved) + total = len(full) / sr + ref = " ".join(texts) + name = f"{lang_code}_long" + path = CACHE / f"{name}.wav" + save_wav(path, full) + logger.info(f" -> {name}: {total:.1f}s ({len(texts)} utterances)") + return {"name": name, "path": str(path), "reference": ref, + "duration": round(total, 2), "language": lang_code} + + +def main(): + samples = [] + + # English clean ~90s + samples.append(download_long_librispeech("clean", "en", target_dur=90)) + + # English noisy ~90s + samples.append(download_long_librispeech("other", "en_noisy", target_dur=90)) + + # French ~90s + samples.append(download_long_mls("french", "fr", target_dur=90)) + + # Save metadata + meta_path = CACHE / "long_samples.json" + meta_path.write_text(json.dumps(samples, indent=2)) + logger.info(f"\nSaved metadata to {meta_path}") + + total = sum(s["duration"] for s in samples) + logger.info(f"Total: {len(samples)} long samples, {total:.0f}s ({total/60:.1f}min)") + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_architecture.py b/scripts/generate_architecture.py index 7f42d45..9ef45c3 100644 --- a/scripts/generate_architecture.py +++ b/scripts/generate_architecture.py @@ -144,13 +144,14 @@ ax.text(10.4, 4.2, "4B params • 15 languages • 6-bit quant (MLX)", fontsize= # ── Qwen3 backend ── section_box(15.0, 3.8, 4.6, 3.2, "Qwen3 ASR (batch + aligner)", border=C_GREEN, bg=C_BOX_BG3) -box(15.2, 5.9, 2.0, 0.6, "Qwen3 ASR\n1.7B / 0.6B", color=C_GREEN, bg="#1a3a2a", fontsize=7, bold=True) -box(17.4, 5.9, 2.0, 0.6, "Forced\nAligner", color=C_GREEN, bg="#1a3a2a", fontsize=7) +box(15.2, 5.9, 1.5, 0.6, "Qwen3 ASR\n1.7B / 0.6B", color=C_GREEN, bg="#1a3a2a", fontsize=7, bold=True) +box(16.9, 5.9, 1.5, 0.6, "Qwen3\nSimul", color=C_GREEN, bg="#1a3a2a", fontsize=7, bold=True) +box(18.6, 5.9, 1.0, 0.6, "Forced\nAligner", color=C_GREEN, bg="#1a3a2a", fontsize=6.5) -ax.text(15.2, 5.4, "Full-audio batch inference", fontsize=6.5, color=C_TEXTDIM, family="monospace") +ax.text(15.2, 5.4, "Batch + SimulStreaming (AlignAtt)", fontsize=6.5, color=C_TEXTDIM, family="monospace") ax.text(15.2, 5.0, "ForcedAligner provides word timestamps", fontsize=6, color=C_GREEN, family="monospace") -ax.text(15.2, 4.6, "Uses LocalAgreement for streaming output", fontsize=6, color=C_TEXTDIM, family="monospace") -ax.text(15.2, 4.2, "12 languages • CUDA/MPS/CPU", fontsize=6, color=C_TEXTDIM, family="monospace") +ax.text(15.2, 4.6, "LocalAgreement or border-distance policy", fontsize=6, color=C_TEXTDIM, family="monospace") +ax.text(15.2, 4.2, "29 languages • CUDA/MPS/CPU", fontsize=6, color=C_TEXTDIM, family="monospace") # ── OpenAI API ── box(15.2, 7.7, 4.2, 0.6, "OpenAI API (cloud)", color="#5a6a7a", fontsize=7) @@ -168,8 +169,10 @@ box(16.0, 2.2, 3.4, 0.8, "Translation\nNLLB • CTranslate2", box(10.4, 0.8, 4.0, 0.8, "WhisperLiveKitConfig\n(single source of truth)", color=C_ACCENT, fontsize=7, bold=True) -box(14.8, 0.8, 4.6, 0.8, "TestHarness\nfull pipeline testing without server", +box(14.8, 0.8, 2.3, 0.8, "TestHarness\npipeline testing", color="#5a6a7a", fontsize=7) +box(17.3, 0.8, 2.3, 0.8, "Benchmark\n8 langs • 13 samples", + color=C_ORANGE, fontsize=7, bold=True) # ═══════════════════════════════════════════════════════════════════ # Arrows: main data flow diff --git a/scripts/run_scatter_benchmark.py b/scripts/run_scatter_benchmark.py new file mode 100644 index 0000000..0176c25 --- /dev/null +++ b/scripts/run_scatter_benchmark.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +"""Run benchmark across all backend x model x policy combos for scatter plot. + +Tests each configuration on long audio samples in two modes: + - Compute-unaware (speed=0): all audio dumped instantly, measures pure model accuracy + - Compute-aware (speed=1.0): real-time simulation, slow models lose audio + +Usage: + python scripts/run_scatter_benchmark.py + python scripts/run_scatter_benchmark.py --aware # only compute-aware + python scripts/run_scatter_benchmark.py --unaware # only compute-unaware + python scripts/run_scatter_benchmark.py --plot-only results.json +""" + +import argparse +import asyncio +import gc +import json +import logging +import platform +import subprocess +import sys +import time +import warnings + +warnings.filterwarnings("ignore") +logging.basicConfig(level=logging.WARNING) +for name in [ + "whisperlivekit", "transformers", "torch", "httpx", "datasets", + "numexpr", "faster_whisper", +]: + logging.getLogger(name).setLevel(logging.ERROR) + + +LONG_SAMPLES_PATH = "~/.cache/whisperlivekit/benchmark_data/long_samples.json" + +# ── All configurations to benchmark ── + +COMBOS = [ + # faster-whisper x LocalAgreement + {"backend": "faster-whisper", "model_size": "base", "policy": "localagreement", + "label": "fw LA base", "color": "#4a9eff", "marker": "o", "size": 100}, + {"backend": "faster-whisper", "model_size": "small", "policy": "localagreement", + "label": "fw LA small", "color": "#4a9eff", "marker": "o", "size": 220}, + # faster-whisper x SimulStreaming + {"backend": "faster-whisper", "model_size": "base", "policy": "simulstreaming", + "label": "fw SS base", "color": "#4a9eff", "marker": "s", "size": 100}, + {"backend": "faster-whisper", "model_size": "small", "policy": "simulstreaming", + "label": "fw SS small", "color": "#4a9eff", "marker": "s", "size": 220}, + # mlx-whisper x LocalAgreement + {"backend": "mlx-whisper", "model_size": "base", "policy": "localagreement", + "label": "mlx LA base", "color": "#4ecca3", "marker": "o", "size": 100}, + {"backend": "mlx-whisper", "model_size": "small", "policy": "localagreement", + "label": "mlx LA small", "color": "#4ecca3", "marker": "o", "size": 220}, + # mlx-whisper x SimulStreaming + {"backend": "mlx-whisper", "model_size": "base", "policy": "simulstreaming", + "label": "mlx SS base", "color": "#4ecca3", "marker": "s", "size": 100}, + {"backend": "mlx-whisper", "model_size": "small", "policy": "simulstreaming", + "label": "mlx SS small", "color": "#4ecca3", "marker": "s", "size": 220}, + # voxtral-mlx (4B, native streaming) + {"backend": "voxtral-mlx", "model_size": "", "policy": "", + "label": "voxtral mlx", "color": "#f5a623", "marker": "D", "size": 250}, +] + + +def is_backend_available(backend): + try: + if backend == "faster-whisper": + import faster_whisper; return True # noqa + elif backend == "mlx-whisper": + import mlx_whisper; return True # noqa + elif backend == "whisper": + import whisper; return True # noqa + elif backend == "voxtral-mlx": + import mlx.core # noqa + from whisperlivekit.voxtral_mlx.loader import load_voxtral_model; return True # noqa + elif backend == "voxtral": + from transformers import VoxtralRealtimeForConditionalGeneration; return True # noqa + elif backend in ("qwen3", "qwen3-simul"): + from whisperlivekit.qwen3_asr import _patch_transformers_compat + _patch_transformers_compat() + from qwen_asr import Qwen3ASRModel; return True # noqa + except (ImportError, Exception): + pass + return False + + +def get_system_info(): + info = {"platform": platform.platform(), "machine": platform.machine()} + try: + info["cpu"] = subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], text=True).strip() + except Exception: + info["cpu"] = platform.processor() + try: + mem = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip()) + info["ram_gb"] = round(mem / (1024**3)) + except Exception: + info["ram_gb"] = None + return info + + +async def run_combo_on_samples(combo, samples, lang="en", speed=0): + """Run one config on all samples, return averaged result. + + Args: + speed: 0 = compute-unaware (instant dump), 1.0 = compute-aware (real-time) + """ + from whisperlivekit.core import TranscriptionEngine + from whisperlivekit.metrics import compute_wer + from whisperlivekit.test_harness import TestHarness, _engine_cache + + kwargs = {"lan": lang, "pcm_input": True} + if combo["backend"]: + kwargs["backend"] = combo["backend"] + if combo["model_size"]: + kwargs["model_size"] = combo["model_size"] + if combo.get("policy"): + kwargs["backend_policy"] = combo["policy"] + + TranscriptionEngine.reset() + _engine_cache.clear() + gc.collect() + + total_ref_words, total_errors = 0, 0 + total_infer_time, total_audio_time = 0.0, 0.0 + n_ok = 0 + + for sample in samples: + try: + async with TestHarness(**kwargs) as h: + await h.feed(sample["path"], speed=speed) + await h.drain(max(5.0, sample["duration"] * 0.5)) + state = await h.finish(timeout=120) + metrics = h.metrics + + hypothesis = state.committed_text or state.text + wer_result = compute_wer(sample["reference"], hypothesis) + + total_ref_words += wer_result["ref_words"] + total_errors += (wer_result["substitutions"] + + wer_result["insertions"] + + wer_result["deletions"]) + + # Use actual inference time from metrics, not wall clock + if metrics and metrics.transcription_durations: + total_infer_time += sum(metrics.transcription_durations) + total_audio_time += sample["duration"] + n_ok += 1 + except Exception as e: + print(f" [WARN: {sample['name']} failed: {e}]", end="") + + if n_ok == 0: + return None + + weighted_wer = total_errors / max(total_ref_words, 1) + # Real RTF = actual inference time / audio duration + real_rtf = total_infer_time / total_audio_time if total_audio_time > 0 else 0 + + return { + "label": combo["label"], + "backend": combo["backend"], + "model_size": combo.get("model_size", ""), + "policy": combo.get("policy", ""), + "color": combo["color"], + "marker": combo["marker"], + "size": combo["size"], + "rtf": round(real_rtf, 4), + "wer_pct": round(weighted_wer * 100, 1), + "n_samples": n_ok, + } + + +async def run_all(combos, samples, lang="en", speed=0): + mode_label = "compute-aware" if speed > 0 else "compute-unaware" + results = [] + for i, combo in enumerate(combos): + if not is_backend_available(combo["backend"]): + print(f" [{i+1}/{len(combos)}] SKIP {combo['label']} (not installed)") + continue + print(f" [{i+1}/{len(combos)}] {combo['label']} ({mode_label})...", end="", flush=True) + result = await run_combo_on_samples(combo, samples, lang, speed=speed) + if result: + results.append(result) + print(f" RTF={result['rtf']:.2f}x WER={result['wer_pct']:.1f}% ({result['n_samples']} samples)") + else: + print(" FAILED (no results)") + return results + + +def get_long_samples_for_lang(lang="en"): + """Load long benchmark samples from long_samples.json, filtered by language.""" + import os + path = os.path.expanduser(LONG_SAMPLES_PATH) + if not os.path.exists(path): + print(f"ERROR: Long samples file not found: {path}") + print("Please generate it first (see benchmark_data/README).") + sys.exit(1) + with open(path) as f: + all_samples = json.load(f) + samples = [s for s in all_samples if s["language"] == lang] + return [{"name": s["name"], "path": s["path"], "reference": s["reference"], + "duration": s["duration"]} for s in samples] + + +LANG_NAMES = { + "en": "English", "fr": "French", "es": "Spanish", "de": "German", + "pt": "Portuguese", "it": "Italian", "nl": "Dutch", "pl": "Polish", + "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ru": "Russian", +} + + +def generate_scatter(results, system_info, output_path, n_samples, lang="en", + mode="unaware", sample_duration=0.0): + """Generate scatter plot. + + Args: + mode: "unaware" or "aware" -- shown in title + sample_duration: total audio duration in seconds -- shown in title + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + from matplotlib.lines import Line2D + + fig, ax = plt.subplots(figsize=(12, 7), facecolor="white") + ax.set_facecolor("#fafafa") + + # Separate main cluster from outliers (RTF > 1.0) + main = [r for r in results if r["rtf"] <= 1.0] + slow = [r for r in results if r["rtf"] > 1.0] + + # Axis limits: tight around main data + if main: + xmax = max(r["rtf"] for r in main) * 1.6 + ymax = max(r["wer_pct"] for r in main) * 1.5 + 1 + else: + xmax, ymax = 0.5, 10 + xmax = max(xmax, 0.45) + ymax = max(ymax, 8) + + # Sweet spot zone + sweet_x = xmax * 0.85 + sweet_y = ymax * 0.55 + rect = plt.Rectangle((0, 0), sweet_x, sweet_y, alpha=0.07, color="#4ecca3", + zorder=0, linewidth=0) + ax.add_patch(rect) + ax.text(sweet_x - 0.005, sweet_y - 0.15, "sweet spot", ha="right", va="top", + fontsize=10, color="#2ecc71", fontstyle="italic", fontweight="bold", alpha=0.5) + + # Manual label offsets keyed by label name — hand-tuned + OFFSETS = { + "fw LA base": (8, 8), + "fw LA small": (8, 8), + "fw SS base": (-55, -14), + "fw SS small": (8, 8), + "mlx LA base": (8, 10), + "mlx LA small": (8, 8), + "mlx SS base": (-55, 8), + "mlx SS small": (-55, -5), + "voxtral mlx": (10, -14), + "qwen3 0.6B": (10, 8), + "fw LA large-v3": (8, -5), + "fw SS large-v3": (8, 5), + } + + # Plot main points + for r in main: + ax.scatter(r["rtf"], r["wer_pct"], c=r["color"], marker=r["marker"], + s=r["size"], edgecolors="white", linewidths=1.0, zorder=5, alpha=0.85) + ox, oy = OFFSETS.get(r["label"], (8, -4)) + ax.annotate(r["label"], (r["rtf"], r["wer_pct"]), + textcoords="offset points", xytext=(ox, oy), + fontsize=8.5, color="#333333", fontweight="medium") + + # Note slow backends outside main view + if slow: + lines = [] + for r in slow: + lines.append(f"{r['label']}: RTF={r['rtf']:.1f}x, WER={r['wer_pct']:.1f}%") + note = "Beyond real-time:\n" + "\n".join(lines) + ax.text(xmax * 0.97, ymax * 0.97, note, ha="right", va="top", + fontsize=7.5, color="#777777", fontstyle="italic", + bbox=dict(boxstyle="round,pad=0.4", facecolor="#f8f8f8", + edgecolor="#dddddd", alpha=0.9)) + + # Axes + ax.set_xlim(left=-0.01, right=xmax) + ax.set_ylim(bottom=0, top=ymax) + ax.set_xlabel("RTF (lower = faster)", fontsize=13, fontweight="bold", labelpad=8) + ax.set_ylabel("WER % (lower = more accurate)", fontsize=13, fontweight="bold", labelpad=8) + ax.grid(True, alpha=0.15, linestyle="-", color="#cccccc") + ax.tick_params(labelsize=10) + + # Title + cpu = system_info.get("cpu", "unknown").replace("Apple ", "") + lang_name = LANG_NAMES.get(lang, lang.upper()) + mode_label = "compute-unaware" if mode == "unaware" else "compute-aware" + dur_str = f"{sample_duration / 60:.0f}min" if sample_duration >= 60 else f"{sample_duration:.0f}s" + ax.set_title( + f"Speed vs Accuracy ({mode_label}) — {n_samples} {lang_name} samples, {dur_str} ({cpu})", + fontsize=14, fontweight="bold", pad=12) + + # Legend — backends + backend_handles = [] + seen = set() + for r in results: + if r["backend"] not in seen: + seen.add(r["backend"]) + backend_handles.append(mpatches.Patch(color=r["color"], label=r["backend"])) + + # Legend — shapes + marker_map = {"o": "LocalAgreement", "s": "SimulStreaming", "D": "Native streaming", + "h": "Batch + aligner"} + active = set(r["marker"] for r in results) + shape_handles = [ + Line2D([0], [0], marker=m, color="#888", label=lbl, + markerfacecolor="#888", markersize=8, linestyle="None") + for m, lbl in marker_map.items() if m in active + ] + # sizes + shape_handles += [ + Line2D([0], [0], marker="o", color="#888", label="base", + markerfacecolor="#888", markersize=5, linestyle="None"), + Line2D([0], [0], marker="o", color="#888", label="small / 4B", + markerfacecolor="#888", markersize=9, linestyle="None"), + ] + + leg1 = ax.legend(handles=backend_handles, loc="upper left", fontsize=9, + framealpha=0.95, edgecolor="#ddd", title="Backend", title_fontsize=9) + ax.add_artist(leg1) + ax.legend(handles=shape_handles, loc="lower right", fontsize=8, + framealpha=0.95, edgecolor="#ddd", ncol=2) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight", pad_inches=0.15) + print(f"Saved {output_path}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--plot-only", default=None) + parser.add_argument("--lang", default="en", help="Language code (en, fr, es, de, ...)") + parser.add_argument("--output", "-o", default=None, + help="Output path prefix (mode suffix added automatically)") + parser.add_argument("--json-output", default=None, + help="JSON output path prefix (mode suffix added automatically)") + parser.add_argument("--aware", action="store_true", + help="Run only compute-aware mode (speed=1.0)") + parser.add_argument("--unaware", action="store_true", + help="Run only compute-unaware mode (speed=0)") + args = parser.parse_args() + + lang = args.lang + + # Determine which modes to run + if args.aware and args.unaware: + modes = ["unaware", "aware"] + elif args.aware: + modes = ["aware"] + elif args.unaware: + modes = ["unaware"] + else: + # Default: run both + modes = ["unaware", "aware"] + + if args.plot_only: + data = json.load(open(args.plot_only)) + mode = data.get("mode", "unaware") + output_path = args.output or f"benchmark_scatter_{lang}_{mode}.png" + generate_scatter(data["results"], data["system_info"], output_path, + data["n_samples"], data.get("lang", "en"), + mode=mode, + sample_duration=data.get("total_audio_s", 0)) + return + + print(f"Loading long {lang} samples from {LONG_SAMPLES_PATH}...") + samples = get_long_samples_for_lang(lang) + if not samples: + print(f"ERROR: No long samples for language '{lang}'") + sys.exit(1) + print(f"Using {len(samples)} samples: {[s['name'] for s in samples]}") + total_dur = sum(s["duration"] for s in samples) + print(f"Total audio: {total_dur:.0f}s ({total_dur / 60:.1f}min)\n") + + # Filter combos to backends that support this language + from whisperlivekit.benchmark.compat import backend_supports_language + combos = [c for c in COMBOS if backend_supports_language(c["backend"], lang)] + + system_info = get_system_info() + + for mode in modes: + speed = 1.0 if mode == "aware" else 0 + mode_label = "compute-aware" if mode == "aware" else "compute-unaware" + print(f"\n{'='*60}") + print(f" Running {mode_label} (speed={speed})") + print(f"{'='*60}\n") + + t0 = time.time() + results = asyncio.run(run_all(combos, samples, lang, speed=speed)) + total = time.time() - t0 + + # Save JSON + json_path = args.json_output or f"/tmp/bench_scatter_{lang}" + json_file = f"{json_path}_{mode}.json" + output_data = { + "system_info": system_info, + "lang": lang, + "mode": mode, + "speed": speed, + "n_samples": len(samples), + "sample_names": [s["name"] for s in samples], + "total_audio_s": round(total_dur, 1), + "total_benchmark_time_s": round(total, 1), + "results": results, + } + with open(json_file, "w") as f: + json.dump(output_data, f, indent=2) + print(f"\nJSON: {json_file} ({total:.0f}s total)") + + # Generate scatter plot + output_base = args.output or f"benchmark_scatter_{lang}" + output_path = f"{output_base}_{mode}.png" + generate_scatter(results, system_info, output_path, len(samples), lang, + mode=mode, sample_duration=total_dur) + + +if __name__ == "__main__": + main() diff --git a/test_backend_offline.py b/test_backend_offline.py deleted file mode 100644 index 75af927..0000000 --- a/test_backend_offline.py +++ /dev/null @@ -1,804 +0,0 @@ -#!/usr/bin/env python3 -""" -Offline test harness and benchmark suite for WhisperLiveKit backends. - -Simulates a client-server session by feeding audio files as PCM bytes through -the full AudioProcessor pipeline (the same path used by the WebSocket server), -without needing a browser or microphone. - -Computes WER (Word Error Rate) and timestamp accuracy when ground truth -transcript files (.transcript.json) are available alongside audio files. - -Usage: - # Test with a single audio file: - python test_backend_offline.py --backend faster-whisper --audio audio_tests/00_00_07_english_1_speaker.wav - - # Test all files in audio_tests/: - python test_backend_offline.py --backend faster-whisper --no-realtime - - # Override streaming policy: - python test_backend_offline.py --backend faster-whisper --policy simulstreaming --no-realtime - - # Multi-backend benchmark (auto-detects all installed backends): - python test_backend_offline.py --benchmark --no-realtime - - # Export results as JSON: - python test_backend_offline.py --benchmark --no-realtime --json results.json - - # Insert silence for testing silence handling: - python test_backend_offline.py --backend faster-whisper --insert-silence 3.0 2.0 -""" - -import argparse -import asyncio -import json -import logging -import sys -import time -import urllib.request -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import List, Optional - -import numpy as np - -logging.basicConfig( - level=logging.WARNING, - format="%(asctime)s %(levelname)s %(name)s: %(message)s", -) -logger = logging.getLogger("test_offline") -logger.setLevel(logging.INFO) - -SAMPLE_RATE = 16000 -JFK_WAV_URL = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav" -CACHE_DIR = Path(__file__).parent / ".test_cache" -AUDIO_TESTS_DIR = Path(__file__).parent / "audio_tests" -AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"} - - -@dataclass -class WordTimestamp: - """Word with its start/end time.""" - word: str - start: float - end: float - - -@dataclass -class TestResult: - """Structured result from a single test run.""" - audio_file: str - audio_duration_s: float - backend: str - policy: str - language: str - chunk_ms: int - realtime_pacing: bool - # Timing - processing_time_s: float - rtf: float # real-time factor - # Transcription output - transcription: str - n_lines: int - n_responses: int - # WER metrics (None if no ground truth) - wer: Optional[float] = None - wer_details: Optional[dict] = None - # Timestamp accuracy (None if no ground truth) - timestamp_mae: Optional[float] = None - timestamp_max_delta: Optional[float] = None - timestamp_median_delta: Optional[float] = None - # Word-level timestamps - word_timestamps: List[WordTimestamp] = field(default_factory=list) - # Raw last response - last_response: Optional[dict] = None - - -def download_sample_audio() -> Path: - """Download the jfk.wav sample if not cached.""" - CACHE_DIR.mkdir(exist_ok=True) - path = CACHE_DIR / "jfk.wav" - if not path.exists(): - logger.info(f"Downloading sample audio to {path} ...") - urllib.request.urlretrieve(JFK_WAV_URL, path) - logger.info("Done.") - return path - - -def load_audio(path: str) -> np.ndarray: - """Load audio file as float32 mono 16kHz numpy array. - - Supports WAV, FLAC (via soundfile) and MP3, OGG, M4A (via librosa). - """ - ext = Path(path).suffix.lower() - if ext in (".mp3", ".ogg", ".m4a"): - import librosa - audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True) - return audio.astype(np.float32) - - import soundfile as sf - audio, sr = sf.read(path, dtype="float32") - if audio.ndim > 1: - audio = audio.mean(axis=1) - if sr != SAMPLE_RATE: - import librosa - audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE) - return audio - - -def insert_silence(audio: np.ndarray, silence_sec: float, position_sec: float) -> np.ndarray: - """Insert silence into audio at a given position. - - Args: - audio: Float32 mono audio array at SAMPLE_RATE. - silence_sec: Duration of silence to insert in seconds. - position_sec: Position in seconds where silence starts. - Returns: - New audio array with silence inserted. - """ - pos_samples = int(position_sec * SAMPLE_RATE) - silence_samples = int(silence_sec * SAMPLE_RATE) - pos_samples = min(pos_samples, len(audio)) - silence = np.zeros(silence_samples, dtype=np.float32) - return np.concatenate([audio[:pos_samples], silence, audio[pos_samples:]]) - - -def float32_to_s16le_bytes(audio: np.ndarray) -> bytes: - """Convert float32 audio to s16le PCM bytes (what the browser sends).""" - return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes() - - -def create_engine( - backend: str, model_size: str, lan: str, - diarization: bool = False, - diarization_backend: str = "", - vac: bool = True, - policy: str = "", -): - """Create a TranscriptionEngine with the given backend config.""" - import gc - - from whisperlivekit.core import TranscriptionEngine - - # Reset singleton so we get a fresh instance - TranscriptionEngine._instance = None - TranscriptionEngine._initialized = False - gc.collect() - - kwargs = dict( - backend=backend, - lan=lan, - pcm_input=True, - vac=vac, - transcription=True, - diarization=diarization, - ) - if diarization_backend: - kwargs["diarization_backend"] = diarization_backend - if model_size: - kwargs["model_size"] = model_size - if policy: - kwargs["backend_policy"] = policy - - return TranscriptionEngine(**kwargs) - - -def _extract_text_from_response(response_dict: dict) -> str: - """Extract full transcription text from a FrontData dict.""" - def _strip_or_empty(value: object) -> str: - return value.strip() if isinstance(value, str) else "" - - segments = response_dict.get("lines", []) - full_text = " ".join( - text - for seg in segments - if isinstance(seg, dict) - for text in [_strip_or_empty(seg.get("text"))] - if text - ) - buf = _strip_or_empty(response_dict.get("buffer_transcription")) - if buf: - full_text = f"{full_text} {buf}".strip() if full_text else buf - return full_text - - -async def run_test( - engine, audio: np.ndarray, chunk_ms: int, realtime: bool, - audio_file: str = "", backend: str = "", policy: str = "", lan: str = "", -) -> TestResult: - """ - Simulate a client session through the full AudioProcessor pipeline. - - 1. Create AudioProcessor (one per "client session") - 2. Start async pipeline (transcription_processor, results_formatter, etc.) - 3. Feed audio as PCM bytes in timed chunks - 4. Collect and display FrontData responses - 5. Signal EOF and cleanup - """ - from whisperlivekit.audio_processor import AudioProcessor - - chunk_samples = int(SAMPLE_RATE * chunk_ms / 1000) - total_samples = len(audio) - audio_duration = total_samples / SAMPLE_RATE - - logger.info( - f"Audio: {audio_duration:.2f}s | " - f"Chunk: {chunk_ms}ms ({chunk_samples} samples) | " - f"Steps: {total_samples // chunk_samples + 1} | " - f"Realtime: {realtime}" - ) - - # --- Server side: create processor and start pipeline --- - processor = AudioProcessor(transcription_engine=engine) - results_generator = await processor.create_tasks() - - # Collect results in background (like handle_websocket_results) - all_responses = [] - response_count = 0 - last_printed_text = "" - - async def collect_results(): - nonlocal response_count, last_printed_text - async for response in results_generator: - all_responses.append(response) - response_count += 1 - d = response.to_dict() - - # Only print when transcription text actually changes - current_text = _extract_text_from_response(d) - if current_text and current_text != last_printed_text: - buf = d.get("buffer_transcription") - buf = buf.strip() if isinstance(buf, str) else "" - committed = current_text - if buf and committed.endswith(buf): - committed = committed[:-len(buf)].strip() - - # Show committed text + buffer separately - display = committed - if buf: - display = f"{committed} \033[90m{buf}\033[0m" if committed else f"\033[90m{buf}\033[0m" - print(f" > {display}", flush=True) - last_printed_text = current_text - - result_task = asyncio.create_task(collect_results()) - - # --- Client side: feed audio as PCM bytes --- - t_start = time.time() - - for offset in range(0, total_samples, chunk_samples): - chunk = audio[offset : offset + chunk_samples] - pcm_bytes = float32_to_s16le_bytes(chunk) - await processor.process_audio(pcm_bytes) - if realtime: - await asyncio.sleep(chunk_ms / 1000) - - feed_elapsed = time.time() - t_start - - logger.info(f"Audio fed in {feed_elapsed:.2f}s. Signaling EOF...") - - # Signal end of audio (like client disconnect / empty message) - await processor.process_audio(None) - - # Wait for pipeline to drain completely - try: - await asyncio.wait_for(result_task, timeout=120.0) - except asyncio.TimeoutError: - logger.warning("Timed out waiting for results. Proceeding with cleanup.") - result_task.cancel() - try: - await result_task - except asyncio.CancelledError: - pass - - # --- Capture word-level timestamps before cleanup --- - word_timestamps = [] - try: - state = await processor.get_current_state() - for token in state.tokens: - if hasattr(token, 'start') and hasattr(token, 'text') and token.text: - word_timestamps.append(WordTimestamp( - word=token.text.strip(), - start=round(token.start, 3), - end=round(token.end, 3), - )) - except Exception as e: - logger.warning(f"Could not capture word timestamps: {e}") - - # Cleanup - await processor.cleanup() - - total_elapsed = time.time() - t_start - - # --- Build result --- - transcription = "" - n_lines = 0 - last_response_dict = None - - if all_responses: - last = all_responses[-1].to_dict() - last_response_dict = last - n_lines = len(last.get("lines", [])) - transcription = _extract_text_from_response(last) - - # --- Compute WER and timestamp accuracy against ground truth --- - from whisperlivekit.metrics import compute_timestamp_accuracy, compute_wer - - wer_val = None - wer_details = None - ts_mae = None - ts_max_delta = None - ts_median_delta = None - - gt_path = Path(audio_file).with_suffix(".transcript.json") - if not gt_path.exists(): - gt_path = AUDIO_TESTS_DIR / gt_path - gt = None - if gt_path.exists(): - with open(gt_path) as f: - gt = json.load(f) - - # WER - gt_text = " ".join(w["word"] for w in gt) - wer_result = compute_wer(gt_text, transcription) - wer_val = round(wer_result["wer"], 4) - wer_details = wer_result - - # Timestamp accuracy - if word_timestamps: - pred_dicts = [{"word": wt.word, "start": wt.start, "end": wt.end} for wt in word_timestamps] - ts_result = compute_timestamp_accuracy(pred_dicts, gt) - ts_mae = ts_result["mae_start"] - ts_max_delta = ts_result["max_delta_start"] - ts_median_delta = ts_result["median_delta_start"] - - result = TestResult( - audio_file=audio_file, - audio_duration_s=round(audio_duration, 2), - backend=backend, - policy=policy, - language=lan, - chunk_ms=chunk_ms, - realtime_pacing=realtime, - processing_time_s=round(total_elapsed, 2), - rtf=round(total_elapsed / audio_duration, 2), - transcription=transcription, - n_lines=n_lines, - n_responses=response_count, - wer=wer_val, - wer_details=wer_details, - timestamp_mae=round(ts_mae, 3) if ts_mae is not None else None, - timestamp_max_delta=round(ts_max_delta, 3) if ts_max_delta is not None else None, - timestamp_median_delta=round(ts_median_delta, 3) if ts_median_delta is not None else None, - word_timestamps=word_timestamps, - last_response=last_response_dict, - ) - - # --- Print summary --- - print(f"\n{'=' * 60}") - print(f"RESULT: {audio_file}") - print(f"{'=' * 60}") - print(f"Transcription: {transcription}") - print(f"Lines: {n_lines} | Responses: {response_count}") - print(f"Audio: {audio_duration:.2f}s | Time: {total_elapsed:.2f}s | RTF: {result.rtf:.2f}x") - - if wer_val is not None: - print(f"WER: {wer_val:.2%} (S={wer_details['substitutions']} I={wer_details['insertions']} D={wer_details['deletions']})") - - # Print word timestamps if available - if word_timestamps: - print(f"\nWord timestamps ({len(word_timestamps)} words):") - for wt in word_timestamps: - print(f" [{wt.start:6.2f} - {wt.end:6.2f}] {wt.word}") - - # Detailed comparison with ground truth - if gt: - print(f"\n vs Ground truth ({len(gt)} words):") - max_words = max(len(word_timestamps), len(gt)) - for i in range(max_words): - pred = word_timestamps[i] if i < len(word_timestamps) else None - ref = gt[i] if i < len(gt) else None - p_str = f"[{pred.start:5.2f}-{pred.end:5.2f}] {pred.word:<15}" if pred else " " * 30 - r_str = f"[{ref['start']:5.2f}-{ref['end']:5.2f}] {ref['word']:<15}" if ref else "" - delta = "" - if pred and ref: - d = pred.start - ref['start'] - delta = f" Δstart={d:+.2f}" - print(f" {p_str} | {r_str}{delta}") - - if ts_mae is not None: - print(f"\n Timestamp stats: MAE={ts_mae:.3f}s max|Δ|={ts_max_delta:.3f}s median|Δ|={ts_median_delta:.3f}s") - - print(f"{'=' * 60}") - - return result - - -def discover_audio_files(directory: str) -> List[Path]: - """Find all supported audio files in directory.""" - d = Path(directory) - files = sorted( - p for p in d.iterdir() - if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS - ) - return files - - -async def run_all_tests( - engine, audio_files: List[Path], chunk_ms: int, realtime: bool, - backend: str, policy: str, lan: str, max_duration: float = 60.0, - silence_insertions: Optional[List[List[float]]] = None, -) -> List[TestResult]: - """Run tests on multiple audio files sequentially.""" - results = [] - for audio_path in audio_files: - # Detect language from filename if "french" in name - file_lan = lan - if "french" in audio_path.name.lower() and lan == "en": - file_lan = "fr" - logger.info("Auto-detected language 'fr' from filename") - - audio = load_audio(str(audio_path)) - - # Insert silence segments (applied in reverse position order to keep offsets valid) - if silence_insertions: - for secs, at_sec in sorted(silence_insertions, key=lambda x: x[1], reverse=True): - logger.info(f"Inserting {secs:.1f}s silence at {at_sec:.1f}s") - audio = insert_silence(audio, secs, at_sec) - - duration = len(audio) / SAMPLE_RATE - - if duration > max_duration: - logger.info(f"Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s max)") - continue - - print(f"\n{'#' * 60}") - print(f"# Testing: {audio_path.name} ({duration:.1f}s)") - print(f"{'#' * 60}") - - result = await run_test( - engine, audio, chunk_ms, realtime, - audio_file=audio_path.name, backend=backend, policy=policy, lan=file_lan, - ) - results.append(result) - - return results - - -def print_benchmark_summary(results: List[TestResult]): - """Print a tabular summary of all test results.""" - print(f"\n{'=' * 110}") - print("BENCHMARK SUMMARY") - print(f"{'=' * 110}") - print( - f"{'File':<40} {'Duration':>8} {'Time':>8} {'RTF':>6} " - f"{'WER':>7} {'MAE(s)':>7} {'Lines':>5}" - ) - print(f"{'-' * 110}") - for r in results: - wer_str = f"{r.wer:.2%}" if r.wer is not None else " -" - mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -" - print( - f"{r.audio_file:<40} {r.audio_duration_s:>7.1f}s {r.processing_time_s:>7.1f}s " - f"{r.rtf:>5.2f}x {wer_str:>7} {mae_str:>7} {r.n_lines:>5}" - ) - print(f"{'-' * 110}") - total_audio = sum(r.audio_duration_s for r in results) - total_time = sum(r.processing_time_s for r in results) - avg_rtf = total_time / total_audio if total_audio > 0 else 0 - wer_vals = [r.wer for r in results if r.wer is not None] - avg_wer_str = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -" - mae_vals = [r.timestamp_mae for r in results if r.timestamp_mae is not None] - avg_mae_str = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -" - print( - f"{'TOTAL/AVG':<40} {total_audio:>7.1f}s {total_time:>7.1f}s " - f"{avg_rtf:>5.2f}x {avg_wer_str:>7} {avg_mae_str:>7}" - ) - print(f"{'=' * 110}") - - # Print transcription excerpts - print("\nTRANSCRIPTIONS:") - print(f"{'-' * 110}") - for r in results: - excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription - print(f" {r.audio_file}:") - print(f" {excerpt}") - print(f"{'=' * 110}") - - -def detect_available_backends() -> List[dict]: - """Probe which backends can be imported and return (backend, policy) combos. - - Returns list of dicts with keys: backend, policy, description. - """ - combos = [] - - # faster-whisper - try: - import faster_whisper # noqa: F401 - combos.append({"backend": "faster-whisper", "policy": "localagreement", "description": "faster-whisper + LocalAgreement"}) - combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "description": "faster-whisper + SimulStreaming"}) - except ImportError: - pass - - # mlx-whisper (macOS only) - try: - import mlx_whisper # noqa: F401 - combos.append({"backend": "mlx-whisper", "policy": "localagreement", "description": "mlx-whisper + LocalAgreement"}) - combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "description": "mlx-whisper + SimulStreaming"}) - except ImportError: - pass - - # openai-whisper - try: - import whisper # noqa: F401 - combos.append({"backend": "whisper", "policy": "localagreement", "description": "openai-whisper + LocalAgreement"}) - combos.append({"backend": "whisper", "policy": "simulstreaming", "description": "openai-whisper + SimulStreaming"}) - except ImportError: - pass - - # voxtral-mlx - try: - from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401 - combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "description": "voxtral-mlx (MLX)"}) - except ImportError: - pass - - # voxtral (HuggingFace) - try: - from transformers import AutoModelForSpeechSeq2Seq # noqa: F401 - combos.append({"backend": "voxtral", "policy": "voxtral", "description": "voxtral (HuggingFace)"}) - except ImportError: - pass - - return combos - - -def print_cross_backend_comparison(all_results: List[TestResult]): - """Print a comparison table across backends and policies.""" - print(f"\n{'=' * 110}") - print("CROSS-BACKEND BENCHMARK COMPARISON") - print(f"{'=' * 110}") - print( - f"{'Backend':<18} {'Policy':<16} {'File':<30} " - f"{'WER':>7} {'RTF':>6} {'MAE(s)':>7} {'MaxΔ(s)':>8}" - ) - print(f"{'-' * 110}") - - for r in all_results: - wer_str = f"{r.wer:.2%}" if r.wer is not None else " -" - rtf_str = f"{r.rtf:.2f}x" - mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -" - max_str = f"{r.timestamp_max_delta:.3f}" if r.timestamp_max_delta is not None else " -" - # Truncate filename for readability - fname = r.audio_file[:28] + ".." if len(r.audio_file) > 30 else r.audio_file - print( - f"{r.backend:<18} {r.policy:<16} {fname:<30} " - f"{wer_str:>7} {rtf_str:>6} {mae_str:>7} {max_str:>8}" - ) - - print(f"{'-' * 110}") - - # Per-backend averages - from collections import defaultdict - by_combo = defaultdict(list) - for r in all_results: - by_combo[(r.backend, r.policy)].append(r) - - print(f"\n{'Backend':<18} {'Policy':<16} {'Avg WER':>8} {'Avg RTF':>8} {'Avg MAE':>8} {'Files':>6}") - print(f"{'-' * 80}") - for (backend, policy), group in sorted(by_combo.items()): - wer_vals = [r.wer for r in group if r.wer is not None] - rtf_vals = [r.rtf for r in group] - mae_vals = [r.timestamp_mae for r in group if r.timestamp_mae is not None] - avg_wer = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -" - avg_rtf = f"{sum(rtf_vals)/len(rtf_vals):.2f}x" - avg_mae = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -" - print( - f"{backend:<18} {policy:<16} {avg_wer:>8} {avg_rtf:>8} {avg_mae:>8} {len(group):>6}" - ) - print(f"{'=' * 110}") - - -def _quiet_loggers(verbose: bool): - """Set internal module log levels to reduce noise.""" - if verbose: - logging.getLogger().setLevel(logging.DEBUG) - else: - for mod in ( - "whisperlivekit.audio_processor", "whisperlivekit.simul_whisper", - "whisperlivekit.tokens_alignment", "whisperlivekit.simul_whisper.align_att_base", - "whisperlivekit.simul_whisper.simul_whisper", - ): - logging.getLogger(mod).setLevel(logging.WARNING) - - -async def run_benchmark( - audio_files: List[Path], chunk_ms: int, realtime: bool, - model_size: str, lan: str, max_duration: float, vac: bool, - verbose: bool, -) -> List[TestResult]: - """Run benchmark across all available backend+policy combinations.""" - combos = detect_available_backends() - if not combos: - logger.error("No backends available. Install at least one ASR backend.") - return [] - - logger.info(f"Detected {len(combos)} backend+policy combinations:") - for c in combos: - logger.info(f" - {c['description']}") - - all_results = [] - for i, combo in enumerate(combos, 1): - backend = combo["backend"] - policy = combo["policy"] - desc = combo["description"] - - print(f"\n{'*' * 70}") - print(f"* BENCHMARK {i}/{len(combos)}: {desc}") - print(f"{'*' * 70}") - - try: - engine = create_engine( - backend, model_size, lan, vac=vac, policy=policy, - ) - _quiet_loggers(verbose) - - results = await run_all_tests( - engine, audio_files, chunk_ms, realtime, - backend=backend, policy=policy, lan=lan, - max_duration=max_duration, - ) - all_results.extend(results) - except Exception as e: - logger.error(f"Failed to run {desc}: {e}") - import traceback - traceback.print_exc() - - return all_results - - -def main(): - parser = argparse.ArgumentParser( - description="Offline backend test harness (AudioProcessor-level)" - ) - parser.add_argument( - "--backend", default="faster-whisper", - help="Backend: voxtral, voxtral-mlx, auto, faster-whisper, mlx-whisper, whisper.", - ) - parser.add_argument( - "--policy", default="", - help="Override backend policy: localagreement, simulstreaming, voxtral.", - ) - parser.add_argument( - "--audio", default=None, - help="Path to a single audio file (WAV, MP3, FLAC, etc.).", - ) - parser.add_argument( - "--audio-dir", default=None, - help="Directory of audio files to test. Defaults to audio_tests/ if neither --audio nor --audio-dir given.", - ) - parser.add_argument( - "--chunk-ms", type=int, default=100, - help="Chunk size in milliseconds (simulates real-time interval).", - ) - parser.add_argument( - "--model", default="", dest="model_size", - help="Model size or HF repo ID.", - ) - parser.add_argument("--lan", default="en", help="Language code.") - parser.add_argument( - "--no-realtime", action="store_true", - help="Skip real-time pacing between chunks (faster but less realistic).", - ) - parser.add_argument( - "--no-vac", action="store_true", - help="Disable Voice Activity Classification (send all audio without silence filtering).", - ) - parser.add_argument( - "--diarization", action="store_true", - help="Enable speaker diarization.", - ) - parser.add_argument( - "--diarization-backend", - default="", - choices=["diart", "sortformer"], - help="Diarization backend when --diarization is enabled.", - ) - parser.add_argument( - "--benchmark", action="store_true", - help="Run benchmark across all detected backend+policy combinations.", - ) - parser.add_argument( - "--json", default=None, dest="json_output", - help="Write structured JSON results to this file.", - ) - parser.add_argument( - "--max-duration", type=float, default=60.0, - help="Skip audio files longer than this many seconds (default: 60).", - ) - parser.add_argument( - "--insert-silence", nargs=2, type=float, metavar=("SECS", "AT_SEC"), - action="append", default=[], - help="Insert SECS of silence at AT_SEC position. Can be repeated. " - "E.g.: --insert-silence 3.0 2.0 --insert-silence 5.0 7.0", - ) - parser.add_argument( - "-v", "--verbose", action="store_true", - help="Show debug-level logs from all components.", - ) - args = parser.parse_args() - - realtime = not args.no_realtime - vac = not args.no_vac - - # Resolve audio file(s) - if args.audio: - audio_files = [Path(args.audio)] - elif args.audio_dir: - audio_files = discover_audio_files(args.audio_dir) - elif AUDIO_TESTS_DIR.is_dir(): - audio_files = discover_audio_files(str(AUDIO_TESTS_DIR)) - else: - # Fall back to jfk.wav download - audio_files = [download_sample_audio()] - - if not audio_files: - logger.error("No audio files found.") - sys.exit(1) - - logger.info(f"Audio files: {[f.name for f in audio_files]}") - - if args.benchmark: - # --- Multi-backend benchmark mode --- - all_results = asyncio.run( - run_benchmark( - audio_files, args.chunk_ms, realtime, - args.model_size, args.lan, args.max_duration, vac, - args.verbose, - ) - ) - if all_results: - print_cross_backend_comparison(all_results) - results = all_results - else: - # --- Single-backend mode --- - policy = args.policy - logger.info(f"Creating {args.backend} engine...") - engine = create_engine( - args.backend, args.model_size, args.lan, - diarization=args.diarization, - diarization_backend=args.diarization_backend, - vac=vac, - policy=policy, - ) - logger.info("Engine ready.") - - _quiet_loggers(args.verbose) - - results = asyncio.run( - run_all_tests( - engine, audio_files, args.chunk_ms, realtime, - args.backend, policy, args.lan, - max_duration=args.max_duration, - silence_insertions=args.insert_silence or None, - ) - ) - - if len(results) > 1: - print_benchmark_summary(results) - - # JSON output - if args.json_output and results: - json_results = [] - for r in results: - d = asdict(r) - d.pop("last_response", None) # too verbose for summary - json_results.append(d) - Path(args.json_output).write_text( - json.dumps(json_results, indent=2, ensure_ascii=False) - ) - logger.info(f"Results written to {args.json_output}") - - -if __name__ == "__main__": - main() diff --git a/whisperlivekit/benchmark/__init__.py b/whisperlivekit/benchmark/__init__.py new file mode 100644 index 0000000..07e4571 --- /dev/null +++ b/whisperlivekit/benchmark/__init__.py @@ -0,0 +1,34 @@ +"""WhisperLiveKit benchmark suite. + +Comprehensive benchmarking of ASR backends using public datasets, +run through the same pipeline as real-time streaming. + +Usage: + wlk bench # benchmark current backend + wlk bench --backend whisper --json results.json + wlk bench --languages en,fr,es # multilingual + wlk bench --quick # fast subset + +Programmatic: + from whisperlivekit.benchmark import BenchmarkRunner + import asyncio + + runner = BenchmarkRunner(backend="whisper", model_size="base") + report = asyncio.run(runner.run()) + print(report.summary_table()) +""" + +from whisperlivekit.benchmark.datasets import ( + BENCHMARK_CATALOG, + get_benchmark_samples, +) +from whisperlivekit.benchmark.metrics import BenchmarkReport, SampleResult +from whisperlivekit.benchmark.runner import BenchmarkRunner + +__all__ = [ + "BENCHMARK_CATALOG", + "BenchmarkReport", + "BenchmarkRunner", + "SampleResult", + "get_benchmark_samples", +] diff --git a/whisperlivekit/benchmark/compat.py b/whisperlivekit/benchmark/compat.py new file mode 100644 index 0000000..024e770 --- /dev/null +++ b/whisperlivekit/benchmark/compat.py @@ -0,0 +1,105 @@ +"""Backend detection and language compatibility matrix.""" + +import logging +from typing import Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + +# Language support per backend. +# None means all Whisper-supported languages. +# A set means only those languages are supported. +BACKEND_LANGUAGES: Dict[str, Optional[Set[str]]] = { + "whisper": None, + "faster-whisper": None, + "mlx-whisper": None, + "voxtral-mlx": None, + "voxtral": None, + "qwen3": { + "zh", "en", "yue", "ar", "de", "fr", "es", "pt", "id", "it", + "ko", "ru", "th", "vi", "ja", "tr", "hi", "ms", "nl", "sv", + "da", "fi", "pl", "cs", "fa", "el", "hu", "mk", "ro", + }, + "qwen3-simul": { + "zh", "en", "yue", "ar", "de", "fr", "es", "pt", "id", "it", + "ko", "ru", "th", "vi", "ja", "tr", "hi", "ms", "nl", "sv", + "da", "fi", "pl", "cs", "fa", "el", "hu", "mk", "ro", + }, +} + + +def backend_supports_language(backend: str, language: str) -> bool: + """Check if a backend supports a given language code.""" + langs = BACKEND_LANGUAGES.get(backend) + if langs is None: + return True + return language in langs + + +def detect_available_backends() -> List[str]: + """Probe which ASR backends are importable.""" + backends = [] + + try: + import whisper # noqa: F401 + backends.append("whisper") + except ImportError: + pass + + try: + import faster_whisper # noqa: F401 + backends.append("faster-whisper") + except ImportError: + pass + + try: + import mlx_whisper # noqa: F401 + backends.append("mlx-whisper") + except ImportError: + pass + + try: + import mlx.core # noqa: F401 + from whisperlivekit.voxtral_mlx.loader import load_voxtral_model # noqa: F401 + backends.append("voxtral-mlx") + except ImportError: + pass + + try: + from transformers import VoxtralRealtimeForConditionalGeneration # noqa: F401 + backends.append("voxtral") + except ImportError: + pass + + try: + from whisperlivekit.qwen3_asr import _patch_transformers_compat + _patch_transformers_compat() + from qwen_asr import Qwen3ASRModel # noqa: F401 + backends.append("qwen3") + backends.append("qwen3-simul") + except (ImportError, Exception): + pass + + return backends + + +def resolve_backend(backend: str) -> str: + """Resolve 'auto' to the best available backend.""" + if backend != "auto": + return backend + + available = detect_available_backends() + if not available: + raise RuntimeError( + "No ASR backend available. Install at least one: " + "pip install openai-whisper, faster-whisper, or mlx-whisper" + ) + + # Priority order + priority = [ + "faster-whisper", "mlx-whisper", "voxtral-mlx", "voxtral", + "qwen3", "qwen3-simul", "whisper", + ] + for p in priority: + if p in available: + return p + return available[0] diff --git a/whisperlivekit/benchmark/datasets.py b/whisperlivekit/benchmark/datasets.py new file mode 100644 index 0000000..e761a15 --- /dev/null +++ b/whisperlivekit/benchmark/datasets.py @@ -0,0 +1,561 @@ +"""Benchmark audio datasets from public HuggingFace repositories. + +Downloads curated samples across languages, noise conditions, and speaker +configurations. All datasets are public and freely accessible — no auth +tokens required. + +Samples are cached in ~/.cache/whisperlivekit/benchmark_data/ and reused +across benchmark runs. + +Datasets used: + - LibriSpeech test-clean (English, clean, single speaker) + - LibriSpeech test-other (English, noisy/hard, single speaker) + - Multilingual LibriSpeech (French, Spanish, German, Portuguese, Italian, Polish, Dutch) + - AMI (English, multi-speaker meeting) +""" + +import json +import logging +import wave +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Set + +import numpy as np + +logger = logging.getLogger(__name__) + +CACHE_DIR = Path.home() / ".cache" / "whisperlivekit" / "benchmark_data" +METADATA_FILE = "benchmark_metadata.json" + + +@dataclass +class BenchmarkSample: + """A benchmark audio sample with metadata and ground truth.""" + + name: str + path: str + reference: str + duration: float + language: str + category: str # "clean", "noisy", "multilingual", "meeting" + sample_rate: int = 16000 + n_speakers: int = 1 + source: str = "" + tags: Set[str] = field(default_factory=set) + + def to_dict(self) -> Dict: + return { + "name": self.name, + "file": Path(self.path).name, + "reference": self.reference, + "duration": self.duration, + "language": self.language, + "category": self.category, + "sample_rate": self.sample_rate, + "n_speakers": self.n_speakers, + "source": self.source, + "tags": list(self.tags), + } + + +# --------------------------------------------------------------------------- +# Dataset catalog — defines what to download +# --------------------------------------------------------------------------- + +BENCHMARK_CATALOG = { + # English clean (LibriSpeech test-clean) + "en_clean_short": { + "dataset": "openslr/librispeech_asr", + "config": "clean", + "split": "test", + "language": "en", + "category": "clean", + "n_samples": 1, + "skip": 0, + "tags": {"short"}, + }, + "en_clean_medium": { + "dataset": "openslr/librispeech_asr", + "config": "clean", + "split": "test", + "language": "en", + "category": "clean", + "n_samples": 1, + "skip": 1, + "tags": {"medium"}, + }, + # English noisy (LibriSpeech test-other) + "en_noisy_1": { + "dataset": "openslr/librispeech_asr", + "config": "other", + "split": "test", + "language": "en", + "category": "noisy", + "n_samples": 1, + "skip": 0, + "tags": {"accented"}, + }, + "en_noisy_2": { + "dataset": "openslr/librispeech_asr", + "config": "other", + "split": "test", + "language": "en", + "category": "noisy", + "n_samples": 1, + "skip": 1, + "tags": {"accented"}, + }, + # French (Multilingual LibriSpeech) + "fr_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "french", + "split": "test", + "language": "fr", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + "fr_clean_2": { + "dataset": "facebook/multilingual_librispeech", + "config": "french", + "split": "test", + "language": "fr", + "category": "multilingual", + "n_samples": 1, + "skip": 1, + "tags": set(), + }, + # Spanish (Multilingual LibriSpeech) + "es_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "spanish", + "split": "test", + "language": "es", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + # German (Multilingual LibriSpeech) + "de_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "german", + "split": "test", + "language": "de", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + # Portuguese (Multilingual LibriSpeech) + "pt_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "portuguese", + "split": "test", + "language": "pt", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + # Italian (Multilingual LibriSpeech) + "it_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "italian", + "split": "test", + "language": "it", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + # Polish (Multilingual LibriSpeech) + "pl_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "polish", + "split": "test", + "language": "pl", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + # Dutch (Multilingual LibriSpeech) + "nl_clean_1": { + "dataset": "facebook/multilingual_librispeech", + "config": "dutch", + "split": "test", + "language": "nl", + "category": "multilingual", + "n_samples": 1, + "skip": 0, + "tags": set(), + }, + # English multi-speaker meeting (AMI) + "en_meeting": { + "dataset": "edinburghcstr/ami", + "config": "ihm", + "split": "test", + "language": "en", + "category": "meeting", + "n_samples": 1, + "skip": 0, + "tags": {"multi_speaker", "long"}, + "max_duration": 60.0, + }, +} + +# Quick mode: subset of samples for fast smoke tests +QUICK_SAMPLES = {"en_clean_short", "en_clean_medium", "en_noisy_1", "fr_clean_1"} + + +# --------------------------------------------------------------------------- +# Audio utilities +# --------------------------------------------------------------------------- + +def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None: + if audio.ndim > 1: + audio = audio.mean(axis=-1) + if audio.dtype in (np.float32, np.float64): + audio = np.clip(audio, -1.0, 1.0) + audio = (audio * 32767).astype(np.int16) + elif audio.dtype != np.int16: + audio = audio.astype(np.int16) + path.parent.mkdir(parents=True, exist_ok=True) + with wave.open(str(path), "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(audio.tobytes()) + + +def _decode_audio(audio_bytes: bytes) -> tuple: + import io + import soundfile as sf + audio_array, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32") + return np.array(audio_array, dtype=np.float32), sr + + +def _ensure_datasets(): + try: + import datasets # noqa: F401 + except ImportError: + raise ImportError( + "The 'datasets' package is required for benchmark data. " + "Install with: pip install whisperlivekit[test]" + ) + + +# --------------------------------------------------------------------------- +# Download functions per dataset type +# --------------------------------------------------------------------------- + +def _download_librispeech(config: str, n_samples: int, skip: int, + category: str, language: str, + prefix: str) -> List[Dict]: + """Download from openslr/librispeech_asr (clean or other).""" + _ensure_datasets() + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info("Downloading LibriSpeech %s samples...", config) + ds = load_dataset( + "openslr/librispeech_asr", config, split="test", streaming=True, + ) + ds = ds.cast_column("audio", Audio(decode=False)) + + samples = [] + for i, item in enumerate(ds): + if i < skip: + continue + if len(samples) >= n_samples: + break + + audio_array, sr = _decode_audio(item["audio"]["bytes"]) + duration = len(audio_array) / sr + text = item["text"] + + wav_name = f"{prefix}_{i}.wav" + _save_wav(CACHE_DIR / wav_name, audio_array, sr) + + samples.append({ + "file": wav_name, + "reference": text, + "duration": round(duration, 2), + "sample_rate": sr, + "language": language, + "category": category, + "n_speakers": 1, + "source": f"openslr/librispeech_asr ({config})", + }) + logger.info(" %.1fs - %s", duration, text[:60]) + + return samples + + +def _download_mls(config: str, n_samples: int, skip: int, + language: str, prefix: str) -> List[Dict]: + """Download from facebook/multilingual_librispeech.""" + _ensure_datasets() + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info("Downloading MLS %s samples...", config) + ds = load_dataset( + "facebook/multilingual_librispeech", config, split="test", streaming=True, + ) + ds = ds.cast_column("audio", Audio(decode=False)) + + samples = [] + for i, item in enumerate(ds): + if i < skip: + continue + if len(samples) >= n_samples: + break + + audio_array, sr = _decode_audio(item["audio"]["bytes"]) + duration = len(audio_array) / sr + text = item.get("text", item.get("transcript", "")) + + wav_name = f"{prefix}_{i}.wav" + _save_wav(CACHE_DIR / wav_name, audio_array, sr) + + samples.append({ + "file": wav_name, + "reference": text, + "duration": round(duration, 2), + "sample_rate": sr, + "language": language, + "category": "multilingual", + "n_speakers": 1, + "source": f"facebook/multilingual_librispeech ({config})", + }) + logger.info(" [%s] %.1fs - %s", language, duration, text[:60]) + + return samples + + +def _download_fleurs(config: str, n_samples: int, skip: int, + language: str, prefix: str) -> List[Dict]: + """Download from google/fleurs.""" + _ensure_datasets() + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info("Downloading FLEURS %s samples...", config) + ds = load_dataset( + "google/fleurs", config, split="test", streaming=True, + ) + ds = ds.cast_column("audio", Audio(decode=False)) + + samples = [] + for i, item in enumerate(ds): + if i < skip: + continue + if len(samples) >= n_samples: + break + + audio_array, sr = _decode_audio(item["audio"]["bytes"]) + duration = len(audio_array) / sr + text = item.get("transcription", item.get("raw_transcription", "")) + + wav_name = f"{prefix}_{i}.wav" + _save_wav(CACHE_DIR / wav_name, audio_array, sr) + + samples.append({ + "file": wav_name, + "reference": text, + "duration": round(duration, 2), + "sample_rate": sr, + "language": language, + "category": "multilingual", + "n_speakers": 1, + "source": f"google/fleurs ({config})", + }) + logger.info(" [%s] %.1fs - %s", language, duration, text[:60]) + + return samples + + +def _download_ami(max_duration: float = 60.0) -> List[Dict]: + """Download one AMI meeting segment with multiple speakers.""" + _ensure_datasets() + import datasets.config + datasets.config.TORCHCODEC_AVAILABLE = False + from datasets import Audio, load_dataset + + logger.info("Downloading AMI meeting sample...") + ds = load_dataset("edinburghcstr/ami", "ihm", split="test", streaming=True) + ds = ds.cast_column("audio", Audio(decode=False)) + + meeting_id = None + audio_arrays = [] + texts = [] + sample_rate = None + + for item in ds: + mid = item.get("meeting_id", "unknown") + if meeting_id is None: + meeting_id = mid + elif mid != meeting_id: + break + + audio_array, sr = _decode_audio(item["audio"]["bytes"]) + sample_rate = sr + texts.append(item.get("text", "")) + audio_arrays.append(audio_array) + + total_dur = sum(len(a) / sr for a in audio_arrays) + if total_dur > max_duration: + break + + if not audio_arrays: + return [] + + full_audio = np.concatenate(audio_arrays) + duration = len(full_audio) / sample_rate + reference = " ".join(t for t in texts if t) + + wav_name = "ami_meeting.wav" + _save_wav(CACHE_DIR / wav_name, full_audio, sample_rate) + + logger.info(" AMI meeting: %.1fs, %d utterances", duration, len(texts)) + return [{ + "file": wav_name, + "reference": reference, + "duration": round(duration, 2), + "sample_rate": sample_rate, + "language": "en", + "category": "meeting", + "n_speakers": 4, + "source": f"edinburghcstr/ami (ihm, meeting {meeting_id})", + }] + + +# --------------------------------------------------------------------------- +# Dispatcher — routes catalog entries to download functions +# --------------------------------------------------------------------------- + +def _download_catalog_entry(name: str, spec: Dict) -> List[Dict]: + """Download a single catalog entry and return metadata dicts.""" + dataset = spec["dataset"] + config = spec.get("config", "") + n_samples = spec.get("n_samples", 1) + skip = spec.get("skip", 0) + language = spec["language"] + category = spec["category"] + + if dataset == "openslr/librispeech_asr": + return _download_librispeech( + config=config, n_samples=n_samples, skip=skip, + category=category, language=language, prefix=name, + ) + elif dataset == "facebook/multilingual_librispeech": + return _download_mls( + config=config, n_samples=n_samples, skip=skip, + language=language, prefix=name, + ) + elif dataset == "google/fleurs": + return _download_fleurs( + config=config, n_samples=n_samples, skip=skip, + language=language, prefix=name, + ) + elif dataset == "edinburghcstr/ami": + return _download_ami(max_duration=spec.get("max_duration", 60.0)) + else: + logger.warning("Unknown dataset: %s", dataset) + return [] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def get_benchmark_samples( + languages: Optional[List[str]] = None, + categories: Optional[List[str]] = None, + quick: bool = False, + force: bool = False, +) -> List[BenchmarkSample]: + """Download and return benchmark samples, filtered by language/category. + + Args: + languages: List of language codes to include (None = all). + categories: List of categories to include (None = all). + quick: If True, only download a small subset for smoke tests. + force: Re-download even if cached. + + Returns: + List of BenchmarkSample objects ready for benchmarking. + """ + CACHE_DIR.mkdir(parents=True, exist_ok=True) + meta_path = CACHE_DIR / METADATA_FILE + + # Load cached metadata + cached = {} + if meta_path.exists() and not force: + cached = json.loads(meta_path.read_text()) + + # Determine which entries to download + entries = BENCHMARK_CATALOG + if quick: + entries = {k: v for k, v in entries.items() if k in QUICK_SAMPLES} + + if languages: + lang_set = set(languages) + entries = {k: v for k, v in entries.items() if v["language"] in lang_set} + + if categories: + cat_set = set(categories) + entries = {k: v for k, v in entries.items() if v["category"] in cat_set} + + # Download missing entries + all_meta = cached.get("samples", {}) + for name, spec in entries.items(): + if name in all_meta and not force: + # Check file exists + file_path = CACHE_DIR / all_meta[name][0]["file"] + if file_path.exists(): + continue + + logger.info("Downloading benchmark sample: %s", name) + try: + downloaded = _download_catalog_entry(name, spec) + if downloaded: + all_meta[name] = downloaded + except Exception as e: + logger.warning("Failed to download %s: %s", name, e) + + # Save metadata + meta_path.write_text(json.dumps({"samples": all_meta}, indent=2)) + + # Build BenchmarkSample objects + samples = [] + for name, spec in entries.items(): + if name not in all_meta: + continue + for meta in all_meta[name]: + file_path = CACHE_DIR / meta["file"] + if not file_path.exists(): + continue + catalog_entry = BENCHMARK_CATALOG.get(name, {}) + samples.append(BenchmarkSample( + name=name, + path=str(file_path), + reference=meta["reference"], + duration=meta["duration"], + language=meta["language"], + category=meta["category"], + sample_rate=meta.get("sample_rate", 16000), + n_speakers=meta.get("n_speakers", 1), + source=meta.get("source", ""), + tags=set(catalog_entry.get("tags", set())), + )) + + logger.info("Loaded %d benchmark samples", len(samples)) + return samples diff --git a/whisperlivekit/benchmark/metrics.py b/whisperlivekit/benchmark/metrics.py new file mode 100644 index 0000000..100701a --- /dev/null +++ b/whisperlivekit/benchmark/metrics.py @@ -0,0 +1,273 @@ +"""Benchmark result data structures and aggregation.""" + +import platform +import subprocess +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class SampleResult: + """Result from benchmarking one audio sample.""" + + sample_name: str + language: str + category: str + duration_s: float + + # Quality + wer: float + wer_details: Dict[str, int] + + # Speed + processing_time_s: float + rtf: float + + # Latency (from SessionMetrics) + avg_latency_ms: float = 0.0 + p95_latency_ms: float = 0.0 + n_transcription_calls: int = 0 + + # Pipeline stats + n_lines: int = 0 + n_tokens: int = 0 + + # Timing quality + timing_valid: bool = True + timing_monotonic: bool = True + + # Memory + peak_memory_mb: Optional[float] = None + + # Texts + hypothesis: str = "" + reference: str = "" + + # Source + source: str = "" + tags: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "sample": self.sample_name, + "language": self.language, + "category": self.category, + "duration_s": round(self.duration_s, 2), + "wer": round(self.wer, 4), + "wer_details": self.wer_details, + "processing_time_s": round(self.processing_time_s, 2), + "rtf": round(self.rtf, 3), + "avg_latency_ms": round(self.avg_latency_ms, 1), + "p95_latency_ms": round(self.p95_latency_ms, 1), + "n_transcription_calls": self.n_transcription_calls, + "n_lines": self.n_lines, + "n_tokens": self.n_tokens, + "timing_valid": self.timing_valid, + "timing_monotonic": self.timing_monotonic, + "peak_memory_mb": round(self.peak_memory_mb, 1) if self.peak_memory_mb else None, + "hypothesis": self.hypothesis, + "reference": self.reference, + "source": self.source, + "tags": self.tags, + } + + +@dataclass +class BenchmarkReport: + """Aggregated benchmark report with system info and per-sample results.""" + + backend: str + model_size: str + timestamp: str = field(default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%S")) + system_info: Dict[str, Any] = field(default_factory=dict) + results: List[SampleResult] = field(default_factory=list) + + # --- Aggregate properties --- + + @property + def n_samples(self) -> int: + return len(self.results) + + @property + def total_audio_s(self) -> float: + return sum(r.duration_s for r in self.results) + + @property + def total_processing_s(self) -> float: + return sum(r.processing_time_s for r in self.results) + + @property + def avg_wer(self) -> float: + if not self.results: + return 0.0 + return sum(r.wer for r in self.results) / len(self.results) + + @property + def weighted_wer(self) -> float: + """Micro-averaged WER: total errors / total reference words.""" + total_errors = sum( + r.wer_details.get("substitutions", 0) + + r.wer_details.get("insertions", 0) + + r.wer_details.get("deletions", 0) + for r in self.results + ) + total_ref = sum(r.wer_details.get("ref_words", 0) for r in self.results) + return total_errors / max(total_ref, 1) + + @property + def avg_rtf(self) -> float: + if not self.results: + return 0.0 + return sum(r.rtf for r in self.results) / len(self.results) + + @property + def overall_rtf(self) -> float: + if self.total_audio_s <= 0: + return 0.0 + return self.total_processing_s / self.total_audio_s + + @property + def avg_latency_ms(self) -> float: + vals = [r.avg_latency_ms for r in self.results if r.avg_latency_ms > 0] + return sum(vals) / len(vals) if vals else 0.0 + + @property + def p95_latency_ms(self) -> float: + vals = [r.p95_latency_ms for r in self.results if r.p95_latency_ms > 0] + return sum(vals) / len(vals) if vals else 0.0 + + # --- Per-dimension breakdowns --- + + def _group_by(self, key: str) -> Dict[str, List[SampleResult]]: + groups: Dict[str, List[SampleResult]] = {} + for r in self.results: + k = getattr(r, key, "unknown") + groups.setdefault(k, []).append(r) + return groups + + def wer_by_language(self) -> Dict[str, float]: + return { + lang: sum(r.wer for r in group) / len(group) + for lang, group in sorted(self._group_by("language").items()) + } + + def rtf_by_language(self) -> Dict[str, float]: + return { + lang: sum(r.rtf for r in group) / len(group) + for lang, group in sorted(self._group_by("language").items()) + } + + def wer_by_category(self) -> Dict[str, float]: + return { + cat: sum(r.wer for r in group) / len(group) + for cat, group in sorted(self._group_by("category").items()) + } + + @property + def languages(self) -> List[str]: + return sorted(set(r.language for r in self.results)) + + @property + def categories(self) -> List[str]: + return sorted(set(r.category for r in self.results)) + + def to_dict(self) -> Dict[str, Any]: + return { + "benchmark_version": "1.0", + "timestamp": self.timestamp, + "system_info": self.system_info, + "config": { + "backend": self.backend, + "model_size": self.model_size, + }, + "summary": { + "n_samples": self.n_samples, + "total_audio_s": round(self.total_audio_s, 1), + "total_processing_s": round(self.total_processing_s, 1), + "avg_wer": round(self.avg_wer, 4), + "weighted_wer": round(self.weighted_wer, 4), + "avg_rtf": round(self.avg_rtf, 3), + "overall_rtf": round(self.overall_rtf, 3), + "avg_latency_ms": round(self.avg_latency_ms, 1), + "p95_latency_ms": round(self.p95_latency_ms, 1), + "wer_by_language": { + k: round(v, 4) for k, v in self.wer_by_language().items() + }, + "rtf_by_language": { + k: round(v, 3) for k, v in self.rtf_by_language().items() + }, + "wer_by_category": { + k: round(v, 4) for k, v in self.wer_by_category().items() + }, + }, + "results": [r.to_dict() for r in self.results], + } + + +def get_system_info() -> Dict[str, Any]: + """Collect system metadata for the benchmark report.""" + info: Dict[str, Any] = { + "platform": platform.platform(), + "machine": platform.machine(), + "python_version": platform.python_version(), + } + + # CPU info + try: + chip = subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], text=True, + ).strip() + info["cpu"] = chip + except Exception: + info["cpu"] = platform.processor() + + # RAM + try: + mem_bytes = int( + subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip() + ) + info["ram_gb"] = round(mem_bytes / (1024**3)) + except Exception: + try: + import os + pages = os.sysconf("SC_PHYS_PAGES") + page_size = os.sysconf("SC_PAGE_SIZE") + info["ram_gb"] = round(pages * page_size / (1024**3)) + except Exception: + info["ram_gb"] = None + + # Accelerator + try: + import torch + if torch.cuda.is_available(): + info["accelerator"] = torch.cuda.get_device_name(0) + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + info["accelerator"] = "Apple Silicon (MPS)" + else: + info["accelerator"] = "CPU" + except ImportError: + info["accelerator"] = "CPU" + + # Backend versions + versions = {} + for pkg, name in [ + ("faster_whisper", "faster-whisper"), + ("whisper", "openai-whisper"), + ("mlx_whisper", "mlx-whisper"), + ("transformers", "transformers"), + ("torch", "torch"), + ]: + try: + mod = __import__(pkg) + versions[name] = getattr(mod, "__version__", "installed") + except ImportError: + pass + try: + import mlx.core as mx + versions["mlx"] = mx.__version__ + except ImportError: + pass + + info["backend_versions"] = versions + return info diff --git a/whisperlivekit/benchmark/report.py b/whisperlivekit/benchmark/report.py new file mode 100644 index 0000000..56c1790 --- /dev/null +++ b/whisperlivekit/benchmark/report.py @@ -0,0 +1,161 @@ +"""Benchmark report formatting — terminal tables and JSON export.""" + +import json +import sys +from pathlib import Path +from typing import TextIO + +from whisperlivekit.benchmark.metrics import BenchmarkReport + +# ANSI color codes +GREEN = "\033[32m" +YELLOW = "\033[33m" +RED = "\033[31m" +CYAN = "\033[36m" +BOLD = "\033[1m" +DIM = "\033[2m" +RESET = "\033[0m" + + +def _wer_color(wer: float) -> str: + if wer < 0.15: + return GREEN + elif wer < 0.30: + return YELLOW + return RED + + +def _rtf_color(rtf: float) -> str: + if rtf < 0.5: + return GREEN + elif rtf < 1.0: + return YELLOW + return RED + + +def _lat_color(ms: float) -> str: + if ms < 500: + return GREEN + elif ms < 1000: + return YELLOW + return RED + + +def print_report(report: BenchmarkReport, out: TextIO = sys.stderr) -> None: + """Print a comprehensive benchmark report to the terminal.""" + w = out.write + + # Header + w(f"\n{BOLD} WhisperLiveKit Benchmark Report{RESET}\n") + w(f" {'─' * 72}\n") + + si = report.system_info + w(f" Backend: {CYAN}{report.backend}{RESET}\n") + w(f" Model: {report.model_size}\n") + w(f" Accelerator: {si.get('accelerator', 'unknown')}\n") + w(f" CPU: {si.get('cpu', 'unknown')}\n") + w(f" RAM: {si.get('ram_gb', '?')} GB\n") + w(f" Timestamp: {report.timestamp}\n") + w(f" {'─' * 72}\n\n") + + # Per-sample table + w(f" {BOLD}{'Sample':<20} {'Lang':>4} {'Dur':>5} {'WER':>7} " + f"{'RTF':>6} {'Lat(avg)':>8} {'Lat(p95)':>8} {'Calls':>5} {'Lines':>5}{RESET}\n") + w(f" {'─' * 72}\n") + + for r in report.results: + wc = _wer_color(r.wer) + rc = _rtf_color(r.rtf) + lc = _lat_color(r.avg_latency_ms) + + name = r.sample_name[:20] + w(f" {name:<20} {r.language:>4} {r.duration_s:>4.1f}s " + f"{wc}{r.wer * 100:>6.1f}%{RESET} " + f"{rc}{r.rtf:>5.2f}x{RESET} " + f"{lc}{r.avg_latency_ms:>7.0f}ms{RESET} " + f"{lc}{r.p95_latency_ms:>7.0f}ms{RESET} " + f"{r.n_transcription_calls:>5} {r.n_lines:>5}\n") + + # Timing warnings + if not r.timing_valid: + w(f" {' ' * 20} {RED}⚠ invalid timestamps{RESET}\n") + if not r.timing_monotonic: + w(f" {' ' * 20} {YELLOW}⚠ non-monotonic timestamps{RESET}\n") + + w(f" {'─' * 72}\n\n") + + # Summary + w(f" {BOLD}Summary{RESET} ({report.n_samples} samples, " + f"{report.total_audio_s:.1f}s total audio)\n\n") + + wc = _wer_color(report.avg_wer) + rc = _rtf_color(report.overall_rtf) + lc = _lat_color(report.avg_latency_ms) + + w(f" Avg WER (macro): {wc}{report.avg_wer * 100:>6.1f}%{RESET}\n") + w(f" Weighted WER: {_wer_color(report.weighted_wer)}" + f"{report.weighted_wer * 100:>6.1f}%{RESET}\n") + w(f" Overall RTF: {rc}{report.overall_rtf:>6.3f}x{RESET} " + f"({report.total_processing_s:.1f}s for {report.total_audio_s:.1f}s audio)\n") + w(f" Avg latency: {lc}{report.avg_latency_ms:>6.0f}ms{RESET}\n") + w(f" P95 latency: {_lat_color(report.p95_latency_ms)}" + f"{report.p95_latency_ms:>6.0f}ms{RESET}\n") + + # Per-language breakdown + wer_by_lang = report.wer_by_language() + rtf_by_lang = report.rtf_by_language() + if len(wer_by_lang) > 1: + w(f"\n {BOLD}By Language{RESET}\n") + w(f" {'─' * 40}\n") + w(f" {'Lang':>4} {'WER':>7} {'RTF':>6} {'Samples':>7}\n") + w(f" {'─' * 34}\n") + lang_groups = {} + for r in report.results: + lang_groups.setdefault(r.language, []).append(r) + for lang in sorted(lang_groups): + group = lang_groups[lang] + avg_wer = sum(r.wer for r in group) / len(group) + avg_rtf = sum(r.rtf for r in group) / len(group) + wc = _wer_color(avg_wer) + rc = _rtf_color(avg_rtf) + w(f" {lang:>4} {wc}{avg_wer * 100:>6.1f}%{RESET} " + f"{rc}{avg_rtf:>5.2f}x{RESET} {len(group):>7}\n") + + # Per-category breakdown + wer_by_cat = report.wer_by_category() + if len(wer_by_cat) > 1: + w(f"\n {BOLD}By Category{RESET}\n") + w(f" {'─' * 40}\n") + w(f" {'Category':>12} {'WER':>7} {'Samples':>7}\n") + w(f" {'─' * 30}\n") + cat_groups = {} + for r in report.results: + cat_groups.setdefault(r.category, []).append(r) + for cat in sorted(cat_groups): + group = cat_groups[cat] + avg_wer = sum(r.wer for r in group) / len(group) + wc = _wer_color(avg_wer) + w(f" {cat:>12} {wc}{avg_wer * 100:>6.1f}%{RESET} {len(group):>7}\n") + + w(f"\n {'─' * 72}\n\n") + + +def print_transcriptions(report: BenchmarkReport, out: TextIO = sys.stderr) -> None: + """Print hypothesis vs reference for each sample.""" + w = out.write + w(f"\n {BOLD}Transcriptions{RESET}\n") + w(f" {'─' * 72}\n") + for r in report.results: + wc = _wer_color(r.wer) + w(f"\n {BOLD}{r.sample_name}{RESET} ({r.language}, {r.category}) " + f"WER={wc}{r.wer * 100:.1f}%{RESET}\n") + ref = r.reference[:120] + "..." if len(r.reference) > 120 else r.reference + hyp = r.hypothesis[:120] + "..." if len(r.hypothesis) > 120 else r.hypothesis + w(f" {DIM}ref: {ref}{RESET}\n") + w(f" hyp: {hyp}\n") + w(f"\n {'─' * 72}\n\n") + + +def write_json(report: BenchmarkReport, path: str) -> None: + """Export the full report as JSON.""" + Path(path).write_text(json.dumps(report.to_dict(), indent=2, ensure_ascii=False)) diff --git a/whisperlivekit/benchmark/runner.py b/whisperlivekit/benchmark/runner.py new file mode 100644 index 0000000..99bdaf4 --- /dev/null +++ b/whisperlivekit/benchmark/runner.py @@ -0,0 +1,181 @@ +"""Benchmark runner — orchestrates runs through TestHarness.""" + +import logging +import resource +import time +from typing import Callable, List, Optional + +from whisperlivekit.benchmark.compat import backend_supports_language, resolve_backend +from whisperlivekit.benchmark.datasets import BenchmarkSample, get_benchmark_samples +from whisperlivekit.benchmark.metrics import BenchmarkReport, SampleResult, get_system_info + +logger = logging.getLogger(__name__) + + +class BenchmarkRunner: + """Orchestrates benchmark runs through TestHarness. + + Args: + backend: ASR backend name or "auto". + model_size: Model size (e.g. "base", "large-v3"). + languages: Language codes to benchmark (None = all available). + categories: Categories to benchmark (None = all). + quick: Use a small subset for fast smoke tests. + speed: Feed speed (0 = instant, 1.0 = real-time). + on_progress: Callback(sample_name, i, total) for progress updates. + """ + + def __init__( + self, + backend: str = "auto", + model_size: str = "base", + languages: Optional[List[str]] = None, + categories: Optional[List[str]] = None, + quick: bool = False, + speed: float = 0, + on_progress: Optional[Callable] = None, + ): + self.backend = resolve_backend(backend) + self.model_size = model_size + self.languages = languages + self.categories = categories + self.quick = quick + self.speed = speed + self.on_progress = on_progress + + async def run(self) -> BenchmarkReport: + """Run the full benchmark suite and return a report.""" + from whisperlivekit.metrics import compute_wer + from whisperlivekit.test_harness import TestHarness + + # Get samples + samples = get_benchmark_samples( + languages=self.languages, + categories=self.categories, + quick=self.quick, + ) + + # Filter by backend language support + compatible = [] + for s in samples: + if backend_supports_language(self.backend, s.language): + compatible.append(s) + else: + logger.info( + "Skipping %s (%s) — backend %s does not support %s", + s.name, s.language, self.backend, s.language, + ) + samples = compatible + + if not samples: + raise RuntimeError( + f"No benchmark samples available for backend={self.backend}, " + f"languages={self.languages}, categories={self.categories}" + ) + + # Build harness kwargs + harness_kwargs = { + "model_size": self.model_size, + "lan": "auto", # let the model auto-detect for multilingual + "pcm_input": True, + } + if self.backend not in ("auto",): + harness_kwargs["backend"] = self.backend + + report = BenchmarkReport( + backend=self.backend, + model_size=self.model_size, + system_info=get_system_info(), + ) + + for i, sample in enumerate(samples): + if self.on_progress: + self.on_progress(sample.name, i, len(samples)) + + result = await self._run_sample( + sample, harness_kwargs, compute_wer, + ) + report.results.append(result) + + if self.on_progress: + self.on_progress("done", len(samples), len(samples)) + + return report + + async def _run_sample( + self, + sample: BenchmarkSample, + harness_kwargs: dict, + compute_wer, + ) -> SampleResult: + """Benchmark a single sample through TestHarness.""" + from whisperlivekit.test_harness import TestHarness + + # Override language for the specific sample + kwargs = {**harness_kwargs, "lan": sample.language} + + # Memory before + mem_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + t_start = time.perf_counter() + + async with TestHarness(**kwargs) as h: + await h.feed(sample.path, speed=self.speed) + # Drain time scales with audio duration for slow backends + drain = max(5.0, sample.duration * 0.5) + await h.drain(drain) + state = await h.finish(timeout=120) + + # Extract metrics from the pipeline + metrics = h.metrics + + t_elapsed = time.perf_counter() - t_start + + # Memory after + mem_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + # On macOS ru_maxrss is bytes, on Linux it's KB + import sys + divisor = 1024 * 1024 if sys.platform == "darwin" else 1024 + mem_delta = (mem_after - mem_before) / divisor + + # RTF + rtf = t_elapsed / sample.duration if sample.duration > 0 else 0 + + # WER + hypothesis = state.committed_text or state.text + wer_result = compute_wer(sample.reference, hypothesis) + + # Latency from SessionMetrics + avg_lat = metrics.avg_latency_ms if metrics else 0 + p95_lat = metrics.p95_latency_ms if metrics else 0 + n_calls = metrics.n_transcription_calls if metrics else 0 + n_tokens = metrics.n_tokens_produced if metrics else 0 + + return SampleResult( + sample_name=sample.name, + language=sample.language, + category=sample.category, + duration_s=sample.duration, + wer=wer_result["wer"], + wer_details={ + "substitutions": wer_result["substitutions"], + "insertions": wer_result["insertions"], + "deletions": wer_result["deletions"], + "ref_words": wer_result["ref_words"], + "hyp_words": wer_result["hyp_words"], + }, + processing_time_s=round(t_elapsed, 2), + rtf=round(rtf, 3), + avg_latency_ms=round(avg_lat, 1), + p95_latency_ms=round(p95_lat, 1), + n_transcription_calls=n_calls, + n_lines=len(state.speech_lines), + n_tokens=n_tokens, + timing_valid=state.timing_valid, + timing_monotonic=state.timing_monotonic, + peak_memory_mb=round(mem_delta, 1) if mem_delta > 0 else None, + hypothesis=hypothesis, + reference=sample.reference, + source=sample.source, + tags=list(sample.tags), + ) diff --git a/whisperlivekit/cli.py b/whisperlivekit/cli.py index 7c8dc98..9feb2a4 100644 --- a/whisperlivekit/cli.py +++ b/whisperlivekit/cli.py @@ -690,7 +690,11 @@ def _subtitle_timestamp(seconds: float, fmt: str) -> str: # --------------------------------------------------------------------------- def cmd_bench(args: list): - """Benchmark the transcription pipeline on standard test audio. + """Benchmark the transcription pipeline on public test audio. + + Downloads samples from LibriSpeech, Multilingual LibriSpeech, FLEURS, + and AMI on first run. Supports multilingual benchmarking across all + available backends. Usage: wlk bench [options] """ @@ -698,27 +702,48 @@ def cmd_bench(args: list): parser = argparse.ArgumentParser( prog="wlk bench", - description="Benchmark WhisperLiveKit on standard test audio.", + description="Benchmark WhisperLiveKit on public test audio.", ) - parser.add_argument("--backend", default="auto", help="ASR backend (default: auto)") - parser.add_argument("--model", default="base", dest="model_size", help="Model size (default: base)") - parser.add_argument("--language", "--lan", default="en", dest="lan", help="Language code (default: en)") - parser.add_argument("--samples", default="all", help="Sample name or 'all' (default: all)") - parser.add_argument("--json", default=None, dest="json_out", help="Export results to JSON file") - parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed logs") + parser.add_argument("--backend", default="auto", + help="ASR backend (default: auto-detect)") + parser.add_argument("--model", default="base", dest="model_size", + help="Model size (default: base)") + parser.add_argument("--languages", "--lan", default=None, + help="Comma-separated language codes, or 'all' (default: en)") + parser.add_argument("--categories", default=None, + help="Comma-separated categories: clean,noisy,multilingual,meeting") + parser.add_argument("--quick", action="store_true", + help="Quick mode: small subset for smoke tests") + parser.add_argument("--json", default=None, dest="json_out", + help="Export full report to JSON file") + parser.add_argument("--transcriptions", action="store_true", + help="Show hypothesis vs reference for each sample") + parser.add_argument("--verbose", "-v", action="store_true", + help="Show detailed logs") parsed = parser.parse_args(args) + # Parse languages + languages = None + if parsed.languages and parsed.languages != "all": + languages = [l.strip() for l in parsed.languages.split(",")] + elif parsed.languages is None: + languages = ["en"] # default to English only + + categories = None + if parsed.categories: + categories = [c.strip() for c in parsed.categories.split(",")] + import asyncio if not parsed.verbose: - asyncio.run(_run_bench_quiet(parsed)) - else: - asyncio.run(_run_bench(parsed)) + _suppress_logging() + + asyncio.run(_run_bench_new(parsed, languages, categories)) -async def _run_bench_quiet(parsed): - """Run benchmark with suppressed logging.""" +def _suppress_logging(): + """Suppress noisy logs during benchmark.""" import warnings warnings.filterwarnings("ignore") logging.root.setLevel(logging.ERROR) @@ -726,130 +751,42 @@ async def _run_bench_quiet(parsed): handler.setLevel(logging.ERROR) for name in list(logging.Logger.manager.loggerDict.keys()): logging.getLogger(name).setLevel(logging.ERROR) - await _run_bench(parsed) -async def _run_bench(parsed): - """Run the benchmark.""" - import json as json_module - import time +async def _run_bench_new(parsed, languages, categories): + """Run the benchmark using the new benchmark module.""" + from whisperlivekit.benchmark.report import print_report, print_transcriptions, write_json + from whisperlivekit.benchmark.runner import BenchmarkRunner - from whisperlivekit.metrics import compute_wer - from whisperlivekit.test_data import get_sample, get_samples - from whisperlivekit.test_harness import TestHarness + def on_progress(name, i, total): + if name == "done": + print(f"\r [{total}/{total}] Done.{' ' * 30}", file=sys.stderr) + else: + print(f"\r [{i + 1}/{total}] {name}...{' ' * 20}", + end="", file=sys.stderr, flush=True) - # Determine samples to run - if parsed.samples == "all": - print(" Downloading test samples (first run only)...", file=sys.stderr) - samples = get_samples() - # Filter to matching language - samples = [s for s in samples if s.language == parsed.lan] - if not samples: - # Fall back to all samples if none match the language - samples = get_samples() - else: - samples = [get_sample(parsed.samples)] + runner = BenchmarkRunner( + backend=parsed.backend, + model_size=parsed.model_size, + languages=languages, + categories=categories, + quick=parsed.quick, + on_progress=on_progress, + ) - backend_label = parsed.backend - if backend_label == "auto": - backend_label = "auto-detect" + print(f"\n Downloading benchmark samples (cached after first run)...", + file=sys.stderr) - print(file=sys.stderr) - print(" WhisperLiveKit Benchmark", file=sys.stderr) - print(f" Backend: {backend_label} | Model: {parsed.model_size} | Language: {parsed.lan}", file=sys.stderr) - print(f" Samples: {len(samples)}", file=sys.stderr) - print(f" {'─' * 70}", file=sys.stderr) + report = await runner.run() - results = [] + print_report(report) - kwargs = { - "model_size": parsed.model_size, - "lan": parsed.lan, - "pcm_input": True, - } - if parsed.backend != "auto": - kwargs["backend"] = parsed.backend + if parsed.transcriptions: + print_transcriptions(report) - for sample in samples: - print(f"\n {sample.name} ({sample.duration:.1f}s, {sample.language})", file=sys.stderr) - - t_start = time.perf_counter() - - async with TestHarness(**kwargs) as h: - await h.feed(sample.path, speed=0) - await h.drain(5.0) - state = await h.finish(timeout=120) - - t_elapsed = time.perf_counter() - t_start - rtf = t_elapsed / sample.duration if sample.duration > 0 else 0 - - # Compute WER - hypothesis = state.committed_text or state.text - wer_result = compute_wer(sample.reference, hypothesis) - - n_lines = len(state.speech_lines) - - result_entry = { - "sample": sample.name, - "duration_s": round(sample.duration, 2), - "processing_time_s": round(t_elapsed, 2), - "rtf": round(rtf, 3), - "wer": round(wer_result["wer"], 4), - "wer_details": { - "substitutions": wer_result["substitutions"], - "insertions": wer_result["insertions"], - "deletions": wer_result["deletions"], - "ref_words": wer_result["ref_words"], - "hyp_words": wer_result["hyp_words"], - }, - "n_lines": n_lines, - "transcription": hypothesis, - } - results.append(result_entry) - - # Print per-sample result - wer_pct = wer_result["wer"] * 100 - wer_color = "\033[32m" if wer_pct < 15 else "\033[33m" if wer_pct < 30 else "\033[31m" - rtf_color = "\033[32m" if rtf < 0.5 else "\033[33m" if rtf < 1.0 else "\033[31m" - - print(f" WER: {wer_color}{wer_pct:5.1f}%\033[0m " - f"(S:{wer_result['substitutions']} I:{wer_result['insertions']} D:{wer_result['deletions']})", - file=sys.stderr) - print(f" RTF: {rtf_color}{rtf:.3f}x\033[0m " - f"({t_elapsed:.1f}s for {sample.duration:.1f}s audio)", - file=sys.stderr) - print(f" Lines: {n_lines}", - file=sys.stderr) - - # Summary - if len(results) > 1: - avg_wer = sum(r["wer"] for r in results) / len(results) - avg_rtf = sum(r["rtf"] for r in results) / len(results) - total_audio = sum(r["duration_s"] for r in results) - total_proc = sum(r["processing_time_s"] for r in results) - - print(f"\n {'─' * 70}", file=sys.stderr) - print(f" Summary ({len(results)} samples, {total_audio:.1f}s total audio)", file=sys.stderr) - wer_color = "\033[32m" if avg_wer * 100 < 15 else "\033[33m" if avg_wer * 100 < 30 else "\033[31m" - rtf_color = "\033[32m" if avg_rtf < 0.5 else "\033[33m" if avg_rtf < 1.0 else "\033[31m" - print(f" Avg WER: {wer_color}{avg_wer * 100:5.1f}%\033[0m", file=sys.stderr) - print(f" Avg RTF: {rtf_color}{avg_rtf:.3f}x\033[0m " - f"({total_proc:.1f}s for {total_audio:.1f}s audio)", file=sys.stderr) - - print(file=sys.stderr) - - # JSON export if parsed.json_out: - export = { - "backend": parsed.backend, - "model_size": parsed.model_size, - "language": parsed.lan, - "accelerator": _gpu_info(), - "results": results, - } - with open(parsed.json_out, "w") as f: - json_module.dump(export, f, indent=2) - print(f" Results exported to: {parsed.json_out}", file=sys.stderr) + write_json(report, parsed.json_out) + print(f" Results exported to: {parsed.json_out}\n", file=sys.stderr) # ---------------------------------------------------------------------------