Distilled model compatibility with HF config.json to ModelDimensions

This commit is contained in:
Quentin Fuxa 2025-11-08 20:20:05 +01:00
parent ffe5284764
commit 0491681be4
2 changed files with 172 additions and 19 deletions

View file

@ -23,7 +23,7 @@ try:
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install `mlx-whisper\n{"="*50}`""")
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: `pip install mlx-whisper`\n{"="*50}""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
@ -33,7 +33,7 @@ else:
HAS_FASTER_WHISPER = True
except ImportError:
if platform.system() != "Darwin":
print(f"""{"="*50}\nFaster-Whisper not found but. Consider installing faster-whisper for better performance: pip install `faster-whisper\n{"="*50}`""")
print(f"""{"="*50}\nFaster-Whisper not found but. Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}`""")
HAS_FASTER_WHISPER = False
def model_path_and_type(model_path):
@ -42,7 +42,8 @@ def model_path_and_type(model_path):
compatible_whisper_mlx = False
compatible_faster_whisper = False
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
if pt_path is None:
pt_path = path if path.is_file() and path.suffix.lower() == '.bin' else None
if path.is_dir():
for file in path.iterdir():
if file.is_file():
@ -52,6 +53,9 @@ def model_path_and_type(model_path):
compatible_faster_whisper = True
elif file.suffix.lower() == '.pt':
pt_path = file
if pt_path is None:
if (model_path / Path("pytorch_model.bin")).exists():
pt_path = model_path / Path("pytorch_model.bin")
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
@ -171,11 +175,11 @@ class SimulStreamingASR():
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path:
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
self.model_name = self.pt_path.stem
is_multilingual = not self.model_path.endswith(".en")
elif self.model_size is not None:
model_mapping = {
'tiny': './tiny.pt',
@ -191,12 +195,12 @@ class SimulStreamingASR():
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
self.model_name = pt_path.name.replace(".pt", "")
self.pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
self.model_name = self.model_size
is_multilingual = not self.model_name.endswith(".en")
self.cfg = AlignAttConfig(
tokenizer_is_multilingual= not self.model_name.endswith(".en"),
tokenizer_is_multilingual= is_multilingual,
segment_length=self.min_chunk_size,
frame_threshold=self.frame_threshold,
language=self.lan,
@ -249,7 +253,7 @@ class SimulStreamingASR():
def load_model(self):
whisper_model = load_model(
name=self.model_path if self.model_path else self.model_name,
name=self.pt_path if self.pt_path else self.model_name,
download_root=self.model_path,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads

View file

@ -1,9 +1,10 @@
import hashlib
import io
import json
import os
import urllib
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict
import torch
from tqdm import tqdm
@ -100,6 +101,137 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
    """
    Attempt to infer ModelDimensions from a HF style config.json located
    next to the given checkpoint, useful for distilled models.

    `path` may be a directory (``<path>/config.json`` is read) or a checkpoint
    file (``config.json`` is looked up in the directory containing it).

    Returns the inferred ModelDimensions, or None when no config.json exists,
    when it cannot be read or parsed, or when a required key is missing. A
    warning is emitted in every failure case except the "no file" one.
    """
    config_dir = path if os.path.isdir(path) else os.path.dirname(path)
    config_path = os.path.join(config_dir, "config.json")
    if not os.path.isfile(config_path):
        return None

    # This helper is best-effort: a broken config must not crash the caller,
    # it should simply yield "no dimensions" like a missing key does.
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        warnings.warn(f"Could not read HuggingFace config {config_path}: {err}")
        return None

    if not isinstance(config, dict):
        warnings.warn(f"HuggingFace config {config_path} is not a JSON object")
        return None

    try:
        return ModelDimensions(
            n_mels=config["num_mel_bins"],
            n_audio_ctx=config["max_source_positions"],
            n_audio_state=config["d_model"],
            n_audio_head=config["encoder_attention_heads"],
            # fall back to the generic layer count when "encoder_layers" is absent
            n_audio_layer=config.get("encoder_layers")
            or config["num_hidden_layers"],
            n_vocab=config["vocab_size"],
            n_text_ctx=config["max_target_positions"],
            n_text_state=config["d_model"],
            n_text_head=config["decoder_attention_heads"],
            n_text_layer=config["decoder_layers"],
        )
    except KeyError as err:
        warnings.warn(f"Missing key {err} in HuggingFace config {config_path}")
        return None
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
converts a HF checkpoint state_dict into the naming convention used by
default whisper
"""
if not any(k.startswith("model.") for k in state_dict):
return state_dict
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
if remainder.startswith("self_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "attn.query",
"k_proj": "attn.key",
"v_proj": "attn.value",
"out_proj": "attn.out",
}
stem = mapping.get(suffix.split(".")[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "self_attn_layer_norm.weight":
return f"{target_prefix}.attn_ln.weight"
elif remainder == "self_attn_layer_norm.bias":
return f"{target_prefix}.attn_ln.bias"
elif remainder.startswith("encoder_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "cross_attn.query",
"k_proj": "cross_attn.key",
"v_proj": "cross_attn.value",
"out_proj": "cross_attn.out",
}
stem = mapping.get(suffix.split(".", 1)[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "encoder_attn_layer_norm.weight":
return f"{target_prefix}.cross_attn_ln.weight"
elif remainder == "encoder_attn_layer_norm.bias":
return f"{target_prefix}.cross_attn_ln.bias"
elif remainder.startswith("fc1."):
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
elif remainder.startswith("fc2."):
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
elif remainder == "final_layer_norm.weight":
return f"{target_prefix}.mlp_ln.weight"
elif remainder == "final_layer_norm.bias":
return f"{target_prefix}.mlp_ln.bias"
return None
converted = {}
for key, value in state_dict.items():
if not key.startswith("model."):
continue
subkey = key[len("model.") :]
if subkey.startswith("encoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("decoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
mapped = subkey
elif subkey == "encoder.embed_positions.weight":
mapped = "encoder.positional_embedding"
elif subkey == "decoder.embed_positions.weight":
mapped = "decoder.positional_embedding"
elif subkey == "encoder.layer_norm.weight":
mapped = "encoder.ln_post.weight"
elif subkey == "encoder.layer_norm.bias":
mapped = "encoder.ln_post.bias"
elif subkey.startswith("decoder.embed_tokens."):
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
elif subkey == "decoder.layer_norm.weight":
mapped = "decoder.ln.weight"
elif subkey == "decoder.layer_norm.bias":
mapped = "decoder.ln.bias"
else:
mapped = None
if mapped:
converted[mapped] = value
return converted if converted else state_dict
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
@ -134,7 +266,6 @@ def load_model(
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
@ -154,16 +285,34 @@ def load_model(
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
state_dict = _convert_hf_state_dict(state_dict)
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(name)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
)
if not isinstance(state_dict, dict):
state_dict = checkpoint
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
state_dict = {
k: v for k, v in state_dict.items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"])
model.load_state_dict(state_dict)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)