Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import itertools

from mmengine.config import Config

from xtuner.registry import BUILDER
def parse_args(argv=None):
    """Parse command-line arguments for the dataset logging tool.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with ``config`` (config file name or path) and
        ``show`` (one of ``'text'``, ``'masked_text'``, ``'input_ids'``,
        ``'labels'``, ``'all'``; defaults to ``'text'``).
    """
    parser = argparse.ArgumentParser(description='Log processed dataset.')
    parser.add_argument('config', help='config file name or path.')
    # Choose which view(s) of the processed dataset sample to print.
    parser.add_argument(
        '--show',
        default='text',
        choices=['text', 'masked_text', 'input_ids', 'labels', 'all'],
        help='which kind of dataset style to show')
    return parser.parse_args(argv)
def _format_masked_labels(labels, tokenizer, ignore_index=-100):
    """Render ``labels`` with masked positions shown in place.

    Consecutive label ids equal to ``ignore_index`` become one
    ``[<ignore_index>]`` placeholder per token. Every other consecutive run
    is decoded with ``tokenizer.decode``. Segments are joined by single
    spaces in their original order, so masked and unmasked spans that
    interleave are shown where they actually occur. Previously, all
    placeholders were printed first.
    """
    placeholder = f'[{ignore_index}]'
    segments = []
    for is_masked, run in itertools.groupby(
            labels, key=lambda token: token == ignore_index):
        run = list(run)
        if is_masked:
            segments.append(' '.join([placeholder] * len(run)))
        else:
            segments.append(tokenizer.decode(run))
    return ' '.join(segments)


def main():
    """Build the training dataset from a config and print its first sample.

    The ``--show`` command-line option selects what is printed: the decoded
    text, the text with masked label positions, the raw ``input_ids``, the
    raw ``labels``, or all of them.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    tokenizer = BUILDER.build(cfg.tokenizer)
    # Configs without a ``framework`` key are treated as mmengine-style,
    # where the dataset is nested under ``train_dataloader.dataset``.
    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = BUILDER.build(cfg.train_dataset)
    else:
        train_dataset = BUILDER.build(cfg.train_dataloader.dataset)

    # Fetch the sample once so every view below describes the same item.
    # Indexing may also redo per-item processing (NOTE(review): depends on
    # the dataset class).
    sample = train_dataset[0]
    show = args.show

    if show in ('text', 'all'):
        print('#' * 20 + ' text ' + '#' * 20)
        print(tokenizer.decode(sample['input_ids']))
    if show in ('masked_text', 'all'):
        print('#' * 20 + ' text(masked) ' + '#' * 20)
        print(_format_masked_labels(sample['labels'], tokenizer))
    if show in ('input_ids', 'all'):
        print('#' * 20 + ' input_ids ' + '#' * 20)
        print(sample['input_ids'])
    if show in ('labels', 'all'):
        print('#' * 20 + ' labels ' + '#' * 20)
        print(sample['labels'])
if __name__ == '__main__':
    # Run the tool only when executed as a script, not when imported.
    main()