167 lines
No EOL
5.7 KiB
Python
167 lines
No EOL
5.7 KiB
Python
# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
|
|
|
|
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
|
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from dataclasses import dataclass, field
|
|
import logging
|
|
import pathlib
|
|
import typing
|
|
import os
|
|
from deepspeed import zero
|
|
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
|
from peft import LoraConfig, get_peft_model
|
|
import transformers
|
|
from transformers import Trainer
|
|
|
|
from toolbench.train.train import (
|
|
DataArguments,
|
|
ModelArguments,
|
|
TrainingArguments,
|
|
make_supervised_data_module,
|
|
)
|
|
|
|
from toolbench.train.llama_flash_attn_monkey_patch import (
|
|
replace_llama_attn_with_flash_attn,
|
|
)
|
|
from toolbench.train.llama_condense_monkey_patch import replace_llama_with_condense
|
|
# Apply the patch from toolbench.train.llama_flash_attn_monkey_patch at import
# time, i.e. before train() below constructs any model. Judging by the module
# name this monkey-patches LLaMA attention to use FlashAttention; presumably it
# must be in place before model construction — confirm against that module.
replace_llama_attn_with_flash_attn()
|
|
|
|
|
|
@dataclass
class LoraArguments:
    """LoRA-specific command-line options, parsed by ``HfArgumentParser`` in ``train()``.

    Every field except ``lora_weight_path`` is forwarded to ``peft.LoraConfig``;
    ``lora_bias`` additionally selects which tensors are saved at the end of training.
    """

    # Forwarded to LoraConfig as ``r``.
    lora_r: int = 8
    # Forwarded to LoraConfig as ``lora_alpha``.
    lora_alpha: int = 16
    # Forwarded to LoraConfig as ``lora_dropout``.
    lora_dropout: float = 0.05
    # Forwarded to LoraConfig as ``target_modules``; each instance gets its own list.
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: list(("q_proj", "v_proj"))
    )
    # NOTE(review): parsed but not read by the code visible in this file — confirm intent.
    lora_weight_path: str = ""
    # Forwarded to LoraConfig as ``bias`` and to get_peft_state_maybe_zero_3;
    # the latter accepts "none", "all" or "lora_only".
    lora_bias: str = "none"
|
|
|
|
|
|
def maybe_zero_3(param):
    """Return a detached CPU clone of ``param``, gathering it first if DeepSpeed owns it.

    A tensor with a ``ds_id`` attribute is treated as a DeepSpeed ZeRO parameter:
    it is expected to be in ``ZeroParamStatus.NOT_AVAILABLE`` and is copied inside
    ``zero.GatheredParameters``. Any other tensor is copied directly.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: nothing to gather.
        return param.detach().cpu().clone()

    assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
    with zero.GatheredParameters([param]):
        # Copy while the gathered data is available inside the context.
        gathered = param.data.detach().cpu().clone()
    return gathered
|
|
|
|
|
|
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select the LoRA tensors to save and return detached CPU copies of them.

    Borrowed from ``peft.utils.get_peft_model_state_dict``.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g.
            ``model.named_parameters()``. It is consumed once.
        bias: which bias tensors to keep alongside the ``"lora_"`` tensors:
            ``"none"`` keeps only names containing ``"lora_"``;
            ``"all"`` additionally keeps every name containing ``"bias"``;
            ``"lora_only"`` additionally keeps only the bias of each layer that
            carries LoRA tensors (name prefix before ``"lora_"`` + ``"bias"``).

    Returns:
        dict mapping each selected name to ``maybe_zero_3(tensor)``.

    Raises:
        NotImplementedError: if ``bias`` is not one of the values above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # The matching bias shares the name prefix before "lora_".
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Biases may be seen before their LoRA tensors, so filter only after
        # every LoRA-carrying layer has been recorded.
        to_return.update(
            {k: t for k, t in maybe_lora_bias.items() if k in lora_bias_names}
        )
    else:
        raise NotImplementedError(f"Unsupported lora_bias value: {bias!r}")
    return {k: maybe_zero_3(v) for k, v in to_return.items()}
|
|
|
|
|
|
def train():
    """Fine-tune a causal LM with LoRA adapters and save the adapter state dict.

    Parses ``ModelArguments``, ``DataArguments``, ``TrainingArguments`` and
    ``LoraArguments`` from the command line. If ``model_max_length`` exceeds
    ``source_model_max_length``, applies ``replace_llama_with_condense`` with the
    (integer) length ratio. Wraps the model with PEFT LoRA, trains with the HF
    ``Trainer`` (resuming when ``checkpoint-*`` directories exist in
    ``output_dir``), then writes the selected LoRA tensors to ``output_dir``.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    if training_args.source_model_max_length < training_args.model_max_length:
        # ratio = N means the sequence length is expanded by N; e.g. for
        # ratio = 4, set model_max_length to 8192 (2048 * ratio).
        condense_ratio = int(
            training_args.model_max_length / training_args.source_model_max_length
        )
        if training_args.model_max_length % training_args.source_model_max_length:
            # int() truncates, so a non-multiple silently yields a smaller ratio.
            logging.warning(
                "model_max_length (%s) is not a multiple of source_model_max_length "
                "(%s); using truncated condense ratio %s.",
                training_args.model_max_length,
                training_args.source_model_max_length,
                condense_ratio,
            )
        replace_llama_with_condense(ratio=condense_ratio)

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    # Under multi-process runs, map the whole model onto this process's LOCAL_RANK device.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
    )
    # NOTE(review): lora_args.lora_weight_path is parsed but never used here —
    # confirm whether existing adapter weights were meant to be loaded from it.
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )
        model.enable_input_require_grads()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    model.config.use_cache = False

    # Resume automatically if a previous run left checkpoints in output_dir.
    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Save states. Weights might be a placeholder in zero3 and need a gather.
    # Kept outside the rank check: ZeRO-3 gathering is presumably collective,
    # so every rank must participate.
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), lora_args.lora_bias
    )
    # local_rank is -1 in a single-process (non-distributed) run; without -1
    # here the adapter would never be written in that case.
    if training_args.local_rank in (-1, 0):
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
|
|
|
|
|
# Script entry point; see the usage line at the top of the file for the launcher.
if __name__ == "__main__":
    train()