# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import logging
import os
import os.path as osp
from functools import partial
from types import FunctionType

from mmengine.config import Config, DictAction
from mmengine.config.lazy import LazyObject
from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmengine.utils import digit_version
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import TrainingArguments

from xtuner.configs import cfgs_name_path
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.model.modules import dispatch_modules
from xtuner.model.modules.dispatch import SUPPORT_FLASH2
from xtuner.model.utils import (LoadWoInit, find_all_linear_names,
                                traverse_dict)
from xtuner.registry import BUILDER, MAP_FUNC
from xtuner.tools.utils import (auto_dtype_of_deepspeed_config,
                                get_seed_from_checkpoint)
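

# Parse command-line options: a config (registered name or file path), an
# optional work dir, DeepSpeed json, resume checkpoint, random seed,
# `--cfg-options` overrides and the distributed job launcher.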
def parse_args():
    parser = argparse.ArgumentParser(description='Train LLM')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--deepspeed',
        type=str,
        default=None,
        help='the path to the .json file for deepspeed')
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='specify checkpoint path to be resumed from.')
    parser.add_argument(
        '--seed', type=int, default=None, help='Random seed for the training')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    return args
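

# Walk the config recursively and register every plain Python function found
# in it into the `MAP_FUNC` registry, replacing the function object with its
# string name so it can later be rebuilt from the registry.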
def register_function(cfg_dict):
    if isinstance(cfg_dict, dict):
        for key, value in dict.items(cfg_dict):
            if isinstance(value, FunctionType):
                value_str = str(value)
                if value_str not in MAP_FUNC:
                    MAP_FUNC.register_module(module=value, name=value_str)
                cfg_dict[key] = value_str
            else:
                register_function(value)
    elif isinstance(cfg_dict, (list, tuple)):
        for value in cfg_dict:
            register_function(value)
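

# Validate mutually dependent options: varlen attention requires batch_size 1
# and FlashAttention-2, sequence parallel requires FlashAttention-2 and (in
# this script) DeepSpeed, and `max_length` must be divisible by the sequence
# parallel world size.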
def check_cfg(cfg, args):
    if getattr(cfg, 'use_varlen_attn',
               False) and cfg.train_dataloader.batch_size > 1:
        raise NotImplementedError(
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {cfg.train_dataloader.batch_size}')

    if getattr(cfg, 'use_varlen_attn', False):
        sequence_parallel = getattr(cfg, 'sequence_parallel_size', 1)
        max_length = getattr(cfg.train_dataloader.dataset, 'max_length', None)
        if max_length is not None:
            assert max_length % sequence_parallel == 0, \
                ('When using varlen attention, `max_length` should be '
                 'divisible by the sequence parallel world size, but got '
                 f'max_length = {max_length} and sequence_parallel = '
                 f'{sequence_parallel}')

    if getattr(cfg, 'sequence_parallel_size', 1) > 1:
        assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use '
                                'sequence parallel.')
        attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
                                      None)
        assert (attn_implementation is None or
                attn_implementation == 'flash_attention_2'), \
            ('If you want to use sequence parallel, please set '
             '`attn_implementation` to `flash_attention_2` or do not set '
             f'this attribute. Got `{attn_implementation}`.')

    if getattr(cfg, 'use_varlen_attn', False):
        assert SUPPORT_FLASH2, ('`flash_attn` is required if you set '
                                '`use_varlen_attn` to True.')
        attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
                                      None)
        assert (attn_implementation is None or
                attn_implementation == 'flash_attention_2'), \
            ('If you want to set `use_varlen_attn` to True, please set '
             '`attn_implementation` to `flash_attention_2` or do not set '
             f'this attribute. Got `{attn_implementation}`.')

    if args.deepspeed is None:
        assert getattr(cfg, 'sequence_parallel_size', 1) == 1, \
            ('Sequence parallel training without DeepSpeed lacks validation. '
             'Please use DeepSpeed to optimize the training phase by '
             '`--deepspeed deepspeed_zero1 (deepspeed_zero2 or '
             'deepspeed_zero3)`.')
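

# Entry point: resolve the config name to a path, apply `--cfg-options`
# overrides, register bare functions and validate option combinations before
# any model or dataset is built.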
def main():
    args = parse_args()

    # parse config
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register FunctionType objects in cfg to the `MAP_FUNC` registry and
    # replace these FunctionType objects with their str names
    register_function(cfg._cfg_dict)

    check_cfg(cfg, args)
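
    # Dispatch on the training backend declared in the config: the
    # 'huggingface' framework uses transformers' Trainer, anything else
    # (the default) uses an MMEngine runner.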
    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        # set default training_args
        if cfg.get('training_args', None) is None:
            cfg.training_args = dict(type=TrainingArguments)
        if args.seed is not None:
            cfg.training_args.seed = args.seed

        # set work_dir
        if args.work_dir is not None:
            # update configs according to CLI args if args.work_dir is set
            cfg.training_args.output_dir = args.work_dir
        elif cfg.training_args.get('output_dir', None) is None:
            # use config filename as the default output_dir if it is not set
            cfg.training_args.output_dir = osp.join(
                './work_dirs',
                osp.splitext(osp.basename(args.config))[0])

        # enable deepspeed
        if args.deepspeed:
            if not osp.isfile(args.deepspeed):
                try:
                    args.deepspeed = cfgs_name_path[args.deepspeed]
                except KeyError:
                    raise FileNotFoundError(f'Cannot find {args.deepspeed}')
            cfg.training_args.deepspeed = args.deepspeed
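
        # With DeepSpeed enabled, parameter placement is left to the engine
        # (no device_map); otherwise each rank loads the full model onto its
        # local GPU for plain data parallelism.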
        if cfg.training_args.get('deepspeed'):
            device_map = None
        else:
            # Data Parallel
            device_map = {
                '': int(os.environ.get('LOCAL_RANK', args.local_rank))
            }

        # build training_args
        training_args = BUILDER.build(cfg.training_args)

        # build model
        with LoadWoInit():
            cfg.model.device_map = device_map
            traverse_dict(cfg.model)
            model = BUILDER.build(cfg.model)
        model.config.use_cache = False
        dispatch_modules(model)

        if cfg.get('lora', None):
            lora = BUILDER.build(cfg.lora)
            model = prepare_model_for_kbit_training(model)
            if lora.target_modules is None:
                modules = find_all_linear_names(model)
                lora.target_modules = modules
            model = get_peft_model(model, lora)

        # build dataset
        train_dataset = BUILDER.build(cfg.train_dataset)
        data_collator = partial(default_collate_fn, return_hf_format=True)

        # build trainer
        trainer = cfg.trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator)

        # training
        trainer.train(resume_from_checkpoint=args.resume)
        trainer.save_state()
        trainer.save_model(output_dir=training_args.output_dir)
    else:
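        # MMEngine path: resolve the random seed and work_dir, optionally
        # convert the DeepSpeed json into a strategy, then hand training off
        # to a runner.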
        if args.seed is not None and args.resume is None:
            # Use args.seed
            cfg.merge_from_dict(dict(randomness=dict(seed=args.seed)))
            print_log(
                f'Set the random seed to {args.seed}.',
                logger='current',
                level=logging.INFO)
        elif args.resume is not None:
            # Use resumed seed
            from mmengine.fileio import PetrelBackend, get_file_backend

            from xtuner.utils.fileio import patch_fileio
            backend = get_file_backend(args.resume)
            if isinstance(backend, PetrelBackend):
                with patch_fileio():
                    resumed_seed = get_seed_from_checkpoint(args.resume)
            else:
                resumed_seed = get_seed_from_checkpoint(args.resume)
            cfg.merge_from_dict(dict(randomness=dict(seed=resumed_seed)))
            if args.seed is not None and args.seed != resumed_seed:
                print_log(
                    (f'The random seed in the resumed checkpoint '
                     f'"{args.resume}" differs from the one passed on the '
                     f'command line. The resumed seed is {resumed_seed}, '
                     f'while the input argument seed is {args.seed}. Using '
                     f'the resumed seed {resumed_seed}.'),
                    logger='current',
                    level=logging.WARNING)
            else:
                print_log(
                    f'Set the random seed to {resumed_seed}.',
                    logger='current',
                    level=logging.INFO)

        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = str(args.local_rank)
        cfg.launcher = args.launcher

        # work_dir is determined in this priority:
        # CLI > segment in file > filename
        if args.work_dir is not None:
            # update configs according to CLI args if args.work_dir is set
            cfg.work_dir = args.work_dir
        elif cfg.get('work_dir', None) is None:
            # use config filename as default work_dir if cfg.work_dir is None
            cfg.work_dir = osp.join(
                './work_dirs',
                osp.splitext(osp.basename(args.config))[0])

        if args.deepspeed:
            try:
                import deepspeed
            except ImportError:
                raise ImportError(
                    'deepspeed is not installed properly, please check.')
            if digit_version(deepspeed.__version__) < digit_version('0.12.3'):
                raise RuntimeError('Please upgrade your DeepSpeed version '
                                   'by running `pip install '
                                   '"deepspeed>=0.12.3"`.')

            optim_wrapper = cfg.optim_wrapper.type
            if optim_wrapper == 'DeepSpeedOptimWrapper':
                print_log(
                    'Deepspeed training is already enabled in your config.',
                    logger='current',
                    level=logging.WARNING)
            else:
                if not osp.isfile(args.deepspeed):
                    try:
                        args.deepspeed = cfgs_name_path[args.deepspeed]
                    except KeyError:
                        raise FileNotFoundError(
                            f'Cannot find {args.deepspeed}')
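                # Reconcile the DeepSpeed json with the MMEngine config: for
                # gradient accumulation, micro batch size and gradient
                # clipping, the MMEngine value wins and a warning is logged
                # on any mismatch.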
                with open(args.deepspeed) as f:
                    ds_cfg = json.load(f)

                ds_grad_accum = ds_cfg.get('gradient_accumulation_steps',
                                           'auto')
                mm_grad_accum = cfg.optim_wrapper.get('accumulative_counts',
                                                      1)
                if ds_grad_accum != 'auto' and ds_grad_accum != mm_grad_accum:
                    print_log(('Mismatch on gradient_accumulation_steps: '
                               f'MMEngine {mm_grad_accum}, '
                               f'Deepspeed {ds_grad_accum}. '
                               f'Set to {mm_grad_accum}'),
                              logger='current',
                              level=logging.WARNING)
                grad_accum = mm_grad_accum

                ds_train_bs = ds_cfg.get('train_micro_batch_size_per_gpu',
                                         'auto')
                mm_train_bs = cfg.train_dataloader.batch_size
                if ds_train_bs != 'auto' and ds_train_bs != mm_train_bs:
                    print_log(
                        ('Mismatch on train_micro_batch_size_per_gpu: '
                         f'MMEngine {mm_train_bs}, Deepspeed {ds_train_bs}. '
                         f'Set to {mm_train_bs}'),
                        logger='current',
                        level=logging.WARNING)
                train_bs = cfg.train_dataloader.batch_size

                ds_grad_clip = ds_cfg.get('gradient_clipping', 'auto')
                clip_grad = cfg.optim_wrapper.get('clip_grad', None)
                if clip_grad and clip_grad.get('max_norm', None) is not None:
                    mm_max_norm = cfg.optim_wrapper.clip_grad.max_norm
                else:
                    mm_max_norm = 1.0
                if ds_grad_clip != 'auto' and ds_grad_clip != mm_max_norm:
                    print_log(
                        ('Mismatch on gradient_clipping: '
                         f'MMEngine {mm_max_norm}, Deepspeed {ds_grad_clip}. '
                         f'Set to {mm_max_norm}'),
                        logger='current',
                        level=logging.WARNING)
                grad_clip = mm_max_norm

                ds_cfg = auto_dtype_of_deepspeed_config(ds_cfg)
                exclude_frozen_parameters = True if digit_version(
                    deepspeed.__version__) >= digit_version('0.10.1') else None

                strategy = dict(
                    type=LazyObject('xtuner.engine', 'DeepSpeedStrategy'),
                    config=ds_cfg,
                    gradient_accumulation_steps=grad_accum,
                    train_micro_batch_size_per_gpu=train_bs,
                    gradient_clipping=grad_clip,
                    exclude_frozen_parameters=exclude_frozen_parameters,
                    sequence_parallel_size=getattr(cfg,
                                                   'sequence_parallel_size',
                                                   1))
                cfg.__setitem__('strategy', strategy)
                optim_wrapper = dict(
                    type='DeepSpeedOptimWrapper',
                    optimizer=cfg.optim_wrapper.optimizer)
                cfg.__setitem__('optim_wrapper', optim_wrapper)
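                # The strategy/optim-wrapper pair is consumed by MMEngine's
                # FlexibleRunner rather than the default Runner.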
                cfg.runner_type = 'FlexibleRunner'

        # resume is determined in this priority: resume from > auto_resume
        if args.resume is not None:
            cfg.resume = True
            cfg.load_from = args.resume

        # build the runner from config
        if 'runner_type' not in cfg:
            # build the default runner
            runner = Runner.from_cfg(cfg)
        else:
            # build customized runner from the registry
            # if 'runner_type' is set in the cfg
            runner = RUNNERS.build(cfg)

        # start training
        runner.train()


if __name__ == '__main__':
    main()
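
# Illustrative launch commands, assuming this file lives at
# xtuner/tools/train.py (<config> is any name known to `cfgs_name_path` or a
# path to a config file):
#   python xtuner/tools/train.py <config> --work-dir ./work_dirs/demo
#   python xtuner/tools/train.py <config> --deepspeed deepspeed_zero2
#   torchrun --nproc_per_node=8 xtuner/tools/train.py <config> \
#       --launcher pytorch --deepspeed deepspeed_zero3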