# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>

# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from dataclasses import dataclass, field
import logging
import pathlib
import typing
import os

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
import transformers
from transformers import Trainer

from toolbench.train.train import (
    DataArguments,
    ModelArguments,
    TrainingArguments,
    make_supervised_data_module,
)

from toolbench.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)
from toolbench.train.llama_condense_monkey_patch import replace_llama_with_condense

# Module-level side effect: the flash-attention patch is applied at import
# time, before any model is instantiated in `train()`.
replace_llama_attn_with_flash_attn()


@dataclass
class LoraArguments:
    """Command-line options that are forwarded to `peft.LoraConfig` in `train()`."""

    # Passed to LoraConfig as `r`.
    lora_r: int = 8
    # Passed to LoraConfig as `lora_alpha`.
    lora_alpha: int = 16
    # Passed to LoraConfig as `lora_dropout`.
    lora_dropout: float = 0.05
    # Passed to LoraConfig as `target_modules`.
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    # Not read anywhere in this file.
    lora_weight_path: str = ""
    # Passed to LoraConfig as `bias`, and used to select which parameters are
    # saved; this file handles "none", "all" and "lora_only".
    lora_bias: str = "none"


def maybe_zero_3(param):
    """Return a detached CPU copy of ``param``.

    Parameters that carry a ``ds_id`` attribute are gathered with
    ``zero.GatheredParameters`` before being copied; any other parameter is
    copied directly.

    Raises:
        AssertionError: if ``param`` has ``ds_id`` but its ``ds_status`` is not
            ``ZeroParamStatus.NOT_AVAILABLE``.
    """
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        # The copy must happen inside the context manager, while the gathered
        # data is accessible.
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias: str):
    """Collect the LoRA parameters (and optionally biases) as CPU tensors.

    Args:
        named_params: iterable of ``(name, parameter)`` pairs. It is consumed
            exactly once, so a generator such as ``model.named_parameters()``
            is acceptable.
        bias: which bias parameters to keep alongside those whose name
            contains ``"lora_"``:
            ``"none"`` keeps no bias, ``"all"`` keeps every parameter whose
            name contains ``"bias"``, and ``"lora_only"`` keeps only a bias
            whose name equals the prefix of a LoRA parameter name (the text
            before ``"lora_"``) followed by ``"bias"``.

    Returns:
        dict mapping parameter name to the result of ``maybe_zero_3``.

    Raises:
        NotImplementedError: if ``bias`` is not one of the values above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate over (name, tensor) pairs and test each candidate's own
        # name. Iterating the dict directly yields only keys, and checking a
        # name left over from the loop above would select the wrong bias.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported lora bias mode: {bias!r}")
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


def train():
    """Parse arguments, wrap the base model with LoRA, train, and save.

    Training resumes from a checkpoint when the output directory already
    contains a ``checkpoint-*`` entry. After training, the trainer state is
    saved and, on local rank 0, the LoRA weights are written to
    ``training_args.output_dir``.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    if training_args.source_model_max_length < training_args.model_max_length:
        condense_ratio = int(training_args.model_max_length/training_args.source_model_max_length)
        # ratio = N means the sequence length is expanded by N, remember to change the model_max_length to 8192 (2048 * ratio) for ratio = 4
        replace_llama_with_condense(ratio=condense_ratio)

    # With more than one process, pin the whole model to this process's
    # LOCAL_RANK device; otherwise leave device placement to the defaults.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map
    )
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )
        # NOTE(review): placement inside this branch is inferred from the
        # statement order of the original; confirm against upstream.
        model.enable_input_require_grads()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Reuse the unknown token as the padding token.
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    model.config.use_cache = False

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Save states. Weights might be a placeholder in zero3 and need a gather
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), lora_args.lora_bias
    )
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)


if __name__ == "__main__":
    train()