Add a script to detect alignement heads, usefull for distilled whisper
This commit is contained in:
parent
0491681be4
commit
a732e0903e
5 changed files with 296 additions and 3 deletions
|
|
@ -171,7 +171,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
| SimulStreaming backend options | Description | Default |
|
| SimulStreaming backend options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
||||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
|
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
|
||||||
|
| `None` |
|
||||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||||
|
|
|
||||||
BIN
scripts/alignment_heads.png
Normal file
BIN
scripts/alignment_heads.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
292
scripts/determine_alignment_heads.py
Normal file
292
scripts/determine_alignment_heads.py
Normal file
|
|
@ -0,0 +1,292 @@
|
||||||
|
"""Determine alignment heads for a variants, such as distilled model"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import gzip
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import Audio as DatasetAudio, load_dataset
|
||||||
|
import soundfile as sf
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||||
|
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||||
|
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
sys.path.insert(0, str(WHISPER_ROOT))
|
||||||
|
|
||||||
|
from whisper import load_model
|
||||||
|
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
|
from whisper.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_clips(name, config, split, limit):
|
||||||
|
ds = load_dataset(name, config, split=split)
|
||||||
|
ds = ds.cast_column("audio", DatasetAudio(decode=False))
|
||||||
|
clips = []
|
||||||
|
for idx, row in enumerate(ds):
|
||||||
|
if limit is not None and idx >= limit:
|
||||||
|
break
|
||||||
|
audio_field = row["audio"]
|
||||||
|
transcript = row["text"]
|
||||||
|
|
||||||
|
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
|
||||||
|
if waveform_np.ndim > 1:
|
||||||
|
waveform_np = waveform_np.mean(axis=1)
|
||||||
|
waveform = waveform_np
|
||||||
|
transcript = str(transcript)
|
||||||
|
|
||||||
|
clips.append((waveform, transcript))
|
||||||
|
return clips
|
||||||
|
|
||||||
|
|
||||||
|
def load_clips(args):
|
||||||
|
return load_dataset_clips(
|
||||||
|
args.dataset,
|
||||||
|
args.dataset_config,
|
||||||
|
args.dataset_split,
|
||||||
|
args.dataset_num_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
|
||||||
|
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="pytorch_model.bin",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
help="Torch device to run on",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="librispeech_asr"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-config",
|
||||||
|
type=str,
|
||||||
|
default="clean"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-split",
|
||||||
|
type=str,
|
||||||
|
default="validation[:1%]",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-num-samples",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--threshold",
|
||||||
|
type=float,
|
||||||
|
default=1.5,
|
||||||
|
help="Z score threshold for a head to be selected",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--votes",
|
||||||
|
type=float,
|
||||||
|
default=0.75,
|
||||||
|
help="percentage of clips that must vote for a head",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
type=str,
|
||||||
|
default="alignment_heads.b85",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--visualize-top-k",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def collect_heads(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
clips: Sequence[Tuple[AudioInput, str]],
|
||||||
|
threshold: float,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
device = model.device
|
||||||
|
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
|
||||||
|
strengths = torch.zeros_like(votes)
|
||||||
|
|
||||||
|
for audio_source, transcript in clips:
|
||||||
|
waveform = pad_or_trim(_waveform_from_source(audio_source))
|
||||||
|
mel = log_mel_spectrogram(waveform, device=device)
|
||||||
|
|
||||||
|
tokens = torch.tensor(
|
||||||
|
[
|
||||||
|
*tokenizer.sot_sequence,
|
||||||
|
tokenizer.no_timestamps,
|
||||||
|
*tokenizer.encode(transcript),
|
||||||
|
tokenizer.eot,
|
||||||
|
],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
qks = [None] * model.dims.n_text_layer
|
||||||
|
hooks = [
|
||||||
|
block.cross_attn.register_forward_hook(
|
||||||
|
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
|
||||||
|
)
|
||||||
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
|
]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
for layer_idx, tensor in enumerate(qks):
|
||||||
|
if tensor is None:
|
||||||
|
continue
|
||||||
|
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||||
|
tensor = tensor.softmax(dim=-1)
|
||||||
|
peak = tensor.max(dim=-1).values # [heads, tokens]
|
||||||
|
strengths[layer_idx] += peak.mean(dim=-1)
|
||||||
|
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
|
||||||
|
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
|
||||||
|
)
|
||||||
|
mask = (zscore > 3).any(dim=-1)
|
||||||
|
votes[layer_idx] += mask.float()
|
||||||
|
|
||||||
|
votes /= len(clips)
|
||||||
|
strengths /= len(clips)
|
||||||
|
return votes, strengths
|
||||||
|
|
||||||
|
|
||||||
|
def _select_heads_for_visualization(selection, strengths, top_k):
|
||||||
|
selected = torch.nonzero(selection, as_tuple=False)
|
||||||
|
if selected.numel() == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
entries = [
|
||||||
|
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
|
||||||
|
for layer, head in selected
|
||||||
|
]
|
||||||
|
entries.sort(key=lambda item: item[2], reverse=True)
|
||||||
|
return entries[:top_k]
|
||||||
|
|
||||||
|
def _extract_heatmaps(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
clip: Tuple[AudioInput, str],
|
||||||
|
heads: Sequence[Tuple[int, int, float]],
|
||||||
|
) -> dict:
|
||||||
|
if not heads:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
target_map = {}
|
||||||
|
for layer, head, _ in heads:
|
||||||
|
target_map.setdefault(layer, set()).add(head)
|
||||||
|
|
||||||
|
waveform = pad_or_trim(_waveform_from_source(clip[0]))
|
||||||
|
mel = log_mel_spectrogram(waveform, device=model.device)
|
||||||
|
transcript = clip[1]
|
||||||
|
tokens = torch.tensor(
|
||||||
|
[
|
||||||
|
*tokenizer.sot_sequence,
|
||||||
|
tokenizer.no_timestamps,
|
||||||
|
*tokenizer.encode(transcript),
|
||||||
|
tokenizer.eot,
|
||||||
|
],
|
||||||
|
device=model.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
QKs = [None] * model.dims.n_text_layer
|
||||||
|
hooks = [
|
||||||
|
block.cross_attn.register_forward_hook(
|
||||||
|
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
|
||||||
|
)
|
||||||
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
|
]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
heatmaps = {}
|
||||||
|
for layer_idx, tensor in enumerate(QKs):
|
||||||
|
if tensor is None or layer_idx not in target_map:
|
||||||
|
continue
|
||||||
|
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||||
|
tensor = tensor.softmax(dim=-1).cpu()
|
||||||
|
for head_idx in target_map[layer_idx]:
|
||||||
|
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
|
||||||
|
|
||||||
|
return heatmaps
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_heatmaps(
|
||||||
|
heads, heatmaps, output_path):
|
||||||
|
cols = min(3, len(heads))
|
||||||
|
rows = math.ceil(len(heads) / cols)
|
||||||
|
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
|
||||||
|
|
||||||
|
for idx, (layer, head, score) in enumerate(heads):
|
||||||
|
ax = axes[idx // cols][idx % cols]
|
||||||
|
mat = heatmaps.get((layer, head))
|
||||||
|
if mat is None:
|
||||||
|
ax.axis("off")
|
||||||
|
continue
|
||||||
|
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
|
||||||
|
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
|
||||||
|
ax.set_xlabel("time")
|
||||||
|
ax.set_ylabel("tokens")
|
||||||
|
|
||||||
|
for j in range(len(heads), rows * cols):
|
||||||
|
axes[j // cols][j % cols].axis("off")
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(output_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_mask(mask: torch.Tensor, output_path: str):
|
||||||
|
payload = mask.numpy().astype(np.bool_)
|
||||||
|
blob = base64.b85encode(gzip.compress(payload.tobytes()))
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(blob)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = _parse_args()
|
||||||
|
model = load_model(args.model, device=args.device)
|
||||||
|
model.eval()
|
||||||
|
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
|
||||||
|
clips = load_clips(args)
|
||||||
|
|
||||||
|
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
|
||||||
|
# selection = votes > 0.5
|
||||||
|
selection = strengths > 0.05
|
||||||
|
_dump_mask(selection.cpu(), args.output)
|
||||||
|
|
||||||
|
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
|
||||||
|
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
|
||||||
|
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
|
"""Copy core files from web directory to Chrome extension directory."""
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def sync_extension_files():
|
def sync_extension_files():
|
||||||
"""Copy core files from web directory to Chrome extension directory."""
|
|
||||||
|
|
||||||
web_dir = Path("whisperlivekit/web")
|
web_dir = Path("whisperlivekit/web")
|
||||||
extension_dir = Path("chrome-extension")
|
extension_dir = Path("chrome-extension")
|
||||||
|
|
@ -195,7 +195,6 @@ class SimulStreamingASR():
|
||||||
'large-v3': './large-v3.pt',
|
'large-v3': './large-v3.pt',
|
||||||
'large': './large-v3.pt'
|
'large': './large-v3.pt'
|
||||||
}
|
}
|
||||||
self.pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
|
|
||||||
self.model_name = self.model_size
|
self.model_name = self.model_size
|
||||||
is_multilingual = not self.model_name.endswith(".en")
|
is_multilingual = not self.model_name.endswith(".en")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue