# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import logging
import os

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset

from xtuner.dataset.huggingface import process_hf_dataset
from xtuner.registry import BUILDER

from .utils import expand2square


class LLaVADataset(Dataset):
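    """LLaVA-style image-text dataset.

    Text annotations come either from an offline pre-tokenized dataset
    (`offline_processed_text_folder`) or from a raw LLaVA-style JSON file
    (`data_path`) that is tokenized on the fly; images are loaded from
    `image_folder`. Text-only samples receive an all-zero dummy image tensor.

    Illustrative usage (the checkpoint names, paths and map functions below
    are examples only, not requirements of this class)::

        from transformers import AutoTokenizer, CLIPImageProcessor
        from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
        from xtuner.utils import PROMPT_TEMPLATE

        dataset = LLaVADataset(
            image_folder='data/llava/images',
            image_processor=dict(
                type=CLIPImageProcessor.from_pretrained,
                pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'),
            data_path='data/llava/llava_instruct.json',
            tokenizer=dict(
                type=AutoTokenizer.from_pretrained,
                pretrained_model_name_or_path='lmsys/vicuna-7b-v1.5'),
            dataset_map_fn=llava_map_fn,
            template_map_fn=dict(
                type=template_map_fn_factory,
                template=PROMPT_TEMPLATE.vicuna),
            max_length=2048,
            pad_image_to_square=True)
    """
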
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
debug=False):
super().__init__()
        assert offline_processed_text_folder or (data_path and tokenizer), \
            ('Either `offline_processed_text_folder` or both `data_path` '
             'and `tokenizer` must be provided.')
self.tokenizer = tokenizer
        if isinstance(tokenizer, (dict, Config, ConfigDict)):
            # The tokenizer may be passed as a config: the `type` entry is the
            # constructor and the remaining keys are its keyword arguments.
            tokenizer_type = self.tokenizer['type']
            del self.tokenizer['type']
            self.tokenizer = tokenizer_type(**self.tokenizer)
self._add_special_tokens()
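
        # Text data either comes pre-tokenized from disk
        # (`offline_processed_text_folder`) or is built on the fly from a
        # LLaVA-style JSON annotation file (`data_path`).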
if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are '
                'set; loading the dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder}).',
                logger='current',
                level=logging.WARNING)
if offline_processed_text_folder is not None:
self.text_data = load_from_disk(offline_processed_text_folder)
else:
            with open(data_path, 'r') as f:
                json_data = json.load(f)
            if debug:
                # Keep only a subset for quick debugging runs.
                json_data = json_data[:10000]
            # Sample ids may mix ints and strings; cast them all to str so
            # `HFDataset.from_list` infers a single consistent column type.
            for idx in range(len(json_data)):
                if isinstance(json_data[idx]['id'], int):
                    json_data[idx]['id'] = str(json_data[idx]['id'])
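            # Wrap into a DatasetDict so `process_hf_dataset` can select the
            # 'train' split below.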
json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
self.text_data = process_hf_dataset(
dataset=json_data,
tokenizer=self.tokenizer,
max_length=max_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
split='train',
max_dataset_length=max_dataset_length,
remove_unused_columns=False,
pack_to_max_length=False,
with_image_token=True,
                map_num_proc=32,  # cap map workers to keep memory usage manageable
)
self.image_folder = image_folder
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
else:
self.image_processor = image_processor
self.pad_image_to_square = pad_image_to_square
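
    # Sign-encoded sample lengths: samples with an image are positive,
    # text-only samples negative, so a modality-aware length-grouped sampler
    # can tell the two apart.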
@property
def modality_length(self):
length_list = []
for data_dict in self.text_data:
cur_len = len(data_dict['input_ids'])
if data_dict.get('image', None) is None:
cur_len = -cur_len
length_list.append(cur_len)
return length_list

    def __len__(self):
return len(self.text_data)

    def __getitem__(self, index):
data_dict = copy.deepcopy(self.text_data[index])
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
image = Image.open(os.path.join(self.image_folder,
image_file)).convert('RGB')
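            # Optionally pad to a square canvas, filling with the image
            # processor's per-channel mean color.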
if self.pad_image_to_square:
image = expand2square(
image,
tuple(
int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
data_dict['pixel_values'] = image
else:
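            # Text-only sample: provide an all-zero dummy image matching the
            # processor's crop size so every item in a batch has pixel_values.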
if hasattr(self.image_processor, 'crop_size'):
crop_size = self.image_processor.crop_size
else:
crop_size = self.image_processor.size
data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
crop_size['width'])
return data_dict

    def _add_special_tokens(self):
        assert hasattr(self, 'tokenizer')
        # Special tokens for pixel grounding.
        segmentation_tokens = ['[SEG]']
        # Tokens for GCG (grounded conversation generation).
        phrase_tokens = ['<p>', '</p>']
        # Tokens for visual prompts.
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = (segmentation_tokens + phrase_tokens + region_tokens +
                          point_tokens)
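        # Note: any model consuming this tokenizer must have its embeddings
        # resized (e.g. via `resize_token_embeddings`) to cover the newly
        # added tokens; that step is expected to happen on the model side.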
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        return