# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook

from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX


def split_list(lst, value):
    """Split ``lst`` into sub-lists at every element equal to ``value``.

    The separator elements themselves are dropped. The result always
    contains exactly one more chunk than there are separators, so leading,
    trailing or adjacent separators yield empty chunks and an empty input
    yields ``[[]]``.

    Args:
        lst: Iterable of items to partition.
        value: Separator; compared against each item with ``==``.

    Returns:
        list[list]: The chunks, in their original order.
    """
    chunks = [[]]
    for item in lst:
        if item == value:
            # Separator reached: open a fresh chunk.
            chunks.append([])
        else:
            chunks[-1].append(item)
    return chunks


class DatasetInfoHook(Hook):
    """Hook that logs the size and the first decoded sample of a dataset.

    Args:
        tokenizer: Tokenizer config; it is handed to ``BUILDER.build`` and
            the built object is used to decode token ids.
        is_intern_repo_dataset (bool): When true, the absolute value of
            every token id is taken before decoding.
            NOTE(review): presumably such datasets carry negative ids as
            markers - confirm against the dataset implementation.
    """

    def __init__(self, tokenizer, is_intern_repo_dataset=False):
        self.tokenizer = BUILDER.build(tokenizer)
        self.is_intern_repo_dataset = is_intern_repo_dataset

    def log(self, runner, dataset, mode='train'):
        """Log the sample count and the decoded text of ``dataset[0]``.

        Args:
            runner: Object exposing ``logger.info``.
            dataset: Sized, indexable dataset whose items provide an
                ``'input_ids'`` entry.
            mode (str): Label inserted into the log messages.
        """
        logger = runner.logger
        logger.info(f'Num {mode} samples {len(dataset)}')
        logger.info(f'{mode} example:')

        token_ids = dataset[0]['input_ids']
        if self.is_intern_repo_dataset:
            token_ids = list(map(abs, token_ids))

        # IMAGE_TOKEN_INDEX is a placeholder id rather than something the
        # tokenizer should decode: decode the segments around it separately
        # and put DEFAULT_IMAGE_TOKEN back where each placeholder was.
        segments = split_list(token_ids, IMAGE_TOKEN_INDEX)
        decoded = DEFAULT_IMAGE_TOKEN.join(
            self.tokenizer.decode(segment) for segment in segments)
        logger.info(decoded)

    def before_train(self, runner) -> None:
        """Log the train dataset and, if a val loop exists, the eval one."""
        # Both loops are queried up front, before any dataloader is touched.
        has_train_loop = runner.train_loop is not None
        has_val_loop = runner.val_loop is not None
        if has_train_loop:
            self.log(runner, runner.train_dataloader.dataset, mode='train')
        if has_val_loop:
            self.log(runner, runner.val_dataloader.dataset, mode='eval')

    def before_val(self, runner) -> None:
        """Log information about the validation dataset."""
        self.log(runner, runner.val_dataloader.dataset, mode='eval')

    def before_test(self, runner) -> None:
        """Log information about the test dataset."""
        self.log(runner, runner.test_dataloader.dataset, mode='test')