OMG-LLaVA/xtuner/tools/process_untokenized_datasets.py
zhangtao-whu's picture
Upload folder using huggingface_hub
476ac07 verified
raw
history blame
2.56 kB
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
from mmengine import Config, ConfigDict
from mmengine.config.lazy import LazyObject
from xtuner.registry import BUILDER
# Ignore every FutureWarning raised in this process; the ones this is meant
# to hide come from the Hugging Face `datasets` library.
warnings.simplefilter(action='ignore', category=FutureWarning)
def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace: Namespace with ``config`` (positional config
            file name or path) and ``save_folder`` (value of the required
            ``--save-folder`` option).

    Raises:
        SystemExit: Raised by argparse when a required argument is missing
            or the command line is otherwise invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='config file name or path.')
    # The script cannot do anything useful without an output folder, so let
    # argparse reject the command line instead of handing back ``None``.
    parser.add_argument(
        '--save-folder',
        required=True,
        help='The folder to save the processed dataset.')
    args = parser.parse_args()
    return args
def modify_config(config, dataset_save_folder):
    """Rewire ``config`` so its training dataset is read back from disk.

    ``config.train_dataloader.dataset`` is replaced in place by a
    ``process_hf_dataset`` entry whose source is ``load_from_disk`` pointed
    at ``dataset_save_folder``. Tokenization, packing and column removal are
    switched off and the tokenizer, map functions and length limits are set
    to ``None``.

    Args:
        config: Config object exposing ``train_dataloader.dataset``.
        dataset_save_folder: Path stored as ``dataset_path`` of the
            ``load_from_disk`` entry.

    Returns:
        The same ``config`` object that was passed in, after modification.
    """
    saved_dataset = ConfigDict(
        type=LazyObject('datasets', 'load_from_disk'),
        dataset_path=dataset_save_folder)
    # Every remaining option is disabled or left empty.
    passthrough_options = dict(
        do_dataset_tokenization=False,
        tokenizer=None,
        max_length=None,
        dataset_map_fn=None,
        template_map_fn=None,
        max_dataset_length=None,
        split=None,
        remove_unused_columns=False,
        rename_maps=[],
        pack_to_max_length=False,
        input_ids_with_output=False)
    config.train_dataloader.dataset = ConfigDict(
        type=LazyObject('xtuner.dataset', 'process_hf_dataset'),
        dataset=saved_dataset,
        **passthrough_options)
    return config
def process_untokenized_dataset(config):
    """Build the training dataset described by ``config``.

    Args:
        config: Config object exposing ``train_dataloader.dataset``.

    Returns:
        Whatever ``BUILDER.build`` returns for that dataset entry.
    """
    return BUILDER.build(config.train_dataloader.dataset)
def _modified_config_path(config_path):
    """Return where the modified copy of ``config_path`` is written.

    The copy is placed next to the original config and named
    ``<stem>_modified.py``. Only the final extension is stripped, so dots
    inside the file name (e.g. ``llava_v1.5_7b.py``) are preserved.
    """
    cfg_folder, cfg_file_name = os.path.split(config_path)
    cfg_stem = os.path.splitext(cfg_file_name)[0]
    return os.path.join(cfg_folder, f'{cfg_stem}_modified.py')


if __name__ == '__main__':
    # Script flow: build the dataset from the given config, save it to disk,
    # then dump a modified config that loads the saved dataset instead.
    args = parse_args()
    # Fail fast: without an output folder the path handling below would
    # raise only after the dataset has already been processed.
    if args.save_folder is None:
        raise SystemExit('--save-folder is required to save the processed '
                         'dataset.')
    cfg = Config.fromfile(args.config)

    print('Start to process untokenized dataset...')
    processed_dataset = process_untokenized_dataset(cfg)
    print('Processing untokenized dataset finished.')

    # Store an absolute path so the modified config does not depend on the
    # working directory it is later used from.
    processed_dataset_save_folder = args.save_folder
    if not os.path.isabs(processed_dataset_save_folder):
        processed_dataset_save_folder = os.path.join(
            os.getcwd(), processed_dataset_save_folder)
    modified_cfg = modify_config(cfg, processed_dataset_save_folder)

    print('Start to save processed dataset...')
    processed_dataset.save_to_disk(processed_dataset_save_folder)
    print(
        f'Processed dataset has been saved to {processed_dataset_save_folder}')

    modified_cfg_save_path = _modified_config_path(args.config)
    modified_cfg.dump(modified_cfg_save_path)
    print(f'Modified config has been saved to {modified_cfg_save_path}. '
          'Please use this new config for the next training phase.')