AnyTool/toolbench/model/model_adapter.py
2024-02-23 15:13:06 +08:00

293 lines
9.6 KiB
Python

"""Model adapter registration."""
import math
import sys
from typing import List, Optional
import warnings
if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache as cache
import psutil
import torch
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
LlamaForCausalLM,
)
from peft import PeftModel
from toolbench.tool_conversation import Conversation, get_conv_template
from toolbench.model.compression import load_compress_model
from toolbench.utils import get_gpu_memory
class BaseAdapter:
    """Default model adapter and base class for the specialised adapters.

    It matches every model path, loads weights through the ``Auto*`` factory
    classes and uses the "one_shot" conversation template. Subclasses override
    any of the three methods to customise matching, loading or prompting.
    """

    def match(self, model_path: str):
        """Accept any ``model_path``; this makes the adapter a catch-all."""
        return True

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load and return ``(model, tokenizer)`` for ``model_path``.

        The tokenizer is the slow (``use_fast=False``) variant.
        ``from_pretrained_kwargs`` is forwarded to
        ``AutoModelForCausalLM.from_pretrained`` alongside
        ``low_cpu_mem_usage=True``.
        """
        slow_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        causal_lm = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return causal_lm, slow_tokenizer

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "one_shot" template; ``model_path`` is not consulted."""
        template_name = "one_shot"
        return get_conv_template(template_name)
# Global registry of adapter instances. get_model_adapter walks it in
# registration order and returns the first adapter whose match() accepts a
# model path, so earlier registrations take priority.
model_adapters: List[BaseAdapter] = []


def register_model_adapter(cls):
    """Instantiate the adapter class ``cls`` and append it to ``model_adapters``.

    Args:
        cls: An adapter class; it is instantiated with no arguments.

    Returns:
        ``cls`` itself, unchanged, so this function can also be used as a
        class decorator (``@register_model_adapter``).
    """
    model_adapters.append(cls())
    return cls
@cache
def get_model_adapter(model_path: str) -> BaseAdapter:
    """Return the first registered adapter whose ``match`` accepts ``model_path``.

    Adapters are tried in registration order. Because of ``@cache`` the result
    is memoised per ``model_path``, so adapters registered after the first
    lookup of a given path are not considered for that path.

    Raises:
        ValueError: if no registered adapter matches ``model_path``.
    """
    chosen = next(
        (candidate for candidate in model_adapters if candidate.match(model_path)),
        None,
    )
    if chosen is None:
        raise ValueError(f"No valid model adapter for {model_path}")
    return chosen
def raise_warning_for_incompatible_cpu_offloading_configuration(
    device: str, load_8bit: bool, cpu_offloading: bool
):
    """Decide whether a CPU-offloading request can be honoured.

    Offloading requires 8-bit quantization, a Linux host and a CUDA device,
    checked in that order. The first unmet requirement is reported through
    ``warnings.warn`` and ``False`` is returned. If offloading was not
    requested, or every requirement holds, ``cpu_offloading`` is returned
    unchanged.
    """
    if not cpu_offloading:
        return cpu_offloading

    fallback_note = "Continuing without cpu-offloading enabled\n"
    if not load_8bit:
        problem = (
            "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
            "Use '--load-8bit' to enable 8-bit-quantization\n"
        )
    elif "linux" not in sys.platform:
        problem = (
            "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
        )
    elif device != "cuda":
        problem = "CPU-offloading is only enabled when using CUDA-devices\n"
    else:
        # Every requirement is satisfied: keep offloading enabled.
        return cpu_offloading

    warnings.warn(problem + fallback_note)
    return False
def load_model(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: Optional[str] = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
    debug: bool = False,
    lora: bool = False,
    lora_base_model: str = "huggyllama/llama-7b",
):
    """Load a model and its tokenizer from Hugging Face (or a local folder).

    Args:
        model_path: Weights to load. When ``lora`` is set this is the path of
            the LoRA adapter weights, applied on top of ``lora_base_model``.
        device: ``"cpu"`` or ``"cuda"``; anything else raises ``ValueError``.
        num_gpus: Number of GPUs to spread a non-LoRA model over (CUDA only).
        max_gpu_memory: Per-GPU memory cap (e.g. ``"13GiB"``) for multi-GPU
            loading. When omitted, 85% of the per-GPU figure returned by
            ``get_gpu_memory`` is used.
        load_8bit: Load weights with 8-bit quantization.
        cpu_offloading: Offload weights that do not fit on the GPU to the CPU.
            It is disabled (with a warning) unless ``load_8bit`` is set on a
            Linux host with ``device="cuda"``.
        debug: Print the loaded model.
        lora: Load ``lora_base_model`` and apply the adapter at ``model_path``
            (supported with ``device="cuda"`` only).
        lora_base_model: Base checkpoint used for both the model and the
            tokenizer when ``lora`` is set.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if ``device`` is unsupported, or ``lora`` is requested
            with ``device="cpu"``.
    """
    # Handle device mapping. This downgrades cpu_offloading to False (with a
    # warning) when the configuration cannot support it.
    cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
        device, load_8bit, cpu_offloading
    )
    if device == "cpu":
        if lora:
            # The LoRA loading path is only implemented for CUDA; without this
            # check `model` would never be assigned and the function would
            # fail later with an UnboundLocalError.
            raise ValueError("LoRA models can only be loaded with device='cuda'")
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if lora:
            model = LlamaForCausalLM.from_pretrained(
                lora_base_model,
                load_in_8bit=load_8bit,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            model = PeftModel.from_pretrained(
                model,
                model_path,
                torch_dtype=torch.float16,
            )
        elif num_gpus != 1:
            kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # "sequential" fills GPUs in order; this matters when the
                # GPUs do not all have the same VRAM size.
                kwargs["device_map"] = "sequential"
                available_gpu_memory = get_gpu_memory(num_gpus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {device}")

    if cpu_offloading:
        # raises an error on incompatible platforms
        from transformers import BitsAndBytesConfig

        if "max_memory" in kwargs:
            kwargs["max_memory"]["cpu"] = (
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
            )
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit_fp32_cpu_offload=cpu_offloading
        )
        kwargs["load_in_8bit"] = load_8bit
    elif load_8bit and not lora:
        # The LoRA model above was already loaded with load_in_8bit; compressing
        # `model_path` (the adapter weights) here would discard it.
        if num_gpus != 1:
            warnings.warn(
                "8-bit quantization is not supported for multi-gpu inference."
            )
        else:
            return load_compress_model(
                model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"]
            )

    # Load model
    if lora:
        # The tokenizer must come from the same base checkpoint as the model.
        tokenizer = AutoTokenizer.from_pretrained(
            lora_base_model, use_fast=False, model_max_length=8192
        )
    else:
        adapter = get_model_adapter(model_path)
        model, tokenizer = adapter.load_model(model_path, kwargs)

    # LoRA models were already placed by device_map="auto" (which may span
    # several devices), so only adapter-loaded single-GPU models are moved.
    if device == "cuda" and num_gpus == 1 and not cpu_offloading and not lora:
        model.to(device)

    if debug:
        print(model)

    return model, tokenizer
def get_conversation_template(model_path: str) -> Conversation:
    """Return the default conversation template of the adapter matching ``model_path``."""
    return get_model_adapter(model_path).get_default_conv_template(model_path)
def add_model_args(parser):
    """Register the model-loading command-line options on ``parser``.

    Options are added in a fixed order (``--model-path``, ``--device``,
    ``--gpus``, ``--num-gpus``, ``--max-gpu-memory``, ``--load-8bit``,
    ``--cpu-offloading``) so the generated help text stays stable.
    """
    option_table = (
        (
            "--model-path",
            dict(
                type=str,
                default="lmsys/fastchat-t5-3b-v1.0",
                help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                choices=["cpu", "cuda"],
                default="cuda",
                help="The device type",
            ),
        ),
        (
            "--gpus",
            dict(
                type=str,
                default=None,
                help="A single GPU like 1 or multiple GPUs like 0,2",
            ),
        ),
        ("--num-gpus", dict(type=int, default=1)),
        (
            "--max-gpu-memory",
            dict(
                type=str,
                help="The maximum memory per gpu. Use a string like '13Gib'",
            ),
        ),
        ("--load-8bit", dict(action="store_true", help="Use 8-bit quantization")),
        (
            "--cpu-offloading",
            dict(
                action="store_true",
                help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
            ),
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
class VicunaAdapter(BaseAdapter):
    """Model adapter for Vicuna v1.1 checkpoints."""

    def match(self, model_path: str):
        """Claim any model path containing the substring "vicuna" (case-sensitive)."""
        return "vicuna" in model_path

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load exactly as ``BaseAdapter`` does, then check for old v0 weights."""
        model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
        self.raise_warning_for_old_weights(model)
        return model, tokenizer

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "vicuna-v1.1" conversation template."""
        return get_conv_template("vicuna-v1.1")

    def raise_warning_for_old_weights(self, model):
        """Warn when ``model`` looks like an old Vicuna-v0 checkpoint.

        A ``LlamaForCausalLM`` whose vocabulary size exceeds 32000 is treated
        as v0; the warning lists ways to upgrade or work around it.
        """
        looks_like_v0 = (
            isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000
        )
        if not looks_like_v0:
            return
        warnings.warn(
            "\nYou are probably using the old Vicuna-v0 model, "
            "which will generate unexpected results with the "
            "current toolbench.\nYou can try one of the following methods:\n"
            "1. Upgrade your weights to the new Vicuna-v1.1: https://github.com/lm-sys/FastChat#vicuna-weights.\n"
            "2. Use the old conversation template by `python3 -m toolbench.serve.cli --model-path /path/to/vicuna-v0 --conv-template conv_one_shot`\n"
            "3. Downgrade fschat to fschat==0.1.10 (Not recommonded).\n"
        )
class ToolLlamaAdapter(BaseAdapter):
    """Model adapter for tool-llama, using the "tool-llama" conversation template.

    Only the exact model path "tool-llama" is claimed (not a substring match),
    so "tool-llama-single-round" falls through to its own adapter.
    """

    def match(self, model_path: str):
        """Claim only the literal model path "tool-llama"."""
        return model_path == "tool-llama"

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load exactly as ``BaseAdapter`` does."""
        return super().load_model(model_path, from_pretrained_kwargs)

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "tool-llama" conversation template."""
        return get_conv_template("tool-llama")
class ToolLlamaAdapterSingleRound(BaseAdapter):
    """Model adapter for tool-llama-single-round.

    It differs from ``BaseAdapter`` in the template it selects and in setting
    the tokenizer's ``model_max_length`` to 8192.
    """

    def match(self, model_path: str):
        """Claim only the literal model path "tool-llama-single-round"."""
        return model_path == "tool-llama-single-round"

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load ``(model, tokenizer)``; the slow tokenizer gets model_max_length=8192."""
        max_length = 8192
        slow_tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=False, model_max_length=max_length
        )
        causal_lm = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return causal_lm, slow_tokenizer

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "tool-llama-single-round" conversation template."""
        return get_conv_template("tool-llama-single-round")
# Registration order matters: get_model_adapter returns the first adapter
# whose match() accepts a path, so earlier entries take priority.
# BaseAdapter.match() accepts every path, so it is registered last as the
# catch-all fallback.
for _adapter_cls in (
    VicunaAdapter,
    ToolLlamaAdapter,
    ToolLlamaAdapterSingleRound,
    BaseAdapter,
):
    register_model_adapter(_adapter_cls)
del _adapter_cls