lora loader in shared whisper core
This commit is contained in:
parent
bcffdbc6b3
commit
1bbbb7903c
1 changed files with 91 additions and 3 deletions
|
|
@ -4,11 +4,12 @@ import json
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Union, Dict
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||||
|
|
@ -233,13 +234,97 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
|
||||||
return converted if converted else state_dict
|
return converted if converted else state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lora_state(lora_path: str):
|
||||||
|
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||||
|
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||||
|
if os.path.isfile(safe_path):
|
||||||
|
try:
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
|
||||||
|
) from exc
|
||||||
|
return load_file(safe_path)
|
||||||
|
if os.path.isfile(bin_path):
|
||||||
|
return torch.load(bin_path, map_location="cpu")
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _collapse_hf_module_name(module: str):
|
||||||
|
if module.startswith("base_model."):
|
||||||
|
module = module[len("base_model.") :]
|
||||||
|
if module.startswith("model.model."):
|
||||||
|
module = module[len("model.") :]
|
||||||
|
if not module.startswith("model."):
|
||||||
|
module = f"model.{module}"
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
||||||
|
if not lora_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||||
|
if not os.path.isfile(config_path):
|
||||||
|
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
|
||||||
|
with open(config_path, "r", encoding="utf-8") as handle:
|
||||||
|
config = json.load(handle)
|
||||||
|
if config.get("peft_type") != "LORA":
|
||||||
|
raise ValueError("Only LoRA adapters are supported.")
|
||||||
|
|
||||||
|
r = config.get("r")
|
||||||
|
alpha = config.get("lora_alpha") or config.get("alpha")
|
||||||
|
if not r or not alpha:
|
||||||
|
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
|
||||||
|
scaling = alpha / r
|
||||||
|
|
||||||
|
adapter_state = _load_lora_state(lora_path)
|
||||||
|
lora_layers: Dict[str, Dict[str, Tensor]] = {}
|
||||||
|
for key, tensor in adapter_state.items():
|
||||||
|
if key.endswith("lora_A.weight"):
|
||||||
|
module = key[: -len(".lora_A.weight")]
|
||||||
|
lora_layers.setdefault(module, {})["A"] = tensor
|
||||||
|
elif key.endswith("lora_B.weight"):
|
||||||
|
module = key[: -len(".lora_B.weight")]
|
||||||
|
lora_layers.setdefault(module, {})["B"] = tensor
|
||||||
|
|
||||||
|
if not lora_layers:
|
||||||
|
raise ValueError(f"No LoRA tensors found in {lora_path}")
|
||||||
|
|
||||||
|
for module, parts in lora_layers.items():
|
||||||
|
if "A" not in parts or "B" not in parts:
|
||||||
|
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
|
||||||
|
|
||||||
|
hf_module = _collapse_hf_module_name(module)
|
||||||
|
hf_weight_key = f"{hf_module}.weight"
|
||||||
|
|
||||||
|
delta = parts["B"] @ parts["A"]
|
||||||
|
delta = delta * scaling
|
||||||
|
|
||||||
|
converted = _convert_hf_state_dict({hf_weight_key: delta})
|
||||||
|
if not converted:
|
||||||
|
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
|
||||||
|
target_name, delta_tensor = next(iter(converted.items()))
|
||||||
|
if target_name not in state_dict:
|
||||||
|
raise KeyError(
|
||||||
|
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
|
||||||
|
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
name: str,
|
name: str,
|
||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
download_root: str = None,
|
download_root: str = None,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
decoder_only=False,
|
decoder_only: bool = False,
|
||||||
custom_alignment_heads=None
|
custom_alignment_heads: Optional[str] = None,
|
||||||
|
lora_path: Optional[str] = None,
|
||||||
) -> Whisper:
|
) -> Whisper:
|
||||||
"""
|
"""
|
||||||
Load a Whisper ASR model
|
Load a Whisper ASR model
|
||||||
|
|
@ -255,6 +340,8 @@ def load_model(
|
||||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||||
in_memory: bool
|
in_memory: bool
|
||||||
whether to preload the model weights into host memory
|
whether to preload the model weights into host memory
|
||||||
|
lora_path: str
|
||||||
|
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -302,6 +389,7 @@ def load_model(
|
||||||
else:
|
else:
|
||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
state_dict = _convert_hf_state_dict(state_dict)
|
state_dict = _convert_hf_state_dict(state_dict)
|
||||||
|
_apply_lora_adapter(state_dict, lora_path)
|
||||||
|
|
||||||
if dims_cfg is not None:
|
if dims_cfg is not None:
|
||||||
dims = ModelDimensions(**dims_cfg)
|
dims = ModelDimensions(**dims_cfg)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue