Upload 8 files
- data_script/flintstones_hdf5.py +51 -0
- data_script/pororo_hdf5.py +83 -0
- data_script/vist_hdf5.py +111 -0
- data_script/vist_img_download.py +61 -0
- datasets/flintstones.py +93 -0
- datasets/pororo.py +144 -0
- datasets/vistdii.py +94 -0
- datasets/vistsis.py +94 -0
data_script/flintstones_hdf5.py
ADDED
@@ -0,0 +1,51 @@
import argparse
import json
import os
import pickle

import cv2
import h5py
import numpy as np
from tqdm import tqdm


def main(args):
    splits = json.load(open(os.path.join(args.data_dir, 'train-val-test_split.json'), 'r'))
    train_ids, val_ids, test_ids = splits["train"], splits["val"], splits["test"]
    followings = pickle.load(open(os.path.join(args.data_dir, 'following_cache4.pkl'), 'rb'))
    annotations = json.load(open(os.path.join(args.data_dir, 'flintstones_annotations_v1-0.json')))
    descriptions = dict()
    for sample in annotations:
        descriptions[sample["globalID"]] = sample["description"]

    f = h5py.File(args.save_path, "w")
    for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
        ids = [i for i in ids if i in followings and len(followings[i]) == 4]
        length = len(ids)

        group = f.create_group(subset)
        images = list()
        for i in range(5):
            images.append(
                group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
        text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
        for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
            globalIDs = [item] + followings[item]
            txt = list()
            for j, globalID in enumerate(globalIDs):
                img = np.load(os.path.join(args.data_dir, 'video_frames_sampled', '{}.npy'.format(globalID)))
                img = np.concatenate(img, axis=0).astype(np.uint8)
                img = cv2.imencode('.png', img)[1].tobytes()
                img = np.frombuffer(img, np.uint8)
                images[j][i] = img
                txt.append(descriptions[globalID])
            text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
    f.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='arguments for flintstones hdf5 file saving')
    parser.add_argument('--data_dir', type=str, required=True, help='flintstones data directory')
    parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
    args = parser.parse_args()
    main(args)
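Note: a minimal sketch of reading the resulting file back (the 'flintstones.h5' path is hypothetical). Each 'image{i}' entry holds PNG-encoded bytes of the vertically stacked sampled frames, and 'text' holds the five captions joined with '|':

import cv2
import h5py

with h5py.File('flintstones.h5', 'r') as f:
    group = f['train']
    # decode the stacked frames of the first story's first image
    frames = cv2.imdecode(group['image0'][0], cv2.IMREAD_COLOR)
    captions = group['text'][0].decode('utf-8').split('|')
    print(frames.shape, len(captions))  # (num_frames * 128, width, 3), 5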
data_script/pororo_hdf5.py
ADDED
@@ -0,0 +1,83 @@
import argparse
import os

import cv2
import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm


def main(args):
    # descriptions.npy stores a pickled Python dict, so load it with
    # allow_pickle=True (encoding='latin1' for Python-2 pickles) and
    # convert the 0-d object array back to a dict with .item().
    descriptions = np.load(os.path.join(args.data_dir, 'descriptions.npy'), allow_pickle=True, encoding='latin1').item()
    # imgs_list holds the path of each source image;
    # followings_list holds the paths of the four images that follow it.
    imgs_list = np.load(os.path.join(args.data_dir, 'img_cache4.npy'), encoding='latin1')
    followings_list = np.load(os.path.join(args.data_dir, 'following_cache4.npy'))
    # train_seen_unseen_ids.npy packs three arrays: the train, val and
    # test ID lists. Unpack them in one assignment and sort each.
    train_ids, val_ids, test_ids = np.load(os.path.join(args.data_dir, 'train_seen_unseen_ids.npy'), allow_pickle=True)
    train_ids = np.sort(train_ids)
    val_ids = np.sort(val_ids)
    test_ids = np.sort(test_ids)

    # Create the output HDF5 file in write mode; it will hold the
    # processed image and text data for all three subsets.
    f = h5py.File(args.save_path, "w")
    for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
        length = len(ids)

        # One group per subset, each with five datasets 'image0'..'image4'
        # (the current image plus its four following images), so that a
        # story's frames are stored together for easy later reading.
        group = f.create_group(subset)
        images = list()
        # Each image dataset uses vlen_dtype(np.dtype('uint8')), i.e. a
        # variable-length array of unsigned 8-bit integers per entry.
        for i in range(5):
            images.append(
                group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
        # 'text' stores each story's captions as utf-8 strings.
        text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
        # Walk over every story in the subset and save its data.
        for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
            # Collect the paths of the current image and its four
            # following images; the [2:-1] slice strips the b'...'
            # quoting left over from the byte-string repr.
            img_paths = [str(imgs_list[item])[2:-1]] + [str(followings_list[item][i])[2:-1] for i in range(4)]
            # Open each image and convert it to an RGB PIL image.
            imgs = [Image.open(os.path.join(args.data_dir, img_path)).convert('RGB') for img_path in img_paths]
            for j, img in enumerate(imgs):
                # PIL image -> uint8 numpy array -> PNG-encoded bytes ->
                # uint8 buffer stored in the matching 'image{j}' dataset.
                img = np.array(img).astype(np.uint8)
                img = cv2.imencode('.png', img)[1].tobytes()
                img = np.frombuffer(img, np.uint8)
                images[j][i] = img
            # Derive the image IDs from the file names, look up each
            # description, strip '\n'/'\t', and join the five captions
            # with '|' before writing them to the 'text' dataset.
            tgt_img_ids = [str(img_path).replace('.png', '') for img_path in img_paths]
            txt = [descriptions[tgt_img_id][0] for tgt_img_id in tgt_img_ids]
            text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
    f.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='arguments for pororo hdf5 file saving')
    parser.add_argument('--data_dir', type=str, required=True, help='pororo data directory')
    parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
    args = parser.parse_args()
    main(args)
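Note: the str(...)[2:-1] slice above exists because the cache arrays hold byte strings; str() of a bytes object yields its repr, and the slice strips the leading b' and the trailing '. A quick illustration (the path is made up):

p = str(b'pororo/scene_001.png')  # "b'pororo/scene_001.png'"
print(p[2:-1])                    # pororo/scene_001.png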
data_script/vist_hdf5.py
ADDED
@@ -0,0 +1,111 @@
import argparse
import json
import os

import cv2
import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm


def main(args):
    train_data = json.load(open(os.path.join(args.sis_json_dir, 'train.story-in-sequence.json')))
    val_data = json.load(open(os.path.join(args.sis_json_dir, 'val.story-in-sequence.json')))
    test_data = json.load(open(os.path.join(args.sis_json_dir, 'test.story-in-sequence.json')))

    prefix = ["train", "val", "test"]
    whole_album = {}
    for i, data in enumerate([train_data, val_data, test_data]):
        album_mapping = {}
        for annot_new in data["annotations"]:
            annot = annot_new[0]
            assert len(annot_new) == 1
            if annot['story_id'] not in album_mapping:
                album_mapping[annot['story_id']] = {"flickr_id": [annot['photo_flickr_id']],
                                                    "sis": [annot['original_text']],
                                                    "length": 1}
            else:
                album_mapping[annot['story_id']]["flickr_id"].append(annot['photo_flickr_id'])
                album_mapping[annot['story_id']]["sis"].append(annot['original_text'])
                album_mapping[annot['story_id']]["length"] += 1
        whole_album[prefix[i]] = album_mapping

    # Drop stories that do not have exactly five annotations or whose
    # images are not all present on disk.
    for p in prefix:
        deletables = []
        for story_id, story in whole_album[p].items():
            if story['length'] != 5:
                print("deleting {}".format(story_id))
                deletables.append(story_id)
                continue
            d = [os.path.exists(os.path.join(args.img_dir, "{}.jpg".format(_))) for _ in story["flickr_id"]]
            if sum(d) < 5:
                print("deleting {}".format(story_id))
                deletables.append(story_id)
        for i in deletables:
            del whole_album[p][i]

    train_data = json.load(open(os.path.join(args.dii_json_dir, 'train.description-in-isolation.json')))
    val_data = json.load(open(os.path.join(args.dii_json_dir, 'val.description-in-isolation.json')))
    test_data = json.load(open(os.path.join(args.dii_json_dir, 'test.description-in-isolation.json')))

    # Map each flickr id to its longest description-in-isolation text.
    flickr_id2text = {}
    for i, data in enumerate([train_data, val_data, test_data]):
        for l in data['annotations']:
            assert len(l) == 1
            if l[0]['photo_flickr_id'] in flickr_id2text:
                flickr_id2text[l[0]['photo_flickr_id']] = \
                    max([flickr_id2text[l[0]['photo_flickr_id']], l[0]['original_text']], key=len)
            else:
                flickr_id2text[l[0]['photo_flickr_id']] = l[0]['original_text']

    # Attach the dii captions to each story; drop stories with any image
    # that has no dii text.
    for p in prefix:
        deletables = []
        for story_id, story in whole_album[p].items():
            story['dii'] = []
            for i, flickr_id in enumerate(story['flickr_id']):
                if flickr_id not in flickr_id2text:
                    print("{} not found in story {}".format(flickr_id, story_id))
                    deletables.append(story_id)
                    break
                story['dii'].append(flickr_id2text[flickr_id])
        for i in deletables:
            del whole_album[p][i]

    f = h5py.File(args.save_path, "w")
    for p in prefix:
        group = f.create_group(p)
        story_dict = whole_album[p]
        length = len(story_dict)
        images = list()
        for i in range(5):
            images.append(
                group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
        sis = group.create_dataset('sis', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
        dii = group.create_dataset('dii', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
        for i, (story_id, story) in enumerate(tqdm(story_dict.items(), leave=True, desc="saveh5")):
            imgs = [Image.open('{}/{}.jpg'.format(args.img_dir, flickr_id)).convert('RGB') for flickr_id in
                    story['flickr_id']]
            for j, img in enumerate(imgs):
                img = np.array(img).astype(np.uint8)
                img = cv2.imencode('.png', img)[1].tobytes()
                img = np.frombuffer(img, np.uint8)
                images[j][i] = img
            sis[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in story['sis']])
            txt_dii = [t.replace('\n', '').replace('\t', '').strip() for t in story['dii']]
            # Deduplicate while preserving first-occurrence order.
            txt_dii = sorted(set(txt_dii), key=txt_dii.index)
            dii[i] = '|'.join(txt_dii)
    f.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='arguments for vist hdf5 file saving')
    parser.add_argument('--sis_json_dir', type=str, required=True, help='sis json file directory')
    parser.add_argument('--dii_json_dir', type=str, required=True, help='dii json file directory')
    parser.add_argument('--img_dir', type=str, required=True, help='image file directory')
    parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
    args = parser.parse_args()
    main(args)
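Note: sorted(set(txt_dii), key=txt_dii.index) deduplicates the dii captions while preserving first-occurrence order, which plain set() alone would not. A quick illustration (values are made up):

txt_dii = ['a dog', 'a park', 'a dog', 'a ball']
print(sorted(set(txt_dii), key=txt_dii.index))  # ['a dog', 'a park', 'a ball']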
data_script/vist_img_download.py
ADDED
@@ -0,0 +1,61 @@
import json
import requests
from io import BytesIO
from PIL import Image
from tqdm import tqdm
from multiprocessing import Process
import os
import argparse


def download_subprocess(dii, save_dir):
    for image in tqdm(dii):
        key, value = image.popitem()
        try:
            img_data = requests.get(value).content
            img = Image.open(BytesIO(img_data)).convert('RGB')
            # PIL's Image.size is (width, height); downscale so the
            # shorter side becomes 512 while preserving the aspect ratio.
            h = img.size[0]
            w = img.size[1]
            if min(h, w) > 512:
                img = img.resize((int(h / (w / 512)), 512) if h > w else (512, int(w / (h / 512))))
            img.save('{}/{}.jpg'.format(save_dir, key))
        except Exception:
            print(key, value)


def main(args):
    train_data = json.load(open(os.path.join(args.json_dir, 'train.description-in-isolation.json')))
    val_data = json.load(open(os.path.join(args.json_dir, 'val.description-in-isolation.json')))
    test_data = json.load(open(os.path.join(args.json_dir, 'test.description-in-isolation.json')))
    dii = []
    for subset in [train_data, val_data, test_data]:
        for image in subset["images"]:
            try:
                dii.append({image['id']: image['url_o']})
            except KeyError:
                dii.append({image['id']: image['url_m']})

    # Skip images that have already been downloaded.
    dii = [image for image in dii if not os.path.exists('{}/{}.jpg'.format(args.img_dir, list(image)[0]))]
    print('total images: {}'.format(len(dii)))

    def splitlist(inlist, chunksize):
        return [inlist[x:x + chunksize] for x in range(0, len(inlist), chunksize)]

    dii_splitted = splitlist(dii, int(len(dii) / args.num_process))
    process_list = []
    for dii_sub_list in dii_splitted:
        p = Process(target=download_subprocess, args=(dii_sub_list, args.img_dir))
        process_list.append(p)
        p.daemon = True
        p.start()
    for p in process_list:
        p.join()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='arguments for vist images downloading')
    parser.add_argument('--json_dir', type=str, required=True, help='dii json file directory')
    parser.add_argument('--img_dir', type=str, required=True, help='images saving directory')
    parser.add_argument('--num_process', type=int, default=32)
    args = parser.parse_args()
    main(args)
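Note: the resize branch keeps the aspect ratio and maps the shorter side to 512. A quick check of the arithmetic with made-up sizes:

def target_size(width, height):
    # mirrors the branch in download_subprocess: shorter side -> 512
    return (int(width / (height / 512)), 512) if width > height else (512, int(height / (width / 512)))

print(target_size(2048, 1024))  # (1024, 512)
print(target_size(1024, 2048))  # (512, 1024)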
datasets/flintstones.py
ADDED
@@ -0,0 +1,93 @@
import random

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
    """
    A custom Dataset class (covers the train, val and test subsets)
    """

    def __init__(self, subset, args):
        super(StoryDataset, self).__init__()
        self.args = args

        self.h5_file = args.get(args.dataset).hdf5_file
        self.subset = subset

        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([512, 512]),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        self.dataset = args.dataset
        self.max_length = args.get(args.dataset).max_length
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        self.blip_tokenizer = init_tokenizer()
        msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
        print("clip {} new tokens added".format(msg))
        msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
        print("blip {} new tokens added".format(msg))

        self.blip_image_processor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

    def open_h5(self):
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    def __getitem__(self, index):
        if not hasattr(self, 'h5'):
            self.open_h5()

        images = list()
        for i in range(5):
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            # each stored image is a vertical stack of sampled frames;
            # pick one random 128-pixel-high frame from the stack
            idx = random.randint(0, 4)
            images.append(im[idx * 128: (idx + 1) * 128])

        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        images = images[1:] if self.args.task == 'continuation' else images
        images = torch.stack([self.augment(im) for im in images]) \
            if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)

        texts = self.h5['text'][index].decode('utf-8').split('|')

        # tokenize caption using default tokenizer
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        return images, captions, attention_mask, source_images, source_caption, source_attention_mask

    def __len__(self):
        if not hasattr(self, 'h5'):
            self.open_h5()
        return len(self.h5['text'])
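Note: open_h5 is deferred to the first __getitem__/__len__ call instead of __init__ so that each DataLoader worker process opens its own HDF5 handle, since h5py handles should not be shared across forked workers. A minimal usage sketch, assuming an args object carrying the fields read in __init__:

from torch.utils.data import DataLoader

dataset = StoryDataset('train', args)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
images, captions, attention_mask, source_images, source_caption, source_attention_mask = next(iter(loader))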
datasets/pororo.py
ADDED
@@ -0,0 +1,144 @@
import copy
import os
import random
from PIL import Image
import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
    """
    A custom Dataset class (covers the train, val and test subsets)
    """

    def __init__(self, subset, args):
        # initialize the parent Dataset class so its methods and
        # attributes are inherited
        super(StoryDataset, self).__init__()
        # args is a namespace object holding the remaining configuration
        self.args = args
        # path of the HDF5 file that stores the image and text data of
        # the train, val and test subsets;
        # args.get(args.dataset) fetches the config of the active dataset
        self.h5_file = args.get(args.dataset).hdf5_file
        # which subset to read (train, val or test)
        self.subset = subset

        # image transform pipeline: to PIL image, resize, to tensor,
        # then normalize
        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            # transforms.Resize([256, 256]),
            transforms.Resize([512, 512]),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        # name of the active dataset
        self.dataset = args.dataset
        # maximum caption length; captions are padded to this length
        # during tokenization
        self.max_length = args.get(args.dataset).max_length
        # tokenizer of the CLIP text encoder
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        # BLIP's own tokenizer for the source captions
        self.blip_tokenizer = init_tokenizer()
        msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
        print("clip {} new tokens added".format(msg))
        msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
        print("blip {} new tokens added".format(msg))

        # preprocessing pipeline for the BLIP image encoder: to PIL
        # image, resize, to tensor, then normalize with CLIP statistics
        self.blip_image_processor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

    # open the HDF5 file that backs the requested subset
    def open_h5(self):
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    # fetch one sample by index: read and augment the five images,
    # tokenize the captions with both the CLIP tokenizer and the BLIP
    # tokenizer, then return the images, captions and attention masks
    def __getitem__(self, index):
        # open the HDF5 file lazily on first access
        if not hasattr(self, 'h5'):
            self.open_h5()
        images = list()
        for i in range(5):
            # read one PNG-encoded image from the HDF5 file and decode it
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            # pick a random 128-pixel-high slice out of the vertically
            # stacked frames
            idx = random.randint(0, im.shape[0] // 128 - 1)
            images.append(im[idx * 128: (idx + 1) * 128])
        # deep copy so ori_images is unaffected by later transforms
        ori_images = copy.deepcopy(images)

        # convert the images to tensors for the BLIP image encoder
        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        # for the continuation task, drop the first (source) frame
        images = images[1:] if self.args.task == 'continuation' else images
        # for train/val, augment each image and stack into a tensor;
        # otherwise convert the list to an array and permute to NCHW
        images = torch.stack([self.augment(im) for im in images]) \
            if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)

        # read the captions at this index, decode as UTF-8, split on '|'
        texts = self.h5['text'][index].decode('utf-8').split('|')

        # tokenize caption using default tokenizer
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        return images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images

    # number of samples: a single sample for the test subset (a
    # debugging shortcut), otherwise the full subset size
    def __len__(self):
        if not hasattr(self, 'h5'):
            self.open_h5()
        if self.subset == 'test':
            return 1
        # if self.subset == 'test':
        #     return 100
        return len(self.h5['text'])
datasets/vistdii.py
ADDED
@@ -0,0 +1,94 @@
import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
    """
    A custom Dataset class (covers the train, val and test subsets)
    """

    def __init__(self, subset, args):
        super(StoryDataset, self).__init__()
        self.args = args

        self.h5_file = args.get(args.dataset).hdf5_file
        self.subset = subset

        # train/val: resize, (random or center) crop to 512, tensorize
        # and normalize; test: only resize and center-crop to 64, leaving
        # PIL images for the numpy conversion in __getitem__
        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(512),
            transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]) if self.subset in ['train', 'val'] else transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.CenterCrop(64)
        ])

        self.dataset = args.dataset
        self.max_length = args.get(args.dataset).max_length
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        self.blip_tokenizer = init_tokenizer()

        self.blip_image_processor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

    def open_h5(self):
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    def __getitem__(self, index):
        if not hasattr(self, 'h5'):
            self.open_h5()

        images = list()
        for i in range(5):
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            images.append(im)

        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        images = images[1:] if self.args.task == 'continuation' else images
        images = [self.augment(im) for im in images]
        images = torch.stack(images) if self.subset in ['train', 'val'] \
            else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)

        texts = self.h5['dii'][index].decode('utf-8').split('|')

        # tokenize caption using default tokenizer
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        return images, captions, attention_mask, source_images, source_caption, source_attention_mask

    def __len__(self):
        if not hasattr(self, 'h5'):
            self.open_h5()
        return len(self.h5['dii'])
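Note: the test-time branch of augment returns PIL images rather than tensors, which is why __getitem__ converts the list with np.array before permuting to NCHW. A minimal shape check of that branch (input size assumed):

import numpy as np
import torch
from torchvision import transforms

eval_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(64),
    transforms.CenterCrop(64)
])
images = [eval_tf(np.zeros((128, 128, 3), dtype=np.uint8)) for _ in range(5)]
batch = torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
print(batch.shape)  # torch.Size([5, 3, 64, 64])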
datasets/vistsis.py
ADDED
@@ -0,0 +1,94 @@
import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
    """
    A custom Dataset class (covers the train, val and test subsets)
    """

    def __init__(self, subset, args):
        super(StoryDataset, self).__init__()
        self.args = args

        self.h5_file = args.get(args.dataset).hdf5_file
        self.subset = subset

        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(512),
            transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]) if self.subset in ['train', 'val'] else transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.CenterCrop(64)
        ])

        self.dataset = args.dataset
        self.max_length = args.get(args.dataset).max_length
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        self.blip_tokenizer = init_tokenizer()

        self.blip_image_processor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

    def open_h5(self):
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    def __getitem__(self, index):
        if not hasattr(self, 'h5'):
            self.open_h5()

        images = list()
        for i in range(5):
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            images.append(im)

        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        images = images[1:] if self.args.task == 'continuation' else images
        images = [self.augment(im) for im in images]
        images = torch.stack(images) if self.subset in ['train', 'val'] \
            else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)

        texts = self.h5['sis'][index].decode('utf-8').split('|')

        # tokenize caption using default tokenizer
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        return images, captions, attention_mask, source_images, source_caption, source_attention_mask

    def __len__(self):
        if not hasattr(self, 'h5'):
            self.open_h5()
        return len(self.h5['sis'])