# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmengine.config import Config

from xtuner.registry import BUILDER


def parse_args():
    """Parse the command-line arguments of the dataset logging tool.

    Returns:
        argparse.Namespace: Namespace with two attributes: ``config`` (the
        positional config file name or path) and ``show`` (one of
        ``'text'``, ``'masked_text'``, ``'input_ids'``, ``'labels'`` or
        ``'all'``; defaults to ``'text'``).
    """
    parser = argparse.ArgumentParser(description='Log processed dataset.')
    parser.add_argument('config', help='config file name or path.')
    # choose which kind of dataset style to show
    parser.add_argument(
        '--show',
        default='text',
        choices=['text', 'masked_text', 'input_ids', 'labels', 'all'],
        help='which kind of dataset style to show')
    return parser.parse_args()


def _print_header(title):
    """Print ``title`` framed by 20 ``#`` characters on each side."""
    print('#' * 20 + ' ' + title + ' ' + '#' * 20)


def main():
    """Build the training dataset from a config and print its first sample.

    The tokenizer is built from ``cfg.tokenizer``. The dataset is built from
    ``cfg.train_dataset`` when the config's ``framework`` entry (default
    ``'mmengine'``) equals ``'huggingface'`` case-insensitively, and from
    ``cfg.train_dataloader.dataset`` otherwise. Which views of sample 0 are
    printed is controlled by ``--show``; ``'all'`` prints every view.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    tokenizer = BUILDER.build(cfg.tokenizer)

    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = BUILDER.build(cfg.train_dataset)
    else:
        train_dataset = BUILDER.build(cfg.train_dataloader.dataset)

    # Fetch the sample once so that every view printed below describes the
    # same item and any lazy per-item processing is not repeated. Keys are
    # still read only inside the branch that needs them.
    sample = train_dataset[0]

    if args.show in ('text', 'all'):
        _print_header('text')
        print(tokenizer.decode(sample['input_ids']))

    if args.show in ('masked_text', 'all'):
        _print_header('text(masked)')
        labels = sample['labels']
        # Labels equal to -100 are rendered as literal '[-100]' placeholders
        # (presumably the loss-ignore index -- confirm against the training
        # code). NOTE: all placeholders are printed first, followed by the
        # decoded remaining labels, so original token positions are not
        # preserved in this view.
        masked_text = ' '.join('[-100]' for i in labels if i == -100)
        unmasked_text = tokenizer.decode([i for i in labels if i != -100])
        print(masked_text + ' ' + unmasked_text)

    if args.show in ('input_ids', 'all'):
        _print_header('input_ids')
        print(sample['input_ids'])

    if args.show in ('labels', 'all'):
        _print_header('labels')
        print(sample['labels'])


if __name__ == '__main__':
    main()