WhisperLiveKit/whisperlivekit/benchmark/datasets.py
2026-03-15 11:16:15 +01:00

561 lines
17 KiB
Python

"""Benchmark audio datasets from public HuggingFace repositories.
Downloads curated samples across languages, noise conditions, and speaker
configurations. All datasets are public and freely accessible — no auth
tokens required.
Samples are cached in ~/.cache/whisperlivekit/benchmark_data/ and reused
across benchmark runs.
Datasets used:
- LibriSpeech test-clean (English, clean, single speaker)
- LibriSpeech test-other (English, noisy/hard, single speaker)
- Multilingual LibriSpeech (French, Spanish, German, Portuguese, Italian, Polish, Dutch)
- AMI (English, multi-speaker meeting)
"""
import json
import logging
import wave
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set
import numpy as np
logger = logging.getLogger(__name__)
CACHE_DIR = Path.home() / ".cache" / "whisperlivekit" / "benchmark_data"
METADATA_FILE = "benchmark_metadata.json"
@dataclass
class BenchmarkSample:
"""A benchmark audio sample with metadata and ground truth."""
name: str
path: str
reference: str
duration: float
language: str
category: str # "clean", "noisy", "multilingual", "meeting"
sample_rate: int = 16000
n_speakers: int = 1
source: str = ""
tags: Set[str] = field(default_factory=set)
def to_dict(self) -> Dict:
return {
"name": self.name,
"file": Path(self.path).name,
"reference": self.reference,
"duration": self.duration,
"language": self.language,
"category": self.category,
"sample_rate": self.sample_rate,
"n_speakers": self.n_speakers,
"source": self.source,
"tags": list(self.tags),
}
# ---------------------------------------------------------------------------
# Dataset catalog — defines what to download
# ---------------------------------------------------------------------------
BENCHMARK_CATALOG = {
# English clean (LibriSpeech test-clean)
"en_clean_short": {
"dataset": "openslr/librispeech_asr",
"config": "clean",
"split": "test",
"language": "en",
"category": "clean",
"n_samples": 1,
"skip": 0,
"tags": {"short"},
},
"en_clean_medium": {
"dataset": "openslr/librispeech_asr",
"config": "clean",
"split": "test",
"language": "en",
"category": "clean",
"n_samples": 1,
"skip": 1,
"tags": {"medium"},
},
# English noisy (LibriSpeech test-other)
"en_noisy_1": {
"dataset": "openslr/librispeech_asr",
"config": "other",
"split": "test",
"language": "en",
"category": "noisy",
"n_samples": 1,
"skip": 0,
"tags": {"accented"},
},
"en_noisy_2": {
"dataset": "openslr/librispeech_asr",
"config": "other",
"split": "test",
"language": "en",
"category": "noisy",
"n_samples": 1,
"skip": 1,
"tags": {"accented"},
},
# French (Multilingual LibriSpeech)
"fr_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "french",
"split": "test",
"language": "fr",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
"fr_clean_2": {
"dataset": "facebook/multilingual_librispeech",
"config": "french",
"split": "test",
"language": "fr",
"category": "multilingual",
"n_samples": 1,
"skip": 1,
"tags": set(),
},
# Spanish (Multilingual LibriSpeech)
"es_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "spanish",
"split": "test",
"language": "es",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
# German (Multilingual LibriSpeech)
"de_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "german",
"split": "test",
"language": "de",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
# Portuguese (Multilingual LibriSpeech)
"pt_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "portuguese",
"split": "test",
"language": "pt",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
# Italian (Multilingual LibriSpeech)
"it_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "italian",
"split": "test",
"language": "it",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
# Polish (Multilingual LibriSpeech)
"pl_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "polish",
"split": "test",
"language": "pl",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
# Dutch (Multilingual LibriSpeech)
"nl_clean_1": {
"dataset": "facebook/multilingual_librispeech",
"config": "dutch",
"split": "test",
"language": "nl",
"category": "multilingual",
"n_samples": 1,
"skip": 0,
"tags": set(),
},
# English multi-speaker meeting (AMI)
"en_meeting": {
"dataset": "edinburghcstr/ami",
"config": "ihm",
"split": "test",
"language": "en",
"category": "meeting",
"n_samples": 1,
"skip": 0,
"tags": {"multi_speaker", "long"},
"max_duration": 60.0,
},
}
# Quick mode: subset of samples for fast smoke tests
QUICK_SAMPLES = {"en_clean_short", "en_clean_medium", "en_noisy_1", "fr_clean_1"}
# ---------------------------------------------------------------------------
# Audio utilities
# ---------------------------------------------------------------------------
def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
if audio.ndim > 1:
audio = audio.mean(axis=-1)
if audio.dtype in (np.float32, np.float64):
audio = np.clip(audio, -1.0, 1.0)
audio = (audio * 32767).astype(np.int16)
elif audio.dtype != np.int16:
audio = audio.astype(np.int16)
path.parent.mkdir(parents=True, exist_ok=True)
with wave.open(str(path), "w") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio.tobytes())
def _decode_audio(audio_bytes: bytes) -> tuple:
import io
import soundfile as sf
audio_array, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
return np.array(audio_array, dtype=np.float32), sr
def _ensure_datasets():
try:
import datasets # noqa: F401
except ImportError:
raise ImportError(
"The 'datasets' package is required for benchmark data. "
"Install with: pip install whisperlivekit[test]"
)
# ---------------------------------------------------------------------------
# Download functions per dataset type
# ---------------------------------------------------------------------------
def _download_librispeech(config: str, n_samples: int, skip: int,
category: str, language: str,
prefix: str) -> List[Dict]:
"""Download from openslr/librispeech_asr (clean or other)."""
_ensure_datasets()
import datasets.config
datasets.config.TORCHCODEC_AVAILABLE = False
from datasets import Audio, load_dataset
logger.info("Downloading LibriSpeech %s samples...", config)
ds = load_dataset(
"openslr/librispeech_asr", config, split="test", streaming=True,
)
ds = ds.cast_column("audio", Audio(decode=False))
samples = []
for i, item in enumerate(ds):
if i < skip:
continue
if len(samples) >= n_samples:
break
audio_array, sr = _decode_audio(item["audio"]["bytes"])
duration = len(audio_array) / sr
text = item["text"]
wav_name = f"{prefix}_{i}.wav"
_save_wav(CACHE_DIR / wav_name, audio_array, sr)
samples.append({
"file": wav_name,
"reference": text,
"duration": round(duration, 2),
"sample_rate": sr,
"language": language,
"category": category,
"n_speakers": 1,
"source": f"openslr/librispeech_asr ({config})",
})
logger.info(" %.1fs - %s", duration, text[:60])
return samples
def _download_mls(config: str, n_samples: int, skip: int,
language: str, prefix: str) -> List[Dict]:
"""Download from facebook/multilingual_librispeech."""
_ensure_datasets()
import datasets.config
datasets.config.TORCHCODEC_AVAILABLE = False
from datasets import Audio, load_dataset
logger.info("Downloading MLS %s samples...", config)
ds = load_dataset(
"facebook/multilingual_librispeech", config, split="test", streaming=True,
)
ds = ds.cast_column("audio", Audio(decode=False))
samples = []
for i, item in enumerate(ds):
if i < skip:
continue
if len(samples) >= n_samples:
break
audio_array, sr = _decode_audio(item["audio"]["bytes"])
duration = len(audio_array) / sr
text = item.get("text", item.get("transcript", ""))
wav_name = f"{prefix}_{i}.wav"
_save_wav(CACHE_DIR / wav_name, audio_array, sr)
samples.append({
"file": wav_name,
"reference": text,
"duration": round(duration, 2),
"sample_rate": sr,
"language": language,
"category": "multilingual",
"n_speakers": 1,
"source": f"facebook/multilingual_librispeech ({config})",
})
logger.info(" [%s] %.1fs - %s", language, duration, text[:60])
return samples
def _download_fleurs(config: str, n_samples: int, skip: int,
language: str, prefix: str) -> List[Dict]:
"""Download from google/fleurs."""
_ensure_datasets()
import datasets.config
datasets.config.TORCHCODEC_AVAILABLE = False
from datasets import Audio, load_dataset
logger.info("Downloading FLEURS %s samples...", config)
ds = load_dataset(
"google/fleurs", config, split="test", streaming=True,
)
ds = ds.cast_column("audio", Audio(decode=False))
samples = []
for i, item in enumerate(ds):
if i < skip:
continue
if len(samples) >= n_samples:
break
audio_array, sr = _decode_audio(item["audio"]["bytes"])
duration = len(audio_array) / sr
text = item.get("transcription", item.get("raw_transcription", ""))
wav_name = f"{prefix}_{i}.wav"
_save_wav(CACHE_DIR / wav_name, audio_array, sr)
samples.append({
"file": wav_name,
"reference": text,
"duration": round(duration, 2),
"sample_rate": sr,
"language": language,
"category": "multilingual",
"n_speakers": 1,
"source": f"google/fleurs ({config})",
})
logger.info(" [%s] %.1fs - %s", language, duration, text[:60])
return samples
def _download_ami(max_duration: float = 60.0) -> List[Dict]:
"""Download one AMI meeting segment with multiple speakers."""
_ensure_datasets()
import datasets.config
datasets.config.TORCHCODEC_AVAILABLE = False
from datasets import Audio, load_dataset
logger.info("Downloading AMI meeting sample...")
ds = load_dataset("edinburghcstr/ami", "ihm", split="test", streaming=True)
ds = ds.cast_column("audio", Audio(decode=False))
meeting_id = None
audio_arrays = []
texts = []
sample_rate = None
for item in ds:
mid = item.get("meeting_id", "unknown")
if meeting_id is None:
meeting_id = mid
elif mid != meeting_id:
break
audio_array, sr = _decode_audio(item["audio"]["bytes"])
sample_rate = sr
texts.append(item.get("text", ""))
audio_arrays.append(audio_array)
total_dur = sum(len(a) / sr for a in audio_arrays)
if total_dur > max_duration:
break
if not audio_arrays:
return []
full_audio = np.concatenate(audio_arrays)
duration = len(full_audio) / sample_rate
reference = " ".join(t for t in texts if t)
wav_name = "ami_meeting.wav"
_save_wav(CACHE_DIR / wav_name, full_audio, sample_rate)
logger.info(" AMI meeting: %.1fs, %d utterances", duration, len(texts))
return [{
"file": wav_name,
"reference": reference,
"duration": round(duration, 2),
"sample_rate": sample_rate,
"language": "en",
"category": "meeting",
"n_speakers": 4,
"source": f"edinburghcstr/ami (ihm, meeting {meeting_id})",
}]
# ---------------------------------------------------------------------------
# Dispatcher — routes catalog entries to download functions
# ---------------------------------------------------------------------------
def _download_catalog_entry(name: str, spec: Dict) -> List[Dict]:
"""Download a single catalog entry and return metadata dicts."""
dataset = spec["dataset"]
config = spec.get("config", "")
n_samples = spec.get("n_samples", 1)
skip = spec.get("skip", 0)
language = spec["language"]
category = spec["category"]
if dataset == "openslr/librispeech_asr":
return _download_librispeech(
config=config, n_samples=n_samples, skip=skip,
category=category, language=language, prefix=name,
)
elif dataset == "facebook/multilingual_librispeech":
return _download_mls(
config=config, n_samples=n_samples, skip=skip,
language=language, prefix=name,
)
elif dataset == "google/fleurs":
return _download_fleurs(
config=config, n_samples=n_samples, skip=skip,
language=language, prefix=name,
)
elif dataset == "edinburghcstr/ami":
return _download_ami(max_duration=spec.get("max_duration", 60.0))
else:
logger.warning("Unknown dataset: %s", dataset)
return []
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def get_benchmark_samples(
languages: Optional[List[str]] = None,
categories: Optional[List[str]] = None,
quick: bool = False,
force: bool = False,
) -> List[BenchmarkSample]:
"""Download and return benchmark samples, filtered by language/category.
Args:
languages: List of language codes to include (None = all).
categories: List of categories to include (None = all).
quick: If True, only download a small subset for smoke tests.
force: Re-download even if cached.
Returns:
List of BenchmarkSample objects ready for benchmarking.
"""
CACHE_DIR.mkdir(parents=True, exist_ok=True)
meta_path = CACHE_DIR / METADATA_FILE
# Load cached metadata
cached = {}
if meta_path.exists() and not force:
cached = json.loads(meta_path.read_text())
# Determine which entries to download
entries = BENCHMARK_CATALOG
if quick:
entries = {k: v for k, v in entries.items() if k in QUICK_SAMPLES}
if languages:
lang_set = set(languages)
entries = {k: v for k, v in entries.items() if v["language"] in lang_set}
if categories:
cat_set = set(categories)
entries = {k: v for k, v in entries.items() if v["category"] in cat_set}
# Download missing entries
all_meta = cached.get("samples", {})
for name, spec in entries.items():
if name in all_meta and not force:
# Check file exists
file_path = CACHE_DIR / all_meta[name][0]["file"]
if file_path.exists():
continue
logger.info("Downloading benchmark sample: %s", name)
try:
downloaded = _download_catalog_entry(name, spec)
if downloaded:
all_meta[name] = downloaded
except Exception as e:
logger.warning("Failed to download %s: %s", name, e)
# Save metadata
meta_path.write_text(json.dumps({"samples": all_meta}, indent=2))
# Build BenchmarkSample objects
samples = []
for name, spec in entries.items():
if name not in all_meta:
continue
for meta in all_meta[name]:
file_path = CACHE_DIR / meta["file"]
if not file_path.exists():
continue
catalog_entry = BENCHMARK_CATALOG.get(name, {})
samples.append(BenchmarkSample(
name=name,
path=str(file_path),
reference=meta["reference"],
duration=meta["duration"],
language=meta["language"],
category=meta["category"],
sample_rate=meta.get("sample_rate", 16000),
n_speakers=meta.get("n_speakers", 1),
source=meta.get("source", ""),
tags=set(catalog_entry.get("tags", set())),
))
logger.info("Loaded %d benchmark samples", len(samples))
return samples