"""Benchmark audio datasets from public HuggingFace repositories.

Downloads curated samples across languages, noise conditions, and speaker
configurations. All datasets are public and freely accessible — no auth
tokens required.

Samples are cached in ~/.cache/whisperlivekit/benchmark_data/ and reused
across benchmark runs.

Datasets used:
  - LibriSpeech test-clean (English, clean, single speaker)
  - LibriSpeech test-other (English, noisy/hard, single speaker)
  - Multilingual LibriSpeech (French, Spanish, German, Portuguese, Italian, Polish, Dutch)
  - AMI (English, multi-speaker meeting)
"""
|
|
|
|
import json
|
|
import logging
|
|
import wave
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Set
|
|
|
|
import numpy as np
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Directory under the user's home where downloaded WAV samples are cached.
CACHE_DIR = Path.home() / ".cache" / "whisperlivekit" / "benchmark_data"
# File name of the JSON metadata index that is stored inside CACHE_DIR.
METADATA_FILE = "benchmark_metadata.json"
|
|
|
|
|
|
@dataclass
class BenchmarkSample:
    """A benchmark audio sample with metadata and ground truth.

    Attributes:
        name: Catalog entry name the sample was downloaded for.
        path: Location of the cached WAV file, as a string.
        reference: Ground-truth transcript of the audio.
        duration: Length of the audio in seconds.
        language: Language code of the speech (e.g. ``"en"``).
        category: One of "clean", "noisy", "multilingual", "meeting".
        sample_rate: Sample rate of the WAV file in Hz.
        n_speakers: Number of speakers in the recording.
        source: Free-text description of the originating dataset.
        tags: Extra labels attached to the catalog entry.
    """

    name: str
    path: str
    reference: str
    duration: float
    language: str
    category: str  # "clean", "noisy", "multilingual", "meeting"
    sample_rate: int = 16000
    n_speakers: int = 1
    source: str = ""
    tags: Set[str] = field(default_factory=set)

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable dict describing this sample.

        Only the file name (not the full path) is exported under ``"file"``.
        Tags are emitted as a sorted list so the output is deterministic
        regardless of set iteration order.
        """
        return {
            "name": self.name,
            "file": Path(self.path).name,
            "reference": self.reference,
            "duration": self.duration,
            "language": self.language,
            "category": self.category,
            "sample_rate": self.sample_rate,
            "n_speakers": self.n_speakers,
            "source": self.source,
            "tags": sorted(self.tags),
        }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dataset catalog — defines what to download
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_LIBRISPEECH = "openslr/librispeech_asr"
_MLS = "facebook/multilingual_librispeech"
_AMI = "edinburghcstr/ami"


def _catalog_entry(dataset: str, config: str, language: str, category: str,
                   skip: int = 0, tags: Optional[Set[str]] = None,
                   **extra) -> Dict:
    """Build one catalog spec: a single sample taken from the ``test`` split.

    Every call returns a fresh dict with its own ``tags`` set; ``extra``
    keyword arguments are appended after the standard keys.
    """
    spec = {
        "dataset": dataset,
        "config": config,
        "split": "test",
        "language": language,
        "category": category,
        "n_samples": 1,
        "skip": skip,
        "tags": set(tags or ()),
    }
    spec.update(extra)
    return spec


BENCHMARK_CATALOG = {
    # English clean (LibriSpeech test-clean)
    "en_clean_short": _catalog_entry(
        _LIBRISPEECH, "clean", "en", "clean", skip=0, tags={"short"},
    ),
    "en_clean_medium": _catalog_entry(
        _LIBRISPEECH, "clean", "en", "clean", skip=1, tags={"medium"},
    ),
    # English noisy (LibriSpeech test-other)
    "en_noisy_1": _catalog_entry(
        _LIBRISPEECH, "other", "en", "noisy", skip=0, tags={"accented"},
    ),
    "en_noisy_2": _catalog_entry(
        _LIBRISPEECH, "other", "en", "noisy", skip=1, tags={"accented"},
    ),
    # French (Multilingual LibriSpeech)
    "fr_clean_1": _catalog_entry(_MLS, "french", "fr", "multilingual", skip=0),
    "fr_clean_2": _catalog_entry(_MLS, "french", "fr", "multilingual", skip=1),
    # Spanish (Multilingual LibriSpeech)
    "es_clean_1": _catalog_entry(_MLS, "spanish", "es", "multilingual"),
    # German (Multilingual LibriSpeech)
    "de_clean_1": _catalog_entry(_MLS, "german", "de", "multilingual"),
    # Portuguese (Multilingual LibriSpeech)
    "pt_clean_1": _catalog_entry(_MLS, "portuguese", "pt", "multilingual"),
    # Italian (Multilingual LibriSpeech)
    "it_clean_1": _catalog_entry(_MLS, "italian", "it", "multilingual"),
    # Polish (Multilingual LibriSpeech)
    "pl_clean_1": _catalog_entry(_MLS, "polish", "pl", "multilingual"),
    # Dutch (Multilingual LibriSpeech)
    "nl_clean_1": _catalog_entry(_MLS, "dutch", "nl", "multilingual"),
    # English multi-speaker meeting (AMI)
    "en_meeting": _catalog_entry(
        _AMI, "ihm", "en", "meeting",
        tags={"multi_speaker", "long"},
        max_duration=60.0,
    ),
}

# Quick mode: subset of samples for fast smoke tests
QUICK_SAMPLES = {"en_clean_short", "en_clean_medium", "en_noisy_1", "fr_clean_1"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Audio utilities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
    """Write *audio* to *path* as a mono, 16-bit PCM WAV file.

    Args:
        path: Destination file; missing parent directories are created.
        audio: Samples, either 1-D or with channels on the last axis.
            Floating-point input is clipped to [-1.0, 1.0] and scaled to
            the int16 range; integer input is taken as PCM values and cast
            to int16.
        sample_rate: Frame rate written to the WAV header, in Hz.
    """
    audio = np.asarray(audio)
    # Decide on scaling from the *input* dtype: averaging integer channels
    # produces float64, which must not be mistaken for normalised audio.
    is_normalised_float = np.issubdtype(audio.dtype, np.floating)

    if audio.ndim > 1:
        # Downmix to mono by averaging across the channel axis.
        audio = audio.mean(axis=-1)

    if is_normalised_float:
        audio = np.clip(audio, -1.0, 1.0)
        audio = (audio * 32767).astype(np.int16)
    elif audio.dtype != np.int16:
        if np.issubdtype(audio.dtype, np.floating):
            # Downmixed integer PCM: round the channel average to the
            # nearest sample value instead of truncating.
            audio = np.rint(audio)
        audio = audio.astype(np.int16)

    path.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(path), "w") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wf.setframerate(sample_rate)
        wf.writeframes(audio.tobytes())
|
|
|
|
|
|
def _decode_audio(audio_bytes: bytes) -> tuple:
    """Decode an encoded audio file held in memory.

    Args:
        audio_bytes: Raw contents of an audio file readable by ``soundfile``.

    Returns:
        A ``(samples, sample_rate)`` pair where ``samples`` is a float32
        NumPy array and ``sample_rate`` is the rate reported by the decoder.
    """
    # Imported lazily so the module can be imported without soundfile.
    from io import BytesIO

    import soundfile

    stream = BytesIO(audio_bytes)
    decoded, rate = soundfile.read(stream, dtype="float32")
    samples = np.array(decoded, dtype=np.float32)
    return samples, rate
|
|
|
|
|
|
def _ensure_datasets() -> None:
    """Verify that the ``datasets`` package can be imported.

    Raises:
        ImportError: If ``datasets`` is not installed. The message tells the
            user how to install it, and the original import failure is
            attached as the cause.
    """
    try:
        import datasets  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "The 'datasets' package is required for benchmark data. "
            "Install with: pip install whisperlivekit[test]"
        ) from exc
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Download functions per dataset type
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _download_librispeech(config: str, n_samples: int, skip: int,
                          category: str, language: str,
                          prefix: str) -> List[Dict]:
    """Download from openslr/librispeech_asr (clean or other).

    Streams the ``test`` split of the given config, skips the first ``skip``
    items, and saves the next ``n_samples`` items as WAV files in
    ``CACHE_DIR``.

    Args:
        config: Dataset configuration name passed to ``load_dataset``.
        n_samples: Number of items to save.
        skip: Number of leading items of the stream to ignore.
        category: Value stored under ``"category"`` in each metadata dict.
        language: Value stored under ``"language"`` in each metadata dict.
        prefix: File-name prefix; files are named ``{prefix}_{index}.wav``
            where ``index`` is the item's position in the stream.

    Returns:
        One metadata dict per saved sample.
    """
    from itertools import islice

    _ensure_datasets()
    import datasets.config
    # NOTE(review): presumably keeps `datasets` from using torchcodec for
    # audio; the raw bytes are decoded by _decode_audio instead — confirm.
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading LibriSpeech %s samples...", config)
    ds = load_dataset(
        "openslr/librispeech_asr", config, split="test", streaming=True,
    )
    # Ask for undecoded audio so each item carries the raw file bytes.
    ds = ds.cast_column("audio", Audio(decode=False))

    # Negative values behave like zero, as in a plain counting loop.
    first = max(skip, 0)
    wanted = max(n_samples, 0)

    samples = []
    # islice stops right after the last wanted item, so no additional item
    # is pulled from the stream once enough samples have been collected.
    window = islice(ds, first, first + wanted)
    for index, item in enumerate(window, start=first):
        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        duration = len(audio_array) / sr
        text = item["text"]

        wav_name = f"{prefix}_{index}.wav"
        _save_wav(CACHE_DIR / wav_name, audio_array, sr)

        samples.append({
            "file": wav_name,
            "reference": text,
            "duration": round(duration, 2),
            "sample_rate": sr,
            "language": language,
            "category": category,
            "n_speakers": 1,
            "source": f"openslr/librispeech_asr ({config})",
        })
        logger.info("  %.1fs - %s", duration, text[:60])

    return samples
|
|
|
|
|
|
def _download_mls(config: str, n_samples: int, skip: int,
                  language: str, prefix: str) -> List[Dict]:
    """Download from facebook/multilingual_librispeech.

    Streams the ``test`` split of the given config, skips the first ``skip``
    items, and saves the next ``n_samples`` items as WAV files in
    ``CACHE_DIR``.

    Args:
        config: Dataset configuration name passed to ``load_dataset``.
        n_samples: Number of items to save.
        skip: Number of leading items of the stream to ignore.
        language: Value stored under ``"language"`` in each metadata dict.
        prefix: File-name prefix; files are named ``{prefix}_{index}.wav``
            where ``index`` is the item's position in the stream.

    Returns:
        One metadata dict per saved sample, with category "multilingual".
    """
    from itertools import islice

    _ensure_datasets()
    import datasets.config
    # NOTE(review): presumably keeps `datasets` from using torchcodec for
    # audio; the raw bytes are decoded by _decode_audio instead — confirm.
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading MLS %s samples...", config)
    ds = load_dataset(
        "facebook/multilingual_librispeech", config, split="test", streaming=True,
    )
    # Ask for undecoded audio so each item carries the raw file bytes.
    ds = ds.cast_column("audio", Audio(decode=False))

    # Negative values behave like zero, as in a plain counting loop.
    first = max(skip, 0)
    wanted = max(n_samples, 0)

    samples = []
    # islice stops right after the last wanted item, so no additional item
    # is pulled from the stream once enough samples have been collected.
    window = islice(ds, first, first + wanted)
    for index, item in enumerate(window, start=first):
        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        duration = len(audio_array) / sr
        # The transcript is looked up under "text" first, then "transcript".
        text = item.get("text", item.get("transcript", ""))

        wav_name = f"{prefix}_{index}.wav"
        _save_wav(CACHE_DIR / wav_name, audio_array, sr)

        samples.append({
            "file": wav_name,
            "reference": text,
            "duration": round(duration, 2),
            "sample_rate": sr,
            "language": language,
            "category": "multilingual",
            "n_speakers": 1,
            "source": f"facebook/multilingual_librispeech ({config})",
        })
        logger.info("  [%s] %.1fs - %s", language, duration, text[:60])

    return samples
|
|
|
|
|
|
def _download_fleurs(config: str, n_samples: int, skip: int,
                     language: str, prefix: str) -> List[Dict]:
    """Download from google/fleurs.

    Streams the ``test`` split of the given config, skips the first ``skip``
    items, and saves the next ``n_samples`` items as WAV files in
    ``CACHE_DIR``.

    Args:
        config: Dataset configuration name passed to ``load_dataset``.
        n_samples: Number of items to save.
        skip: Number of leading items of the stream to ignore.
        language: Value stored under ``"language"`` in each metadata dict.
        prefix: File-name prefix; files are named ``{prefix}_{index}.wav``
            where ``index`` is the item's position in the stream.

    Returns:
        One metadata dict per saved sample, with category "multilingual".
    """
    from itertools import islice

    _ensure_datasets()
    import datasets.config
    # NOTE(review): presumably keeps `datasets` from using torchcodec for
    # audio; the raw bytes are decoded by _decode_audio instead — confirm.
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading FLEURS %s samples...", config)
    ds = load_dataset(
        "google/fleurs", config, split="test", streaming=True,
    )
    # Ask for undecoded audio so each item carries the raw file bytes.
    ds = ds.cast_column("audio", Audio(decode=False))

    # Negative values behave like zero, as in a plain counting loop.
    first = max(skip, 0)
    wanted = max(n_samples, 0)

    samples = []
    # islice stops right after the last wanted item, so no additional item
    # is pulled from the stream once enough samples have been collected.
    window = islice(ds, first, first + wanted)
    for index, item in enumerate(window, start=first):
        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        duration = len(audio_array) / sr
        # Looked up under "transcription" first, then "raw_transcription".
        text = item.get("transcription", item.get("raw_transcription", ""))

        wav_name = f"{prefix}_{index}.wav"
        _save_wav(CACHE_DIR / wav_name, audio_array, sr)

        samples.append({
            "file": wav_name,
            "reference": text,
            "duration": round(duration, 2),
            "sample_rate": sr,
            "language": language,
            "category": "multilingual",
            "n_speakers": 1,
            "source": f"google/fleurs ({config})",
        })
        logger.info("  [%s] %.1fs - %s", language, duration, text[:60])

    return samples
|
|
|
|
|
|
def _download_ami(max_duration: float = 60.0) -> List[Dict]:
    """Download one AMI meeting segment with multiple speakers.

    Streams the ``test`` split of the ``ihm`` config and concatenates
    consecutive items that share the first item's ``meeting_id``. Collection
    stops once the accumulated audio exceeds ``max_duration`` or an item
    from another meeting is reached, so the result may be slightly longer
    than ``max_duration``.

    Args:
        max_duration: Collection stops once the accumulated audio is longer
            than this many seconds.

    Returns:
        A one-element list with the metadata dict of the saved
        ``ami_meeting.wav`` file, or an empty list if the stream was empty.
    """
    _ensure_datasets()
    import datasets.config
    # NOTE(review): presumably keeps `datasets` from using torchcodec for
    # audio; the raw bytes are decoded by _decode_audio instead — confirm.
    datasets.config.TORCHCODEC_AVAILABLE = False
    from datasets import Audio, load_dataset

    logger.info("Downloading AMI meeting sample...")
    ds = load_dataset("edinburghcstr/ami", "ihm", split="test", streaming=True)
    # Ask for undecoded audio so each item carries the raw file bytes.
    ds = ds.cast_column("audio", Audio(decode=False))

    meeting_id = None
    audio_arrays = []
    texts = []
    sample_rate = None
    # Running total in seconds; avoids re-summing every collected segment
    # on each iteration.
    total_dur = 0.0

    for item in ds:
        mid = item.get("meeting_id", "unknown")
        if meeting_id is None:
            meeting_id = mid
        elif mid != meeting_id:
            # Only segments of the first meeting are kept.
            break

        audio_array, sr = _decode_audio(item["audio"]["bytes"])
        sample_rate = sr
        texts.append(item.get("text", ""))
        audio_arrays.append(audio_array)

        total_dur += len(audio_array) / sr
        if total_dur > max_duration:
            break

    if not audio_arrays:
        return []

    full_audio = np.concatenate(audio_arrays)
    duration = len(full_audio) / sample_rate
    # Empty or missing transcripts are left out of the reference text.
    reference = " ".join(t for t in texts if t)

    wav_name = "ami_meeting.wav"
    _save_wav(CACHE_DIR / wav_name, full_audio, sample_rate)

    logger.info("  AMI meeting: %.1fs, %d utterances", duration, len(texts))
    return [{
        "file": wav_name,
        "reference": reference,
        "duration": round(duration, 2),
        "sample_rate": sample_rate,
        "language": "en",
        "category": "meeting",
        # Hard-coded; not derived from the streamed items.
        "n_speakers": 4,
        "source": f"edinburghcstr/ami (ihm, meeting {meeting_id})",
    }]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dispatcher — routes catalog entries to download functions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _download_catalog_entry(name: str, spec: Dict) -> List[Dict]:
    """Fetch the samples of one catalog entry via the matching downloader.

    Args:
        name: Catalog key; passed on as the file-name prefix.
        spec: Catalog spec. ``dataset``, ``language`` and ``category`` are
            required; ``config``, ``n_samples``, ``skip`` and
            ``max_duration`` fall back to defaults.

    Returns:
        The metadata dicts produced by the downloader, or an empty list when
        the dataset is not one of the supported ones.
    """
    source = spec["dataset"]
    subset = spec.get("config", "")
    count = spec.get("n_samples", 1)
    offset = spec.get("skip", 0)
    lang = spec["language"]
    kind = spec["category"]

    # Keyword arguments shared by the three LibriSpeech-style downloaders.
    shared = {
        "config": subset,
        "n_samples": count,
        "skip": offset,
        "language": lang,
        "prefix": name,
    }

    if source == "openslr/librispeech_asr":
        return _download_librispeech(category=kind, **shared)
    if source == "facebook/multilingual_librispeech":
        return _download_mls(**shared)
    if source == "google/fleurs":
        return _download_fleurs(**shared)
    if source == "edinburghcstr/ami":
        return _download_ami(max_duration=spec.get("max_duration", 60.0))

    logger.warning("Unknown dataset: %s", source)
    return []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_cached_metadata(meta_path: Path) -> Dict:
    """Return the ``samples`` mapping stored in *meta_path*, or ``{}``."""
    if not meta_path.exists():
        return {}
    try:
        cached = json.loads(meta_path.read_text())
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError; an
        # unreadable index is treated as an empty cache instead of a crash.
        logger.warning(
            "Ignoring unreadable benchmark metadata %s: %s", meta_path, exc,
        )
        return {}
    samples = cached.get("samples") if isinstance(cached, dict) else None
    return samples if isinstance(samples, dict) else {}


def get_benchmark_samples(
    languages: Optional[List[str]] = None,
    categories: Optional[List[str]] = None,
    quick: bool = False,
    force: bool = False,
) -> List[BenchmarkSample]:
    """Download and return benchmark samples, filtered by language/category.

    Args:
        languages: List of language codes to include (None = all).
        categories: List of categories to include (None = all).
        quick: If True, only download a small subset for smoke tests.
        force: Re-download the selected entries even if cached. Cached
            metadata of entries outside the current selection is kept.

    Returns:
        List of BenchmarkSample objects ready for benchmarking. Entries
        whose download failed (a warning is logged) or whose WAV file is
        missing are left out.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    meta_path = CACHE_DIR / METADATA_FILE

    # Always start from the existing index so that rewriting it below does
    # not drop entries that are not part of this call's selection.
    all_meta = _load_cached_metadata(meta_path)

    # Determine which entries to download
    entries = BENCHMARK_CATALOG
    if quick:
        entries = {k: v for k, v in entries.items() if k in QUICK_SAMPLES}

    if languages:
        lang_set = set(languages)
        entries = {k: v for k, v in entries.items() if v["language"] in lang_set}

    if categories:
        cat_set = set(categories)
        entries = {k: v for k, v in entries.items() if v["category"] in cat_set}

    # Download missing entries
    for name, spec in entries.items():
        if force:
            # Forget the cached record so a failed re-download does not
            # leave a stale entry behind.
            all_meta.pop(name, None)
        else:
            cached_items = all_meta.get(name)
            # Reuse the cache only if every recorded file is still on disk.
            if cached_items and all(
                (CACHE_DIR / meta["file"]).exists() for meta in cached_items
            ):
                continue

        logger.info("Downloading benchmark sample: %s", name)
        try:
            downloaded = _download_catalog_entry(name, spec)
            if downloaded:
                all_meta[name] = downloaded
        except Exception as e:
            # Best effort: one failing dataset must not abort the others.
            logger.warning("Failed to download %s: %s", name, e)

    # Save metadata
    meta_path.write_text(json.dumps({"samples": all_meta}, indent=2))

    # Build BenchmarkSample objects
    samples = []
    for name, spec in entries.items():
        if name not in all_meta:
            continue
        # Tags come from the catalog spec; they are not stored in the index.
        entry_tags = spec.get("tags", set())
        for meta in all_meta[name]:
            file_path = CACHE_DIR / meta["file"]
            if not file_path.exists():
                continue
            samples.append(BenchmarkSample(
                name=name,
                path=str(file_path),
                reference=meta["reference"],
                duration=meta["duration"],
                language=meta["language"],
                category=meta["category"],
                sample_rate=meta.get("sample_rate", 16000),
                n_speakers=meta.get("n_speakers", 1),
                source=meta.get("source", ""),
                tags=set(entry_tags),
            ))

    logger.info("Loaded %d benchmark samples", len(samples))
    return samples
|