"""Model adapter registration.""" import math import sys from typing import List, Optional import warnings if sys.version_info >= (3, 9): from functools import cache else: from functools import lru_cache as cache import psutil import torch from transformers import ( AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, ) from peft import PeftModel from toolbench.tool_conversation import Conversation, get_conv_template from toolbench.model.compression import load_compress_model from toolbench.utils import get_gpu_memory class BaseAdapter: """The base and the default model adapter.""" def match(self, model_path: str): return True def load_model(self, model_path: str, from_pretrained_kwargs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs ) return model, tokenizer def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("one_shot") # A global registry for all model adapters model_adapters: List[BaseAdapter] = [] def register_model_adapter(cls): """Register a model adapter.""" model_adapters.append(cls()) @cache def get_model_adapter(model_path: str) -> BaseAdapter: """Get a model adapter for a model_path.""" for adapter in model_adapters: if adapter.match(model_path): return adapter raise ValueError(f"No valid model adapter for {model_path}") def raise_warning_for_incompatible_cpu_offloading_configuration( device: str, load_8bit: bool, cpu_offloading: bool ): if cpu_offloading: if not load_8bit: warnings.warn( "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n" "Use '--load-8bit' to enable 8-bit-quantization\n" "Continuing without cpu-offloading enabled\n" ) return False if not "linux" in sys.platform: warnings.warn( "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n" "Continuing without cpu-offloading enabled\n" ) return False if device != "cuda": warnings.warn( "CPU-offloading is only enabled when using CUDA-devices\n" "Continuing without cpu-offloading enabled\n" ) return False return cpu_offloading def load_model( model_path: str, device: str, num_gpus: int, max_gpu_memory: Optional[str] = None, load_8bit: bool = False, cpu_offloading: bool = False, debug: bool = False, lora: bool = False, lora_base_model : str = "huggyllama/llama-7b" ): """Load a model from Hugging Face.""" # Handle device mapping cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( device, load_8bit, cpu_offloading ) if device == "cpu": kwargs = {"torch_dtype": torch.float32} elif device == "cuda": kwargs = {"torch_dtype": torch.float16} if lora: model = LlamaForCausalLM.from_pretrained( lora_base_model, load_in_8bit=load_8bit, torch_dtype=torch.float16, device_map="auto", ) model = PeftModel.from_pretrained( model, model_path, torch_dtype=torch.float16, ) elif num_gpus != 1: kwargs["device_map"] = "auto" if max_gpu_memory is None: kwargs[ "device_map" ] = "sequential" # This is important for not the same VRAM sizes available_gpu_memory = get_gpu_memory(num_gpus) kwargs["max_memory"] = { i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus) } else: kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} else: raise ValueError(f"Invalid device: {device}") if cpu_offloading: # raises an error on incompatible platforms from transformers import BitsAndBytesConfig if "max_memory" in kwargs: kwargs["max_memory"]["cpu"] = ( str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" ) kwargs["quantization_config"] = BitsAndBytesConfig( load_in_8bit_fp32_cpu_offload=cpu_offloading ) kwargs["load_in_8bit"] = load_8bit elif load_8bit: if num_gpus != 1: warnings.warn( "8-bit quantization is not supported for multi-gpu inference." ) else: return load_compress_model( model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"] ) # Load model if not lora: adapter = get_model_adapter(model_path) model, tokenizer = adapter.load_model(model_path, kwargs) else: tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False, model_max_length=8192) if device == "cuda" and num_gpus == 1 and not cpu_offloading: model.to(device) if debug: print(model) return model, tokenizer def get_conversation_template(model_path: str) -> Conversation: adapter = get_model_adapter(model_path) return adapter.get_default_conv_template(model_path) def add_model_args(parser): parser.add_argument( "--model-path", type=str, default="lmsys/fastchat-t5-3b-v1.0", help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", ) parser.add_argument( "--device", type=str, choices=["cpu", "cuda"], default="cuda", help="The device type", ) parser.add_argument( "--gpus", type=str, default=None, help="A single GPU like 1 or multiple GPUs like 0,2", ) parser.add_argument("--num-gpus", type=int, default=1) parser.add_argument( "--max-gpu-memory", type=str, help="The maximum memory per gpu. Use a string like '13Gib'", ) parser.add_argument( "--load-8bit", action="store_true", help="Use 8-bit quantization" ) parser.add_argument( "--cpu-offloading", action="store_true", help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", ) class VicunaAdapter(BaseAdapter): "Model adapater for vicuna-v1.1" def match(self, model_path: str): return "vicuna" in model_path def load_model(self, model_path: str, from_pretrained_kwargs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs, ) self.raise_warning_for_old_weights(model) return model, tokenizer def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("vicuna-v1.1") def raise_warning_for_old_weights(self, model): if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000: warnings.warn( "\nYou are probably using the old Vicuna-v0 model, " "which will generate unexpected results with the " "current toolbench.\nYou can try one of the following methods:\n" "1. Upgrade your weights to the new Vicuna-v1.1: https://github.com/lm-sys/FastChat#vicuna-weights.\n" "2. Use the old conversation template by `python3 -m toolbench.serve.cli --model-path /path/to/vicuna-v0 --conv-template conv_one_shot`\n" "3. Downgrade fschat to fschat==0.1.10 (Not recommonded).\n" ) class ToolLlamaAdapter(BaseAdapter): "Model adapater for tool-llama" def match(self, model_path: str): return "tool-llama" == model_path def load_model(self, model_path: str, from_pretrained_kwargs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs, ) return model, tokenizer def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("tool-llama") class ToolLlamaAdapterSingleRound(BaseAdapter): "Model adapater for tool-llama-single-round" def match(self, model_path: str): return "tool-llama-single-round" == model_path def load_model(self, model_path: str, from_pretrained_kwargs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, model_max_length=8192) model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs, ) return model, tokenizer def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("tool-llama-single-round") # Note: the registration order matters. # The one registered earlier has a higher matching priority. register_model_adapter(VicunaAdapter) register_model_adapter(ToolLlamaAdapter) register_model_adapter(ToolLlamaAdapterSingleRound) # After all adapters, try the default base adapter. register_model_adapter(BaseAdapter)