293 lines
9.6 KiB
Python
293 lines
9.6 KiB
Python
"""Model adapter registration."""
|
|
|
|
import math
|
|
import sys
|
|
from typing import List, Optional
|
|
import warnings
|
|
|
|
if sys.version_info >= (3, 9):
|
|
from functools import cache
|
|
else:
|
|
from functools import lru_cache as cache
|
|
|
|
import psutil
|
|
import torch
|
|
from transformers import (
|
|
AutoModel,
|
|
AutoModelForCausalLM,
|
|
AutoTokenizer,
|
|
LlamaForCausalLM,
|
|
)
|
|
from peft import PeftModel
|
|
|
|
from toolbench.tool_conversation import Conversation, get_conv_template
|
|
from toolbench.model.compression import load_compress_model
|
|
from toolbench.utils import get_gpu_memory
|
|
|
|
|
|
class BaseAdapter:
    """Catch-all adapter that also supplies the default loading behaviour."""

    def match(self, model_path: str):
        # Accepts every path, which makes this adapter the fallback choice.
        return True

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the tokenizer and causal language model found at ``model_path``.

        The tokenizer is requested with ``use_fast=False``.
        ``from_pretrained_kwargs`` is forwarded unchanged to
        ``AutoModelForCausalLM.from_pretrained``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        tok = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        lm = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **from_pretrained_kwargs,
        )
        return lm, tok

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the ``"one_shot"`` conversation template."""
        return get_conv_template("one_shot")
|
|
|
|
|
|
# A global registry for all model adapters
|
|
# Instances are stored in registration order; get_model_adapter returns the
# first one whose match() accepts the path.
model_adapters: List[BaseAdapter] = []
|
|
|
|
|
|
def register_model_adapter(cls):
    """Instantiate ``cls`` and append the instance to ``model_adapters``."""
    adapter_instance = cls()
    model_adapters.append(adapter_instance)
|
|
|
|
|
|
@cache
def get_model_adapter(model_path: str) -> BaseAdapter:
    """Return the first registered adapter whose ``match`` accepts ``model_path``.

    Adapters are tried in the order they appear in ``model_adapters``.
    Successful lookups are memoised per ``model_path`` by the ``cache``
    decorator.

    Raises:
        ValueError: If no registered adapter matches ``model_path``.
    """
    matched = next(
        (candidate for candidate in model_adapters if candidate.match(model_path)),
        None,
    )
    if matched is None:
        raise ValueError(f"No valid model adapter for {model_path}")
    return matched
|
|
|
|
|
|
def raise_warning_for_incompatible_cpu_offloading_configuration(
    device: str, load_8bit: bool, cpu_offloading: bool
):
    """Check a CPU-offloading request and return the flag that should be used.

    When offloading was not requested, ``cpu_offloading`` is returned as-is.
    When it was requested, the checks run in this order and the first failing
    one emits a warning and yields ``False``:

    1. 8-bit quantization must be enabled (``load_8bit``).
    2. ``sys.platform`` must contain ``"linux"``.
    3. ``device`` must be ``"cuda"``.

    If every check passes, ``cpu_offloading`` is returned unchanged.
    """
    if not cpu_offloading:
        return cpu_offloading

    if not load_8bit:
        problem = (
            "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
            "Use '--load-8bit' to enable 8-bit-quantization\n"
            "Continuing without cpu-offloading enabled\n"
        )
    elif "linux" not in sys.platform:
        problem = (
            "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
            "Continuing without cpu-offloading enabled\n"
        )
    elif device != "cuda":
        problem = (
            "CPU-offloading is only enabled when using CUDA-devices\n"
            "Continuing without cpu-offloading enabled\n"
        )
    else:
        # Every prerequisite is met: keep the caller's request.
        return cpu_offloading

    warnings.warn(problem)
    return False
|
|
|
|
|
|
def load_model(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: Optional[str] = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
    debug: bool = False,
    lora: bool = False,
    lora_base_model: str = "huggyllama/llama-7b",
):
    """Load a model and its tokenizer from Hugging Face.

    Args:
        model_path: Local folder or Hugging Face repo ID of the weights. When
            ``lora`` is true this is where the LoRA weights are loaded from.
        device: ``"cpu"`` or ``"cuda"``.
        num_gpus: Number of GPUs. On CUDA without ``lora``, any value other
            than 1 enables ``device_map``/``max_memory`` placement.
        max_gpu_memory: Per-GPU ``max_memory`` value for the multi-GPU path.
            When omitted, 85% of each value reported by ``get_gpu_memory`` is
            used and the device map is ``"sequential"``.
        load_8bit: Use 8-bit quantization.
        cpu_offloading: Request CPU offloading. The request is validated
            first and may be turned off with a warning.
        debug: Print the loaded model.
        lora: Load ``lora_base_model`` and apply the LoRA weights found at
            ``model_path`` on top of it.
        lora_base_model: Base model, and tokenizer source, used when ``lora``
            is true.

    Returns:
        A ``(model, tokenizer)`` tuple. When ``load_8bit`` is set, CPU
        offloading is off and ``num_gpus == 1``, the result of
        ``load_compress_model`` is returned directly instead.

    Raises:
        ValueError: If ``device`` is neither ``"cpu"`` nor ``"cuda"``, or if
            ``lora`` is requested on a non-CUDA device.
    """
    # Handle device mapping
    cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
        device, load_8bit, cpu_offloading
    )
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if lora:
            model = LlamaForCausalLM.from_pretrained(
                lora_base_model,
                load_in_8bit=load_8bit,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            model = PeftModel.from_pretrained(
                model,
                model_path,
                torch_dtype=torch.float16,
            )
        elif num_gpus != 1:
            kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # This is important for not the same VRAM sizes
                kwargs["device_map"] = "sequential"
                available_gpu_memory = get_gpu_memory(num_gpus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {device}")

    if cpu_offloading:
        # raises an error on incompatible platforms
        from transformers import BitsAndBytesConfig

        if "max_memory" in kwargs:
            # Let the CPU take up to the currently available RAM (in MiB).
            kwargs["max_memory"]["cpu"] = (
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
            )
        # NOTE(review): transformers documents this option as
        # `llm_int8_enable_fp32_cpu_offload`; confirm that the keyword used
        # here is actually honoured by the installed version.
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit_fp32_cpu_offload=cpu_offloading
        )
        kwargs["load_in_8bit"] = load_8bit
    elif load_8bit:
        if num_gpus != 1:
            warnings.warn(
                "8-bit quantization is not supported for multi-gpu inference."
            )
        else:
            # NOTE(review): with lora=True this discards the LoRA model built
            # above and compresses `model_path` instead - confirm intent.
            return load_compress_model(
                model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"]
            )

    # Load model
    if not lora:
        adapter = get_model_adapter(model_path)
        model, tokenizer = adapter.load_model(model_path, kwargs)
    else:
        if device != "cuda":
            # The LoRA weights are only loaded in the CUDA branch above, so
            # `model` would be unbound here; fail with a clear message.
            raise ValueError("LoRA models can only be loaded on CUDA devices")
        # Take the tokenizer from the same base model the LoRA weights were
        # applied to, rather than a hard-coded repo ID.
        tokenizer = AutoTokenizer.from_pretrained(
            lora_base_model, use_fast=False, model_max_length=8192
        )
    if device == "cuda" and num_gpus == 1 and not cpu_offloading:
        model.to(device)

    if debug:
        print(model)

    return model, tokenizer
|
|
|
|
|
|
def get_conversation_template(model_path: str) -> Conversation:
    """Return the default conversation template of the adapter matching ``model_path``."""
    return get_model_adapter(model_path).get_default_conv_template(model_path)
|
|
|
|
|
|
def add_model_args(parser):
    """Register the model-loading command-line options on ``parser``.

    ``parser`` only needs an ``add_argument`` method. The options are added
    in the order listed below.
    """
    option_table = (
        (
            "--model-path",
            {
                "type": str,
                "default": "lmsys/fastchat-t5-3b-v1.0",
                "help": "The path to the weights. This can be a local folder or a Hugging Face repo ID.",
            },
        ),
        (
            "--device",
            {
                "type": str,
                "choices": ["cpu", "cuda"],
                "default": "cuda",
                "help": "The device type",
            },
        ),
        (
            "--gpus",
            {
                "type": str,
                "default": None,
                "help": "A single GPU like 1 or multiple GPUs like 0,2",
            },
        ),
        ("--num-gpus", {"type": int, "default": 1}),
        (
            "--max-gpu-memory",
            {
                "type": str,
                "help": "The maximum memory per gpu. Use a string like '13Gib'",
            },
        ),
        (
            "--load-8bit",
            {"action": "store_true", "help": "Use 8-bit quantization"},
        ),
        (
            "--cpu-offloading",
            {
                "action": "store_true",
                "help": "Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
            },
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
|
|
|
|
|
|
class VicunaAdapter(BaseAdapter):
    """Model adapter for vicuna-v1.1."""

    def match(self, model_path: str):
        # Any path containing the substring "vicuna" selects this adapter.
        return "vicuna" in model_path

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the model like ``BaseAdapter`` does, then check for old weights.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        # The loading steps are the same as in BaseAdapter.load_model.
        lm, tok = super().load_model(model_path, from_pretrained_kwargs)
        self.raise_warning_for_old_weights(lm)
        return lm, tok

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the ``"vicuna-v1.1"`` conversation template."""
        return get_conv_template("vicuna-v1.1")

    def raise_warning_for_old_weights(self, model):
        """Warn when ``model`` is a Llama model with more than 32000 vocab entries."""
        probably_v0 = (
            isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000
        )
        if not probably_v0:
            return
        warnings.warn(
            "\nYou are probably using the old Vicuna-v0 model, "
            "which will generate unexpected results with the "
            "current toolbench.\nYou can try one of the following methods:\n"
            "1. Upgrade your weights to the new Vicuna-v1.1: https://github.com/lm-sys/FastChat#vicuna-weights.\n"
            "2. Use the old conversation template by `python3 -m toolbench.serve.cli --model-path /path/to/vicuna-v0 --conv-template conv_one_shot`\n"
            "3. Downgrade fschat to fschat==0.1.10 (Not recommonded).\n"
        )
|
|
|
|
|
|
class ToolLlamaAdapter(BaseAdapter):
    """Model adapter for tool-llama."""

    def match(self, model_path: str):
        # Exact comparison: only the literal name "tool-llama" is accepted.
        return model_path == "tool-llama"

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the model and tokenizer exactly as ``BaseAdapter`` does.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        lm, tok = super().load_model(model_path, from_pretrained_kwargs)
        return lm, tok

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the ``"tool-llama"`` conversation template."""
        return get_conv_template("tool-llama")
|
|
|
|
class ToolLlamaAdapterSingleRound(BaseAdapter):
    """Model adapter for tool-llama-single-round."""

    def match(self, model_path: str):
        # Exact comparison: only the literal name selects this adapter.
        return model_path == "tool-llama-single-round"

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the model with a tokenizer whose ``model_max_length`` is 8192.

        ``from_pretrained_kwargs`` is forwarded unchanged to
        ``AutoModelForCausalLM.from_pretrained``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        tok = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=False,
            model_max_length=8192,
        )
        lm = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return lm, tok

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the ``"tool-llama-single-round"`` conversation template."""
        return get_conv_template("tool-llama-single-round")
|
|
|
|
|
|
# Registration order is matching priority: the adapter registered earlier is
# tried first. BaseAdapter accepts every path, so it is registered last and
# acts as the default.
for _adapter_cls in (
    VicunaAdapter,
    ToolLlamaAdapter,
    ToolLlamaAdapterSingleRound,
    BaseAdapter,
):
    register_model_adapter(_adapter_cls)
del _adapter_cls
|