Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import argparse | |
import os | |
import warnings | |
from mmengine import Config, ConfigDict | |
from mmengine.config.lazy import LazyObject | |
from xtuner.registry import BUILDER | |
# Suppress every FutureWarning process-wide (motivated by hf datasets).
warnings.simplefilter('ignore', FutureWarning)
def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace: Namespace with ``config`` (the positional config
        file name or path) and ``save_folder`` (value of ``--save-folder``).

    Raises:
        SystemExit: Raised by argparse when ``config`` or ``--save-folder``
            is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='config file name or path.')
    # The script always writes the processed dataset into this folder, so a
    # missing value is rejected here with a usage message instead of failing
    # later with a TypeError when the path is inspected.
    parser.add_argument(
        '--save-folder',
        required=True,
        help='The folder to save the processed dataset.')
    return parser.parse_args()
def modify_config(config, dataset_save_folder):
    """Point the training dataloader at a dataset stored on disk.

    Args:
        config: Config object whose ``train_dataloader.dataset`` entry is
            overwritten in place.
        dataset_save_folder: Path used as ``dataset_path`` for
            ``datasets.load_from_disk``.

    Returns:
        The same ``config`` object that was passed in.
    """
    # Entry referencing `datasets.load_from_disk` with the saved folder.
    saved_dataset_cfg = ConfigDict(
        type=LazyObject('datasets', 'load_from_disk'),
        dataset_path=dataset_save_folder)

    # Every remaining `process_hf_dataset` option is set to False, None or
    # an empty list.
    passthrough_options = dict(
        do_dataset_tokenization=False,
        tokenizer=None,
        max_length=None,
        dataset_map_fn=None,
        template_map_fn=None,
        max_dataset_length=None,
        split=None,
        remove_unused_columns=False,
        rename_maps=[],
        pack_to_max_length=False,
        input_ids_with_output=False)

    config.train_dataloader.dataset = ConfigDict(
        type=LazyObject('xtuner.dataset', 'process_hf_dataset'),
        dataset=saved_dataset_cfg,
        **passthrough_options)
    return config
def process_untokenized_dataset(config):
    """Build the dataset described by ``config.train_dataloader.dataset``.

    Args:
        config: Config object providing ``train_dataloader.dataset``.

    Returns:
        Whatever ``BUILDER.build`` returns for that dataset entry.
    """
    return BUILDER.build(config.train_dataloader.dataset)
def get_modified_config_path(config_path):
    """Return the path the modified config is written to.

    The result lives in the same folder as ``config_path`` and is named
    ``<stem>_modified.py``. ``<stem>`` is the file name with only its final
    extension removed, so dots inside the name (e.g. ``llama2.7b_qlora.py``)
    are preserved instead of truncating the name at the first dot.

    Args:
        config_path (str): Path of the original config file.

    Returns:
        str: Path of the modified config file.
    """
    cfg_folder, cfg_file_name = os.path.split(config_path)
    cfg_stem = os.path.splitext(cfg_file_name)[0]
    return os.path.join(cfg_folder, f'{cfg_stem}_modified.py')


if __name__ == '__main__':
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Build the dataset from the original (unmodified) config first.
    print('Start to process untokenized dataset...')
    processed_dataset = process_untokenized_dataset(cfg)
    print('Processing untokenized dataset finished.')

    # The folder is made absolute before it is written into the config.
    processed_dataset_save_folder = args.save_folder
    if not os.path.isabs(processed_dataset_save_folder):
        processed_dataset_save_folder = os.path.join(
            os.getcwd(), processed_dataset_save_folder)
    modified_cfg = modify_config(cfg, processed_dataset_save_folder)

    print('Start to save processed dataset...')
    processed_dataset.save_to_disk(processed_dataset_save_folder)
    print(
        f'Processed dataset has been saved to {processed_dataset_save_folder}')

    # The modified config is dumped next to the original config file.
    modified_cfg_save_path = get_modified_config_path(args.config)
    modified_cfg.dump(modified_cfg_save_path)
    print(f'Modified config has been saved to {modified_cfg_save_path}. '
          'Please use this new config for the next training phase.')