# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import logging
import os

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset

from xtuner.dataset.huggingface import process_hf_dataset
from xtuner.registry import BUILDER

from .utils import expand2square


class LLaVADataset(Dataset):
    """LLaVA-style image/text dataset that also registers grounding tokens.

    Text samples come either from a dataset saved with ``datasets``
    (``offline_processed_text_folder``, loaded via ``load_from_disk``) or
    from a JSON file (``data_path``) that is processed here through
    ``process_hf_dataset``. Images are opened lazily in ``__getitem__``.

    Args:
        image_folder (str): Directory that each sample's ``image`` path is
            joined onto.
        image_processor: An image processor instance, or a
            ``dict``/``Config``/``ConfigDict`` built with ``BUILDER.build``.
        data_path (str, optional): Path to a JSON file holding a list of
            sample dicts. Required (with ``tokenizer``) when
            ``offline_processed_text_folder`` is not given.
        tokenizer (optional): A tokenizer instance, or a
            ``dict``/``Config``/``ConfigDict`` whose ``type`` entry is a
            callable invoked with the remaining entries as keyword
            arguments. The passed config is not modified.
        offline_processed_text_folder (str, optional): Directory for
            ``load_from_disk``; takes precedence over ``data_path``.
        max_dataset_length (int, optional): Forwarded to
            ``process_hf_dataset``.
        dataset_map_fn (optional): Forwarded to ``process_hf_dataset``.
        template_map_fn (optional): Forwarded to ``process_hf_dataset``.
        max_length (int): Forwarded to ``process_hf_dataset``.
            Defaults to 2048.
        pad_image_to_square (bool): If True, pad each image to a square
            using the processor's ``image_mean`` (scaled to 0-255) before
            preprocessing. Defaults to False.
        debug (bool): If True, keep only the first 10000 JSON samples.
            Defaults to False.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 debug=False):
        super().__init__()

        assert offline_processed_text_folder or (data_path and tokenizer)

        if isinstance(tokenizer, (dict, Config, ConfigDict)):
            # Copy into a plain kwargs dict and pop 'type' from the copy:
            # deleting the key from the caller's config would make a second
            # build from the same (possibly shared) config fail.
            tokenizer_kwargs = dict(tokenizer)
            tokenizer_type = tokenizer_kwargs.pop('type')
            tokenizer = tokenizer_type(**tokenizer_kwargs)
        self.tokenizer = tokenizer
        # A tokenizer is optional when loading offline-processed text, so
        # only register the extra tokens when one is actually available.
        if self.tokenizer is not None:
            self._add_special_tokens()

        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            self.text_data = load_from_disk(offline_processed_text_folder)
        else:
            with open(data_path, encoding='utf-8') as f:
                json_data = json.load(f)
            if debug:
                json_data = json_data[:10000]
            # Integer ids are converted to str; presumably so the `id`
            # column has a single type for `HFDataset.from_list`.
            for sample in json_data:
                if isinstance(sample.get('id'), int):
                    sample['id'] = str(sample['id'])
            json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
            self.text_data = process_hf_dataset(
                dataset=json_data,
                tokenizer=self.tokenizer,
                max_length=max_length,
                dataset_map_fn=dataset_map_fn,
                template_map_fn=template_map_fn,
                split='train',
                max_dataset_length=max_dataset_length,
                remove_unused_columns=False,
                pack_to_max_length=False,
                with_image_token=True,
                map_num_proc=32,  # because limited mem
            )

        self.image_folder = image_folder
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        self.pad_image_to_square = pad_image_to_square

    @property
    def modality_length(self):
        """Per-sample token lengths; negated for samples without an image.

        The sign encodes the modality (negative means text-only); presumably
        consumed by a modality-aware length-grouped sampler.
        """
        length_list = []
        for data_dict in self.text_data:
            cur_len = len(data_dict['input_ids'])
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, index):
        """Return a copy of sample ``index`` with ``pixel_values`` attached.

        Samples without an ``image`` get an all-zero ``(3, H, W)`` tensor,
        with ``H``/``W`` taken from the processor's ``crop_size`` (or
        ``size`` when it has no ``crop_size``).
        """
        data_dict = copy.deepcopy(self.text_data[index])
        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            image_path = os.path.join(self.image_folder, image_file)
            # Context manager closes the underlying file once converted.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.pad_image_to_square:
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image
        else:
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])
        return data_dict

    def _add_special_tokens(self):
        """Register the segmentation, phrase and visual-prompt tokens."""
        assert hasattr(self, "tokenizer")
        # Adding special tokens for pixel grounding
        segmentation_tokens = ['[SEG]']
        # Adding tokens for GCG
        phrase_tokens = ['<p>', '</p>']
        # add for visual prompt
        # NOTE(review): confirm these literals match the tokens used by the
        # training data and the model.
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = (segmentation_tokens + phrase_tokens +
                          region_tokens + point_tokens)

        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        return