Quα»³nh PhΓΉng commited on
Commit
ce7c64a
Β·
1 Parent(s): 589b7f1
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. __pycache__/app.cpython-38.pyc +0 -0
  2. __pycache__/example_component.cpython-38.pyc +0 -0
  3. dataset/__init__.py +0 -0
  4. dataset/__pycache__/__init__.cpython-38.pyc +0 -0
  5. dataset/__pycache__/catalog.cpython-38.pyc +0 -0
  6. dataset/__pycache__/concat_dataset.cpython-38.pyc +0 -0
  7. dataset/base_dataset.py +220 -0
  8. dataset/catalog.py +72 -0
  9. dataset/cd_dataset.py +250 -0
  10. dataset/concat_dataset.py +65 -0
  11. dataset/grounding_dataset.py +205 -0
  12. dataset/layout_dataset.py +237 -0
  13. dataset/tsv.py +212 -0
  14. dataset/tsv_dataset.py +326 -0
  15. dataset/utils.py +116 -0
  16. gligen/__pycache__/__init__.cpython-38.pyc +0 -0
  17. gligen/__pycache__/distributed.cpython-38.pyc +0 -0
  18. gligen/__pycache__/evaluator.cpython-38.pyc +0 -0
  19. gligen/__pycache__/task_grounded_generation.cpython-38.pyc +0 -0
  20. gligen/__pycache__/trainer.cpython-38.pyc +0 -0
  21. gligen/ldm/__pycache__/util.cpython-38.pyc +0 -0
  22. gligen/ldm/models/.DS_Store +0 -0
  23. gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  24. gligen/ldm/models/autoencoder.py +52 -0
  25. gligen/ldm/models/diffusion/__init__.py +0 -0
  26. gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  27. gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  28. gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  29. gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc +0 -0
  30. gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc +0 -0
  31. gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc +0 -0
  32. gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc +0 -0
  33. gligen/ldm/models/diffusion/classifier.py +267 -0
  34. gligen/ldm/models/diffusion/ddim.py +134 -0
  35. gligen/ldm/models/diffusion/ddpm.py +72 -0
  36. gligen/ldm/models/diffusion/gaussian_smoothing.py +119 -0
  37. gligen/ldm/models/diffusion/ldm.py +88 -0
  38. gligen/ldm/models/diffusion/loss.py +170 -0
  39. gligen/ldm/models/diffusion/plms.py +295 -0
  40. gligen/ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
  41. gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc +0 -0
  42. gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc +0 -0
  43. gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
  44. gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc +0 -0
  45. gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
  46. gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc +0 -0
  47. gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc +0 -0
  48. gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc +0 -0
  49. gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc +0 -0
  50. gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc +0 -0
__pycache__/app.cpython-38.pyc ADDED
Binary file (25.8 kB). View file
 
__pycache__/example_component.cpython-38.pyc ADDED
Binary file (26.6 kB). View file
 
dataset/__init__.py ADDED
File without changes
dataset/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (139 Bytes). View file
 
dataset/__pycache__/catalog.cpython-38.pyc ADDED
Binary file (1.11 kB). View file
 
dataset/__pycache__/concat_dataset.cpython-38.pyc ADDED
Binary file (1.88 kB). View file
 
dataset/base_dataset.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image, ImageDraw
3
+ import torchvision.transforms as transforms
4
+ import torchvision
5
+ from zipfile import ZipFile
6
+ import os
7
+ import multiprocessing
8
+ import math
9
+ import numpy as np
10
+ import random
11
+ from io import BytesIO
12
+
13
+ VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
14
+
15
+
16
+ def check_filenames_in_zipdata(filenames, ziproot):
17
+ samples = []
18
+ for fst in ZipFile(ziproot).infolist():
19
+ fname = fst.filename
20
+ if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
21
+ continue
22
+ if os.path.splitext(fname)[1].lower() in VALID_IMAGE_TYPES:
23
+ samples.append((fname))
24
+ filenames = set(filenames)
25
+ samples = set(samples)
26
+ assert filenames.issubset(samples), 'Something wrong with your zip data'
27
+
28
+
29
+
30
+ def draw_box(img, boxes):
31
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
32
+ draw = ImageDraw.Draw(img)
33
+ for bid, box in enumerate(boxes):
34
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline =colors[bid % len(colors)], width=4)
35
+ # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
36
+ return img
37
+
38
+
39
+
40
+ def to_valid(x0, y0, x1, y1, image_size, min_box_size):
41
+ valid = True
42
+
43
+ if x0>image_size or y0>image_size or x1<0 or y1<0:
44
+ valid = False # no way to make this box vide, it is completely cropped out
45
+ return valid, (None, None, None, None)
46
+
47
+ x0 = max(x0, 0)
48
+ y0 = max(y0, 0)
49
+ x1 = min(x1, image_size)
50
+ y1 = min(y1, image_size)
51
+
52
+ if (x1-x0)*(y1-y0) / (image_size*image_size) < min_box_size:
53
+ valid = False
54
+ return valid, (None, None, None, None)
55
+
56
+ return valid, (x0, y0, x1, y1)
57
+
58
+
59
+
60
+
61
+
62
+ def recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, image_size, min_box_size):
63
+ """
64
+ x,y,w,h: the original annotation corresponding to the raw image size.
65
+ trans_info: what resizing and cropping have been applied to the raw image
66
+ image_size: what is the final image size
67
+ """
68
+
69
+ x0 = x * trans_info["performed_scale"] - trans_info['crop_x']
70
+ y0 = y * trans_info["performed_scale"] - trans_info['crop_y']
71
+ x1 = (x + w) * trans_info["performed_scale"] - trans_info['crop_x']
72
+ y1 = (y + h) * trans_info["performed_scale"] - trans_info['crop_y']
73
+
74
+
75
+ # at this point, box annotation has been recalculated based on scaling and cropping
76
+ # but some point may fall off the image_size region (e.g., negative value), thus we
77
+ # need to clamp them into 0-image_size. But if all points falling outsize of image
78
+ # region, then we will consider this is an invalid box.
79
+ valid, (x0, y0, x1, y1) = to_valid(x0, y0, x1, y1, image_size, min_box_size)
80
+
81
+ if valid:
82
+ # we also perform random flip.
83
+ # Here boxes are valid, and are based on image_size
84
+ if trans_info["performed_flip"]:
85
+ x0, x1 = image_size-x1, image_size-x0
86
+
87
+ return valid, (x0, y0, x1, y1)
88
+
89
+
90
+
91
+ class BaseDataset(torch.utils.data.Dataset):
92
+ def __init__(self, image_root, random_crop, random_flip, image_size):
93
+ super().__init__()
94
+ self.image_root = image_root
95
+ self.random_crop = random_crop
96
+ self.random_flip = random_flip
97
+ self.image_size = image_size
98
+ self.use_zip = False
99
+
100
+ if image_root[-4::] == 'zip':
101
+ self.use_zip = True
102
+ self.zip_dict = {}
103
+
104
+ if self.random_crop:
105
+ assert False, 'NOT IMPLEMENTED'
106
+
107
+
108
+ def fetch_zipfile(self, ziproot):
109
+ pid = multiprocessing.current_process().pid # get pid of this process.
110
+ if pid not in self.zip_dict:
111
+ self.zip_dict[pid] = ZipFile(ziproot)
112
+ zip_file = self.zip_dict[pid]
113
+ return zip_file
114
+
115
+ def fetch_image(self, filename):
116
+ if self.use_zip:
117
+ zip_file = self.fetch_zipfile(self.image_root)
118
+ image = Image.open( BytesIO(zip_file.read(filename)) ).convert('RGB')
119
+ return image
120
+ else:
121
+ image = Image.open( os.path.join(self.image_root,filename) ).convert('RGB')
122
+ return image
123
+
124
+
125
+ def vis_getitem_data(self, index=None, out=None, return_tensor=False, name="res.jpg", print_caption=True):
126
+
127
+ if out is None:
128
+ out = self[index]
129
+
130
+ img = torchvision.transforms.functional.to_pil_image( out["image"]*0.5+0.5 )
131
+ canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(out["image"]) )
132
+ W, H = img.size
133
+
134
+ if print_caption:
135
+ caption = out["caption"]
136
+ print(caption)
137
+ print(" ")
138
+
139
+ boxes = []
140
+ for box in out["boxes"]:
141
+ x0,y0,x1,y1 = box
142
+ boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] )
143
+ img = draw_box(img, boxes)
144
+
145
+ if return_tensor:
146
+ return torchvision.transforms.functional.to_tensor(img)
147
+ else:
148
+ img.save(name)
149
+
150
+
151
+ def transform_image(self, pil_image):
152
+ if self.random_crop:
153
+ assert False
154
+ arr = random_crop_arr(pil_image, self.image_size)
155
+ else:
156
+ arr, info = center_crop_arr(pil_image, self.image_size)
157
+
158
+ info["performed_flip"] = False
159
+ if self.random_flip and random.random()<0.5:
160
+ arr = arr[:, ::-1]
161
+ info["performed_flip"] = True
162
+
163
+ arr = arr.astype(np.float32) / 127.5 - 1
164
+ arr = np.transpose(arr, [2,0,1])
165
+
166
+ return torch.tensor(arr), info
167
+
168
+
169
+
170
+ def center_crop_arr(pil_image, image_size):
171
+ # We are not on a new enough PIL to support the `reducing_gap`
172
+ # argument, which uses BOX downsampling at powers of two first.
173
+ # Thus, we do it by hand to improve downsample quality.
174
+ WW, HH = pil_image.size
175
+
176
+ while min(*pil_image.size) >= 2 * image_size:
177
+ pil_image = pil_image.resize(
178
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
179
+ )
180
+
181
+ scale = image_size / min(*pil_image.size)
182
+
183
+ pil_image = pil_image.resize(
184
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
185
+ )
186
+
187
+ # at this point, the min of pil_image side is desired image_size
188
+ performed_scale = image_size / min(WW, HH)
189
+
190
+ arr = np.array(pil_image)
191
+ crop_y = (arr.shape[0] - image_size) // 2
192
+ crop_x = (arr.shape[1] - image_size) // 2
193
+
194
+ info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH}
195
+
196
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info
197
+
198
+
199
+ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
200
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
201
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
202
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
203
+
204
+ # We are not on a new enough PIL to support the `reducing_gap`
205
+ # argument, which uses BOX downsampling at powers of two first.
206
+ # Thus, we do it by hand to improve downsample quality.
207
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
208
+ pil_image = pil_image.resize(
209
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
210
+ )
211
+
212
+ scale = smaller_dim_size / min(*pil_image.size)
213
+ pil_image = pil_image.resize(
214
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
215
+ )
216
+
217
+ arr = np.array(pil_image)
218
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
219
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
220
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
dataset/catalog.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ class DatasetCatalog:
4
+ def __init__(self, ROOT, which_embedder):
5
+ assert which_embedder in ['clip', 'bert']
6
+
7
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
8
+
9
+
10
+ self.VGGrounding = {
11
+ "target": "dataset.tsv_dataset.TSVDataset",
12
+ "train_params": dict(
13
+ tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'),
14
+ )
15
+ }
16
+
17
+
18
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
19
+
20
+
21
+ self.FlickrGrounding = {
22
+ "target": "dataset.tsv_dataset.TSVDataset",
23
+ "train_params":dict(
24
+ tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'),
25
+ )
26
+ }
27
+
28
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
29
+
30
+ self.SBUGrounding = {
31
+ "target": "dataset.tsv_dataset.TSVDataset",
32
+ "train_params":dict(
33
+ tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'),
34
+ )
35
+ }
36
+
37
+
38
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
39
+
40
+
41
+ self.CC3MGrounding = {
42
+ "target": "dataset.tsv_dataset.TSVDataset",
43
+ "train_params":dict(
44
+ tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'),
45
+ )
46
+ }
47
+
48
+
49
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
50
+
51
+
52
+ self.CC12MGrounding = {
53
+ "target": "dataset.tsv_dataset.TSVDataset",
54
+ "train_params":dict(
55
+ tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'),
56
+ )
57
+ }
58
+
59
+
60
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
61
+
62
+ # temp = 'category_embedding_clip.pth' if which_embedder == 'clip' else 'category_embedding_bert.pth'
63
+ # obj365_category_embedding_path = os.path.join(ROOT, 'OBJECTS365', temp)
64
+
65
+ self.Obj365Detection = {
66
+ "target": "dataset.tsv_dataset.TSVDataset",
67
+ "train_params":dict(
68
+ tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'),
69
+ ),
70
+ }
71
+
72
+
dataset/cd_dataset.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, random, math
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import torchvision.transforms as transforms
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
12
+ from io import BytesIO
13
+
14
+
15
+
16
+ def not_in_at_all(list1, list2):
17
+ for a in list1:
18
+ if a in list2:
19
+ return False
20
+ return True
21
+
22
+
23
+ def clean_annotations(annotations):
24
+ for anno in annotations:
25
+ anno.pop("segmentation", None)
26
+ anno.pop("area", None)
27
+ anno.pop("iscrowd", None)
28
+ # anno.pop("id", None)
29
+
30
+
31
+ def make_a_sentence(obj_names, clean=False):
32
+
33
+ if clean:
34
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
35
+
36
+ caption = ""
37
+ tokens_positive = []
38
+ for obj_name in obj_names:
39
+ start_len = len(caption)
40
+ caption += obj_name
41
+ end_len = len(caption)
42
+ caption += ", "
43
+ tokens_positive.append(
44
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
45
+ )
46
+ caption = caption[:-2] # remove last ", "
47
+
48
+ return caption #, tokens_positive
49
+
50
+
51
+ def check_all_have_same_images(instances_data, stuff_data, caption_data):
52
+ if stuff_data is not None:
53
+ assert instances_data["images"] == stuff_data["images"]
54
+ if caption_data is not None:
55
+ assert instances_data["images"] == caption_data["images"]
56
+
57
+
58
+ class CDDataset(BaseDataset):
59
+ "CD: Caption Detection"
60
+ def __init__(self,
61
+ image_root,
62
+ category_embedding_path,
63
+ instances_json_path = None,
64
+ stuff_json_path = None,
65
+ caption_json_path = None,
66
+ prob_real_caption = 0,
67
+ fake_caption_type = 'empty',
68
+ image_size=256,
69
+ max_images=None,
70
+ min_box_size=0.01,
71
+ max_boxes_per_image=8,
72
+ include_other=False,
73
+ random_crop = False,
74
+ random_flip = True,
75
+ ):
76
+ super().__init__(random_crop, random_flip, image_size)
77
+
78
+ self.image_root = image_root
79
+ self.category_embedding_path = category_embedding_path
80
+ self.instances_json_path = instances_json_path
81
+ self.stuff_json_path = stuff_json_path
82
+ self.caption_json_path = caption_json_path
83
+ self.prob_real_caption = prob_real_caption
84
+ self.fake_caption_type = fake_caption_type
85
+ self.max_images = max_images
86
+ self.min_box_size = min_box_size
87
+ self.max_boxes_per_image = max_boxes_per_image
88
+ self.include_other = include_other
89
+
90
+
91
+ assert fake_caption_type in ["empty", "made"]
92
+ if prob_real_caption > 0:
93
+ assert caption_json_path is not None, "caption json must be given"
94
+
95
+
96
+ # Load all jsons
97
+ with open(instances_json_path, 'r') as f:
98
+ instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
99
+ clean_annotations(instances_data["annotations"])
100
+ self.instances_data = instances_data
101
+
102
+ self.stuff_data = None
103
+ if stuff_json_path is not None:
104
+ with open(stuff_json_path, 'r') as f:
105
+ stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
106
+ clean_annotations(stuff_data["annotations"])
107
+ self.stuff_data = stuff_data
108
+
109
+ self.captions_data = None
110
+ if caption_json_path is not None:
111
+ with open(caption_json_path, 'r') as f:
112
+ captions_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
113
+ clean_annotations(captions_data["annotations"])
114
+ self.captions_data = captions_data
115
+
116
+
117
+ # Load preprocessed name embedding
118
+ self.category_embeddings = torch.load(category_embedding_path)
119
+ self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
120
+
121
+
122
+ # Misc
123
+ self.image_ids = [] # main list for selecting images
124
+ self.image_id_to_filename = {} # file names used to read image
125
+ check_all_have_same_images(self.instances_data, self.stuff_data, self.captions_data)
126
+ for image_data in self.instances_data['images']:
127
+ image_id = image_data['id']
128
+ filename = image_data['file_name']
129
+ self.image_ids.append(image_id)
130
+ self.image_id_to_filename[image_id] = filename
131
+
132
+
133
+ # All category names (including things and stuff)
134
+ self.object_idx_to_name = {}
135
+ for category_data in self.instances_data['categories']:
136
+ self.object_idx_to_name[category_data['id']] = category_data['name']
137
+ if self.stuff_data is not None:
138
+ for category_data in self.stuff_data['categories']:
139
+ self.object_idx_to_name[category_data['id']] = category_data['name']
140
+
141
+
142
+ # Add object data from instances and stuff
143
+ self.image_id_to_objects = defaultdict(list)
144
+ self.select_objects( self.instances_data['annotations'] )
145
+ if self.stuff_data is not None:
146
+ self.select_objects( self.stuff_data['annotations'] )
147
+
148
+ # Add caption data
149
+ if self.captions_data is not None:
150
+ self.image_id_to_captions = defaultdict(list)
151
+ self.select_captions( self.captions_data['annotations'] )
152
+
153
+ # Check if all filenames can be found in the zip file
154
+ # all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
155
+ # check_filenames_in_zipdata(all_filenames, image_root)
156
+
157
+
158
+ def select_objects(self, annotations):
159
+ for object_anno in annotations:
160
+ image_id = object_anno['image_id']
161
+ object_name = self.object_idx_to_name[object_anno['category_id']]
162
+ other_ok = object_name != 'other' or self.include_other
163
+ if other_ok:
164
+ self.image_id_to_objects[image_id].append(object_anno)
165
+
166
+
167
+ def select_captions(self, annotations):
168
+ for caption_data in annotations:
169
+ image_id = caption_data['image_id']
170
+ self.image_id_to_captions[image_id].append(caption_data)
171
+
172
+
173
+ def total_images(self):
174
+ return len(self)
175
+
176
+
177
+ def __getitem__(self, index):
178
+ if self.max_boxes_per_image > 99:
179
+ assert False, "Are you sure setting such large number of boxes?"
180
+
181
+ out = {}
182
+
183
+ image_id = self.image_ids[index]
184
+ out['id'] = image_id
185
+
186
+ # Image
187
+ filename = self.image_id_to_filename[image_id]
188
+ image = self.fetch_image(filename)
189
+ #WW, HH = image.size
190
+ image_tensor, trans_info = self.transform_image(image)
191
+ out["image"] = image_tensor
192
+
193
+
194
+ # Select valid boxes after cropping (center or random)
195
+ this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])
196
+ areas = []
197
+ all_obj_names = []
198
+ all_boxes = []
199
+ all_masks = []
200
+ all_positive_embeddings = []
201
+ for object_anno in this_image_obj_annos:
202
+
203
+ x, y, w, h = object_anno['bbox']
204
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
205
+
206
+ if valid:
207
+ areas.append( (x1-x0)*(y1-y0) )
208
+ obj_name = self.object_idx_to_name[ object_anno['category_id'] ]
209
+ all_obj_names.append(obj_name)
210
+ all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
211
+ all_masks.append(1)
212
+ all_positive_embeddings.append( self.category_embeddings[obj_name] )
213
+
214
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
215
+ wanted_idxs = wanted_idxs[0:self.max_boxes_per_image]
216
+ obj_names = [] # used for making a sentence
217
+ boxes = torch.zeros(self.max_boxes_per_image, 4)
218
+ masks = torch.zeros(self.max_boxes_per_image)
219
+ positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
220
+ for i, idx in enumerate(wanted_idxs):
221
+ obj_names.append( all_obj_names[idx] )
222
+ boxes[i] = all_boxes[idx]
223
+ masks[i] = all_masks[idx]
224
+ positive_embeddings[i] = all_positive_embeddings[idx]
225
+
226
+ # Caption
227
+ if random.uniform(0, 1) < self.prob_real_caption:
228
+ caption_data = self.image_id_to_captions[image_id]
229
+ idx = random.randint(0, len(caption_data)-1 )
230
+ caption = caption_data[idx]["caption"]
231
+ else:
232
+ if self.fake_caption_type == "empty":
233
+ caption = ""
234
+ else:
235
+ caption = make_a_sentence(obj_names, clean=True)
236
+
237
+
238
+ out["caption"] = caption
239
+ out["boxes"] = boxes
240
+ out["masks"] = masks
241
+ out["positive_embeddings"] = positive_embeddings
242
+
243
+ return out
244
+
245
+
246
+ def __len__(self):
247
+ if self.max_images is None:
248
+ return len(self.image_ids)
249
+ return min(len(self.image_ids), self.max_images)
250
+
dataset/concat_dataset.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .catalog import DatasetCatalog
2
+ from ldm.util import instantiate_from_config
3
+ import torch
4
+
5
+
6
+
7
+
8
+ class ConCatDataset():
9
+ def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None):
10
+ self.datasets = []
11
+ cul_previous_dataset_length = 0
12
+ offset_map = []
13
+ which_dataset = []
14
+
15
+ if repeats is None:
16
+ repeats = [1] * len(dataset_name_list)
17
+ else:
18
+ assert len(repeats) == len(dataset_name_list)
19
+
20
+
21
+ Catalog = DatasetCatalog(ROOT, which_embedder)
22
+ for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()):
23
+ repeat = repeats[dataset_idx]
24
+
25
+ dataset_dict = getattr(Catalog, dataset_name)
26
+
27
+ target = dataset_dict['target']
28
+ params = dataset_dict['train_params'] if train else dataset_dict['val_params']
29
+ if yaml_params is not None:
30
+ params.update(yaml_params)
31
+ dataset = instantiate_from_config( dict(target=target, params=params) )
32
+
33
+ self.datasets.append(dataset)
34
+ for _ in range(repeat):
35
+ offset_map.append( torch.ones(len(dataset))*cul_previous_dataset_length )
36
+ which_dataset.append( torch.ones(len(dataset))*dataset_idx )
37
+ cul_previous_dataset_length += len(dataset)
38
+ offset_map = torch.cat(offset_map, dim=0).long()
39
+ self.total_length = cul_previous_dataset_length
40
+
41
+ self.mapping = torch.arange(self.total_length) - offset_map
42
+ self.which_dataset = torch.cat(which_dataset, dim=0).long()
43
+
44
+
45
+ def total_images(self):
46
+ count = 0
47
+ for dataset in self.datasets:
48
+ print(dataset.total_images())
49
+ count += dataset.total_images()
50
+ return count
51
+
52
+
53
+
54
+ def __getitem__(self, idx):
55
+ dataset = self.datasets[ self.which_dataset[idx] ]
56
+ return dataset[ self.mapping[idx] ]
57
+
58
+
59
+ def __len__(self):
60
+ return self.total_length
61
+
62
+
63
+
64
+
65
+
dataset/grounding_dataset.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tkinter.messagebox import NO
2
+ import torch
3
+ import json
4
+ from collections import defaultdict
5
+ from PIL import Image, ImageDraw
6
+ from copy import deepcopy
7
+ import os
8
+ import torchvision.transforms as transforms
9
+ import torchvision
10
+ from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
11
+ from io import BytesIO
12
+ import random
13
+
14
+ def check_unique(images, fields):
15
+ for field in fields:
16
+ temp_list = []
17
+ for img_info in images:
18
+ temp_list.append(img_info[field])
19
+ assert len(set(temp_list)) == len(temp_list), field
20
+
21
+ def clean_data(data):
22
+ for data_info in data:
23
+ data_info.pop("original_img_id", None)
24
+ data_info.pop("original_id", None)
25
+ data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image)
26
+ data_info.pop("dataset_name", None)
27
+ data_info.pop("data_source", None)
28
+ data_info["data_id"] = data_info.pop("id")
29
+
30
+
31
+ def clean_annotations(annotations):
32
+ for anno_info in annotations:
33
+ anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco
34
+ anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. This is not always 1 for coco, but I do not think we need this annotation
35
+ anno_info.pop("area", None)
36
+ # anno_info.pop("id", None)
37
+ anno_info["data_id"] = anno_info.pop("image_id")
38
+
39
+
40
+ def draw_box(img, boxes):
41
+ draw = ImageDraw.Draw(img)
42
+ for box in boxes:
43
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
44
+ return img
45
+
46
+
47
+ def xyhw2xyxy(box):
48
+ x0, y0, w, h = box
49
+ return [ x0, y0, x0+w, y0+h ]
50
+
51
+
52
+
53
+ class GroundingDataset(BaseDataset):
54
+ def __init__(self,
55
+ image_root,
56
+ json_path,
57
+ annotation_embedding_path,
58
+ prob_real_caption=1,
59
+ image_size=256,
60
+ min_box_size=0.01,
61
+ max_boxes_per_data=8,
62
+ max_images=None, # set as 30K used to eval
63
+ random_crop = False,
64
+ random_flip = True,
65
+ ):
66
+ super().__init__(image_root, random_crop, random_flip, image_size)
67
+ self.image_root = image_root
68
+ self.json_path = json_path
69
+ self.annotation_embedding_path = annotation_embedding_path
70
+ self.prob_real_caption = prob_real_caption
71
+ self.min_box_size = min_box_size
72
+ self.max_boxes_per_data = max_boxes_per_data
73
+ self.max_images = max_images
74
+
75
+
76
+ # Load raw data
77
+ with open(json_path, 'r') as f:
78
+ json_raw = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
79
+ self.data = json_raw["images"] # donot name it images, which is misleading
80
+ self.annotations = json_raw["annotations"]
81
+
82
+
83
+ # Load preprocessed name embedding
84
+ if 'bert' in annotation_embedding_path:
85
+ self.embedding_len = 1280
86
+ elif 'clip' in annotation_embedding_path:
87
+ self.embedding_len = 768
88
+ else:
89
+ assert False
90
+
91
+
92
+ # clean data and annotation
93
+ check_unique( self.data, ['id'] )
94
+ check_unique( self.annotations, ['id'] )
95
+ clean_data(self.data)
96
+ clean_annotations(self.annotations)
97
+ self.data_id_list = [ datum['data_id'] for datum in self.data ]
98
+ self.data = { datum['data_id']:datum for datum in self.data } # map self.data from a list into a dict
99
+
100
+
101
+ # data point to its annotation mapping
102
+ self.data_id_to_annos = defaultdict(list)
103
+ for anno in self.annotations:
104
+ self.data_id_to_annos[ anno["data_id"] ].append(anno)
105
+
106
+
107
+
108
+ # These are not used that offen, but are useful in some cases
109
+ self.file_names = [] # all training images
110
+ self.file_name_to_data_ids = defaultdict(list) # for each image, there are multiple data points (captions)
111
+ for data_id in self.data_id_list:
112
+ fine_name = self.data[data_id]["file_name"]
113
+ self.file_names.append(fine_name)
114
+ self.file_name_to_data_ids[fine_name].append(data_id)
115
+ self.file_names = list(set(self.file_names))
116
+
117
+
118
+ if self.max_images is not None:
119
+ "This is only used as COCO2017P evulation, when we set max_images as 30k"
120
+ assert False, 'I have commented out the following code to save cpu memory'
121
+ # new_data_id_list = []
122
+ # new_file_name_to_data_ids = defaultdict(list)
123
+ # self.file_names = self.file_names[0:self.max_images]
124
+ # for file_name in self.file_names:
125
+ # data_id = self.file_name_to_data_ids[file_name][0]
126
+ # new_data_id_list.append(data_id)
127
+ # new_file_name_to_data_ids[file_name].append(data_id)
128
+ # self.data_id_list = new_data_id_list
129
+ # self.file_name_to_data_ids = new_file_name_to_data_ids
130
+
131
+
132
+ # Check if all filenames can be found in the zip file
133
+ # all_filenames = [self.data[idx]['file_name'] for idx in self.data_id_list ]
134
+ # check_filenames_in_zipdata(all_filenames, image_root)
135
+
136
+
137
+ def total_images(self):
138
+ return len(self.file_names)
139
+
140
+
141
+ def __getitem__(self, index):
142
+ if self.max_boxes_per_data > 99:
143
+ assert False, "Are you sure setting such large number of boxes?"
144
+
145
+ out = {}
146
+
147
+ data_id = self.data_id_list[index]
148
+ out['id'] = data_id
149
+
150
+
151
+ # Image and caption
152
+ file_name = self.data[data_id]['file_name']
153
+ image = self.fetch_image(file_name)
154
+ image_tensor, trans_info = self.transform_image(image)
155
+ out["image"] = image_tensor
156
+
157
+ if random.uniform(0, 1) < self.prob_real_caption:
158
+ out["caption"] = self.data[data_id]["caption"]
159
+ else:
160
+ out["caption"] = ""
161
+
162
+
163
+
164
+ annos = deepcopy(self.data_id_to_annos[data_id])
165
+ areas = []
166
+ all_boxes = []
167
+ all_masks = []
168
+ all_positive_embeddings = []
169
+
170
+
171
+ for anno in annos:
172
+
173
+ x, y, w, h = anno['bbox']
174
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
175
+
176
+ if valid:
177
+ areas.append( (x1-x0)*(y1-y0) )
178
+ all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
179
+ all_masks.append(1)
180
+ all_positive_embeddings.append( torch.load(os.path.join(self.annotation_embedding_path,str(anno["id"])), map_location='cpu' ) )
181
+
182
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
183
+ wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]
184
+
185
+ boxes = torch.zeros(self.max_boxes_per_data, 4)
186
+ masks = torch.zeros(self.max_boxes_per_data)
187
+ positive_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
188
+ for i, idx in enumerate(wanted_idxs):
189
+ boxes[i] = all_boxes[idx]
190
+ masks[i] = all_masks[idx]
191
+ positive_embeddings[i] = all_positive_embeddings[idx]
192
+
193
+
194
+ out["boxes"] = boxes
195
+ out["masks"] = masks
196
+ out["positive_embeddings"] = positive_embeddings
197
+
198
+ return out
199
+
200
+
201
+
202
+ def __len__(self):
203
+ return len(self.data_id_list)
204
+
205
+
dataset/layout_dataset.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, random, math
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import torchvision.transforms as transforms
8
+
9
+ import numpy as np
10
+ from PIL import Image, ImageOps
11
+ from .base_dataset import BaseDataset, check_filenames_in_zipdata
12
+ from io import BytesIO
13
+
14
+
15
+
16
+
17
+ def clean_annotations(annotations):
18
+ for anno in annotations:
19
+ anno.pop("segmentation", None)
20
+ anno.pop("area", None)
21
+ anno.pop("iscrowd", None)
22
+ anno.pop("id", None)
23
+
24
+
25
+ def make_a_sentence(obj_names, clean=False):
26
+
27
+ if clean:
28
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
29
+
30
+ caption = ""
31
+ tokens_positive = []
32
+ for obj_name in obj_names:
33
+ start_len = len(caption)
34
+ caption += obj_name
35
+ end_len = len(caption)
36
+ caption += ", "
37
+ tokens_positive.append(
38
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
39
+ )
40
+ caption = caption[:-2] # remove last ", "
41
+
42
+ return caption #, tokens_positive
43
+
44
+
45
+ class LayoutDataset(BaseDataset):
46
+ """
47
+ Note: this dataset can somehow be achieved in cd_dataset.CDDataset
48
+ Since if you donot set prob_real_caption=0 in CDDataset, then that
49
+ dataset will only use detection annotations. However, in that dataset,
50
+ we do not remove images but remove boxes.
51
+
52
+ However, in layout2img works, people will just resize raw image data into 256*256,
53
+ thus they pre-calculate box size and apply min_box_size before min/max_boxes_per_image.
54
+ And then they will remove images if does not follow the rule.
55
+
56
+ These two different methods will lead to different number of training/val images.
57
+ Thus this dataset here is only for layout2img.
58
+
59
+ """
60
+ def __init__(self,
61
+ image_root,
62
+ instances_json_path,
63
+ stuff_json_path,
64
+ category_embedding_path,
65
+ fake_caption_type = 'empty',
66
+ image_size=256,
67
+ max_samples=None,
68
+ min_box_size=0.02,
69
+ min_boxes_per_image=3,
70
+ max_boxes_per_image=8,
71
+ include_other=False,
72
+ random_flip=True
73
+ ):
74
+ super().__init__(random_crop=None, random_flip=None, image_size=None) # we only use vis_getitem func in BaseDataset, donot use the others.
75
+
76
+ assert fake_caption_type in ['empty', 'made']
77
+ self.image_root = image_root
78
+ self.instances_json_path = instances_json_path
79
+ self.stuff_json_path = stuff_json_path
80
+ self.category_embedding_path = category_embedding_path
81
+ self.fake_caption_type = fake_caption_type
82
+ self.image_size = image_size
83
+ self.max_samples = max_samples
84
+ self.min_box_size = min_box_size
85
+ self.min_boxes_per_image = min_boxes_per_image
86
+ self.max_boxes_per_image = max_boxes_per_image
87
+ self.include_other = include_other
88
+ self.random_flip = random_flip
89
+
90
+
91
+ self.transform = transforms.Compose([transforms.Resize( (image_size, image_size) ),
92
+ transforms.ToTensor(),
93
+ transforms.Lambda(lambda t: (t * 2) - 1) ])
94
+
95
+ # Load all jsons
96
+ with open(instances_json_path, 'r') as f:
97
+ instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
98
+ clean_annotations(instances_data["annotations"])
99
+ self.instances_data = instances_data
100
+
101
+ with open(stuff_json_path, 'r') as f:
102
+ stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
103
+ clean_annotations(stuff_data["annotations"])
104
+ self.stuff_data = stuff_data
105
+
106
+
107
+ # Load preprocessed name embedding
108
+ self.category_embeddings = torch.load(category_embedding_path)
109
+ self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
110
+
111
+
112
+ # Misc
113
+ self.image_ids = [] # main list for selecting images
114
+ self.image_id_to_filename = {} # file names used to read image
115
+ self.image_id_to_size = {} # original size of this image
116
+ assert instances_data['images'] == stuff_data["images"]
117
+ for image_data in instances_data['images']:
118
+ image_id = image_data['id']
119
+ filename = image_data['file_name']
120
+ width = image_data['width']
121
+ height = image_data['height']
122
+ self.image_ids.append(image_id)
123
+ self.image_id_to_filename[image_id] = filename
124
+ self.image_id_to_size[image_id] = (width, height)
125
+
126
+ # All category names (including things and stuff)
127
+ self.things_id_list = []
128
+ self.stuff_id_list = []
129
+ self.object_idx_to_name = {}
130
+ for category_data in instances_data['categories']:
131
+ self.things_id_list.append( category_data['id'] )
132
+ self.object_idx_to_name[category_data['id']] = category_data['name']
133
+ for category_data in stuff_data['categories']:
134
+ self.stuff_id_list.append( category_data['id'] )
135
+ self.object_idx_to_name[category_data['id']] = category_data['name']
136
+ self.all_categories = [ self.object_idx_to_name.get(k, None) for k in range(183+1) ]
137
+
138
+
139
+ # Add object data from instances and stuff
140
+ self.image_id_to_objects = defaultdict(list)
141
+ self.select_objects( instances_data['annotations'] )
142
+ self.select_objects( stuff_data['annotations'] )
143
+
144
+
145
+ # Prune images that have too few or too many objects
146
+ new_image_ids = []
147
+ for image_id in self.image_ids:
148
+ num_objs = len(self.image_id_to_objects[image_id])
149
+ if self.min_boxes_per_image <= num_objs <= self.max_boxes_per_image:
150
+ new_image_ids.append(image_id)
151
+ self.image_ids = new_image_ids
152
+
153
+
154
+ # Check if all filenames can be found in the zip file
155
+ all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
156
+ check_filenames_in_zipdata(all_filenames, image_root)
157
+
158
+
159
+
160
+ def select_objects(self, annotations):
161
+ for object_anno in annotations:
162
+ image_id = object_anno['image_id']
163
+ _, _, w, h = object_anno['bbox']
164
+ W, H = self.image_id_to_size[image_id]
165
+ box_area = (w * h) / (W * H)
166
+ box_ok = box_area > self.min_box_size
167
+ object_name = self.object_idx_to_name[object_anno['category_id']]
168
+ other_ok = object_name != 'other' or self.include_other
169
+ if box_ok and other_ok:
170
+ self.image_id_to_objects[image_id].append(object_anno)
171
+
172
+
173
+ def total_images(self):
174
+ return len(self)
175
+
176
+
177
+ def __getitem__(self, index):
178
+ if self.max_boxes_per_image > 99:
179
+ assert False, "Are you sure setting such large number of boxes?"
180
+
181
+ out = {}
182
+
183
+ image_id = self.image_ids[index]
184
+ out['id'] = image_id
185
+
186
+ flip = self.random_flip and random.random()<0.5
187
+
188
+ # Image
189
+ filename = self.image_id_to_filename[image_id]
190
+ zip_file = self.fetch_zipfile(self.image_root)
191
+ image = Image.open(BytesIO(zip_file.read(filename))).convert('RGB')
192
+ WW, HH = image.size
193
+ if flip:
194
+ image = ImageOps.mirror(image)
195
+ out["image"] = self.transform(image)
196
+
197
+ this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])
198
+
199
+ # Make a sentence
200
+ obj_names = [] # used for make a sentence
201
+ boxes = torch.zeros(self.max_boxes_per_image, 4)
202
+ masks = torch.zeros(self.max_boxes_per_image)
203
+ positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
204
+ for idx, object_anno in enumerate(this_image_obj_annos):
205
+ obj_name = self.object_idx_to_name[ object_anno['category_id'] ]
206
+ obj_names.append(obj_name)
207
+ x, y, w, h = object_anno['bbox']
208
+ x0 = x / WW
209
+ y0 = y / HH
210
+ x1 = (x + w) / WW
211
+ y1 = (y + h) / HH
212
+ if flip:
213
+ x0, x1 = 1-x1, 1-x0
214
+ boxes[idx] = torch.tensor([x0,y0,x1,y1])
215
+ masks[idx] = 1
216
+ positive_embeddings[idx] = self.category_embeddings[obj_name]
217
+
218
+ if self.fake_caption_type == 'empty':
219
+ caption = ""
220
+ else:
221
+ caption = make_a_sentence(obj_names, clean=True)
222
+
223
+ out["caption"] = caption
224
+ out["boxes"] = boxes
225
+ out["masks"] = masks
226
+ out["positive_embeddings"] = positive_embeddings
227
+
228
+
229
+ return out
230
+
231
+
232
+ def __len__(self):
233
+ if self.max_samples is None:
234
+ return len(self.image_ids)
235
+ return min(len(self.image_ids), self.max_samples)
236
+
237
+
dataset/tsv.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as op
3
+ import gc
4
+ import json
5
+ from typing import List
6
+ import logging
7
+
8
+ try:
9
+ from .blob_storage import BlobStorage, disk_usage
10
+ except:
11
+ class BlobStorage:
12
+ pass
13
+
14
+
15
+ def generate_lineidx(filein: str, idxout: str) -> None:
16
+ idxout_tmp = idxout + '.tmp'
17
+ with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
18
+ fsize = os.fstat(tsvin.fileno()).st_size
19
+ fpos = 0
20
+ while fpos != fsize:
21
+ tsvout.write(str(fpos) + "\n")
22
+ tsvin.readline()
23
+ fpos = tsvin.tell()
24
+ os.rename(idxout_tmp, idxout)
25
+
26
+
27
+ def read_to_character(fp, c):
28
+ result = []
29
+ while True:
30
+ s = fp.read(32)
31
+ assert s != ''
32
+ if c in s:
33
+ result.append(s[: s.index(c)])
34
+ break
35
+ else:
36
+ result.append(s)
37
+ return ''.join(result)
38
+
39
+
40
+ class TSVFile(object):
41
+ def __init__(self,
42
+ tsv_file: str,
43
+ if_generate_lineidx: bool = False,
44
+ lineidx: str = None,
45
+ class_selector: List[str] = None,
46
+ blob_storage: BlobStorage = None):
47
+ self.tsv_file = tsv_file
48
+ self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
49
+ if not lineidx else lineidx
50
+ self.linelist = op.splitext(tsv_file)[0] + '.linelist'
51
+ self.chunks = op.splitext(tsv_file)[0] + '.chunks'
52
+ self._fp = None
53
+ self._lineidx = None
54
+ self._sample_indices = None
55
+ self._class_boundaries = None
56
+ self._class_selector = class_selector
57
+ self._blob_storage = blob_storage
58
+ self._len = None
59
+ # the process always keeps the process which opens the file.
60
+ # If the pid is not equal to the currrent pid, we will re-open the file.
61
+ self.pid = None
62
+ # generate lineidx if not exist
63
+ if not op.isfile(self.lineidx) and if_generate_lineidx:
64
+ generate_lineidx(self.tsv_file, self.lineidx)
65
+
66
+ def __del__(self):
67
+ self.gcidx()
68
+ if self._fp:
69
+ self._fp.close()
70
+ # physically remove the tsv file if it is retrieved by BlobStorage
71
+ if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
72
+ try:
73
+ original_usage = disk_usage('/')
74
+ os.remove(self.tsv_file)
75
+ logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
76
+ (self.tsv_file, original_usage, disk_usage('/') * 100))
77
+ except:
78
+ # Known issue: multiple threads attempting to delete the file will raise a FileNotFound error.
79
+ # TODO: try Threadling.Lock to better handle the race condition
80
+ pass
81
+
82
+ def __str__(self):
83
+ return "TSVFile(tsv_file='{}')".format(self.tsv_file)
84
+
85
+ def __repr__(self):
86
+ return str(self)
87
+
88
+ def gcidx(self):
89
+ logging.debug('Run gc collect')
90
+ self._lineidx = None
91
+ self._sample_indices = None
92
+ #self._class_boundaries = None
93
+ return gc.collect()
94
+
95
+ def get_class_boundaries(self):
96
+ return self._class_boundaries
97
+
98
+ def num_rows(self, gcf=False):
99
+ if (self._len is None):
100
+ self._ensure_lineidx_loaded()
101
+ retval = len(self._sample_indices)
102
+
103
+ if (gcf):
104
+ self.gcidx()
105
+
106
+ self._len = retval
107
+
108
+ return self._len
109
+
110
+ def seek(self, idx: int):
111
+ self._ensure_tsv_opened()
112
+ self._ensure_lineidx_loaded()
113
+ try:
114
+ pos = self._lineidx[self._sample_indices[idx]]
115
+ except:
116
+ logging.info('=> {}-{}'.format(self.tsv_file, idx))
117
+ raise
118
+ self._fp.seek(pos)
119
+ return [s.strip() for s in self._fp.readline().split('\t')]
120
+
121
+ def seek_first_column(self, idx: int):
122
+ self._ensure_tsv_opened()
123
+ self._ensure_lineidx_loaded()
124
+ pos = self._lineidx[idx]
125
+ self._fp.seek(pos)
126
+ return read_to_character(self._fp, '\t')
127
+
128
+ def get_key(self, idx: int):
129
+ return self.seek_first_column(idx)
130
+
131
+ def __getitem__(self, index: int):
132
+ return self.seek(index)
133
+
134
+ def __len__(self):
135
+ return self.num_rows()
136
+
137
+ def _ensure_lineidx_loaded(self):
138
+ if self._lineidx is None:
139
+ logging.debug('=> loading lineidx: {}'.format(self.lineidx))
140
+ with open(self.lineidx, 'r') as fp:
141
+ lines = fp.readlines()
142
+ lines = [line.strip() for line in lines]
143
+ self._lineidx = [int(line) for line in lines]
144
+
145
+ # read the line list if exists
146
+ linelist = None
147
+ if op.isfile(self.linelist):
148
+ with open(self.linelist, 'r') as fp:
149
+ linelist = sorted(
150
+ [
151
+ int(line.strip())
152
+ for line in fp.readlines()
153
+ ]
154
+ )
155
+
156
+ if op.isfile(self.chunks):
157
+ self._sample_indices = []
158
+ self._class_boundaries = []
159
+ class_boundaries = json.load(open(self.chunks, 'r'))
160
+ for class_name, boundary in class_boundaries.items():
161
+ start = len(self._sample_indices)
162
+ if class_name in self._class_selector:
163
+ for idx in range(boundary[0], boundary[1] + 1):
164
+ # NOTE: potentially slow when linelist is long, try to speed it up
165
+ if linelist and idx not in linelist:
166
+ continue
167
+ self._sample_indices.append(idx)
168
+ end = len(self._sample_indices)
169
+ self._class_boundaries.append((start, end))
170
+ else:
171
+ if linelist:
172
+ self._sample_indices = linelist
173
+ else:
174
+ self._sample_indices = list(range(len(self._lineidx)))
175
+
176
+ def _ensure_tsv_opened(self):
177
+ if self._fp is None:
178
+ if self._blob_storage:
179
+ self._fp = self._blob_storage.open(self.tsv_file)
180
+ else:
181
+ self._fp = open(self.tsv_file, 'r')
182
+ self.pid = os.getpid()
183
+
184
+ if self.pid != os.getpid():
185
+ logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
186
+ self._fp = open(self.tsv_file, 'r')
187
+ self.pid = os.getpid()
188
+
189
+
190
+ class TSVWriter(object):
191
+ def __init__(self, tsv_file):
192
+ self.tsv_file = tsv_file
193
+ self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
194
+ self.tsv_file_tmp = self.tsv_file + '.tmp'
195
+ self.lineidx_file_tmp = self.lineidx_file + '.tmp'
196
+
197
+ self.tsv_fp = open(self.tsv_file_tmp, 'w')
198
+ self.lineidx_fp = open(self.lineidx_file_tmp, 'w')
199
+
200
+ self.idx = 0
201
+
202
+ def write(self, values, sep='\t'):
203
+ v = '{0}\n'.format(sep.join(map(str, values)))
204
+ self.tsv_fp.write(v)
205
+ self.lineidx_fp.write(str(self.idx) + '\n')
206
+ self.idx = self.idx + len(v)
207
+
208
+ def close(self):
209
+ self.tsv_fp.close()
210
+ self.lineidx_fp.close()
211
+ os.rename(self.tsv_file_tmp, self.tsv_file)
212
+ os.rename(self.lineidx_file_tmp, self.lineidx_file)
dataset/tsv_dataset.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tkinter.messagebox import NO
2
+ import torch
3
+ import json
4
+ from collections import defaultdict
5
+ from PIL import Image, ImageDraw
6
+ from copy import deepcopy
7
+ import os
8
+ import torchvision.transforms as transforms
9
+ import torchvision
10
+ from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
11
+ from io import BytesIO
12
+ import random
13
+
14
+ from .tsv import TSVFile
15
+
16
+ from io import BytesIO
17
+ import base64
18
+ from PIL import Image
19
+ import numpy as np
20
+
21
+
22
+ def decode_base64_to_pillow(image_b64):
23
+ return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB')
24
+
25
+ def decode_tensor_from_string(arr_str, use_tensor=True):
26
+ arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32')
27
+ if use_tensor:
28
+ arr = torch.from_numpy(arr)
29
+ return arr
30
+
31
+ def decode_item(item):
32
+ item = json.loads(item)
33
+ item['image'] = decode_base64_to_pillow(item['image'])
34
+
35
+ for anno in item['annos']:
36
+ anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before'])
37
+ anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before'])
38
+ anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after'])
39
+ anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after'])
40
+ return item
41
+
42
+ def check_unique(images, fields):
43
+ for field in fields:
44
+ temp_list = []
45
+ for img_info in images:
46
+ temp_list.append(img_info[field])
47
+ assert len(set(temp_list)) == len(temp_list), field
48
+
49
+ def clean_data(data):
50
+ for data_info in data:
51
+ data_info.pop("original_img_id", None)
52
+ data_info.pop("original_id", None)
53
+ data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image)
54
+ data_info.pop("dataset_name", None)
55
+ data_info.pop("data_source", None)
56
+ data_info["data_id"] = data_info.pop("id")
57
+
58
+
59
+ def clean_annotations(annotations):
60
+ for anno_info in annotations:
61
+ anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco
62
+ anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. This is not always 1 for coco, but I do not think we need this annotation
63
+ anno_info.pop("area", None)
64
+ # anno_info.pop("id", None)
65
+ anno_info["data_id"] = anno_info.pop("image_id")
66
+
67
+
68
+ def draw_box(img, boxes):
69
+ draw = ImageDraw.Draw(img)
70
+ for box in boxes:
71
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
72
+ return img
73
+
74
+
75
+ def xyhw2xyxy(box):
76
+ x0, y0, w, h = box
77
+ return [ x0, y0, x0+w, y0+h ]
78
+
79
+
80
+ def make_a_sentence(obj_names, clean=False):
81
+
82
+ if clean:
83
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
84
+
85
+ caption = ""
86
+ tokens_positive = []
87
+ for obj_name in obj_names:
88
+ start_len = len(caption)
89
+ caption += obj_name
90
+ end_len = len(caption)
91
+ caption += ", "
92
+ tokens_positive.append(
93
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
94
+ )
95
+ caption = caption[:-2] # remove last ", "
96
+
97
+ return caption #, tokens_positive
98
+
99
+
100
+ def mask_for_random_drop_text_or_image_feature(masks, random_drop_embedding):
101
+ """
102
+ input masks tell how many valid grounding tokens for this image
103
+ e.g., 1,1,1,1,0,0,0,0,0,0...
104
+
105
+ If random_drop_embedding=both. we will random drop either image or
106
+ text feature for each token,
107
+ but we always make sure there is at least one feature used.
108
+ In other words, the following masks are not valid
109
+ (because for the second obj, no feature at all):
110
+ image: 1,0,1,1,0,0,0,0,0
111
+ text: 1,0,0,0,0,0,0,0,0
112
+
113
+ if random_drop_embedding=image. we will random drop image feature
114
+ and always keep the text one.
115
+
116
+ """
117
+ N = masks.shape[0]
118
+
119
+ if random_drop_embedding=='both':
120
+ temp_mask = torch.ones(2,N)
121
+ for i in range(N):
122
+ if random.uniform(0, 1) < 0.5: # else keep both features
123
+ idx = random.sample([0,1], 1)[0] # randomly choose to drop image or text feature
124
+ temp_mask[idx,i] = 0
125
+ image_masks = temp_mask[0]*masks
126
+ text_masks = temp_mask[1]*masks
127
+
128
+ if random_drop_embedding=='image':
129
+ image_masks = masks*(torch.rand(N)>0.5)*1
130
+ text_masks = masks
131
+
132
+ return image_masks, text_masks
133
+
134
+
135
+
136
+
137
+
138
+ def project(x, projection_matrix):
139
+ """
140
+ x (Batch*768) should be the penultimate feature of CLIP (before projection)
141
+ projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer
142
+ defined in CLIP (out_dim, in_dim), thus we need to apply transpose below.
143
+ this function will return the CLIP feature (without normalziation)
144
+ """
145
+ return x@torch.transpose(projection_matrix, 0, 1)
146
+
147
+
148
+ def inv_project(y, projection_matrix):
149
+ """
150
+ y (Batch*768) should be the CLIP feature (after projection)
151
+ projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer
152
+ defined in CLIP (out_dim, in_dim).
153
+ this function will return the CLIP penultimate feature.
154
+
155
+ Note: to make sure getting the correct penultimate feature, the input y should not be normalized.
156
+ If it is normalized, then the result will be scaled by CLIP feature norm, which is unknown.
157
+ """
158
+ return y@torch.transpose(torch.linalg.inv(projection_matrix), 0, 1)
159
+
160
+
161
+
162
+
163
+ class TSVDataset(BaseDataset):
164
+ def __init__(self,
165
+ tsv_path,
166
+ which_embedder='clip',
167
+ which_layer=['after','after'], # text and image
168
+ prob_use_caption=1,
169
+ random_drop_embedding='none',
170
+ image_size=256,
171
+ min_box_size=0.01,
172
+ max_boxes_per_data=8,
173
+ max_images=None, # set as 30K used to eval
174
+ random_crop = False,
175
+ random_flip = True,
176
+ ):
177
+ image_root = "a placeholder path as we are using tsv here"
178
+ super().__init__(image_root, random_crop, random_flip, image_size)
179
+ self.tsv_path = tsv_path
180
+ self.which_embedder = which_embedder
181
+ self.prob_use_caption = prob_use_caption
182
+ self.random_drop_embedding = random_drop_embedding
183
+ self.min_box_size = min_box_size
184
+ self.max_boxes_per_data = max_boxes_per_data
185
+ self.max_images = max_images
186
+
187
+ assert which_layer in [ ['after','after'], ['before','after_renorm'], ['before','after_reproject'] ]
188
+ assert random_drop_embedding in ['none', 'both', 'image']
189
+ self.which_layer_text = which_layer[0]
190
+ self.which_layer_image = which_layer[1]
191
+
192
+ #self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
193
+ self.projection_matrix = torch.load('projection_matrix.pth')
194
+
195
+ # Load tsv data
196
+ self.tsv_file = TSVFile(self.tsv_path)
197
+
198
+
199
+ # Load preprocessed name embedding
200
+ if which_embedder == 'bert':
201
+ self.embedding_len = 1280
202
+ elif which_embedder == 'clip':
203
+ self.embedding_len = 768
204
+ else:
205
+ assert False
206
+
207
+ def total_images(self):
208
+ return len(self)
209
+
210
+ def get_item_from_tsv(self, index):
211
+ _, item = self.tsv_file[index]
212
+ item = decode_item(item)
213
+ return item
214
+
215
+
216
+ def mapping(self, image_embedding):
217
+ if self.which_layer_image == 'after':
218
+ # both use CLIP aligned feature
219
+ return image_embedding
220
+ elif self.which_layer_image == 'after_renorm':
221
+ # text use before, but image use after projection but normalize to 28.7
222
+ return image_embedding*28.7
223
+ elif self.which_layer_image == 'after_reproject':
224
+ image_embedding = project( image_embedding.unsqueeze(0), self.projection_matrix.T )
225
+ image_embedding = image_embedding.squeeze(0)
226
+ image_embedding = image_embedding / image_embedding.norm()
227
+ image_embedding = image_embedding * 28.7
228
+ return image_embedding
229
+
230
+
231
+
232
+ def __getitem__(self, index):
233
+ if self.max_boxes_per_data > 99:
234
+ assert False, "Are you sure setting such large number of boxes?"
235
+
236
+ raw_item = self.get_item_from_tsv(index)
237
+ is_det = raw_item.get('is_det', False) # if it is from detection (such as o365), then we will make a caption
238
+
239
+ out = {}
240
+
241
+ # -------------------- id and image ------------------- #
242
+ out['id'] = raw_item['data_id']
243
+ image = raw_item['image']
244
+ image_tensor, trans_info = self.transform_image(image)
245
+ out["image"] = image_tensor
246
+
247
+
248
+
249
+ # -------------------- grounding token ------------------- #
250
+ annos = raw_item['annos']
251
+
252
+ areas = []
253
+ all_boxes = []
254
+ all_masks = []
255
+ all_text_embeddings = []
256
+ all_image_embeddings = []
257
+ if is_det:
258
+ all_category_names = []
259
+
260
+ text_embedding_name = 'text_embedding_before' if self.which_layer_text == 'before' else 'text_embedding_after'
261
+ image_embedding_name = 'image_embedding_after'
262
+
263
+ for anno in annos:
264
+ x, y, w, h = anno['bbox']
265
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
266
+
267
+ if valid:
268
+ areas.append( (x1-x0)*(y1-y0) )
269
+ all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
270
+ all_masks.append(1)
271
+ all_text_embeddings.append(anno[text_embedding_name])
272
+ all_image_embeddings.append( self.mapping(anno[image_embedding_name]) )
273
+ if is_det:
274
+ all_category_names.append(anno["category_name"])
275
+
276
+
277
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
278
+ wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]
279
+
280
+ boxes = torch.zeros(self.max_boxes_per_data, 4)
281
+ masks = torch.zeros(self.max_boxes_per_data)
282
+ text_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
283
+ image_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
284
+ if is_det:
285
+ category_names = []
286
+ for i, idx in enumerate(wanted_idxs):
287
+ boxes[i] = all_boxes[idx]
288
+ masks[i] = all_masks[idx]
289
+ text_embeddings[i] = all_text_embeddings[idx]
290
+ image_embeddings[i] = all_image_embeddings[idx]
291
+ if is_det:
292
+ category_names.append(all_category_names[idx])
293
+
294
+ if self.random_drop_embedding != 'none':
295
+ image_masks, text_masks = mask_for_random_drop_text_or_image_feature(masks, self.random_drop_embedding)
296
+ else:
297
+ image_masks = masks
298
+ text_masks = masks
299
+
300
+
301
+ out["boxes"] = boxes
302
+ out["masks"] = masks
303
+ out["image_masks"] = image_masks
304
+ out["text_masks"] = text_masks
305
+ out["text_embeddings"] = text_embeddings
306
+ out["image_embeddings"] = image_embeddings
307
+
308
+
309
+
310
+ # -------------------- caption ------------------- #
311
+ if random.uniform(0, 1) < self.prob_use_caption:
312
+ if is_det:
313
+ out["caption"] = make_a_sentence(category_names)
314
+ else:
315
+ out["caption"] = raw_item["caption"]
316
+ else:
317
+ out["caption"] = ""
318
+
319
+ return out
320
+
321
+
322
+
323
+ def __len__(self):
324
+ return len(self.tsv_file)
325
+
326
+
dataset/utils.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # Copyright 2018 Google LLC
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import PIL
18
+ import torch
19
+ import torchvision.transforms as T
20
+
21
+
22
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
23
+ IMAGENET_STD = [0.229, 0.224, 0.225]
24
+
25
+ INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN]
26
+ INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD]
27
+
28
+
29
+ def imagenet_preprocess():
30
+ return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
31
+
32
+
33
+ def rescale(x):
34
+ lo, hi = x.min(), x.max()
35
+ return x.sub(lo).div(hi - lo)
36
+
37
+
38
+ def imagenet_deprocess(rescale_image=True):
39
+ transforms = [
40
+ T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD),
41
+ T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]),
42
+ ]
43
+ if rescale_image:
44
+ transforms.append(rescale)
45
+ return T.Compose(transforms)
46
+
47
+
48
+ def imagenet_deprocess_batch(imgs, rescale=True):
49
+ """
50
+ Input:
51
+ - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images
52
+
53
+ Output:
54
+ - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images
55
+ in the range [0, 255]
56
+ """
57
+ if isinstance(imgs, torch.autograd.Variable):
58
+ imgs = imgs.data
59
+ imgs = imgs.cpu().clone()
60
+ deprocess_fn = imagenet_deprocess(rescale_image=rescale)
61
+ imgs_de = []
62
+ for i in range(imgs.size(0)):
63
+ img_de = deprocess_fn(imgs[i])[None]
64
+ img_de = img_de.mul(255).clamp(0, 255).byte()
65
+ imgs_de.append(img_de)
66
+ imgs_de = torch.cat(imgs_de, dim=0)
67
+ return imgs_de
68
+
69
+
70
+ class Resize(object):
71
+ def __init__(self, size, interp=PIL.Image.BILINEAR):
72
+ if isinstance(size, tuple):
73
+ H, W = size
74
+ self.size = (W, H)
75
+ else:
76
+ self.size = (size, size)
77
+ self.interp = interp
78
+
79
+ def __call__(self, img):
80
+ return img.resize(self.size, self.interp)
81
+
82
+
83
+ def unpack_var(v):
84
+ if isinstance(v, torch.autograd.Variable):
85
+ return v.data
86
+ return v
87
+
88
+
89
+ def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img):
90
+ triples = unpack_var(triples)
91
+ obj_data = [unpack_var(o) for o in obj_data]
92
+ obj_to_img = unpack_var(obj_to_img)
93
+ triple_to_img = unpack_var(triple_to_img)
94
+
95
+ triples_out = []
96
+ obj_data_out = [[] for _ in obj_data]
97
+ obj_offset = 0
98
+ N = obj_to_img.max() + 1
99
+ for i in range(N):
100
+ o_idxs = (obj_to_img == i).nonzero().view(-1)
101
+ t_idxs = (triple_to_img == i).nonzero().view(-1)
102
+
103
+ cur_triples = triples[t_idxs].clone()
104
+ cur_triples[:, 0] -= obj_offset
105
+ cur_triples[:, 2] -= obj_offset
106
+ triples_out.append(cur_triples)
107
+
108
+ for j, o_data in enumerate(obj_data):
109
+ cur_o_data = None
110
+ if o_data is not None:
111
+ cur_o_data = o_data[o_idxs]
112
+ obj_data_out[j].append(cur_o_data)
113
+
114
+ obj_offset += o_idxs.size(0)
115
+
116
+ return triples_out, obj_data_out
gligen/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (345 Bytes). View file
 
gligen/__pycache__/distributed.cpython-38.pyc ADDED
Binary file (2.91 kB). View file
 
gligen/__pycache__/evaluator.cpython-38.pyc ADDED
Binary file (5.9 kB). View file
 
gligen/__pycache__/task_grounded_generation.cpython-38.pyc ADDED
Binary file (9.11 kB). View file
 
gligen/__pycache__/trainer.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
gligen/ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (3.2 kB). View file
 
gligen/ldm/models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (1.58 kB). View file
 
gligen/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ #import pytorch_lightning as pl
4
+ import torch.nn.functional as F
5
+ from contextlib import contextmanager
6
+
7
+ # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
8
+
9
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
10
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
11
+
12
+ from ldm.util import instantiate_from_config
13
+
14
+
15
+
16
+
17
+ class AutoencoderKL(nn.Module):
18
+ def __init__(self,
19
+ ddconfig,
20
+ embed_dim,
21
+ scale_factor=1
22
+ ):
23
+ super().__init__()
24
+ self.encoder = Encoder(**ddconfig)
25
+ self.decoder = Decoder(**ddconfig)
26
+ assert ddconfig["double_z"]
27
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
28
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
29
+ self.embed_dim = embed_dim
30
+ self.scale_factor = scale_factor
31
+
32
+
33
+
34
+ def encode(self, x):
35
+ h = self.encoder(x)
36
+ moments = self.quant_conv(h)
37
+ posterior = DiagonalGaussianDistribution(moments)
38
+ return posterior.sample() * self.scale_factor
39
+
40
+ def decode(self, z):
41
+ z = 1. / self.scale_factor * z
42
+ z = self.post_quant_conv(z)
43
+ dec = self.decoder(z)
44
+ return dec
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
gligen/ldm/models/diffusion/__init__.py ADDED
File without changes
gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (159 Bytes). View file
 
gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (4.57 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc ADDED
Binary file (2.12 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc ADDED
Binary file (4.11 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc ADDED
Binary file (1.21 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc ADDED
Binary file (4.23 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc ADDED
Binary file (8.71 kB). View file
 
gligen/ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
gligen/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from functools import partial
5
+
6
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
+
8
+
9
+ class DDIMSampler(object):
10
+ def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11
+ super().__init__()
12
+ self.diffusion = diffusion
13
+ self.model = model
14
+ self.device = diffusion.betas.device
15
+ self.ddpm_num_timesteps = diffusion.num_timesteps
16
+ self.schedule = schedule
17
+ self.alpha_generator_func = alpha_generator_func
18
+ self.set_alpha_scale = set_alpha_scale
19
+
20
+
21
+ def register_buffer(self, name, attr):
22
+ if type(attr) == torch.Tensor:
23
+ attr = attr.to(self.device)
24
+ setattr(self, name, attr)
25
+
26
+
27
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
28
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=False)
30
+ alphas_cumprod = self.diffusion.alphas_cumprod
31
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
33
+
34
+ self.register_buffer('betas', to_torch(self.diffusion.betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
+
45
+ # ddim sampling parameters
46
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=False)
49
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
50
+ self.register_buffer('ddim_alphas', ddim_alphas)
51
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
+
59
+ @torch.no_grad()
60
+ def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None):
61
+ self.make_schedule(ddim_num_steps=S)
62
+ return self.ddim_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0)
63
+
64
+
65
+ @torch.no_grad()
66
+ def ddim_sampling(self, shape, input, uc, guidance_scale=1, mask=None, x0=None):
67
+ b = shape[0]
68
+
69
+ img = input["x"]
70
+ if img == None:
71
+ img = torch.randn(shape, device=self.device)
72
+ input["x"] = img
73
+
74
+
75
+ time_range = np.flip(self.ddim_timesteps)
76
+ total_steps = self.ddim_timesteps.shape[0]
77
+
78
+ #iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
79
+ iterator = time_range
80
+
81
+ if self.alpha_generator_func != None:
82
+ alphas = self.alpha_generator_func(len(iterator))
83
+
84
+
85
+ for i, step in enumerate(iterator):
86
+
87
+ # set alpha
88
+ if self.alpha_generator_func != None:
89
+ self.set_alpha_scale(self.model, alphas[i])
90
+ if alphas[i] == 0:
91
+ self.model.restore_first_conv_from_SD()
92
+
93
+ # run
94
+ index = total_steps - i - 1
95
+ input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long)
96
+
97
+ if mask is not None:
98
+ assert x0 is not None
99
+ img_orig = self.diffusion.q_sample( x0, input["timesteps"] )
100
+ img = img_orig * mask + (1. - mask) * img
101
+ input["x"] = img
102
+
103
+ img, pred_x0 = self.p_sample_ddim(input, index=index, uc=uc, guidance_scale=guidance_scale)
104
+ input["x"] = img
105
+
106
+ return img
107
+
108
+
109
+ @torch.no_grad()
110
+ def p_sample_ddim(self, input, index, uc=None, guidance_scale=1):
111
+
112
+
113
+ e_t = self.model(input)
114
+ if uc is not None and guidance_scale != 1:
115
+ unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input'])
116
+ e_t_uncond = self.model( unconditional_input )
117
+ e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
118
+
119
+ # select parameters corresponding to the currently considered timestep
120
+ b = input["x"].shape[0]
121
+ a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
122
+ a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
123
+ sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
124
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
125
+
126
+ # current prediction for x_0
127
+ pred_x0 = (input["x"] - sqrt_one_minus_at * e_t) / a_t.sqrt()
128
+
129
+ # direction pointing to x_t
130
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
131
+ noise = sigma_t * torch.randn_like( input["x"] )
132
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
133
+
134
+ return x_prev, pred_x0
gligen/ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+ from ldm.modules.diffusionmodules.util import make_beta_schedule
6
+
7
+
8
+
9
+
10
+
11
+ class DDPM(nn.Module):
12
+ def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
13
+ super().__init__()
14
+
15
+ self.v_posterior = 0
16
+ self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s)
17
+
18
+
19
+ def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
20
+
21
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
22
+ alphas = 1. - betas
23
+ alphas_cumprod = np.cumprod(alphas, axis=0)
24
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
25
+
26
+ timesteps, = betas.shape
27
+ self.num_timesteps = int(timesteps)
28
+ self.linear_start = linear_start
29
+ self.linear_end = linear_end
30
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
31
+
32
+ to_torch = partial(torch.tensor, dtype=torch.float32)
33
+
34
+ self.register_buffer('betas', to_torch(betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
44
+
45
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
46
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas
47
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
48
+
49
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
50
+
51
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
52
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
53
+ self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
54
+ self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
gligen/ldm/models/diffusion/gaussian_smoothing.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class GaussianSmoothing(nn.Module):
9
+ """
10
+ Apply gaussian smoothing on a
11
+ 1d, 2d or 3d tensor. Filtering is performed seperately for each channel
12
+ in the input using a depthwise convolution.
13
+ Arguments:
14
+ channels (int, sequence): Number of channels of the input tensors. Output will
15
+ have this number of channels as well.
16
+ kernel_size (int, sequence): Size of the gaussian kernel.
17
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
18
+ dim (int, optional): The number of dimensions of the data.
19
+ Default value is 2 (spatial).
20
+ """
21
+ def __init__(self, channels, kernel_size, sigma, dim=2):
22
+ super(GaussianSmoothing, self).__init__()
23
+ if isinstance(kernel_size, numbers.Number):
24
+ kernel_size = [kernel_size] * dim
25
+ if isinstance(sigma, numbers.Number):
26
+ sigma = [sigma] * dim
27
+
28
+ # The gaussian kernel is the product of the
29
+ # gaussian function of each dimension.
30
+ kernel = 1
31
+ meshgrids = torch.meshgrid(
32
+ [
33
+ torch.arange(size, dtype=torch.float32)
34
+ for size in kernel_size
35
+ ]
36
+ )
37
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
38
+ mean = (size - 1) / 2
39
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
40
+ torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
41
+
42
+ # Make sure sum of values in gaussian kernel equals 1.
43
+ kernel = kernel / torch.sum(kernel)
44
+
45
+ # Reshape to depthwise convolutional weight
46
+ kernel = kernel.view(1, 1, *kernel.size())
47
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
48
+
49
+ self.register_buffer('weight', kernel)
50
+ self.groups = channels
51
+
52
+ if dim == 1:
53
+ self.conv = F.conv1d
54
+ elif dim == 2:
55
+ self.conv = F.conv2d
56
+ elif dim == 3:
57
+ self.conv = F.conv3d
58
+ else:
59
+ raise RuntimeError(
60
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
61
+ )
62
+
63
+ def forward(self, input):
64
+ """
65
+ Apply gaussian filter to input.
66
+ Arguments:
67
+ input (torch.Tensor): Input to apply gaussian filter on.
68
+ Returns:
69
+ filtered (torch.Tensor): Filtered output.
70
+ """
71
+ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
72
+
73
+
74
+ class AverageSmoothing(nn.Module):
75
+ """
76
+ Apply average smoothing on a
77
+ 1d, 2d or 3d tensor. Filtering is performed seperately for each channel
78
+ in the input using a depthwise convolution.
79
+ Arguments:
80
+ channels (int, sequence): Number of channels of the input tensors. Output will
81
+ have this number of channels as well.
82
+ kernel_size (int, sequence): Size of the average kernel.
83
+ sigma (float, sequence): Standard deviation of the rage kernel.
84
+ dim (int, optional): The number of dimensions of the data.
85
+ Default value is 2 (spatial).
86
+ """
87
+ def __init__(self, channels, kernel_size, dim=2):
88
+ super(AverageSmoothing, self).__init__()
89
+
90
+ # Make sure sum of values in gaussian kernel equals 1.
91
+ kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size)
92
+
93
+ # Reshape to depthwise convolutional weight
94
+ kernel = kernel.view(1, 1, *kernel.size())
95
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
96
+
97
+ self.register_buffer('weight', kernel)
98
+ self.groups = channels
99
+
100
+ if dim == 1:
101
+ self.conv = F.conv1d
102
+ elif dim == 2:
103
+ self.conv = F.conv2d
104
+ elif dim == 3:
105
+ self.conv = F.conv3d
106
+ else:
107
+ raise RuntimeError(
108
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
109
+ )
110
+
111
+ def forward(self, input):
112
+ """
113
+ Apply average filter to input.
114
+ Arguments:
115
+ input (torch.Tensor): Input to apply average filter on.
116
+ Returns:
117
+ filtered (torch.Tensor): Filtered output.
118
+ """
119
+ return self.conv(input, weight=self.weight, groups=self.groups)
gligen/ldm/models/diffusion/ldm.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from ldm.util import default
6
+ from ldm.modules.diffusionmodules.util import extract_into_tensor
7
+ from .ddpm import DDPM
8
+
9
+
10
+
11
+ class LatentDiffusion(DDPM):
12
+ def __init__(self, *args, **kwargs):
13
+ super().__init__(*args, **kwargs)
14
+ # hardcoded
15
+ self.clip_denoised = False
16
+
17
+
18
+
19
+ def q_sample(self, x_start, t, noise=None):
20
+ noise = default(noise, lambda: torch.randn_like(x_start))
21
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
22
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
23
+
24
+
25
+ "Does not support DDPM sampling anymore. Only do DDIM or PLMS"
26
+
27
+ # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = #
28
+
29
+ # def predict_start_from_noise(self, x_t, t, noise):
30
+ # return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
31
+ # extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise )
32
+
33
+ # def q_posterior(self, x_start, x_t, t):
34
+ # posterior_mean = (
35
+ # extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
36
+ # extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
37
+ # )
38
+ # posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
39
+ # posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
40
+ # return posterior_mean, posterior_variance, posterior_log_variance_clipped
41
+
42
+
43
+ # def p_mean_variance(self, model, x, c, t):
44
+
45
+ # model_out = model(x, t, c)
46
+ # x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
47
+
48
+ # if self.clip_denoised:
49
+ # x_recon.clamp_(-1., 1.)
50
+
51
+ # model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
52
+ # return model_mean, posterior_variance, posterior_log_variance, x_recon
53
+
54
+
55
+ # @torch.no_grad()
56
+ # def p_sample(self, model, x, c, t):
57
+ # b, *_, device = *x.shape, x.device
58
+ # model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, )
59
+ # noise = torch.randn_like(x)
60
+
61
+ # # no noise when t == 0
62
+ # nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
63
+
64
+ # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
65
+
66
+
67
+ # @torch.no_grad()
68
+ # def p_sample_loop(self, model, shape, c):
69
+ # device = self.betas.device
70
+ # b = shape[0]
71
+ # img = torch.randn(shape, device=device)
72
+
73
+ # iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps)
74
+ # for i in iterator:
75
+ # ts = torch.full((b,), i, device=device, dtype=torch.long)
76
+ # img, x0 = self.p_sample(model, img, c, ts)
77
+
78
+ # return img
79
+
80
+
81
+ # @torch.no_grad()
82
+ # def sample(self, model, shape, c, uc=None, guidance_scale=None):
83
+ # return self.p_sample_loop(model, shape, c)
84
+
85
+
86
+
87
+
88
+
gligen/ldm/models/diffusion/loss.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing
4
+ from torch.nn import functional as F
5
+ from torchvision.utils import save_image
6
+
7
+
8
+
9
+
10
+
11
+
12
+ def loss_one_att_outside(attn_map,bboxes, object_positions,t):
13
+ # loss = torch.tensor(0).to('cuda')
14
+ loss = 0
15
+ object_number = len(bboxes)
16
+ b, i, j = attn_map.shape
17
+ H = W = int(math.sqrt(i))
18
+
19
+
20
+ # if t== 20: import pdb; pdb.set_trace()
21
+
22
+ for obj_idx in range(object_number):
23
+
24
+ for obj_box in bboxes[obj_idx]:
25
+ mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W))
26
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
27
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
28
+ mask[y_min: y_max, x_min: x_max] = 1.
29
+ mask_out = 1. - mask
30
+ index = (mask == 1.).nonzero(as_tuple=False)
31
+ index_in_key = index[:,0]* H + index[:, 1]
32
+ att_box = torch.zeros_like(attn_map)
33
+ att_box[:,index_in_key,:] = attn_map[:,index_in_key,:]
34
+
35
+ att_box = att_box.sum(axis=1) / index_in_key.shape[0]
36
+ att_box = att_box.reshape(-1, H, H)
37
+ activation_value = (att_box* mask_out).reshape(b, -1).sum(dim=-1) #/ att_box.reshape(b, -1).sum(dim=-1)
38
+ loss += torch.mean(activation_value)
39
+
40
+ return loss / object_number
41
+
42
+ def caculate_loss_self_att(self_first, self_second, self_third, bboxes, object_positions, t, list_res=[256], smooth_att = True,sigma=0.5,kernel_size=3 ):
43
+ all_attn = get_all_self_att(self_first, self_second, self_third)
44
+ cnt = 0
45
+ total_loss = 0
46
+ for res in list_res:
47
+ attn_maps = all_attn[res]
48
+ for attn in attn_maps:
49
+ total_loss += loss_one_att_outside(attn, bboxes, object_positions,t)
50
+ cnt += 1
51
+
52
+ return total_loss /cnt
53
+
54
+
55
+ def get_all_self_att(self_first, self_second, self_third):
56
+ result = {256:[], 1024:[], 4096:[], 64:[], 94:[],1054:[] ,286:[],4126:[] }
57
+ # import pdb; pdb.set_trace()
58
+ all_att = [self_first, self_second, self_third]
59
+ for self_att in all_att:
60
+ for att in self_att:
61
+ if att != []:
62
+ temp = att[0]
63
+ for attn_map in temp:
64
+ current_res = attn_map.shape[1]
65
+ # print(current_res)
66
+ result[current_res].append(attn_map)
67
+ return result
68
+
69
+ def get_all_attention(attn_maps_mid, attn_maps_up , attn_maps_down, res):
70
+ result = []
71
+
72
+ for attn_map_integrated in attn_maps_up:
73
+ if attn_map_integrated == []: continue
74
+ attn_map = attn_map_integrated[0][0]
75
+ b, i, j = attn_map.shape
76
+ H = W = int(math.sqrt(i))
77
+ # print(H)
78
+ if H == res:
79
+ result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
80
+ for attn_map_integrated in attn_maps_mid:
81
+
82
+ # for attn_map_integrated in attn_maps_mid:
83
+ attn_map = attn_map_integrated[0]
84
+ b, i, j = attn_map.shape
85
+ H = W = int(math.sqrt(i))
86
+ # print(H)
87
+ if (H==res):
88
+ result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
89
+ # import pdb; pdb.set_trace()
90
+ for attn_map_integrated in attn_maps_down:
91
+ if attn_map_integrated == []: continue
92
+ attn_map = attn_map_integrated[0][0]
93
+ if attn_map == []: continue
94
+ b, i, j = attn_map.shape
95
+ H = W = int(math.sqrt(i))
96
+ # print(H)
97
+ if (H==res):
98
+ result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
99
+
100
+ result = torch.cat(result, dim=0)
101
+ result = result.sum(0) / result.shape[0]
102
+ return result
103
+
104
+
105
+ def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = True,sigma=0.5,kernel_size=3 ):
106
+ attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
107
+ # attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
108
+ # attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
109
+ # attn8 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 8)
110
+ all_attn = [attn16]
111
+ obj_number = len(bboxes)
112
+ total_loss = 0
113
+ # import pdb; pdb.set_trace()
114
+ for attn in all_attn[0:1]:
115
+ attn_text = attn[:, :, 1:-1]
116
+ attn_text *= 100
117
+ attn_text = torch.nn.functional.softmax(attn_text, dim=-1)
118
+ current_res = attn.shape[0]
119
+ H = W = current_res
120
+
121
+ # if t == 49: import pdb; pdb.set_trace()
122
+ for obj_idx in range(obj_number):
123
+ num_boxes= 0
124
+
125
+ for obj_position in object_positions[obj_idx]:
126
+ true_obj_position = obj_position - 1
127
+ att_map_obj = attn_text[:,:, true_obj_position]
128
+ if smooth_att:
129
+ smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
130
+ input = F.pad(att_map_obj.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
131
+ att_map_obj = smoothing(input).squeeze(0).squeeze(0)
132
+ other_att_map_obj = att_map_obj.clone()
133
+ att_copy = att_map_obj.clone()
134
+
135
+ for obj_box in bboxes[obj_idx]:
136
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
137
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
138
+
139
+
140
+ if att_map_obj[y_min: y_max, x_min: x_max].numel() == 0:
141
+ max_inside=1.
142
+
143
+ else:
144
+ max_inside = att_map_obj[y_min: y_max, x_min: x_max].max()
145
+ total_loss += 1. - max_inside
146
+
147
+ # find max outside the box, find in the other boxes
148
+
149
+ att_copy[y_min: y_max, x_min: x_max] = 0.
150
+ other_att_map_obj[y_min: y_max, x_min: x_max] = 0.
151
+
152
+ for obj_outside in range(obj_number):
153
+ if obj_outside != obj_idx:
154
+ for obj_out_box in bboxes[obj_outside]:
155
+ x_min_out, y_min_out, x_max_out, y_max_out = int(obj_out_box[0] * W), \
156
+ int(obj_out_box[1] * H), int(obj_out_box[2] * W), int(obj_out_box[3] * H)
157
+
158
+ # att_copy[y_min: y_max, x_min: x_max] = 0.
159
+ if other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].numel() == 0:
160
+ max_outside_one= 0
161
+ else:
162
+ max_outside_one = other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].max()
163
+ # max_outside = max(max_outside,max_outside_one )
164
+ att_copy[y_min_out: y_max_out, x_min_out: x_max_out] = 0.
165
+ total_loss += max_outside_one
166
+ max_background = att_copy.max()
167
+ total_loss += len(bboxes[obj_idx]) *max_background /2.
168
+
169
+ return total_loss/obj_number
170
+
gligen/ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from functools import partial
5
+ from copy import deepcopy
6
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
+ import math
8
+ from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att
9
+ class PLMSSampler(object):
10
+ def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11
+ super().__init__()
12
+ self.diffusion = diffusion
13
+ self.model = model
14
+ self.device = diffusion.betas.device
15
+ self.ddpm_num_timesteps = diffusion.num_timesteps
16
+ self.schedule = schedule
17
+ self.alpha_generator_func = alpha_generator_func
18
+ self.set_alpha_scale = set_alpha_scale
19
+
20
+ def register_buffer(self, name, attr):
21
+ if type(attr) == torch.Tensor:
22
+ attr = attr.to(self.device)
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
26
+ if ddim_eta != 0:
27
+ raise ValueError('ddim_eta must be 0 for PLMS')
28
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30
+ alphas_cumprod = self.diffusion.alphas_cumprod
31
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
33
+
34
+ self.register_buffer('betas', to_torch(self.diffusion.betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
+
45
+ # ddim sampling parameters
46
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=verbose)
49
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
50
+ self.register_buffer('ddim_alphas', ddim_alphas)
51
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
+
59
+ # @torch.no_grad()
60
+ def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
61
+ self.make_schedule(ddim_num_steps=S)
62
+ # import pdb; pdb.set_trace()
63
+ return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
64
+
65
+
66
+ # @torch.no_grad()
67
+ def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
68
+
69
+ b = shape[0]
70
+
71
+ img = input["x"]
72
+ if img == None:
73
+ img = torch.randn(shape, device=self.device)
74
+ input["x"] = img
75
+
76
+ time_range = np.flip(self.ddim_timesteps)
77
+ total_steps = self.ddim_timesteps.shape[0]
78
+
79
+ old_eps = []
80
+
81
+ if self.alpha_generator_func != None:
82
+ alphas = self.alpha_generator_func(len(time_range))
83
+
84
+ for i, step in enumerate(time_range):
85
+
86
+ # set alpha and restore first conv layer
87
+ if self.alpha_generator_func != None:
88
+ self.set_alpha_scale(self.model, alphas[i])
89
+ if alphas[i] == 0:
90
+ self.model.restore_first_conv_from_SD()
91
+
92
+ # run
93
+ index = total_steps - i - 1
94
+ ts = torch.full((b,), step, device=self.device, dtype=torch.long)
95
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long)
96
+
97
+ if mask is not None:
98
+ assert x0 is not None
99
+ img_orig = self.diffusion.q_sample(x0, ts)
100
+ img = img_orig * mask + (1. - mask) * img
101
+ input["x"] = img
102
+ # three loss types
103
+ if loss_type !=None and loss_type!='standard':
104
+ if input['object_position'] != []:
105
+ if loss_type=='SAR_CAR':
106
+ x = self.update_loss_self_cross( input,i, index, ts )
107
+ elif loss_type=='SAR':
108
+ x = self.update_only_self( input,i, index, ts )
109
+ elif loss_type=='CAR':
110
+ x = self.update_loss_only_cross( input,i, index, ts )
111
+ input["x"] = x
112
+ img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
113
+ input["x"] = img
114
+ old_eps.append(e_t)
115
+ if len(old_eps) >= 4:
116
+ old_eps.pop(0)
117
+
118
+ return img
119
+
120
+ def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
121
+ if index1 < 10:
122
+ loss_scale = 3
123
+ max_iter = 5
124
+ elif index1 < 20:
125
+ loss_scale = 2
126
+ max_iter = 5
127
+ else:
128
+ loss_scale = 0.8
129
+ max_iter = 1
130
+
131
+ loss_threshold = 0.1
132
+ max_index = 20
133
+ x = deepcopy(input["x"])
134
+ iteration = 0
135
+ loss = torch.tensor(10000)
136
+ input["timesteps"] = ts
137
+
138
+ print("optimize", index1)
139
+ self.model.train()
140
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
141
+ print('iter', iteration)
142
+ # import pdb; pdb.set_trace()
143
+ x = x.requires_grad_(True)
144
+ input['x'] = x
145
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
146
+ bboxes = input['boxes_att']
147
+ object_positions = input['object_position']
148
+ loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
149
+ object_positions=object_positions, t = index1)*loss_scale
150
+ loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
151
+ object_positions=object_positions, t = index1)*loss_scale
152
+ loss = loss1 + loss2
153
+ print('loss', loss, loss1, loss2)
154
+ hh = torch.autograd.backward(loss, retain_graph=True)
155
+ grad_cond = x.grad
156
+ x = x - grad_cond
157
+ x = x.detach()
158
+ iteration += 1
159
+ torch.cuda.empty_cache()
160
+ return x
161
+
162
+ def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
163
+
164
+ if index1 < 10:
165
+ loss_scale = 3
166
+ max_iter = 5
167
+ elif index1 < 20:
168
+ loss_scale = 2
169
+ max_iter = 5
170
+ else:
171
+ loss_scale = 1
172
+ max_iter = 1
173
+ loss_threshold = 0.1
174
+
175
+ max_index = 30
176
+ x = deepcopy(input["x"])
177
+ iteration = 0
178
+ loss = torch.tensor(10000)
179
+ input["timesteps"] = ts
180
+
181
+ print("optimize", index1)
182
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
183
+ print('iter', iteration)
184
+ x = x.requires_grad_(True)
185
+ input['x'] = x
186
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
187
+
188
+ bboxes = input['boxes']
189
+ object_positions = input['object_position']
190
+ loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
191
+ object_positions=object_positions, t = index1)*loss_scale
192
+ loss = loss2
193
+ print('loss', loss)
194
+ hh = torch.autograd.backward(loss)
195
+ grad_cond = x.grad
196
+ x = x - grad_cond
197
+ x = x.detach()
198
+ iteration += 1
199
+ torch.cuda.empty_cache()
200
+ return x
201
+
202
+ def update_only_self(self, input,index1, index, ts,type_loss='self_accross' ):
203
+ if index1 < 10:
204
+ loss_scale = 4
205
+ max_iter = 5
206
+ elif index1 < 20:
207
+ loss_scale = 3
208
+ max_iter = 5
209
+ else:
210
+ loss_scale = 1
211
+ max_iter = 1
212
+ loss_threshold = 0.1
213
+
214
+ max_index = 30
215
+ x = deepcopy(input["x"])
216
+ iteration = 0
217
+ loss = torch.tensor(10000)
218
+ input["timesteps"] = ts
219
+
220
+ print("optimize", index1)
221
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
222
+ print('iter', iteration)
223
+ x = x.requires_grad_(True)
224
+ input['x'] = x
225
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
226
+
227
+ bboxes = input['boxes']
228
+ object_positions = input['object_position']
229
+ loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
230
+ object_positions=object_positions, t = index1)*loss_scale
231
+ print('loss', loss)
232
+ hh = torch.autograd.backward(loss)
233
+ grad_cond = x.grad
234
+
235
+ x = x - grad_cond
236
+ x = x.detach()
237
+ iteration += 1
238
+ torch.cuda.empty_cache()
239
+ return x
240
+
241
+ @torch.no_grad()
242
+ def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
243
+ x = deepcopy(input["x"])
244
+ b = x.shape[0]
245
+ self.model.eval()
246
+ def get_model_output(input):
247
+ e_t, first, second, third,_,_,_ = self.model(input)
248
+ if uc is not None and guidance_scale != 1:
249
+ unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=None, grounding_extra_input=None)
250
+ # unconditional_input=input
251
+ e_t_uncond, _, _, _, _, _, _ = self.model( unconditional_input)
252
+ e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
253
+ return e_t
254
+
255
+
256
+ def get_x_prev_and_pred_x0(e_t, index):
257
+ # select parameters corresponding to the currently considered timestep
258
+ a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
259
+ a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
260
+ sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
261
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
262
+
263
+ # current prediction for x_0
264
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
265
+
266
+ # direction pointing to x_t
267
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
268
+ noise = sigma_t * torch.randn_like(x)
269
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
270
+ return x_prev, pred_x0
271
+
272
+ input["timesteps"] = t
273
+ e_t = get_model_output(input)
274
+ if len(old_eps) == 0:
275
+ # Pseudo Improved Euler (2nd order)
276
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
277
+ input["x"] = x_prev
278
+ input["timesteps"] = t_next
279
+ e_t_next = get_model_output(input)
280
+ e_t_prime = (e_t + e_t_next) / 2
281
+ elif len(old_eps) == 1:
282
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
283
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
284
+ elif len(old_eps) == 2:
285
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
286
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
287
+ elif len(old_eps) >= 3:
288
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
289
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
290
+
291
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
292
+
293
+ return x_prev, pred_x0, e_t
294
+
295
+
gligen/ldm/modules/__pycache__/attention.cpython-38.pyc ADDED
Binary file (13 kB). View file
 
gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc ADDED
Binary file (18.3 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (188 Bytes). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (167 Bytes). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc ADDED
Binary file (8.43 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc ADDED
Binary file (1.94 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc ADDED
Binary file (13 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc ADDED
Binary file (1.66 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc ADDED
Binary file (10.1 kB). View file
 
gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc ADDED
Binary file (10.2 kB). View file