# OMG-LLaVA / xtuner/tools/log_dataset.py
# Uploaded by zhangtao-whu using huggingface_hub (commit 476ac07, verified).
# File size: 1.79 kB
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmengine.config import Config
from xtuner.registry import BUILDER
def parse_args(argv=None):
    """Parse the command-line options of the dataset logging tool.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave as before.

    Returns:
        argparse.Namespace: Namespace with two attributes, ``config`` (the
        positional config file name or path) and ``show`` (one of
        ``'text'``, ``'masked_text'``, ``'input_ids'``, ``'labels'`` or
        ``'all'``; defaults to ``'text'``).
    """
    parser = argparse.ArgumentParser(description='Log processed dataset.')
    parser.add_argument('config', help='config file name or path.')
    # Choose which kind of dataset style to show.
    parser.add_argument(
        '--show',
        default='text',
        choices=['text', 'masked_text', 'input_ids', 'labels', 'all'],
        help='which kind of dataset style to show')
    return parser.parse_args(argv)
def main():
    """Build the configured training dataset and print its first sample.

    Reads the config path and the ``--show`` mode from the command line
    (via ``parse_args``), builds the tokenizer and the training dataset
    through ``BUILDER``, and prints the requested view(s) of
    ``train_dataset[0]`` to stdout.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    tokenizer = BUILDER.build(cfg.tokenizer)
    # The dataset config is read from a different key depending on the
    # ``framework`` entry of the config (``'mmengine'`` when absent).
    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = BUILDER.build(cfg.train_dataset)
    else:
        train_dataset = BUILDER.build(cfg.train_dataloader.dataset)

    # Fetch the first sample exactly once: every printed view then describes
    # the same item, and the dataset's ``__getitem__`` is not re-run for
    # each view when ``--show all`` is used.
    sample = train_dataset[0]
    show_all = args.show == 'all'

    if show_all or args.show == 'text':
        print('#' * 20 + ' text ' + '#' * 20)
        print(tokenizer.decode(sample['input_ids']))

    if show_all or args.show == 'masked_text':
        print('#' * 20 + ' text(masked) ' + '#' * 20)
        labels = sample['labels']
        # Label entries equal to -100 are rendered as a literal '[-100]'
        # placeholder; the remaining entries are decoded as text. All
        # placeholders are printed first, so original positions are not
        # preserved in this view.
        masked_text = ' '.join(
            ['[-100]' for label in labels if label == -100])
        unmasked_text = tokenizer.decode(
            [label for label in labels if label != -100])
        print(masked_text + ' ' + unmasked_text)

    if show_all or args.show == 'input_ids':
        print('#' * 20 + ' input_ids ' + '#' * 20)
        print(sample['input_ids'])

    if show_all or args.show == 'labels':
        print('#' * 20 + ' labels ' + '#' * 20)
        print(sample['labels'])
# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()