# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from functools import partial

import numpy as np
from datasets import DatasetDict
from mmengine.config import Config

from xtuner.dataset.utils import Packer, encode_fn
from xtuner.registry import BUILDER


def parse_args():
    parser = argparse.ArgumentParser(
        description='Verify the correctness of the config file for the '
        'custom dataset.')
    parser.add_argument('config', help='config file name or path.')
    args = parser.parse_args()
    return args


def is_standard_format(dataset):
    """Check whether the dataset is already in the XTuner-defined format:

    each example carries a `conversation` list whose items are dicts with
    string `input` and `output` fields.
    """
    example = next(iter(dataset))
    if 'conversation' not in example:
        return False
    conversation = example['conversation']
    if not isinstance(conversation, list):
        return False
    for item in conversation:
        if (not isinstance(item, dict)) or ('input' not in item) or (
                'output' not in item):
            return False
        input, output = item['input'], item['output']
        if (not isinstance(input, str)) or (not isinstance(output, str)):
            return False
    return True


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    tokenizer = BUILDER.build(cfg.tokenizer)
    # The dataset config lives in a different place depending on the
    # training framework.
    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = cfg.train_dataset
    else:
        train_dataset = cfg.train_dataloader.dataset

    dataset = train_dataset.dataset
    max_length = train_dataset.max_length
    dataset_map_fn = train_dataset.get('dataset_map_fn', None)
    template_map_fn = train_dataset.get('template_map_fn', None)
    max_dataset_length = train_dataset.get('max_dataset_length', 10)
    split = train_dataset.get('split', 'train')
    remove_unused_columns = train_dataset.get('remove_unused_columns', False)
    rename_maps = train_dataset.get('rename_maps', [])
    shuffle_before_pack = train_dataset.get('shuffle_before_pack', True)
    pack_to_max_length = train_dataset.get('pack_to_max_length', True)
    input_ids_with_output = train_dataset.get('input_ids_with_output', True)

    if dataset.get('path', '') != 'json':
        raise ValueError(
            'You are using custom datasets for SFT. '
            'The custom datasets should be in json format. To load your JSON '
            'file, you can use the following code snippet: \n'
            '"""\nfrom datasets import load_dataset \n'
            'dataset = dict(type=load_dataset, path=\'json\', '
            'data_files=\'your_json_file.json\')\n"""\n'
            'For more details, please refer to Step 5 in the '
            '`Using Custom Datasets` section of the documentation found at'
            ' docs/zh_cn/user_guides/single_turn_conversation.md.')

    try:
        dataset = BUILDER.build(dataset)
    except RuntimeError:
        raise RuntimeError(
            'Unable to load the custom JSON file using '
            '`datasets.load_dataset`. Your data-related config is '
            f'{train_dataset}. Please refer to the official documentation on'
            ' `load_dataset` (https://huggingface.co/docs/datasets/loading) '
            'for more details.')

    if isinstance(dataset, DatasetDict):
        dataset = dataset[split]

    if not is_standard_format(dataset) and dataset_map_fn is None:
        raise ValueError(
            'If the custom dataset is not in the XTuner-defined '
            'format, please utilize `dataset_map_fn` to map the original data'
            ' to the standard format. For more details, please refer to '
            'Step 1 and Step 5 in the `Using Custom Datasets` section of the '
            'documentation found at '
            '`docs/zh_cn/user_guides/single_turn_conversation.md`.')

    if is_standard_format(dataset) and dataset_map_fn is not None:
        raise ValueError(
            'If the custom dataset is already in the XTuner-defined format, '
            'please set `dataset_map_fn` to None. '
            'For more details, please refer to Step 1 and Step 5 in the '
            '`Using Custom Datasets` section of the documentation found at'
            ' docs/zh_cn/user_guides/single_turn_conversation.md.')

    # Sample at most `max_dataset_length` examples to keep the check fast.
    max_dataset_length = min(max_dataset_length, len(dataset))
    indices = np.random.choice(
        len(dataset), max_dataset_length, replace=False)
    dataset = dataset.select(indices)

    if dataset_map_fn is not None:
        dataset = dataset.map(dataset_map_fn)
        print('#' * 20 + ' dataset after `dataset_map_fn` ' + '#' * 20)
        print(dataset[0]['conversation'])

    if template_map_fn is not None:
        template_map_fn = BUILDER.build(template_map_fn)
        dataset = dataset.map(template_map_fn)
        print('#' * 20 + ' dataset after adding templates ' + '#' * 20)
        print(dataset[0]['conversation'])

    for old, new in rename_maps:
        dataset = dataset.rename_column(old, new)

    if pack_to_max_length and (not remove_unused_columns):
        raise ValueError('We have to remove unused columns if '
                         '`pack_to_max_length` is set to True.')

    # Tokenize the conversations; unused columns must be dropped before
    # packing.
    dataset = dataset.map(
        partial(
            encode_fn,
            tokenizer=tokenizer,
            max_length=max_length,
            input_ids_with_output=input_ids_with_output),
        remove_columns=list(dataset.column_names)
        if remove_unused_columns else None)

    print('#' * 20 + ' encoded input_ids ' + '#' * 20)
    print(dataset[0]['input_ids'])
    print('#' * 20 + ' encoded labels ' + '#' * 20)
    print(dataset[0]['labels'])

    if pack_to_max_length and split == 'train':
        # Concatenate tokenized samples into fixed-length blocks of
        # `max_length` tokens.
        if shuffle_before_pack:
            dataset = dataset.shuffle()
            dataset = dataset.flatten_indices()
        dataset = dataset.map(Packer(max_length), batched=True)
        print('#' * 20 + ' input_ids after packed to max_length ' + '#' * 20)
        print(dataset[0]['input_ids'])
        print('#' * 20 + ' labels after packed to max_length ' + '#' * 20)
        print(dataset[0]['labels'])


if __name__ == '__main__':
    main()
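# A minimal sketch of what the checker expects, assuming this script is saved
# as `check_custom_dataset.py` (the file name and sample text below are
# illustrative, not taken from the source):
#
#   # your_json_file.json -- each row is XTuner-standard: a `conversation`
#   # list whose items hold string `input`/`output` pairs.
#   [{"conversation": [{"input": "Give three tips for staying healthy.",
#                       "output": "1. Eat a balanced diet. ..."}]}]
#
#   # Run the check against a config whose `train_dataloader.dataset` (or
#   # `train_dataset`, when `framework` is 'huggingface') loads that file:
#   python check_custom_dataset.py ${CONFIG_PATH}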