Spaces:
Runtime error
Runtime error
Quα»³nh PhΓΉng
commited on
Commit
Β·
ce7c64a
1
Parent(s):
589b7f1
update
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- __pycache__/app.cpython-38.pyc +0 -0
- __pycache__/example_component.cpython-38.pyc +0 -0
- dataset/__init__.py +0 -0
- dataset/__pycache__/__init__.cpython-38.pyc +0 -0
- dataset/__pycache__/catalog.cpython-38.pyc +0 -0
- dataset/__pycache__/concat_dataset.cpython-38.pyc +0 -0
- dataset/base_dataset.py +220 -0
- dataset/catalog.py +72 -0
- dataset/cd_dataset.py +250 -0
- dataset/concat_dataset.py +65 -0
- dataset/grounding_dataset.py +205 -0
- dataset/layout_dataset.py +237 -0
- dataset/tsv.py +212 -0
- dataset/tsv_dataset.py +326 -0
- dataset/utils.py +116 -0
- gligen/__pycache__/__init__.cpython-38.pyc +0 -0
- gligen/__pycache__/distributed.cpython-38.pyc +0 -0
- gligen/__pycache__/evaluator.cpython-38.pyc +0 -0
- gligen/__pycache__/task_grounded_generation.cpython-38.pyc +0 -0
- gligen/__pycache__/trainer.cpython-38.pyc +0 -0
- gligen/ldm/__pycache__/util.cpython-38.pyc +0 -0
- gligen/ldm/models/.DS_Store +0 -0
- gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
- gligen/ldm/models/autoencoder.py +52 -0
- gligen/ldm/models/diffusion/__init__.py +0 -0
- gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc +0 -0
- gligen/ldm/models/diffusion/classifier.py +267 -0
- gligen/ldm/models/diffusion/ddim.py +134 -0
- gligen/ldm/models/diffusion/ddpm.py +72 -0
- gligen/ldm/models/diffusion/gaussian_smoothing.py +119 -0
- gligen/ldm/models/diffusion/ldm.py +88 -0
- gligen/ldm/models/diffusion/loss.py +170 -0
- gligen/ldm/models/diffusion/plms.py +295 -0
- gligen/ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
- gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc +0 -0
- gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc +0 -0
__pycache__/app.cpython-38.pyc
ADDED
Binary file (25.8 kB). View file
|
|
__pycache__/example_component.cpython-38.pyc
ADDED
Binary file (26.6 kB). View file
|
|
dataset/__init__.py
ADDED
File without changes
|
dataset/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (139 Bytes). View file
|
|
dataset/__pycache__/catalog.cpython-38.pyc
ADDED
Binary file (1.11 kB). View file
|
|
dataset/__pycache__/concat_dataset.cpython-38.pyc
ADDED
Binary file (1.88 kB). View file
|
|
dataset/base_dataset.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image, ImageDraw
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import torchvision
|
5 |
+
from zipfile import ZipFile
|
6 |
+
import os
|
7 |
+
import multiprocessing
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
import random
|
11 |
+
from io import BytesIO
|
12 |
+
|
13 |
+
VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
|
14 |
+
|
15 |
+
|
16 |
+
def check_filenames_in_zipdata(filenames, ziproot):
|
17 |
+
samples = []
|
18 |
+
for fst in ZipFile(ziproot).infolist():
|
19 |
+
fname = fst.filename
|
20 |
+
if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
|
21 |
+
continue
|
22 |
+
if os.path.splitext(fname)[1].lower() in VALID_IMAGE_TYPES:
|
23 |
+
samples.append((fname))
|
24 |
+
filenames = set(filenames)
|
25 |
+
samples = set(samples)
|
26 |
+
assert filenames.issubset(samples), 'Something wrong with your zip data'
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def draw_box(img, boxes):
|
31 |
+
colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
|
32 |
+
draw = ImageDraw.Draw(img)
|
33 |
+
for bid, box in enumerate(boxes):
|
34 |
+
draw.rectangle([box[0], box[1], box[2], box[3]], outline =colors[bid % len(colors)], width=4)
|
35 |
+
# draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
|
36 |
+
return img
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
def to_valid(x0, y0, x1, y1, image_size, min_box_size):
|
41 |
+
valid = True
|
42 |
+
|
43 |
+
if x0>image_size or y0>image_size or x1<0 or y1<0:
|
44 |
+
valid = False # no way to make this box vide, it is completely cropped out
|
45 |
+
return valid, (None, None, None, None)
|
46 |
+
|
47 |
+
x0 = max(x0, 0)
|
48 |
+
y0 = max(y0, 0)
|
49 |
+
x1 = min(x1, image_size)
|
50 |
+
y1 = min(y1, image_size)
|
51 |
+
|
52 |
+
if (x1-x0)*(y1-y0) / (image_size*image_size) < min_box_size:
|
53 |
+
valid = False
|
54 |
+
return valid, (None, None, None, None)
|
55 |
+
|
56 |
+
return valid, (x0, y0, x1, y1)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, image_size, min_box_size):
|
63 |
+
"""
|
64 |
+
x,y,w,h: the original annotation corresponding to the raw image size.
|
65 |
+
trans_info: what resizing and cropping have been applied to the raw image
|
66 |
+
image_size: what is the final image size
|
67 |
+
"""
|
68 |
+
|
69 |
+
x0 = x * trans_info["performed_scale"] - trans_info['crop_x']
|
70 |
+
y0 = y * trans_info["performed_scale"] - trans_info['crop_y']
|
71 |
+
x1 = (x + w) * trans_info["performed_scale"] - trans_info['crop_x']
|
72 |
+
y1 = (y + h) * trans_info["performed_scale"] - trans_info['crop_y']
|
73 |
+
|
74 |
+
|
75 |
+
# at this point, box annotation has been recalculated based on scaling and cropping
|
76 |
+
# but some point may fall off the image_size region (e.g., negative value), thus we
|
77 |
+
# need to clamp them into 0-image_size. But if all points falling outsize of image
|
78 |
+
# region, then we will consider this is an invalid box.
|
79 |
+
valid, (x0, y0, x1, y1) = to_valid(x0, y0, x1, y1, image_size, min_box_size)
|
80 |
+
|
81 |
+
if valid:
|
82 |
+
# we also perform random flip.
|
83 |
+
# Here boxes are valid, and are based on image_size
|
84 |
+
if trans_info["performed_flip"]:
|
85 |
+
x0, x1 = image_size-x1, image_size-x0
|
86 |
+
|
87 |
+
return valid, (x0, y0, x1, y1)
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
class BaseDataset(torch.utils.data.Dataset):
|
92 |
+
def __init__(self, image_root, random_crop, random_flip, image_size):
|
93 |
+
super().__init__()
|
94 |
+
self.image_root = image_root
|
95 |
+
self.random_crop = random_crop
|
96 |
+
self.random_flip = random_flip
|
97 |
+
self.image_size = image_size
|
98 |
+
self.use_zip = False
|
99 |
+
|
100 |
+
if image_root[-4::] == 'zip':
|
101 |
+
self.use_zip = True
|
102 |
+
self.zip_dict = {}
|
103 |
+
|
104 |
+
if self.random_crop:
|
105 |
+
assert False, 'NOT IMPLEMENTED'
|
106 |
+
|
107 |
+
|
108 |
+
def fetch_zipfile(self, ziproot):
|
109 |
+
pid = multiprocessing.current_process().pid # get pid of this process.
|
110 |
+
if pid not in self.zip_dict:
|
111 |
+
self.zip_dict[pid] = ZipFile(ziproot)
|
112 |
+
zip_file = self.zip_dict[pid]
|
113 |
+
return zip_file
|
114 |
+
|
115 |
+
def fetch_image(self, filename):
|
116 |
+
if self.use_zip:
|
117 |
+
zip_file = self.fetch_zipfile(self.image_root)
|
118 |
+
image = Image.open( BytesIO(zip_file.read(filename)) ).convert('RGB')
|
119 |
+
return image
|
120 |
+
else:
|
121 |
+
image = Image.open( os.path.join(self.image_root,filename) ).convert('RGB')
|
122 |
+
return image
|
123 |
+
|
124 |
+
|
125 |
+
def vis_getitem_data(self, index=None, out=None, return_tensor=False, name="res.jpg", print_caption=True):
|
126 |
+
|
127 |
+
if out is None:
|
128 |
+
out = self[index]
|
129 |
+
|
130 |
+
img = torchvision.transforms.functional.to_pil_image( out["image"]*0.5+0.5 )
|
131 |
+
canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(out["image"]) )
|
132 |
+
W, H = img.size
|
133 |
+
|
134 |
+
if print_caption:
|
135 |
+
caption = out["caption"]
|
136 |
+
print(caption)
|
137 |
+
print(" ")
|
138 |
+
|
139 |
+
boxes = []
|
140 |
+
for box in out["boxes"]:
|
141 |
+
x0,y0,x1,y1 = box
|
142 |
+
boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] )
|
143 |
+
img = draw_box(img, boxes)
|
144 |
+
|
145 |
+
if return_tensor:
|
146 |
+
return torchvision.transforms.functional.to_tensor(img)
|
147 |
+
else:
|
148 |
+
img.save(name)
|
149 |
+
|
150 |
+
|
151 |
+
def transform_image(self, pil_image):
|
152 |
+
if self.random_crop:
|
153 |
+
assert False
|
154 |
+
arr = random_crop_arr(pil_image, self.image_size)
|
155 |
+
else:
|
156 |
+
arr, info = center_crop_arr(pil_image, self.image_size)
|
157 |
+
|
158 |
+
info["performed_flip"] = False
|
159 |
+
if self.random_flip and random.random()<0.5:
|
160 |
+
arr = arr[:, ::-1]
|
161 |
+
info["performed_flip"] = True
|
162 |
+
|
163 |
+
arr = arr.astype(np.float32) / 127.5 - 1
|
164 |
+
arr = np.transpose(arr, [2,0,1])
|
165 |
+
|
166 |
+
return torch.tensor(arr), info
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
def center_crop_arr(pil_image, image_size):
|
171 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
172 |
+
# argument, which uses BOX downsampling at powers of two first.
|
173 |
+
# Thus, we do it by hand to improve downsample quality.
|
174 |
+
WW, HH = pil_image.size
|
175 |
+
|
176 |
+
while min(*pil_image.size) >= 2 * image_size:
|
177 |
+
pil_image = pil_image.resize(
|
178 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
179 |
+
)
|
180 |
+
|
181 |
+
scale = image_size / min(*pil_image.size)
|
182 |
+
|
183 |
+
pil_image = pil_image.resize(
|
184 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
185 |
+
)
|
186 |
+
|
187 |
+
# at this point, the min of pil_image side is desired image_size
|
188 |
+
performed_scale = image_size / min(WW, HH)
|
189 |
+
|
190 |
+
arr = np.array(pil_image)
|
191 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
192 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
193 |
+
|
194 |
+
info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH}
|
195 |
+
|
196 |
+
return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info
|
197 |
+
|
198 |
+
|
199 |
+
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
|
200 |
+
min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
|
201 |
+
max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
|
202 |
+
smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
|
203 |
+
|
204 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
205 |
+
# argument, which uses BOX downsampling at powers of two first.
|
206 |
+
# Thus, we do it by hand to improve downsample quality.
|
207 |
+
while min(*pil_image.size) >= 2 * smaller_dim_size:
|
208 |
+
pil_image = pil_image.resize(
|
209 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
210 |
+
)
|
211 |
+
|
212 |
+
scale = smaller_dim_size / min(*pil_image.size)
|
213 |
+
pil_image = pil_image.resize(
|
214 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
215 |
+
)
|
216 |
+
|
217 |
+
arr = np.array(pil_image)
|
218 |
+
crop_y = random.randrange(arr.shape[0] - image_size + 1)
|
219 |
+
crop_x = random.randrange(arr.shape[1] - image_size + 1)
|
220 |
+
return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
|
dataset/catalog.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
class DatasetCatalog:
|
4 |
+
def __init__(self, ROOT, which_embedder):
|
5 |
+
assert which_embedder in ['clip', 'bert']
|
6 |
+
|
7 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
8 |
+
|
9 |
+
|
10 |
+
self.VGGrounding = {
|
11 |
+
"target": "dataset.tsv_dataset.TSVDataset",
|
12 |
+
"train_params": dict(
|
13 |
+
tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'),
|
14 |
+
)
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
19 |
+
|
20 |
+
|
21 |
+
self.FlickrGrounding = {
|
22 |
+
"target": "dataset.tsv_dataset.TSVDataset",
|
23 |
+
"train_params":dict(
|
24 |
+
tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'),
|
25 |
+
)
|
26 |
+
}
|
27 |
+
|
28 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
29 |
+
|
30 |
+
self.SBUGrounding = {
|
31 |
+
"target": "dataset.tsv_dataset.TSVDataset",
|
32 |
+
"train_params":dict(
|
33 |
+
tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'),
|
34 |
+
)
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
39 |
+
|
40 |
+
|
41 |
+
self.CC3MGrounding = {
|
42 |
+
"target": "dataset.tsv_dataset.TSVDataset",
|
43 |
+
"train_params":dict(
|
44 |
+
tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'),
|
45 |
+
)
|
46 |
+
}
|
47 |
+
|
48 |
+
|
49 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
50 |
+
|
51 |
+
|
52 |
+
self.CC12MGrounding = {
|
53 |
+
"target": "dataset.tsv_dataset.TSVDataset",
|
54 |
+
"train_params":dict(
|
55 |
+
tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'),
|
56 |
+
)
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
61 |
+
|
62 |
+
# temp = 'category_embedding_clip.pth' if which_embedder == 'clip' else 'category_embedding_bert.pth'
|
63 |
+
# obj365_category_embedding_path = os.path.join(ROOT, 'OBJECTS365', temp)
|
64 |
+
|
65 |
+
self.Obj365Detection = {
|
66 |
+
"target": "dataset.tsv_dataset.TSVDataset",
|
67 |
+
"train_params":dict(
|
68 |
+
tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'),
|
69 |
+
),
|
70 |
+
}
|
71 |
+
|
72 |
+
|
dataset/cd_dataset.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json, os, random, math
|
2 |
+
from collections import defaultdict
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
|
12 |
+
from io import BytesIO
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def not_in_at_all(list1, list2):
|
17 |
+
for a in list1:
|
18 |
+
if a in list2:
|
19 |
+
return False
|
20 |
+
return True
|
21 |
+
|
22 |
+
|
23 |
+
def clean_annotations(annotations):
|
24 |
+
for anno in annotations:
|
25 |
+
anno.pop("segmentation", None)
|
26 |
+
anno.pop("area", None)
|
27 |
+
anno.pop("iscrowd", None)
|
28 |
+
# anno.pop("id", None)
|
29 |
+
|
30 |
+
|
31 |
+
def make_a_sentence(obj_names, clean=False):
|
32 |
+
|
33 |
+
if clean:
|
34 |
+
obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
|
35 |
+
|
36 |
+
caption = ""
|
37 |
+
tokens_positive = []
|
38 |
+
for obj_name in obj_names:
|
39 |
+
start_len = len(caption)
|
40 |
+
caption += obj_name
|
41 |
+
end_len = len(caption)
|
42 |
+
caption += ", "
|
43 |
+
tokens_positive.append(
|
44 |
+
[[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
|
45 |
+
)
|
46 |
+
caption = caption[:-2] # remove last ", "
|
47 |
+
|
48 |
+
return caption #, tokens_positive
|
49 |
+
|
50 |
+
|
51 |
+
def check_all_have_same_images(instances_data, stuff_data, caption_data):
|
52 |
+
if stuff_data is not None:
|
53 |
+
assert instances_data["images"] == stuff_data["images"]
|
54 |
+
if caption_data is not None:
|
55 |
+
assert instances_data["images"] == caption_data["images"]
|
56 |
+
|
57 |
+
|
58 |
+
class CDDataset(BaseDataset):
|
59 |
+
"CD: Caption Detection"
|
60 |
+
def __init__(self,
|
61 |
+
image_root,
|
62 |
+
category_embedding_path,
|
63 |
+
instances_json_path = None,
|
64 |
+
stuff_json_path = None,
|
65 |
+
caption_json_path = None,
|
66 |
+
prob_real_caption = 0,
|
67 |
+
fake_caption_type = 'empty',
|
68 |
+
image_size=256,
|
69 |
+
max_images=None,
|
70 |
+
min_box_size=0.01,
|
71 |
+
max_boxes_per_image=8,
|
72 |
+
include_other=False,
|
73 |
+
random_crop = False,
|
74 |
+
random_flip = True,
|
75 |
+
):
|
76 |
+
super().__init__(random_crop, random_flip, image_size)
|
77 |
+
|
78 |
+
self.image_root = image_root
|
79 |
+
self.category_embedding_path = category_embedding_path
|
80 |
+
self.instances_json_path = instances_json_path
|
81 |
+
self.stuff_json_path = stuff_json_path
|
82 |
+
self.caption_json_path = caption_json_path
|
83 |
+
self.prob_real_caption = prob_real_caption
|
84 |
+
self.fake_caption_type = fake_caption_type
|
85 |
+
self.max_images = max_images
|
86 |
+
self.min_box_size = min_box_size
|
87 |
+
self.max_boxes_per_image = max_boxes_per_image
|
88 |
+
self.include_other = include_other
|
89 |
+
|
90 |
+
|
91 |
+
assert fake_caption_type in ["empty", "made"]
|
92 |
+
if prob_real_caption > 0:
|
93 |
+
assert caption_json_path is not None, "caption json must be given"
|
94 |
+
|
95 |
+
|
96 |
+
# Load all jsons
|
97 |
+
with open(instances_json_path, 'r') as f:
|
98 |
+
instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
|
99 |
+
clean_annotations(instances_data["annotations"])
|
100 |
+
self.instances_data = instances_data
|
101 |
+
|
102 |
+
self.stuff_data = None
|
103 |
+
if stuff_json_path is not None:
|
104 |
+
with open(stuff_json_path, 'r') as f:
|
105 |
+
stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
|
106 |
+
clean_annotations(stuff_data["annotations"])
|
107 |
+
self.stuff_data = stuff_data
|
108 |
+
|
109 |
+
self.captions_data = None
|
110 |
+
if caption_json_path is not None:
|
111 |
+
with open(caption_json_path, 'r') as f:
|
112 |
+
captions_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
|
113 |
+
clean_annotations(captions_data["annotations"])
|
114 |
+
self.captions_data = captions_data
|
115 |
+
|
116 |
+
|
117 |
+
# Load preprocessed name embedding
|
118 |
+
self.category_embeddings = torch.load(category_embedding_path)
|
119 |
+
self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
|
120 |
+
|
121 |
+
|
122 |
+
# Misc
|
123 |
+
self.image_ids = [] # main list for selecting images
|
124 |
+
self.image_id_to_filename = {} # file names used to read image
|
125 |
+
check_all_have_same_images(self.instances_data, self.stuff_data, self.captions_data)
|
126 |
+
for image_data in self.instances_data['images']:
|
127 |
+
image_id = image_data['id']
|
128 |
+
filename = image_data['file_name']
|
129 |
+
self.image_ids.append(image_id)
|
130 |
+
self.image_id_to_filename[image_id] = filename
|
131 |
+
|
132 |
+
|
133 |
+
# All category names (including things and stuff)
|
134 |
+
self.object_idx_to_name = {}
|
135 |
+
for category_data in self.instances_data['categories']:
|
136 |
+
self.object_idx_to_name[category_data['id']] = category_data['name']
|
137 |
+
if self.stuff_data is not None:
|
138 |
+
for category_data in self.stuff_data['categories']:
|
139 |
+
self.object_idx_to_name[category_data['id']] = category_data['name']
|
140 |
+
|
141 |
+
|
142 |
+
# Add object data from instances and stuff
|
143 |
+
self.image_id_to_objects = defaultdict(list)
|
144 |
+
self.select_objects( self.instances_data['annotations'] )
|
145 |
+
if self.stuff_data is not None:
|
146 |
+
self.select_objects( self.stuff_data['annotations'] )
|
147 |
+
|
148 |
+
# Add caption data
|
149 |
+
if self.captions_data is not None:
|
150 |
+
self.image_id_to_captions = defaultdict(list)
|
151 |
+
self.select_captions( self.captions_data['annotations'] )
|
152 |
+
|
153 |
+
# Check if all filenames can be found in the zip file
|
154 |
+
# all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
|
155 |
+
# check_filenames_in_zipdata(all_filenames, image_root)
|
156 |
+
|
157 |
+
|
158 |
+
def select_objects(self, annotations):
|
159 |
+
for object_anno in annotations:
|
160 |
+
image_id = object_anno['image_id']
|
161 |
+
object_name = self.object_idx_to_name[object_anno['category_id']]
|
162 |
+
other_ok = object_name != 'other' or self.include_other
|
163 |
+
if other_ok:
|
164 |
+
self.image_id_to_objects[image_id].append(object_anno)
|
165 |
+
|
166 |
+
|
167 |
+
def select_captions(self, annotations):
|
168 |
+
for caption_data in annotations:
|
169 |
+
image_id = caption_data['image_id']
|
170 |
+
self.image_id_to_captions[image_id].append(caption_data)
|
171 |
+
|
172 |
+
|
173 |
+
def total_images(self):
|
174 |
+
return len(self)
|
175 |
+
|
176 |
+
|
177 |
+
def __getitem__(self, index):
|
178 |
+
if self.max_boxes_per_image > 99:
|
179 |
+
assert False, "Are you sure setting such large number of boxes?"
|
180 |
+
|
181 |
+
out = {}
|
182 |
+
|
183 |
+
image_id = self.image_ids[index]
|
184 |
+
out['id'] = image_id
|
185 |
+
|
186 |
+
# Image
|
187 |
+
filename = self.image_id_to_filename[image_id]
|
188 |
+
image = self.fetch_image(filename)
|
189 |
+
#WW, HH = image.size
|
190 |
+
image_tensor, trans_info = self.transform_image(image)
|
191 |
+
out["image"] = image_tensor
|
192 |
+
|
193 |
+
|
194 |
+
# Select valid boxes after cropping (center or random)
|
195 |
+
this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])
|
196 |
+
areas = []
|
197 |
+
all_obj_names = []
|
198 |
+
all_boxes = []
|
199 |
+
all_masks = []
|
200 |
+
all_positive_embeddings = []
|
201 |
+
for object_anno in this_image_obj_annos:
|
202 |
+
|
203 |
+
x, y, w, h = object_anno['bbox']
|
204 |
+
valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
|
205 |
+
|
206 |
+
if valid:
|
207 |
+
areas.append( (x1-x0)*(y1-y0) )
|
208 |
+
obj_name = self.object_idx_to_name[ object_anno['category_id'] ]
|
209 |
+
all_obj_names.append(obj_name)
|
210 |
+
all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
|
211 |
+
all_masks.append(1)
|
212 |
+
all_positive_embeddings.append( self.category_embeddings[obj_name] )
|
213 |
+
|
214 |
+
wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
|
215 |
+
wanted_idxs = wanted_idxs[0:self.max_boxes_per_image]
|
216 |
+
obj_names = [] # used for making a sentence
|
217 |
+
boxes = torch.zeros(self.max_boxes_per_image, 4)
|
218 |
+
masks = torch.zeros(self.max_boxes_per_image)
|
219 |
+
positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
|
220 |
+
for i, idx in enumerate(wanted_idxs):
|
221 |
+
obj_names.append( all_obj_names[idx] )
|
222 |
+
boxes[i] = all_boxes[idx]
|
223 |
+
masks[i] = all_masks[idx]
|
224 |
+
positive_embeddings[i] = all_positive_embeddings[idx]
|
225 |
+
|
226 |
+
# Caption
|
227 |
+
if random.uniform(0, 1) < self.prob_real_caption:
|
228 |
+
caption_data = self.image_id_to_captions[image_id]
|
229 |
+
idx = random.randint(0, len(caption_data)-1 )
|
230 |
+
caption = caption_data[idx]["caption"]
|
231 |
+
else:
|
232 |
+
if self.fake_caption_type == "empty":
|
233 |
+
caption = ""
|
234 |
+
else:
|
235 |
+
caption = make_a_sentence(obj_names, clean=True)
|
236 |
+
|
237 |
+
|
238 |
+
out["caption"] = caption
|
239 |
+
out["boxes"] = boxes
|
240 |
+
out["masks"] = masks
|
241 |
+
out["positive_embeddings"] = positive_embeddings
|
242 |
+
|
243 |
+
return out
|
244 |
+
|
245 |
+
|
246 |
+
def __len__(self):
|
247 |
+
if self.max_images is None:
|
248 |
+
return len(self.image_ids)
|
249 |
+
return min(len(self.image_ids), self.max_images)
|
250 |
+
|
dataset/concat_dataset.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .catalog import DatasetCatalog
|
2 |
+
from ldm.util import instantiate_from_config
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class ConCatDataset():
|
9 |
+
def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None):
|
10 |
+
self.datasets = []
|
11 |
+
cul_previous_dataset_length = 0
|
12 |
+
offset_map = []
|
13 |
+
which_dataset = []
|
14 |
+
|
15 |
+
if repeats is None:
|
16 |
+
repeats = [1] * len(dataset_name_list)
|
17 |
+
else:
|
18 |
+
assert len(repeats) == len(dataset_name_list)
|
19 |
+
|
20 |
+
|
21 |
+
Catalog = DatasetCatalog(ROOT, which_embedder)
|
22 |
+
for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()):
|
23 |
+
repeat = repeats[dataset_idx]
|
24 |
+
|
25 |
+
dataset_dict = getattr(Catalog, dataset_name)
|
26 |
+
|
27 |
+
target = dataset_dict['target']
|
28 |
+
params = dataset_dict['train_params'] if train else dataset_dict['val_params']
|
29 |
+
if yaml_params is not None:
|
30 |
+
params.update(yaml_params)
|
31 |
+
dataset = instantiate_from_config( dict(target=target, params=params) )
|
32 |
+
|
33 |
+
self.datasets.append(dataset)
|
34 |
+
for _ in range(repeat):
|
35 |
+
offset_map.append( torch.ones(len(dataset))*cul_previous_dataset_length )
|
36 |
+
which_dataset.append( torch.ones(len(dataset))*dataset_idx )
|
37 |
+
cul_previous_dataset_length += len(dataset)
|
38 |
+
offset_map = torch.cat(offset_map, dim=0).long()
|
39 |
+
self.total_length = cul_previous_dataset_length
|
40 |
+
|
41 |
+
self.mapping = torch.arange(self.total_length) - offset_map
|
42 |
+
self.which_dataset = torch.cat(which_dataset, dim=0).long()
|
43 |
+
|
44 |
+
|
45 |
+
def total_images(self):
|
46 |
+
count = 0
|
47 |
+
for dataset in self.datasets:
|
48 |
+
print(dataset.total_images())
|
49 |
+
count += dataset.total_images()
|
50 |
+
return count
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
def __getitem__(self, idx):
|
55 |
+
dataset = self.datasets[ self.which_dataset[idx] ]
|
56 |
+
return dataset[ self.mapping[idx] ]
|
57 |
+
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
return self.total_length
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
dataset/grounding_dataset.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tkinter.messagebox import NO
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from collections import defaultdict
|
5 |
+
from PIL import Image, ImageDraw
|
6 |
+
from copy import deepcopy
|
7 |
+
import os
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
import torchvision
|
10 |
+
from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
|
11 |
+
from io import BytesIO
|
12 |
+
import random
|
13 |
+
|
14 |
+
def check_unique(images, fields):
|
15 |
+
for field in fields:
|
16 |
+
temp_list = []
|
17 |
+
for img_info in images:
|
18 |
+
temp_list.append(img_info[field])
|
19 |
+
assert len(set(temp_list)) == len(temp_list), field
|
20 |
+
|
21 |
+
def clean_data(data):
|
22 |
+
for data_info in data:
|
23 |
+
data_info.pop("original_img_id", None)
|
24 |
+
data_info.pop("original_id", None)
|
25 |
+
data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image)
|
26 |
+
data_info.pop("dataset_name", None)
|
27 |
+
data_info.pop("data_source", None)
|
28 |
+
data_info["data_id"] = data_info.pop("id")
|
29 |
+
|
30 |
+
|
31 |
+
def clean_annotations(annotations):
|
32 |
+
for anno_info in annotations:
|
33 |
+
anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco
|
34 |
+
anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. This is not always 1 for coco, but I do not think we need this annotation
|
35 |
+
anno_info.pop("area", None)
|
36 |
+
# anno_info.pop("id", None)
|
37 |
+
anno_info["data_id"] = anno_info.pop("image_id")
|
38 |
+
|
39 |
+
|
40 |
+
def draw_box(img, boxes):
|
41 |
+
draw = ImageDraw.Draw(img)
|
42 |
+
for box in boxes:
|
43 |
+
draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
|
44 |
+
return img
|
45 |
+
|
46 |
+
|
47 |
+
def xyhw2xyxy(box):
|
48 |
+
x0, y0, w, h = box
|
49 |
+
return [ x0, y0, x0+w, y0+h ]
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
class GroundingDataset(BaseDataset):
|
54 |
+
def __init__(self,
|
55 |
+
image_root,
|
56 |
+
json_path,
|
57 |
+
annotation_embedding_path,
|
58 |
+
prob_real_caption=1,
|
59 |
+
image_size=256,
|
60 |
+
min_box_size=0.01,
|
61 |
+
max_boxes_per_data=8,
|
62 |
+
max_images=None, # set as 30K used to eval
|
63 |
+
random_crop = False,
|
64 |
+
random_flip = True,
|
65 |
+
):
|
66 |
+
super().__init__(image_root, random_crop, random_flip, image_size)
|
67 |
+
self.image_root = image_root
|
68 |
+
self.json_path = json_path
|
69 |
+
self.annotation_embedding_path = annotation_embedding_path
|
70 |
+
self.prob_real_caption = prob_real_caption
|
71 |
+
self.min_box_size = min_box_size
|
72 |
+
self.max_boxes_per_data = max_boxes_per_data
|
73 |
+
self.max_images = max_images
|
74 |
+
|
75 |
+
|
76 |
+
# Load raw data
|
77 |
+
with open(json_path, 'r') as f:
|
78 |
+
json_raw = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
|
79 |
+
self.data = json_raw["images"] # donot name it images, which is misleading
|
80 |
+
self.annotations = json_raw["annotations"]
|
81 |
+
|
82 |
+
|
83 |
+
# Load preprocessed name embedding
|
84 |
+
if 'bert' in annotation_embedding_path:
|
85 |
+
self.embedding_len = 1280
|
86 |
+
elif 'clip' in annotation_embedding_path:
|
87 |
+
self.embedding_len = 768
|
88 |
+
else:
|
89 |
+
assert False
|
90 |
+
|
91 |
+
|
92 |
+
# clean data and annotation
|
93 |
+
check_unique( self.data, ['id'] )
|
94 |
+
check_unique( self.annotations, ['id'] )
|
95 |
+
clean_data(self.data)
|
96 |
+
clean_annotations(self.annotations)
|
97 |
+
self.data_id_list = [ datum['data_id'] for datum in self.data ]
|
98 |
+
self.data = { datum['data_id']:datum for datum in self.data } # map self.data from a list into a dict
|
99 |
+
|
100 |
+
|
101 |
+
# data point to its annotation mapping
|
102 |
+
self.data_id_to_annos = defaultdict(list)
|
103 |
+
for anno in self.annotations:
|
104 |
+
self.data_id_to_annos[ anno["data_id"] ].append(anno)
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
# These are not used that offen, but are useful in some cases
|
109 |
+
self.file_names = [] # all training images
|
110 |
+
self.file_name_to_data_ids = defaultdict(list) # for each image, there are multiple data points (captions)
|
111 |
+
for data_id in self.data_id_list:
|
112 |
+
fine_name = self.data[data_id]["file_name"]
|
113 |
+
self.file_names.append(fine_name)
|
114 |
+
self.file_name_to_data_ids[fine_name].append(data_id)
|
115 |
+
self.file_names = list(set(self.file_names))
|
116 |
+
|
117 |
+
|
118 |
+
if self.max_images is not None:
|
119 |
+
"This is only used as COCO2017P evulation, when we set max_images as 30k"
|
120 |
+
assert False, 'I have commented out the following code to save cpu memory'
|
121 |
+
# new_data_id_list = []
|
122 |
+
# new_file_name_to_data_ids = defaultdict(list)
|
123 |
+
# self.file_names = self.file_names[0:self.max_images]
|
124 |
+
# for file_name in self.file_names:
|
125 |
+
# data_id = self.file_name_to_data_ids[file_name][0]
|
126 |
+
# new_data_id_list.append(data_id)
|
127 |
+
# new_file_name_to_data_ids[file_name].append(data_id)
|
128 |
+
# self.data_id_list = new_data_id_list
|
129 |
+
# self.file_name_to_data_ids = new_file_name_to_data_ids
|
130 |
+
|
131 |
+
|
132 |
+
# Check if all filenames can be found in the zip file
|
133 |
+
# all_filenames = [self.data[idx]['file_name'] for idx in self.data_id_list ]
|
134 |
+
# check_filenames_in_zipdata(all_filenames, image_root)
|
135 |
+
|
136 |
+
|
137 |
+
def total_images(self):
|
138 |
+
return len(self.file_names)
|
139 |
+
|
140 |
+
|
141 |
+
def __getitem__(self, index):
|
142 |
+
if self.max_boxes_per_data > 99:
|
143 |
+
assert False, "Are you sure setting such large number of boxes?"
|
144 |
+
|
145 |
+
out = {}
|
146 |
+
|
147 |
+
data_id = self.data_id_list[index]
|
148 |
+
out['id'] = data_id
|
149 |
+
|
150 |
+
|
151 |
+
# Image and caption
|
152 |
+
file_name = self.data[data_id]['file_name']
|
153 |
+
image = self.fetch_image(file_name)
|
154 |
+
image_tensor, trans_info = self.transform_image(image)
|
155 |
+
out["image"] = image_tensor
|
156 |
+
|
157 |
+
if random.uniform(0, 1) < self.prob_real_caption:
|
158 |
+
out["caption"] = self.data[data_id]["caption"]
|
159 |
+
else:
|
160 |
+
out["caption"] = ""
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
annos = deepcopy(self.data_id_to_annos[data_id])
|
165 |
+
areas = []
|
166 |
+
all_boxes = []
|
167 |
+
all_masks = []
|
168 |
+
all_positive_embeddings = []
|
169 |
+
|
170 |
+
|
171 |
+
for anno in annos:
|
172 |
+
|
173 |
+
x, y, w, h = anno['bbox']
|
174 |
+
valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
|
175 |
+
|
176 |
+
if valid:
|
177 |
+
areas.append( (x1-x0)*(y1-y0) )
|
178 |
+
all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
|
179 |
+
all_masks.append(1)
|
180 |
+
all_positive_embeddings.append( torch.load(os.path.join(self.annotation_embedding_path,str(anno["id"])), map_location='cpu' ) )
|
181 |
+
|
182 |
+
wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
|
183 |
+
wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]
|
184 |
+
|
185 |
+
boxes = torch.zeros(self.max_boxes_per_data, 4)
|
186 |
+
masks = torch.zeros(self.max_boxes_per_data)
|
187 |
+
positive_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
|
188 |
+
for i, idx in enumerate(wanted_idxs):
|
189 |
+
boxes[i] = all_boxes[idx]
|
190 |
+
masks[i] = all_masks[idx]
|
191 |
+
positive_embeddings[i] = all_positive_embeddings[idx]
|
192 |
+
|
193 |
+
|
194 |
+
out["boxes"] = boxes
|
195 |
+
out["masks"] = masks
|
196 |
+
out["positive_embeddings"] = positive_embeddings
|
197 |
+
|
198 |
+
return out
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
def __len__(self):
|
203 |
+
return len(self.data_id_list)
|
204 |
+
|
205 |
+
|
dataset/layout_dataset.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json, os, random, math
|
2 |
+
from collections import defaultdict
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image, ImageOps
|
11 |
+
from .base_dataset import BaseDataset, check_filenames_in_zipdata
|
12 |
+
from io import BytesIO
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def clean_annotations(annotations):
|
18 |
+
for anno in annotations:
|
19 |
+
anno.pop("segmentation", None)
|
20 |
+
anno.pop("area", None)
|
21 |
+
anno.pop("iscrowd", None)
|
22 |
+
anno.pop("id", None)
|
23 |
+
|
24 |
+
|
25 |
+
def make_a_sentence(obj_names, clean=False):
|
26 |
+
|
27 |
+
if clean:
|
28 |
+
obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
|
29 |
+
|
30 |
+
caption = ""
|
31 |
+
tokens_positive = []
|
32 |
+
for obj_name in obj_names:
|
33 |
+
start_len = len(caption)
|
34 |
+
caption += obj_name
|
35 |
+
end_len = len(caption)
|
36 |
+
caption += ", "
|
37 |
+
tokens_positive.append(
|
38 |
+
[[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
|
39 |
+
)
|
40 |
+
caption = caption[:-2] # remove last ", "
|
41 |
+
|
42 |
+
return caption #, tokens_positive
|
43 |
+
|
44 |
+
|
45 |
+
class LayoutDataset(BaseDataset):
|
46 |
+
"""
|
47 |
+
Note: this dataset can somehow be achieved in cd_dataset.CDDataset
|
48 |
+
Since if you donot set prob_real_caption=0 in CDDataset, then that
|
49 |
+
dataset will only use detection annotations. However, in that dataset,
|
50 |
+
we do not remove images but remove boxes.
|
51 |
+
|
52 |
+
However, in layout2img works, people will just resize raw image data into 256*256,
|
53 |
+
thus they pre-calculate box size and apply min_box_size before min/max_boxes_per_image.
|
54 |
+
And then they will remove images if does not follow the rule.
|
55 |
+
|
56 |
+
These two different methods will lead to different number of training/val images.
|
57 |
+
Thus this dataset here is only for layout2img.
|
58 |
+
|
59 |
+
"""
|
60 |
+
def __init__(self,
|
61 |
+
image_root,
|
62 |
+
instances_json_path,
|
63 |
+
stuff_json_path,
|
64 |
+
category_embedding_path,
|
65 |
+
fake_caption_type = 'empty',
|
66 |
+
image_size=256,
|
67 |
+
max_samples=None,
|
68 |
+
min_box_size=0.02,
|
69 |
+
min_boxes_per_image=3,
|
70 |
+
max_boxes_per_image=8,
|
71 |
+
include_other=False,
|
72 |
+
random_flip=True
|
73 |
+
):
|
74 |
+
super().__init__(random_crop=None, random_flip=None, image_size=None) # we only use vis_getitem func in BaseDataset, donot use the others.
|
75 |
+
|
76 |
+
assert fake_caption_type in ['empty', 'made']
|
77 |
+
self.image_root = image_root
|
78 |
+
self.instances_json_path = instances_json_path
|
79 |
+
self.stuff_json_path = stuff_json_path
|
80 |
+
self.category_embedding_path = category_embedding_path
|
81 |
+
self.fake_caption_type = fake_caption_type
|
82 |
+
self.image_size = image_size
|
83 |
+
self.max_samples = max_samples
|
84 |
+
self.min_box_size = min_box_size
|
85 |
+
self.min_boxes_per_image = min_boxes_per_image
|
86 |
+
self.max_boxes_per_image = max_boxes_per_image
|
87 |
+
self.include_other = include_other
|
88 |
+
self.random_flip = random_flip
|
89 |
+
|
90 |
+
|
91 |
+
self.transform = transforms.Compose([transforms.Resize( (image_size, image_size) ),
|
92 |
+
transforms.ToTensor(),
|
93 |
+
transforms.Lambda(lambda t: (t * 2) - 1) ])
|
94 |
+
|
95 |
+
# Load all jsons
|
96 |
+
with open(instances_json_path, 'r') as f:
|
97 |
+
instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
|
98 |
+
clean_annotations(instances_data["annotations"])
|
99 |
+
self.instances_data = instances_data
|
100 |
+
|
101 |
+
with open(stuff_json_path, 'r') as f:
|
102 |
+
stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
|
103 |
+
clean_annotations(stuff_data["annotations"])
|
104 |
+
self.stuff_data = stuff_data
|
105 |
+
|
106 |
+
|
107 |
+
# Load preprocessed name embedding
|
108 |
+
self.category_embeddings = torch.load(category_embedding_path)
|
109 |
+
self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
|
110 |
+
|
111 |
+
|
112 |
+
# Misc
|
113 |
+
self.image_ids = [] # main list for selecting images
|
114 |
+
self.image_id_to_filename = {} # file names used to read image
|
115 |
+
self.image_id_to_size = {} # original size of this image
|
116 |
+
assert instances_data['images'] == stuff_data["images"]
|
117 |
+
for image_data in instances_data['images']:
|
118 |
+
image_id = image_data['id']
|
119 |
+
filename = image_data['file_name']
|
120 |
+
width = image_data['width']
|
121 |
+
height = image_data['height']
|
122 |
+
self.image_ids.append(image_id)
|
123 |
+
self.image_id_to_filename[image_id] = filename
|
124 |
+
self.image_id_to_size[image_id] = (width, height)
|
125 |
+
|
126 |
+
# All category names (including things and stuff)
|
127 |
+
self.things_id_list = []
|
128 |
+
self.stuff_id_list = []
|
129 |
+
self.object_idx_to_name = {}
|
130 |
+
for category_data in instances_data['categories']:
|
131 |
+
self.things_id_list.append( category_data['id'] )
|
132 |
+
self.object_idx_to_name[category_data['id']] = category_data['name']
|
133 |
+
for category_data in stuff_data['categories']:
|
134 |
+
self.stuff_id_list.append( category_data['id'] )
|
135 |
+
self.object_idx_to_name[category_data['id']] = category_data['name']
|
136 |
+
self.all_categories = [ self.object_idx_to_name.get(k, None) for k in range(183+1) ]
|
137 |
+
|
138 |
+
|
139 |
+
# Add object data from instances and stuff
|
140 |
+
self.image_id_to_objects = defaultdict(list)
|
141 |
+
self.select_objects( instances_data['annotations'] )
|
142 |
+
self.select_objects( stuff_data['annotations'] )
|
143 |
+
|
144 |
+
|
145 |
+
# Prune images that have too few or too many objects
|
146 |
+
new_image_ids = []
|
147 |
+
for image_id in self.image_ids:
|
148 |
+
num_objs = len(self.image_id_to_objects[image_id])
|
149 |
+
if self.min_boxes_per_image <= num_objs <= self.max_boxes_per_image:
|
150 |
+
new_image_ids.append(image_id)
|
151 |
+
self.image_ids = new_image_ids
|
152 |
+
|
153 |
+
|
154 |
+
# Check if all filenames can be found in the zip file
|
155 |
+
all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
|
156 |
+
check_filenames_in_zipdata(all_filenames, image_root)
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
def select_objects(self, annotations):
|
161 |
+
for object_anno in annotations:
|
162 |
+
image_id = object_anno['image_id']
|
163 |
+
_, _, w, h = object_anno['bbox']
|
164 |
+
W, H = self.image_id_to_size[image_id]
|
165 |
+
box_area = (w * h) / (W * H)
|
166 |
+
box_ok = box_area > self.min_box_size
|
167 |
+
object_name = self.object_idx_to_name[object_anno['category_id']]
|
168 |
+
other_ok = object_name != 'other' or self.include_other
|
169 |
+
if box_ok and other_ok:
|
170 |
+
self.image_id_to_objects[image_id].append(object_anno)
|
171 |
+
|
172 |
+
|
173 |
+
def total_images(self):
|
174 |
+
return len(self)
|
175 |
+
|
176 |
+
|
177 |
+
def __getitem__(self, index):
|
178 |
+
if self.max_boxes_per_image > 99:
|
179 |
+
assert False, "Are you sure setting such large number of boxes?"
|
180 |
+
|
181 |
+
out = {}
|
182 |
+
|
183 |
+
image_id = self.image_ids[index]
|
184 |
+
out['id'] = image_id
|
185 |
+
|
186 |
+
flip = self.random_flip and random.random()<0.5
|
187 |
+
|
188 |
+
# Image
|
189 |
+
filename = self.image_id_to_filename[image_id]
|
190 |
+
zip_file = self.fetch_zipfile(self.image_root)
|
191 |
+
image = Image.open(BytesIO(zip_file.read(filename))).convert('RGB')
|
192 |
+
WW, HH = image.size
|
193 |
+
if flip:
|
194 |
+
image = ImageOps.mirror(image)
|
195 |
+
out["image"] = self.transform(image)
|
196 |
+
|
197 |
+
this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])
|
198 |
+
|
199 |
+
# Make a sentence
|
200 |
+
obj_names = [] # used for make a sentence
|
201 |
+
boxes = torch.zeros(self.max_boxes_per_image, 4)
|
202 |
+
masks = torch.zeros(self.max_boxes_per_image)
|
203 |
+
positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
|
204 |
+
for idx, object_anno in enumerate(this_image_obj_annos):
|
205 |
+
obj_name = self.object_idx_to_name[ object_anno['category_id'] ]
|
206 |
+
obj_names.append(obj_name)
|
207 |
+
x, y, w, h = object_anno['bbox']
|
208 |
+
x0 = x / WW
|
209 |
+
y0 = y / HH
|
210 |
+
x1 = (x + w) / WW
|
211 |
+
y1 = (y + h) / HH
|
212 |
+
if flip:
|
213 |
+
x0, x1 = 1-x1, 1-x0
|
214 |
+
boxes[idx] = torch.tensor([x0,y0,x1,y1])
|
215 |
+
masks[idx] = 1
|
216 |
+
positive_embeddings[idx] = self.category_embeddings[obj_name]
|
217 |
+
|
218 |
+
if self.fake_caption_type == 'empty':
|
219 |
+
caption = ""
|
220 |
+
else:
|
221 |
+
caption = make_a_sentence(obj_names, clean=True)
|
222 |
+
|
223 |
+
out["caption"] = caption
|
224 |
+
out["boxes"] = boxes
|
225 |
+
out["masks"] = masks
|
226 |
+
out["positive_embeddings"] = positive_embeddings
|
227 |
+
|
228 |
+
|
229 |
+
return out
|
230 |
+
|
231 |
+
|
232 |
+
def __len__(self):
|
233 |
+
if self.max_samples is None:
|
234 |
+
return len(self.image_ids)
|
235 |
+
return min(len(self.image_ids), self.max_samples)
|
236 |
+
|
237 |
+
|
dataset/tsv.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as op
|
3 |
+
import gc
|
4 |
+
import json
|
5 |
+
from typing import List
|
6 |
+
import logging
|
7 |
+
|
8 |
+
try:
|
9 |
+
from .blob_storage import BlobStorage, disk_usage
|
10 |
+
except:
|
11 |
+
class BlobStorage:
|
12 |
+
pass
|
13 |
+
|
14 |
+
|
15 |
+
def generate_lineidx(filein: str, idxout: str) -> None:
|
16 |
+
idxout_tmp = idxout + '.tmp'
|
17 |
+
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
|
18 |
+
fsize = os.fstat(tsvin.fileno()).st_size
|
19 |
+
fpos = 0
|
20 |
+
while fpos != fsize:
|
21 |
+
tsvout.write(str(fpos) + "\n")
|
22 |
+
tsvin.readline()
|
23 |
+
fpos = tsvin.tell()
|
24 |
+
os.rename(idxout_tmp, idxout)
|
25 |
+
|
26 |
+
|
27 |
+
def read_to_character(fp, c):
|
28 |
+
result = []
|
29 |
+
while True:
|
30 |
+
s = fp.read(32)
|
31 |
+
assert s != ''
|
32 |
+
if c in s:
|
33 |
+
result.append(s[: s.index(c)])
|
34 |
+
break
|
35 |
+
else:
|
36 |
+
result.append(s)
|
37 |
+
return ''.join(result)
|
38 |
+
|
39 |
+
|
40 |
+
class TSVFile(object):
|
41 |
+
def __init__(self,
|
42 |
+
tsv_file: str,
|
43 |
+
if_generate_lineidx: bool = False,
|
44 |
+
lineidx: str = None,
|
45 |
+
class_selector: List[str] = None,
|
46 |
+
blob_storage: BlobStorage = None):
|
47 |
+
self.tsv_file = tsv_file
|
48 |
+
self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
|
49 |
+
if not lineidx else lineidx
|
50 |
+
self.linelist = op.splitext(tsv_file)[0] + '.linelist'
|
51 |
+
self.chunks = op.splitext(tsv_file)[0] + '.chunks'
|
52 |
+
self._fp = None
|
53 |
+
self._lineidx = None
|
54 |
+
self._sample_indices = None
|
55 |
+
self._class_boundaries = None
|
56 |
+
self._class_selector = class_selector
|
57 |
+
self._blob_storage = blob_storage
|
58 |
+
self._len = None
|
59 |
+
# the process always keeps the process which opens the file.
|
60 |
+
# If the pid is not equal to the currrent pid, we will re-open the file.
|
61 |
+
self.pid = None
|
62 |
+
# generate lineidx if not exist
|
63 |
+
if not op.isfile(self.lineidx) and if_generate_lineidx:
|
64 |
+
generate_lineidx(self.tsv_file, self.lineidx)
|
65 |
+
|
66 |
+
def __del__(self):
|
67 |
+
self.gcidx()
|
68 |
+
if self._fp:
|
69 |
+
self._fp.close()
|
70 |
+
# physically remove the tsv file if it is retrieved by BlobStorage
|
71 |
+
if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
|
72 |
+
try:
|
73 |
+
original_usage = disk_usage('/')
|
74 |
+
os.remove(self.tsv_file)
|
75 |
+
logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
|
76 |
+
(self.tsv_file, original_usage, disk_usage('/') * 100))
|
77 |
+
except:
|
78 |
+
# Known issue: multiple threads attempting to delete the file will raise a FileNotFound error.
|
79 |
+
# TODO: try Threadling.Lock to better handle the race condition
|
80 |
+
pass
|
81 |
+
|
82 |
+
def __str__(self):
|
83 |
+
return "TSVFile(tsv_file='{}')".format(self.tsv_file)
|
84 |
+
|
85 |
+
def __repr__(self):
|
86 |
+
return str(self)
|
87 |
+
|
88 |
+
def gcidx(self):
|
89 |
+
logging.debug('Run gc collect')
|
90 |
+
self._lineidx = None
|
91 |
+
self._sample_indices = None
|
92 |
+
#self._class_boundaries = None
|
93 |
+
return gc.collect()
|
94 |
+
|
95 |
+
def get_class_boundaries(self):
|
96 |
+
return self._class_boundaries
|
97 |
+
|
98 |
+
def num_rows(self, gcf=False):
|
99 |
+
if (self._len is None):
|
100 |
+
self._ensure_lineidx_loaded()
|
101 |
+
retval = len(self._sample_indices)
|
102 |
+
|
103 |
+
if (gcf):
|
104 |
+
self.gcidx()
|
105 |
+
|
106 |
+
self._len = retval
|
107 |
+
|
108 |
+
return self._len
|
109 |
+
|
110 |
+
def seek(self, idx: int):
|
111 |
+
self._ensure_tsv_opened()
|
112 |
+
self._ensure_lineidx_loaded()
|
113 |
+
try:
|
114 |
+
pos = self._lineidx[self._sample_indices[idx]]
|
115 |
+
except:
|
116 |
+
logging.info('=> {}-{}'.format(self.tsv_file, idx))
|
117 |
+
raise
|
118 |
+
self._fp.seek(pos)
|
119 |
+
return [s.strip() for s in self._fp.readline().split('\t')]
|
120 |
+
|
121 |
+
def seek_first_column(self, idx: int):
|
122 |
+
self._ensure_tsv_opened()
|
123 |
+
self._ensure_lineidx_loaded()
|
124 |
+
pos = self._lineidx[idx]
|
125 |
+
self._fp.seek(pos)
|
126 |
+
return read_to_character(self._fp, '\t')
|
127 |
+
|
128 |
+
def get_key(self, idx: int):
|
129 |
+
return self.seek_first_column(idx)
|
130 |
+
|
131 |
+
def __getitem__(self, index: int):
|
132 |
+
return self.seek(index)
|
133 |
+
|
134 |
+
def __len__(self):
|
135 |
+
return self.num_rows()
|
136 |
+
|
137 |
+
def _ensure_lineidx_loaded(self):
|
138 |
+
if self._lineidx is None:
|
139 |
+
logging.debug('=> loading lineidx: {}'.format(self.lineidx))
|
140 |
+
with open(self.lineidx, 'r') as fp:
|
141 |
+
lines = fp.readlines()
|
142 |
+
lines = [line.strip() for line in lines]
|
143 |
+
self._lineidx = [int(line) for line in lines]
|
144 |
+
|
145 |
+
# read the line list if exists
|
146 |
+
linelist = None
|
147 |
+
if op.isfile(self.linelist):
|
148 |
+
with open(self.linelist, 'r') as fp:
|
149 |
+
linelist = sorted(
|
150 |
+
[
|
151 |
+
int(line.strip())
|
152 |
+
for line in fp.readlines()
|
153 |
+
]
|
154 |
+
)
|
155 |
+
|
156 |
+
if op.isfile(self.chunks):
|
157 |
+
self._sample_indices = []
|
158 |
+
self._class_boundaries = []
|
159 |
+
class_boundaries = json.load(open(self.chunks, 'r'))
|
160 |
+
for class_name, boundary in class_boundaries.items():
|
161 |
+
start = len(self._sample_indices)
|
162 |
+
if class_name in self._class_selector:
|
163 |
+
for idx in range(boundary[0], boundary[1] + 1):
|
164 |
+
# NOTE: potentially slow when linelist is long, try to speed it up
|
165 |
+
if linelist and idx not in linelist:
|
166 |
+
continue
|
167 |
+
self._sample_indices.append(idx)
|
168 |
+
end = len(self._sample_indices)
|
169 |
+
self._class_boundaries.append((start, end))
|
170 |
+
else:
|
171 |
+
if linelist:
|
172 |
+
self._sample_indices = linelist
|
173 |
+
else:
|
174 |
+
self._sample_indices = list(range(len(self._lineidx)))
|
175 |
+
|
176 |
+
def _ensure_tsv_opened(self):
|
177 |
+
if self._fp is None:
|
178 |
+
if self._blob_storage:
|
179 |
+
self._fp = self._blob_storage.open(self.tsv_file)
|
180 |
+
else:
|
181 |
+
self._fp = open(self.tsv_file, 'r')
|
182 |
+
self.pid = os.getpid()
|
183 |
+
|
184 |
+
if self.pid != os.getpid():
|
185 |
+
logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
|
186 |
+
self._fp = open(self.tsv_file, 'r')
|
187 |
+
self.pid = os.getpid()
|
188 |
+
|
189 |
+
|
190 |
+
class TSVWriter(object):
|
191 |
+
def __init__(self, tsv_file):
|
192 |
+
self.tsv_file = tsv_file
|
193 |
+
self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
|
194 |
+
self.tsv_file_tmp = self.tsv_file + '.tmp'
|
195 |
+
self.lineidx_file_tmp = self.lineidx_file + '.tmp'
|
196 |
+
|
197 |
+
self.tsv_fp = open(self.tsv_file_tmp, 'w')
|
198 |
+
self.lineidx_fp = open(self.lineidx_file_tmp, 'w')
|
199 |
+
|
200 |
+
self.idx = 0
|
201 |
+
|
202 |
+
def write(self, values, sep='\t'):
|
203 |
+
v = '{0}\n'.format(sep.join(map(str, values)))
|
204 |
+
self.tsv_fp.write(v)
|
205 |
+
self.lineidx_fp.write(str(self.idx) + '\n')
|
206 |
+
self.idx = self.idx + len(v)
|
207 |
+
|
208 |
+
def close(self):
|
209 |
+
self.tsv_fp.close()
|
210 |
+
self.lineidx_fp.close()
|
211 |
+
os.rename(self.tsv_file_tmp, self.tsv_file)
|
212 |
+
os.rename(self.lineidx_file_tmp, self.lineidx_file)
|
dataset/tsv_dataset.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tkinter.messagebox import NO
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from collections import defaultdict
|
5 |
+
from PIL import Image, ImageDraw
|
6 |
+
from copy import deepcopy
|
7 |
+
import os
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
import torchvision
|
10 |
+
from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
|
11 |
+
from io import BytesIO
|
12 |
+
import random
|
13 |
+
|
14 |
+
from .tsv import TSVFile
|
15 |
+
|
16 |
+
from io import BytesIO
|
17 |
+
import base64
|
18 |
+
from PIL import Image
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
|
22 |
+
def decode_base64_to_pillow(image_b64):
|
23 |
+
return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB')
|
24 |
+
|
25 |
+
def decode_tensor_from_string(arr_str, use_tensor=True):
|
26 |
+
arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32')
|
27 |
+
if use_tensor:
|
28 |
+
arr = torch.from_numpy(arr)
|
29 |
+
return arr
|
30 |
+
|
31 |
+
def decode_item(item):
|
32 |
+
item = json.loads(item)
|
33 |
+
item['image'] = decode_base64_to_pillow(item['image'])
|
34 |
+
|
35 |
+
for anno in item['annos']:
|
36 |
+
anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before'])
|
37 |
+
anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before'])
|
38 |
+
anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after'])
|
39 |
+
anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after'])
|
40 |
+
return item
|
41 |
+
|
42 |
+
def check_unique(images, fields):
|
43 |
+
for field in fields:
|
44 |
+
temp_list = []
|
45 |
+
for img_info in images:
|
46 |
+
temp_list.append(img_info[field])
|
47 |
+
assert len(set(temp_list)) == len(temp_list), field
|
48 |
+
|
49 |
+
def clean_data(data):
|
50 |
+
for data_info in data:
|
51 |
+
data_info.pop("original_img_id", None)
|
52 |
+
data_info.pop("original_id", None)
|
53 |
+
data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image)
|
54 |
+
data_info.pop("dataset_name", None)
|
55 |
+
data_info.pop("data_source", None)
|
56 |
+
data_info["data_id"] = data_info.pop("id")
|
57 |
+
|
58 |
+
|
59 |
+
def clean_annotations(annotations):
|
60 |
+
for anno_info in annotations:
|
61 |
+
anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco
|
62 |
+
anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. This is not always 1 for coco, but I do not think we need this annotation
|
63 |
+
anno_info.pop("area", None)
|
64 |
+
# anno_info.pop("id", None)
|
65 |
+
anno_info["data_id"] = anno_info.pop("image_id")
|
66 |
+
|
67 |
+
|
68 |
+
def draw_box(img, boxes):
|
69 |
+
draw = ImageDraw.Draw(img)
|
70 |
+
for box in boxes:
|
71 |
+
draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
|
72 |
+
return img
|
73 |
+
|
74 |
+
|
75 |
+
def xyhw2xyxy(box):
|
76 |
+
x0, y0, w, h = box
|
77 |
+
return [ x0, y0, x0+w, y0+h ]
|
78 |
+
|
79 |
+
|
80 |
+
def make_a_sentence(obj_names, clean=False):
|
81 |
+
|
82 |
+
if clean:
|
83 |
+
obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
|
84 |
+
|
85 |
+
caption = ""
|
86 |
+
tokens_positive = []
|
87 |
+
for obj_name in obj_names:
|
88 |
+
start_len = len(caption)
|
89 |
+
caption += obj_name
|
90 |
+
end_len = len(caption)
|
91 |
+
caption += ", "
|
92 |
+
tokens_positive.append(
|
93 |
+
[[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
|
94 |
+
)
|
95 |
+
caption = caption[:-2] # remove last ", "
|
96 |
+
|
97 |
+
return caption #, tokens_positive
|
98 |
+
|
99 |
+
|
100 |
+
def mask_for_random_drop_text_or_image_feature(masks, random_drop_embedding):
|
101 |
+
"""
|
102 |
+
input masks tell how many valid grounding tokens for this image
|
103 |
+
e.g., 1,1,1,1,0,0,0,0,0,0...
|
104 |
+
|
105 |
+
If random_drop_embedding=both. we will random drop either image or
|
106 |
+
text feature for each token,
|
107 |
+
but we always make sure there is at least one feature used.
|
108 |
+
In other words, the following masks are not valid
|
109 |
+
(because for the second obj, no feature at all):
|
110 |
+
image: 1,0,1,1,0,0,0,0,0
|
111 |
+
text: 1,0,0,0,0,0,0,0,0
|
112 |
+
|
113 |
+
if random_drop_embedding=image. we will random drop image feature
|
114 |
+
and always keep the text one.
|
115 |
+
|
116 |
+
"""
|
117 |
+
N = masks.shape[0]
|
118 |
+
|
119 |
+
if random_drop_embedding=='both':
|
120 |
+
temp_mask = torch.ones(2,N)
|
121 |
+
for i in range(N):
|
122 |
+
if random.uniform(0, 1) < 0.5: # else keep both features
|
123 |
+
idx = random.sample([0,1], 1)[0] # randomly choose to drop image or text feature
|
124 |
+
temp_mask[idx,i] = 0
|
125 |
+
image_masks = temp_mask[0]*masks
|
126 |
+
text_masks = temp_mask[1]*masks
|
127 |
+
|
128 |
+
if random_drop_embedding=='image':
|
129 |
+
image_masks = masks*(torch.rand(N)>0.5)*1
|
130 |
+
text_masks = masks
|
131 |
+
|
132 |
+
return image_masks, text_masks
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def project(x, projection_matrix):
|
139 |
+
"""
|
140 |
+
x (Batch*768) should be the penultimate feature of CLIP (before projection)
|
141 |
+
projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer
|
142 |
+
defined in CLIP (out_dim, in_dim), thus we need to apply transpose below.
|
143 |
+
this function will return the CLIP feature (without normalziation)
|
144 |
+
"""
|
145 |
+
return x@torch.transpose(projection_matrix, 0, 1)
|
146 |
+
|
147 |
+
|
148 |
+
def inv_project(y, projection_matrix):
|
149 |
+
"""
|
150 |
+
y (Batch*768) should be the CLIP feature (after projection)
|
151 |
+
projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer
|
152 |
+
defined in CLIP (out_dim, in_dim).
|
153 |
+
this function will return the CLIP penultimate feature.
|
154 |
+
|
155 |
+
Note: to make sure getting the correct penultimate feature, the input y should not be normalized.
|
156 |
+
If it is normalized, then the result will be scaled by CLIP feature norm, which is unknown.
|
157 |
+
"""
|
158 |
+
return y@torch.transpose(torch.linalg.inv(projection_matrix), 0, 1)
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
class TSVDataset(BaseDataset):
|
164 |
+
def __init__(self,
|
165 |
+
tsv_path,
|
166 |
+
which_embedder='clip',
|
167 |
+
which_layer=['after','after'], # text and image
|
168 |
+
prob_use_caption=1,
|
169 |
+
random_drop_embedding='none',
|
170 |
+
image_size=256,
|
171 |
+
min_box_size=0.01,
|
172 |
+
max_boxes_per_data=8,
|
173 |
+
max_images=None, # set as 30K used to eval
|
174 |
+
random_crop = False,
|
175 |
+
random_flip = True,
|
176 |
+
):
|
177 |
+
image_root = "a placeholder path as we are using tsv here"
|
178 |
+
super().__init__(image_root, random_crop, random_flip, image_size)
|
179 |
+
self.tsv_path = tsv_path
|
180 |
+
self.which_embedder = which_embedder
|
181 |
+
self.prob_use_caption = prob_use_caption
|
182 |
+
self.random_drop_embedding = random_drop_embedding
|
183 |
+
self.min_box_size = min_box_size
|
184 |
+
self.max_boxes_per_data = max_boxes_per_data
|
185 |
+
self.max_images = max_images
|
186 |
+
|
187 |
+
assert which_layer in [ ['after','after'], ['before','after_renorm'], ['before','after_reproject'] ]
|
188 |
+
assert random_drop_embedding in ['none', 'both', 'image']
|
189 |
+
self.which_layer_text = which_layer[0]
|
190 |
+
self.which_layer_image = which_layer[1]
|
191 |
+
|
192 |
+
#self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
|
193 |
+
self.projection_matrix = torch.load('projection_matrix.pth')
|
194 |
+
|
195 |
+
# Load tsv data
|
196 |
+
self.tsv_file = TSVFile(self.tsv_path)
|
197 |
+
|
198 |
+
|
199 |
+
# Load preprocessed name embedding
|
200 |
+
if which_embedder == 'bert':
|
201 |
+
self.embedding_len = 1280
|
202 |
+
elif which_embedder == 'clip':
|
203 |
+
self.embedding_len = 768
|
204 |
+
else:
|
205 |
+
assert False
|
206 |
+
|
207 |
+
def total_images(self):
|
208 |
+
return len(self)
|
209 |
+
|
210 |
+
def get_item_from_tsv(self, index):
|
211 |
+
_, item = self.tsv_file[index]
|
212 |
+
item = decode_item(item)
|
213 |
+
return item
|
214 |
+
|
215 |
+
|
216 |
+
def mapping(self, image_embedding):
|
217 |
+
if self.which_layer_image == 'after':
|
218 |
+
# both use CLIP aligned feature
|
219 |
+
return image_embedding
|
220 |
+
elif self.which_layer_image == 'after_renorm':
|
221 |
+
# text use before, but image use after projection but normalize to 28.7
|
222 |
+
return image_embedding*28.7
|
223 |
+
elif self.which_layer_image == 'after_reproject':
|
224 |
+
image_embedding = project( image_embedding.unsqueeze(0), self.projection_matrix.T )
|
225 |
+
image_embedding = image_embedding.squeeze(0)
|
226 |
+
image_embedding = image_embedding / image_embedding.norm()
|
227 |
+
image_embedding = image_embedding * 28.7
|
228 |
+
return image_embedding
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
def __getitem__(self, index):
|
233 |
+
if self.max_boxes_per_data > 99:
|
234 |
+
assert False, "Are you sure setting such large number of boxes?"
|
235 |
+
|
236 |
+
raw_item = self.get_item_from_tsv(index)
|
237 |
+
is_det = raw_item.get('is_det', False) # if it is from detection (such as o365), then we will make a caption
|
238 |
+
|
239 |
+
out = {}
|
240 |
+
|
241 |
+
# -------------------- id and image ------------------- #
|
242 |
+
out['id'] = raw_item['data_id']
|
243 |
+
image = raw_item['image']
|
244 |
+
image_tensor, trans_info = self.transform_image(image)
|
245 |
+
out["image"] = image_tensor
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
# -------------------- grounding token ------------------- #
|
250 |
+
annos = raw_item['annos']
|
251 |
+
|
252 |
+
areas = []
|
253 |
+
all_boxes = []
|
254 |
+
all_masks = []
|
255 |
+
all_text_embeddings = []
|
256 |
+
all_image_embeddings = []
|
257 |
+
if is_det:
|
258 |
+
all_category_names = []
|
259 |
+
|
260 |
+
text_embedding_name = 'text_embedding_before' if self.which_layer_text == 'before' else 'text_embedding_after'
|
261 |
+
image_embedding_name = 'image_embedding_after'
|
262 |
+
|
263 |
+
for anno in annos:
|
264 |
+
x, y, w, h = anno['bbox']
|
265 |
+
valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
|
266 |
+
|
267 |
+
if valid:
|
268 |
+
areas.append( (x1-x0)*(y1-y0) )
|
269 |
+
all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
|
270 |
+
all_masks.append(1)
|
271 |
+
all_text_embeddings.append(anno[text_embedding_name])
|
272 |
+
all_image_embeddings.append( self.mapping(anno[image_embedding_name]) )
|
273 |
+
if is_det:
|
274 |
+
all_category_names.append(anno["category_name"])
|
275 |
+
|
276 |
+
|
277 |
+
wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
|
278 |
+
wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]
|
279 |
+
|
280 |
+
boxes = torch.zeros(self.max_boxes_per_data, 4)
|
281 |
+
masks = torch.zeros(self.max_boxes_per_data)
|
282 |
+
text_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
|
283 |
+
image_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
|
284 |
+
if is_det:
|
285 |
+
category_names = []
|
286 |
+
for i, idx in enumerate(wanted_idxs):
|
287 |
+
boxes[i] = all_boxes[idx]
|
288 |
+
masks[i] = all_masks[idx]
|
289 |
+
text_embeddings[i] = all_text_embeddings[idx]
|
290 |
+
image_embeddings[i] = all_image_embeddings[idx]
|
291 |
+
if is_det:
|
292 |
+
category_names.append(all_category_names[idx])
|
293 |
+
|
294 |
+
if self.random_drop_embedding != 'none':
|
295 |
+
image_masks, text_masks = mask_for_random_drop_text_or_image_feature(masks, self.random_drop_embedding)
|
296 |
+
else:
|
297 |
+
image_masks = masks
|
298 |
+
text_masks = masks
|
299 |
+
|
300 |
+
|
301 |
+
out["boxes"] = boxes
|
302 |
+
out["masks"] = masks
|
303 |
+
out["image_masks"] = image_masks
|
304 |
+
out["text_masks"] = text_masks
|
305 |
+
out["text_embeddings"] = text_embeddings
|
306 |
+
out["image_embeddings"] = image_embeddings
|
307 |
+
|
308 |
+
|
309 |
+
|
310 |
+
# -------------------- caption ------------------- #
|
311 |
+
if random.uniform(0, 1) < self.prob_use_caption:
|
312 |
+
if is_det:
|
313 |
+
out["caption"] = make_a_sentence(category_names)
|
314 |
+
else:
|
315 |
+
out["caption"] = raw_item["caption"]
|
316 |
+
else:
|
317 |
+
out["caption"] = ""
|
318 |
+
|
319 |
+
return out
|
320 |
+
|
321 |
+
|
322 |
+
|
323 |
+
def __len__(self):
|
324 |
+
return len(self.tsv_file)
|
325 |
+
|
326 |
+
|
dataset/utils.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
#
|
3 |
+
# Copyright 2018 Google LLC
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import PIL
|
18 |
+
import torch
|
19 |
+
import torchvision.transforms as T
|
20 |
+
|
21 |
+
|
22 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
23 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
24 |
+
|
25 |
+
INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN]
|
26 |
+
INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD]
|
27 |
+
|
28 |
+
|
29 |
+
def imagenet_preprocess():
|
30 |
+
return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
31 |
+
|
32 |
+
|
33 |
+
def rescale(x):
|
34 |
+
lo, hi = x.min(), x.max()
|
35 |
+
return x.sub(lo).div(hi - lo)
|
36 |
+
|
37 |
+
|
38 |
+
def imagenet_deprocess(rescale_image=True):
|
39 |
+
transforms = [
|
40 |
+
T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD),
|
41 |
+
T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]),
|
42 |
+
]
|
43 |
+
if rescale_image:
|
44 |
+
transforms.append(rescale)
|
45 |
+
return T.Compose(transforms)
|
46 |
+
|
47 |
+
|
48 |
+
def imagenet_deprocess_batch(imgs, rescale=True):
|
49 |
+
"""
|
50 |
+
Input:
|
51 |
+
- imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images
|
52 |
+
|
53 |
+
Output:
|
54 |
+
- imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images
|
55 |
+
in the range [0, 255]
|
56 |
+
"""
|
57 |
+
if isinstance(imgs, torch.autograd.Variable):
|
58 |
+
imgs = imgs.data
|
59 |
+
imgs = imgs.cpu().clone()
|
60 |
+
deprocess_fn = imagenet_deprocess(rescale_image=rescale)
|
61 |
+
imgs_de = []
|
62 |
+
for i in range(imgs.size(0)):
|
63 |
+
img_de = deprocess_fn(imgs[i])[None]
|
64 |
+
img_de = img_de.mul(255).clamp(0, 255).byte()
|
65 |
+
imgs_de.append(img_de)
|
66 |
+
imgs_de = torch.cat(imgs_de, dim=0)
|
67 |
+
return imgs_de
|
68 |
+
|
69 |
+
|
70 |
+
class Resize(object):
|
71 |
+
def __init__(self, size, interp=PIL.Image.BILINEAR):
|
72 |
+
if isinstance(size, tuple):
|
73 |
+
H, W = size
|
74 |
+
self.size = (W, H)
|
75 |
+
else:
|
76 |
+
self.size = (size, size)
|
77 |
+
self.interp = interp
|
78 |
+
|
79 |
+
def __call__(self, img):
|
80 |
+
return img.resize(self.size, self.interp)
|
81 |
+
|
82 |
+
|
83 |
+
def unpack_var(v):
|
84 |
+
if isinstance(v, torch.autograd.Variable):
|
85 |
+
return v.data
|
86 |
+
return v
|
87 |
+
|
88 |
+
|
89 |
+
def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img):
|
90 |
+
triples = unpack_var(triples)
|
91 |
+
obj_data = [unpack_var(o) for o in obj_data]
|
92 |
+
obj_to_img = unpack_var(obj_to_img)
|
93 |
+
triple_to_img = unpack_var(triple_to_img)
|
94 |
+
|
95 |
+
triples_out = []
|
96 |
+
obj_data_out = [[] for _ in obj_data]
|
97 |
+
obj_offset = 0
|
98 |
+
N = obj_to_img.max() + 1
|
99 |
+
for i in range(N):
|
100 |
+
o_idxs = (obj_to_img == i).nonzero().view(-1)
|
101 |
+
t_idxs = (triple_to_img == i).nonzero().view(-1)
|
102 |
+
|
103 |
+
cur_triples = triples[t_idxs].clone()
|
104 |
+
cur_triples[:, 0] -= obj_offset
|
105 |
+
cur_triples[:, 2] -= obj_offset
|
106 |
+
triples_out.append(cur_triples)
|
107 |
+
|
108 |
+
for j, o_data in enumerate(obj_data):
|
109 |
+
cur_o_data = None
|
110 |
+
if o_data is not None:
|
111 |
+
cur_o_data = o_data[o_idxs]
|
112 |
+
obj_data_out[j].append(cur_o_data)
|
113 |
+
|
114 |
+
obj_offset += o_idxs.size(0)
|
115 |
+
|
116 |
+
return triples_out, obj_data_out
|
gligen/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (345 Bytes). View file
|
|
gligen/__pycache__/distributed.cpython-38.pyc
ADDED
Binary file (2.91 kB). View file
|
|
gligen/__pycache__/evaluator.cpython-38.pyc
ADDED
Binary file (5.9 kB). View file
|
|
gligen/__pycache__/task_grounded_generation.cpython-38.pyc
ADDED
Binary file (9.11 kB). View file
|
|
gligen/__pycache__/trainer.cpython-38.pyc
ADDED
Binary file (11.7 kB). View file
|
|
gligen/ldm/__pycache__/util.cpython-38.pyc
ADDED
Binary file (3.2 kB). View file
|
|
gligen/ldm/models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc
ADDED
Binary file (1.58 kB). View file
|
|
gligen/ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
#import pytorch_lightning as pl
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from contextlib import contextmanager
|
6 |
+
|
7 |
+
# from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
8 |
+
|
9 |
+
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
10 |
+
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
11 |
+
|
12 |
+
from ldm.util import instantiate_from_config
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class AutoencoderKL(nn.Module):
|
18 |
+
def __init__(self,
|
19 |
+
ddconfig,
|
20 |
+
embed_dim,
|
21 |
+
scale_factor=1
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
self.encoder = Encoder(**ddconfig)
|
25 |
+
self.decoder = Decoder(**ddconfig)
|
26 |
+
assert ddconfig["double_z"]
|
27 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
28 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
29 |
+
self.embed_dim = embed_dim
|
30 |
+
self.scale_factor = scale_factor
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
def encode(self, x):
|
35 |
+
h = self.encoder(x)
|
36 |
+
moments = self.quant_conv(h)
|
37 |
+
posterior = DiagonalGaussianDistribution(moments)
|
38 |
+
return posterior.sample() * self.scale_factor
|
39 |
+
|
40 |
+
def decode(self, z):
|
41 |
+
z = 1. / self.scale_factor * z
|
42 |
+
z = self.post_quant_conv(z)
|
43 |
+
dec = self.decoder(z)
|
44 |
+
return dec
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
|
gligen/ldm/models/diffusion/__init__.py
ADDED
File without changes
|
gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (159 Bytes). View file
|
|
gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
ADDED
Binary file (4.57 kB). View file
|
|
gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
ADDED
Binary file (2.12 kB). View file
|
|
gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc
ADDED
Binary file (4.11 kB). View file
|
|
gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc
ADDED
Binary file (1.21 kB). View file
|
|
gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc
ADDED
Binary file (4.23 kB). View file
|
|
gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc
ADDED
Binary file (8.71 kB). View file
|
|
gligen/ldm/models/diffusion/classifier.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.optim import AdamW
|
7 |
+
from torch.optim.lr_scheduler import LambdaLR
|
8 |
+
from copy import deepcopy
|
9 |
+
from einops import rearrange
|
10 |
+
from glob import glob
|
11 |
+
from natsort import natsorted
|
12 |
+
|
13 |
+
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
14 |
+
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
15 |
+
|
16 |
+
__models__ = {
|
17 |
+
'class_label': EncoderUNetModel,
|
18 |
+
'segmentation': UNetModel
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def disabled_train(self, mode=True):
|
23 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
24 |
+
does not change anymore."""
|
25 |
+
return self
|
26 |
+
|
27 |
+
|
28 |
+
class NoisyLatentImageClassifier(pl.LightningModule):
|
29 |
+
|
30 |
+
def __init__(self,
|
31 |
+
diffusion_path,
|
32 |
+
num_classes,
|
33 |
+
ckpt_path=None,
|
34 |
+
pool='attention',
|
35 |
+
label_key=None,
|
36 |
+
diffusion_ckpt_path=None,
|
37 |
+
scheduler_config=None,
|
38 |
+
weight_decay=1.e-2,
|
39 |
+
log_steps=10,
|
40 |
+
monitor='val/loss',
|
41 |
+
*args,
|
42 |
+
**kwargs):
|
43 |
+
super().__init__(*args, **kwargs)
|
44 |
+
self.num_classes = num_classes
|
45 |
+
# get latest config of diffusion model
|
46 |
+
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
47 |
+
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
48 |
+
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
49 |
+
self.load_diffusion()
|
50 |
+
|
51 |
+
self.monitor = monitor
|
52 |
+
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
53 |
+
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
54 |
+
self.log_steps = log_steps
|
55 |
+
|
56 |
+
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
57 |
+
else self.diffusion_model.cond_stage_key
|
58 |
+
|
59 |
+
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
60 |
+
|
61 |
+
if self.label_key not in __models__:
|
62 |
+
raise NotImplementedError()
|
63 |
+
|
64 |
+
self.load_classifier(ckpt_path, pool)
|
65 |
+
|
66 |
+
self.scheduler_config = scheduler_config
|
67 |
+
self.use_scheduler = self.scheduler_config is not None
|
68 |
+
self.weight_decay = weight_decay
|
69 |
+
|
70 |
+
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
71 |
+
sd = torch.load(path, map_location="cpu")
|
72 |
+
if "state_dict" in list(sd.keys()):
|
73 |
+
sd = sd["state_dict"]
|
74 |
+
keys = list(sd.keys())
|
75 |
+
for k in keys:
|
76 |
+
for ik in ignore_keys:
|
77 |
+
if k.startswith(ik):
|
78 |
+
print("Deleting key {} from state_dict.".format(k))
|
79 |
+
del sd[k]
|
80 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
81 |
+
sd, strict=False)
|
82 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
83 |
+
if len(missing) > 0:
|
84 |
+
print(f"Missing Keys: {missing}")
|
85 |
+
if len(unexpected) > 0:
|
86 |
+
print(f"Unexpected Keys: {unexpected}")
|
87 |
+
|
88 |
+
def load_diffusion(self):
|
89 |
+
model = instantiate_from_config(self.diffusion_config)
|
90 |
+
self.diffusion_model = model.eval()
|
91 |
+
self.diffusion_model.train = disabled_train
|
92 |
+
for param in self.diffusion_model.parameters():
|
93 |
+
param.requires_grad = False
|
94 |
+
|
95 |
+
def load_classifier(self, ckpt_path, pool):
|
96 |
+
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
97 |
+
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
98 |
+
model_config.out_channels = self.num_classes
|
99 |
+
if self.label_key == 'class_label':
|
100 |
+
model_config.pool = pool
|
101 |
+
|
102 |
+
self.model = __models__[self.label_key](**model_config)
|
103 |
+
if ckpt_path is not None:
|
104 |
+
print('#####################################################################')
|
105 |
+
print(f'load from ckpt "{ckpt_path}"')
|
106 |
+
print('#####################################################################')
|
107 |
+
self.init_from_ckpt(ckpt_path)
|
108 |
+
|
109 |
+
@torch.no_grad()
|
110 |
+
def get_x_noisy(self, x, t, noise=None):
|
111 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
112 |
+
continuous_sqrt_alpha_cumprod = None
|
113 |
+
if self.diffusion_model.use_continuous_noise:
|
114 |
+
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
115 |
+
# todo: make sure t+1 is correct here
|
116 |
+
|
117 |
+
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
118 |
+
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
119 |
+
|
120 |
+
def forward(self, x_noisy, t, *args, **kwargs):
|
121 |
+
return self.model(x_noisy, t)
|
122 |
+
|
123 |
+
@torch.no_grad()
|
124 |
+
def get_input(self, batch, k):
|
125 |
+
x = batch[k]
|
126 |
+
if len(x.shape) == 3:
|
127 |
+
x = x[..., None]
|
128 |
+
x = rearrange(x, 'b h w c -> b c h w')
|
129 |
+
x = x.to(memory_format=torch.contiguous_format).float()
|
130 |
+
return x
|
131 |
+
|
132 |
+
@torch.no_grad()
|
133 |
+
def get_conditioning(self, batch, k=None):
|
134 |
+
if k is None:
|
135 |
+
k = self.label_key
|
136 |
+
assert k is not None, 'Needs to provide label key'
|
137 |
+
|
138 |
+
targets = batch[k].to(self.device)
|
139 |
+
|
140 |
+
if self.label_key == 'segmentation':
|
141 |
+
targets = rearrange(targets, 'b h w c -> b c h w')
|
142 |
+
for down in range(self.numd):
|
143 |
+
h, w = targets.shape[-2:]
|
144 |
+
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
145 |
+
|
146 |
+
# targets = rearrange(targets,'b c h w -> b h w c')
|
147 |
+
|
148 |
+
return targets
|
149 |
+
|
150 |
+
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
151 |
+
_, top_ks = torch.topk(logits, k, dim=1)
|
152 |
+
if reduction == "mean":
|
153 |
+
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
154 |
+
elif reduction == "none":
|
155 |
+
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
156 |
+
|
157 |
+
def on_train_epoch_start(self):
|
158 |
+
# save some memory
|
159 |
+
self.diffusion_model.model.to('cpu')
|
160 |
+
|
161 |
+
@torch.no_grad()
|
162 |
+
def write_logs(self, loss, logits, targets):
|
163 |
+
log_prefix = 'train' if self.training else 'val'
|
164 |
+
log = {}
|
165 |
+
log[f"{log_prefix}/loss"] = loss.mean()
|
166 |
+
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
167 |
+
logits, targets, k=1, reduction="mean"
|
168 |
+
)
|
169 |
+
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
170 |
+
logits, targets, k=5, reduction="mean"
|
171 |
+
)
|
172 |
+
|
173 |
+
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
174 |
+
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
175 |
+
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
176 |
+
lr = self.optimizers().param_groups[0]['lr']
|
177 |
+
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
178 |
+
|
179 |
+
def shared_step(self, batch, t=None):
|
180 |
+
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
181 |
+
targets = self.get_conditioning(batch)
|
182 |
+
if targets.dim() == 4:
|
183 |
+
targets = targets.argmax(dim=1)
|
184 |
+
if t is None:
|
185 |
+
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
186 |
+
else:
|
187 |
+
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
188 |
+
x_noisy = self.get_x_noisy(x, t)
|
189 |
+
logits = self(x_noisy, t)
|
190 |
+
|
191 |
+
loss = F.cross_entropy(logits, targets, reduction='none')
|
192 |
+
|
193 |
+
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
194 |
+
|
195 |
+
loss = loss.mean()
|
196 |
+
return loss, logits, x_noisy, targets
|
197 |
+
|
198 |
+
def training_step(self, batch, batch_idx):
|
199 |
+
loss, *_ = self.shared_step(batch)
|
200 |
+
return loss
|
201 |
+
|
202 |
+
def reset_noise_accs(self):
|
203 |
+
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
204 |
+
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
205 |
+
|
206 |
+
def on_validation_start(self):
|
207 |
+
self.reset_noise_accs()
|
208 |
+
|
209 |
+
@torch.no_grad()
|
210 |
+
def validation_step(self, batch, batch_idx):
|
211 |
+
loss, *_ = self.shared_step(batch)
|
212 |
+
|
213 |
+
for t in self.noisy_acc:
|
214 |
+
_, logits, _, targets = self.shared_step(batch, t)
|
215 |
+
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
216 |
+
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
217 |
+
|
218 |
+
return loss
|
219 |
+
|
220 |
+
def configure_optimizers(self):
|
221 |
+
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
222 |
+
|
223 |
+
if self.use_scheduler:
|
224 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
225 |
+
|
226 |
+
print("Setting up LambdaLR scheduler...")
|
227 |
+
scheduler = [
|
228 |
+
{
|
229 |
+
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
230 |
+
'interval': 'step',
|
231 |
+
'frequency': 1
|
232 |
+
}]
|
233 |
+
return [optimizer], scheduler
|
234 |
+
|
235 |
+
return optimizer
|
236 |
+
|
237 |
+
@torch.no_grad()
|
238 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
239 |
+
log = dict()
|
240 |
+
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
241 |
+
log['inputs'] = x
|
242 |
+
|
243 |
+
y = self.get_conditioning(batch)
|
244 |
+
|
245 |
+
if self.label_key == 'class_label':
|
246 |
+
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
247 |
+
log['labels'] = y
|
248 |
+
|
249 |
+
if ismap(y):
|
250 |
+
log['labels'] = self.diffusion_model.to_rgb(y)
|
251 |
+
|
252 |
+
for step in range(self.log_steps):
|
253 |
+
current_time = step * self.log_time_interval
|
254 |
+
|
255 |
+
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
256 |
+
|
257 |
+
log[f'inputs@t{current_time}'] = x_noisy
|
258 |
+
|
259 |
+
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
260 |
+
pred = rearrange(pred, 'b h w c -> b c h w')
|
261 |
+
|
262 |
+
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
263 |
+
|
264 |
+
for key in log:
|
265 |
+
log[key] = log[key][:N]
|
266 |
+
|
267 |
+
return log
|
gligen/ldm/models/diffusion/ddim.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
7 |
+
|
8 |
+
|
9 |
+
class DDIMSampler(object):
|
10 |
+
def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
|
11 |
+
super().__init__()
|
12 |
+
self.diffusion = diffusion
|
13 |
+
self.model = model
|
14 |
+
self.device = diffusion.betas.device
|
15 |
+
self.ddpm_num_timesteps = diffusion.num_timesteps
|
16 |
+
self.schedule = schedule
|
17 |
+
self.alpha_generator_func = alpha_generator_func
|
18 |
+
self.set_alpha_scale = set_alpha_scale
|
19 |
+
|
20 |
+
|
21 |
+
def register_buffer(self, name, attr):
|
22 |
+
if type(attr) == torch.Tensor:
|
23 |
+
attr = attr.to(self.device)
|
24 |
+
setattr(self, name, attr)
|
25 |
+
|
26 |
+
|
27 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
|
28 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
29 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=False)
|
30 |
+
alphas_cumprod = self.diffusion.alphas_cumprod
|
31 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
32 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
33 |
+
|
34 |
+
self.register_buffer('betas', to_torch(self.diffusion.betas))
|
35 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
36 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
|
37 |
+
|
38 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
39 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
40 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
41 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
42 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
43 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
44 |
+
|
45 |
+
# ddim sampling parameters
|
46 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
47 |
+
ddim_timesteps=self.ddim_timesteps,
|
48 |
+
eta=ddim_eta,verbose=False)
|
49 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
50 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
51 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
52 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
53 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
54 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
55 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
56 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None):
|
61 |
+
self.make_schedule(ddim_num_steps=S)
|
62 |
+
return self.ddim_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0)
|
63 |
+
|
64 |
+
|
65 |
+
@torch.no_grad()
|
66 |
+
def ddim_sampling(self, shape, input, uc, guidance_scale=1, mask=None, x0=None):
|
67 |
+
b = shape[0]
|
68 |
+
|
69 |
+
img = input["x"]
|
70 |
+
if img == None:
|
71 |
+
img = torch.randn(shape, device=self.device)
|
72 |
+
input["x"] = img
|
73 |
+
|
74 |
+
|
75 |
+
time_range = np.flip(self.ddim_timesteps)
|
76 |
+
total_steps = self.ddim_timesteps.shape[0]
|
77 |
+
|
78 |
+
#iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
79 |
+
iterator = time_range
|
80 |
+
|
81 |
+
if self.alpha_generator_func != None:
|
82 |
+
alphas = self.alpha_generator_func(len(iterator))
|
83 |
+
|
84 |
+
|
85 |
+
for i, step in enumerate(iterator):
|
86 |
+
|
87 |
+
# set alpha
|
88 |
+
if self.alpha_generator_func != None:
|
89 |
+
self.set_alpha_scale(self.model, alphas[i])
|
90 |
+
if alphas[i] == 0:
|
91 |
+
self.model.restore_first_conv_from_SD()
|
92 |
+
|
93 |
+
# run
|
94 |
+
index = total_steps - i - 1
|
95 |
+
input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long)
|
96 |
+
|
97 |
+
if mask is not None:
|
98 |
+
assert x0 is not None
|
99 |
+
img_orig = self.diffusion.q_sample( x0, input["timesteps"] )
|
100 |
+
img = img_orig * mask + (1. - mask) * img
|
101 |
+
input["x"] = img
|
102 |
+
|
103 |
+
img, pred_x0 = self.p_sample_ddim(input, index=index, uc=uc, guidance_scale=guidance_scale)
|
104 |
+
input["x"] = img
|
105 |
+
|
106 |
+
return img
|
107 |
+
|
108 |
+
|
109 |
+
@torch.no_grad()
|
110 |
+
def p_sample_ddim(self, input, index, uc=None, guidance_scale=1):
|
111 |
+
|
112 |
+
|
113 |
+
e_t = self.model(input)
|
114 |
+
if uc is not None and guidance_scale != 1:
|
115 |
+
unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input'])
|
116 |
+
e_t_uncond = self.model( unconditional_input )
|
117 |
+
e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
|
118 |
+
|
119 |
+
# select parameters corresponding to the currently considered timestep
|
120 |
+
b = input["x"].shape[0]
|
121 |
+
a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
|
122 |
+
a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
|
123 |
+
sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
|
124 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
|
125 |
+
|
126 |
+
# current prediction for x_0
|
127 |
+
pred_x0 = (input["x"] - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
128 |
+
|
129 |
+
# direction pointing to x_t
|
130 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
131 |
+
noise = sigma_t * torch.randn_like( input["x"] )
|
132 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
133 |
+
|
134 |
+
return x_prev, pred_x0
|
gligen/ldm/models/diffusion/ddpm.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from functools import partial
|
5 |
+
from ldm.modules.diffusionmodules.util import make_beta_schedule
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class DDPM(nn.Module):
|
12 |
+
def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.v_posterior = 0
|
16 |
+
self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
17 |
+
|
18 |
+
|
19 |
+
def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
20 |
+
|
21 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
22 |
+
alphas = 1. - betas
|
23 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
24 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
25 |
+
|
26 |
+
timesteps, = betas.shape
|
27 |
+
self.num_timesteps = int(timesteps)
|
28 |
+
self.linear_start = linear_start
|
29 |
+
self.linear_end = linear_end
|
30 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
31 |
+
|
32 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
33 |
+
|
34 |
+
self.register_buffer('betas', to_torch(betas))
|
35 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
36 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
37 |
+
|
38 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
39 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
40 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
41 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
42 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
43 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
44 |
+
|
45 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
46 |
+
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas
|
47 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
48 |
+
|
49 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
50 |
+
|
51 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
52 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
53 |
+
self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
54 |
+
self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
gligen/ldm/models/diffusion/gaussian_smoothing.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numbers
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class GaussianSmoothing(nn.Module):
|
9 |
+
"""
|
10 |
+
Apply gaussian smoothing on a
|
11 |
+
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
12 |
+
in the input using a depthwise convolution.
|
13 |
+
Arguments:
|
14 |
+
channels (int, sequence): Number of channels of the input tensors. Output will
|
15 |
+
have this number of channels as well.
|
16 |
+
kernel_size (int, sequence): Size of the gaussian kernel.
|
17 |
+
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
18 |
+
dim (int, optional): The number of dimensions of the data.
|
19 |
+
Default value is 2 (spatial).
|
20 |
+
"""
|
21 |
+
def __init__(self, channels, kernel_size, sigma, dim=2):
|
22 |
+
super(GaussianSmoothing, self).__init__()
|
23 |
+
if isinstance(kernel_size, numbers.Number):
|
24 |
+
kernel_size = [kernel_size] * dim
|
25 |
+
if isinstance(sigma, numbers.Number):
|
26 |
+
sigma = [sigma] * dim
|
27 |
+
|
28 |
+
# The gaussian kernel is the product of the
|
29 |
+
# gaussian function of each dimension.
|
30 |
+
kernel = 1
|
31 |
+
meshgrids = torch.meshgrid(
|
32 |
+
[
|
33 |
+
torch.arange(size, dtype=torch.float32)
|
34 |
+
for size in kernel_size
|
35 |
+
]
|
36 |
+
)
|
37 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
38 |
+
mean = (size - 1) / 2
|
39 |
+
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
40 |
+
torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
|
41 |
+
|
42 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
43 |
+
kernel = kernel / torch.sum(kernel)
|
44 |
+
|
45 |
+
# Reshape to depthwise convolutional weight
|
46 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
47 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
48 |
+
|
49 |
+
self.register_buffer('weight', kernel)
|
50 |
+
self.groups = channels
|
51 |
+
|
52 |
+
if dim == 1:
|
53 |
+
self.conv = F.conv1d
|
54 |
+
elif dim == 2:
|
55 |
+
self.conv = F.conv2d
|
56 |
+
elif dim == 3:
|
57 |
+
self.conv = F.conv3d
|
58 |
+
else:
|
59 |
+
raise RuntimeError(
|
60 |
+
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, input):
|
64 |
+
"""
|
65 |
+
Apply gaussian filter to input.
|
66 |
+
Arguments:
|
67 |
+
input (torch.Tensor): Input to apply gaussian filter on.
|
68 |
+
Returns:
|
69 |
+
filtered (torch.Tensor): Filtered output.
|
70 |
+
"""
|
71 |
+
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
|
72 |
+
|
73 |
+
|
74 |
+
class AverageSmoothing(nn.Module):
|
75 |
+
"""
|
76 |
+
Apply average smoothing on a
|
77 |
+
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
78 |
+
in the input using a depthwise convolution.
|
79 |
+
Arguments:
|
80 |
+
channels (int, sequence): Number of channels of the input tensors. Output will
|
81 |
+
have this number of channels as well.
|
82 |
+
kernel_size (int, sequence): Size of the average kernel.
|
83 |
+
sigma (float, sequence): Standard deviation of the rage kernel.
|
84 |
+
dim (int, optional): The number of dimensions of the data.
|
85 |
+
Default value is 2 (spatial).
|
86 |
+
"""
|
87 |
+
def __init__(self, channels, kernel_size, dim=2):
|
88 |
+
super(AverageSmoothing, self).__init__()
|
89 |
+
|
90 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
91 |
+
kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size)
|
92 |
+
|
93 |
+
# Reshape to depthwise convolutional weight
|
94 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
95 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
96 |
+
|
97 |
+
self.register_buffer('weight', kernel)
|
98 |
+
self.groups = channels
|
99 |
+
|
100 |
+
if dim == 1:
|
101 |
+
self.conv = F.conv1d
|
102 |
+
elif dim == 2:
|
103 |
+
self.conv = F.conv2d
|
104 |
+
elif dim == 3:
|
105 |
+
self.conv = F.conv3d
|
106 |
+
else:
|
107 |
+
raise RuntimeError(
|
108 |
+
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
109 |
+
)
|
110 |
+
|
111 |
+
def forward(self, input):
|
112 |
+
"""
|
113 |
+
Apply average filter to input.
|
114 |
+
Arguments:
|
115 |
+
input (torch.Tensor): Input to apply average filter on.
|
116 |
+
Returns:
|
117 |
+
filtered (torch.Tensor): Filtered output.
|
118 |
+
"""
|
119 |
+
return self.conv(input, weight=self.weight, groups=self.groups)
|
gligen/ldm/models/diffusion/ldm.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from ldm.util import default
|
6 |
+
from ldm.modules.diffusionmodules.util import extract_into_tensor
|
7 |
+
from .ddpm import DDPM
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class LatentDiffusion(DDPM):
|
12 |
+
def __init__(self, *args, **kwargs):
|
13 |
+
super().__init__(*args, **kwargs)
|
14 |
+
# hardcoded
|
15 |
+
self.clip_denoised = False
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def q_sample(self, x_start, t, noise=None):
|
20 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
21 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
22 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
23 |
+
|
24 |
+
|
25 |
+
"Does not support DDPM sampling anymore. Only do DDIM or PLMS"
|
26 |
+
|
27 |
+
# = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = #
|
28 |
+
|
29 |
+
# def predict_start_from_noise(self, x_t, t, noise):
|
30 |
+
# return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
31 |
+
# extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise )
|
32 |
+
|
33 |
+
# def q_posterior(self, x_start, x_t, t):
|
34 |
+
# posterior_mean = (
|
35 |
+
# extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
36 |
+
# extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
37 |
+
# )
|
38 |
+
# posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
39 |
+
# posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
40 |
+
# return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
41 |
+
|
42 |
+
|
43 |
+
# def p_mean_variance(self, model, x, c, t):
|
44 |
+
|
45 |
+
# model_out = model(x, t, c)
|
46 |
+
# x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
47 |
+
|
48 |
+
# if self.clip_denoised:
|
49 |
+
# x_recon.clamp_(-1., 1.)
|
50 |
+
|
51 |
+
# model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
52 |
+
# return model_mean, posterior_variance, posterior_log_variance, x_recon
|
53 |
+
|
54 |
+
|
55 |
+
# @torch.no_grad()
|
56 |
+
# def p_sample(self, model, x, c, t):
|
57 |
+
# b, *_, device = *x.shape, x.device
|
58 |
+
# model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, )
|
59 |
+
# noise = torch.randn_like(x)
|
60 |
+
|
61 |
+
# # no noise when t == 0
|
62 |
+
# nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
63 |
+
|
64 |
+
# return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
65 |
+
|
66 |
+
|
67 |
+
# @torch.no_grad()
|
68 |
+
# def p_sample_loop(self, model, shape, c):
|
69 |
+
# device = self.betas.device
|
70 |
+
# b = shape[0]
|
71 |
+
# img = torch.randn(shape, device=device)
|
72 |
+
|
73 |
+
# iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps)
|
74 |
+
# for i in iterator:
|
75 |
+
# ts = torch.full((b,), i, device=device, dtype=torch.long)
|
76 |
+
# img, x0 = self.p_sample(model, img, c, ts)
|
77 |
+
|
78 |
+
# return img
|
79 |
+
|
80 |
+
|
81 |
+
# @torch.no_grad()
|
82 |
+
# def sample(self, model, shape, c, uc=None, guidance_scale=None):
|
83 |
+
# return self.p_sample_loop(model, shape, c)
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
|
gligen/ldm/models/diffusion/loss.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torchvision.utils import save_image
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def loss_one_att_outside(attn_map,bboxes, object_positions,t):
|
13 |
+
# loss = torch.tensor(0).to('cuda')
|
14 |
+
loss = 0
|
15 |
+
object_number = len(bboxes)
|
16 |
+
b, i, j = attn_map.shape
|
17 |
+
H = W = int(math.sqrt(i))
|
18 |
+
|
19 |
+
|
20 |
+
# if t== 20: import pdb; pdb.set_trace()
|
21 |
+
|
22 |
+
for obj_idx in range(object_number):
|
23 |
+
|
24 |
+
for obj_box in bboxes[obj_idx]:
|
25 |
+
mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W))
|
26 |
+
x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
|
27 |
+
int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
|
28 |
+
mask[y_min: y_max, x_min: x_max] = 1.
|
29 |
+
mask_out = 1. - mask
|
30 |
+
index = (mask == 1.).nonzero(as_tuple=False)
|
31 |
+
index_in_key = index[:,0]* H + index[:, 1]
|
32 |
+
att_box = torch.zeros_like(attn_map)
|
33 |
+
att_box[:,index_in_key,:] = attn_map[:,index_in_key,:]
|
34 |
+
|
35 |
+
att_box = att_box.sum(axis=1) / index_in_key.shape[0]
|
36 |
+
att_box = att_box.reshape(-1, H, H)
|
37 |
+
activation_value = (att_box* mask_out).reshape(b, -1).sum(dim=-1) #/ att_box.reshape(b, -1).sum(dim=-1)
|
38 |
+
loss += torch.mean(activation_value)
|
39 |
+
|
40 |
+
return loss / object_number
|
41 |
+
|
42 |
+
def caculate_loss_self_att(self_first, self_second, self_third, bboxes, object_positions, t, list_res=[256], smooth_att = True,sigma=0.5,kernel_size=3 ):
|
43 |
+
all_attn = get_all_self_att(self_first, self_second, self_third)
|
44 |
+
cnt = 0
|
45 |
+
total_loss = 0
|
46 |
+
for res in list_res:
|
47 |
+
attn_maps = all_attn[res]
|
48 |
+
for attn in attn_maps:
|
49 |
+
total_loss += loss_one_att_outside(attn, bboxes, object_positions,t)
|
50 |
+
cnt += 1
|
51 |
+
|
52 |
+
return total_loss /cnt
|
53 |
+
|
54 |
+
|
55 |
+
def get_all_self_att(self_first, self_second, self_third):
|
56 |
+
result = {256:[], 1024:[], 4096:[], 64:[], 94:[],1054:[] ,286:[],4126:[] }
|
57 |
+
# import pdb; pdb.set_trace()
|
58 |
+
all_att = [self_first, self_second, self_third]
|
59 |
+
for self_att in all_att:
|
60 |
+
for att in self_att:
|
61 |
+
if att != []:
|
62 |
+
temp = att[0]
|
63 |
+
for attn_map in temp:
|
64 |
+
current_res = attn_map.shape[1]
|
65 |
+
# print(current_res)
|
66 |
+
result[current_res].append(attn_map)
|
67 |
+
return result
|
68 |
+
|
69 |
+
def get_all_attention(attn_maps_mid, attn_maps_up , attn_maps_down, res):
|
70 |
+
result = []
|
71 |
+
|
72 |
+
for attn_map_integrated in attn_maps_up:
|
73 |
+
if attn_map_integrated == []: continue
|
74 |
+
attn_map = attn_map_integrated[0][0]
|
75 |
+
b, i, j = attn_map.shape
|
76 |
+
H = W = int(math.sqrt(i))
|
77 |
+
# print(H)
|
78 |
+
if H == res:
|
79 |
+
result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
|
80 |
+
for attn_map_integrated in attn_maps_mid:
|
81 |
+
|
82 |
+
# for attn_map_integrated in attn_maps_mid:
|
83 |
+
attn_map = attn_map_integrated[0]
|
84 |
+
b, i, j = attn_map.shape
|
85 |
+
H = W = int(math.sqrt(i))
|
86 |
+
# print(H)
|
87 |
+
if (H==res):
|
88 |
+
result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
|
89 |
+
# import pdb; pdb.set_trace()
|
90 |
+
for attn_map_integrated in attn_maps_down:
|
91 |
+
if attn_map_integrated == []: continue
|
92 |
+
attn_map = attn_map_integrated[0][0]
|
93 |
+
if attn_map == []: continue
|
94 |
+
b, i, j = attn_map.shape
|
95 |
+
H = W = int(math.sqrt(i))
|
96 |
+
# print(H)
|
97 |
+
if (H==res):
|
98 |
+
result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
|
99 |
+
|
100 |
+
result = torch.cat(result, dim=0)
|
101 |
+
result = result.sum(0) / result.shape[0]
|
102 |
+
return result
|
103 |
+
|
104 |
+
|
105 |
+
def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = True,sigma=0.5,kernel_size=3 ):
|
106 |
+
attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
|
107 |
+
# attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
|
108 |
+
# attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
|
109 |
+
# attn8 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 8)
|
110 |
+
all_attn = [attn16]
|
111 |
+
obj_number = len(bboxes)
|
112 |
+
total_loss = 0
|
113 |
+
# import pdb; pdb.set_trace()
|
114 |
+
for attn in all_attn[0:1]:
|
115 |
+
attn_text = attn[:, :, 1:-1]
|
116 |
+
attn_text *= 100
|
117 |
+
attn_text = torch.nn.functional.softmax(attn_text, dim=-1)
|
118 |
+
current_res = attn.shape[0]
|
119 |
+
H = W = current_res
|
120 |
+
|
121 |
+
# if t == 49: import pdb; pdb.set_trace()
|
122 |
+
for obj_idx in range(obj_number):
|
123 |
+
num_boxes= 0
|
124 |
+
|
125 |
+
for obj_position in object_positions[obj_idx]:
|
126 |
+
true_obj_position = obj_position - 1
|
127 |
+
att_map_obj = attn_text[:,:, true_obj_position]
|
128 |
+
if smooth_att:
|
129 |
+
smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
|
130 |
+
input = F.pad(att_map_obj.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
|
131 |
+
att_map_obj = smoothing(input).squeeze(0).squeeze(0)
|
132 |
+
other_att_map_obj = att_map_obj.clone()
|
133 |
+
att_copy = att_map_obj.clone()
|
134 |
+
|
135 |
+
for obj_box in bboxes[obj_idx]:
|
136 |
+
x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
|
137 |
+
int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
|
138 |
+
|
139 |
+
|
140 |
+
if att_map_obj[y_min: y_max, x_min: x_max].numel() == 0:
|
141 |
+
max_inside=1.
|
142 |
+
|
143 |
+
else:
|
144 |
+
max_inside = att_map_obj[y_min: y_max, x_min: x_max].max()
|
145 |
+
total_loss += 1. - max_inside
|
146 |
+
|
147 |
+
# find max outside the box, find in the other boxes
|
148 |
+
|
149 |
+
att_copy[y_min: y_max, x_min: x_max] = 0.
|
150 |
+
other_att_map_obj[y_min: y_max, x_min: x_max] = 0.
|
151 |
+
|
152 |
+
for obj_outside in range(obj_number):
|
153 |
+
if obj_outside != obj_idx:
|
154 |
+
for obj_out_box in bboxes[obj_outside]:
|
155 |
+
x_min_out, y_min_out, x_max_out, y_max_out = int(obj_out_box[0] * W), \
|
156 |
+
int(obj_out_box[1] * H), int(obj_out_box[2] * W), int(obj_out_box[3] * H)
|
157 |
+
|
158 |
+
# att_copy[y_min: y_max, x_min: x_max] = 0.
|
159 |
+
if other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].numel() == 0:
|
160 |
+
max_outside_one= 0
|
161 |
+
else:
|
162 |
+
max_outside_one = other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].max()
|
163 |
+
# max_outside = max(max_outside,max_outside_one )
|
164 |
+
att_copy[y_min_out: y_max_out, x_min_out: x_max_out] = 0.
|
165 |
+
total_loss += max_outside_one
|
166 |
+
max_background = att_copy.max()
|
167 |
+
total_loss += len(bboxes[obj_idx]) *max_background /2.
|
168 |
+
|
169 |
+
return total_loss/obj_number
|
170 |
+
|
gligen/ldm/models/diffusion/plms.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
from functools import partial
|
5 |
+
from copy import deepcopy
|
6 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
7 |
+
import math
|
8 |
+
from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att
|
9 |
+
class PLMSSampler(object):
|
10 |
+
def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
|
11 |
+
super().__init__()
|
12 |
+
self.diffusion = diffusion
|
13 |
+
self.model = model
|
14 |
+
self.device = diffusion.betas.device
|
15 |
+
self.ddpm_num_timesteps = diffusion.num_timesteps
|
16 |
+
self.schedule = schedule
|
17 |
+
self.alpha_generator_func = alpha_generator_func
|
18 |
+
self.set_alpha_scale = set_alpha_scale
|
19 |
+
|
20 |
+
def register_buffer(self, name, attr):
|
21 |
+
if type(attr) == torch.Tensor:
|
22 |
+
attr = attr.to(self.device)
|
23 |
+
setattr(self, name, attr)
|
24 |
+
|
25 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
|
26 |
+
if ddim_eta != 0:
|
27 |
+
raise ValueError('ddim_eta must be 0 for PLMS')
|
28 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
29 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
30 |
+
alphas_cumprod = self.diffusion.alphas_cumprod
|
31 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
32 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
33 |
+
|
34 |
+
self.register_buffer('betas', to_torch(self.diffusion.betas))
|
35 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
36 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
|
37 |
+
|
38 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
39 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
40 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
41 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
42 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
43 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
44 |
+
|
45 |
+
# ddim sampling parameters
|
46 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
47 |
+
ddim_timesteps=self.ddim_timesteps,
|
48 |
+
eta=ddim_eta,verbose=verbose)
|
49 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
50 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
51 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
52 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
53 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
54 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
55 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
56 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
57 |
+
|
58 |
+
|
59 |
+
# @torch.no_grad()
|
60 |
+
def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
|
61 |
+
self.make_schedule(ddim_num_steps=S)
|
62 |
+
# import pdb; pdb.set_trace()
|
63 |
+
return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
|
64 |
+
|
65 |
+
|
66 |
+
# @torch.no_grad()
|
67 |
+
def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
|
68 |
+
|
69 |
+
b = shape[0]
|
70 |
+
|
71 |
+
img = input["x"]
|
72 |
+
if img == None:
|
73 |
+
img = torch.randn(shape, device=self.device)
|
74 |
+
input["x"] = img
|
75 |
+
|
76 |
+
time_range = np.flip(self.ddim_timesteps)
|
77 |
+
total_steps = self.ddim_timesteps.shape[0]
|
78 |
+
|
79 |
+
old_eps = []
|
80 |
+
|
81 |
+
if self.alpha_generator_func != None:
|
82 |
+
alphas = self.alpha_generator_func(len(time_range))
|
83 |
+
|
84 |
+
for i, step in enumerate(time_range):
|
85 |
+
|
86 |
+
# set alpha and restore first conv layer
|
87 |
+
if self.alpha_generator_func != None:
|
88 |
+
self.set_alpha_scale(self.model, alphas[i])
|
89 |
+
if alphas[i] == 0:
|
90 |
+
self.model.restore_first_conv_from_SD()
|
91 |
+
|
92 |
+
# run
|
93 |
+
index = total_steps - i - 1
|
94 |
+
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
|
95 |
+
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long)
|
96 |
+
|
97 |
+
if mask is not None:
|
98 |
+
assert x0 is not None
|
99 |
+
img_orig = self.diffusion.q_sample(x0, ts)
|
100 |
+
img = img_orig * mask + (1. - mask) * img
|
101 |
+
input["x"] = img
|
102 |
+
# three loss types
|
103 |
+
if loss_type !=None and loss_type!='standard':
|
104 |
+
if input['object_position'] != []:
|
105 |
+
if loss_type=='SAR_CAR':
|
106 |
+
x = self.update_loss_self_cross( input,i, index, ts )
|
107 |
+
elif loss_type=='SAR':
|
108 |
+
x = self.update_only_self( input,i, index, ts )
|
109 |
+
elif loss_type=='CAR':
|
110 |
+
x = self.update_loss_only_cross( input,i, index, ts )
|
111 |
+
input["x"] = x
|
112 |
+
img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
|
113 |
+
input["x"] = img
|
114 |
+
old_eps.append(e_t)
|
115 |
+
if len(old_eps) >= 4:
|
116 |
+
old_eps.pop(0)
|
117 |
+
|
118 |
+
return img
|
119 |
+
|
120 |
+
def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
|
121 |
+
if index1 < 10:
|
122 |
+
loss_scale = 3
|
123 |
+
max_iter = 5
|
124 |
+
elif index1 < 20:
|
125 |
+
loss_scale = 2
|
126 |
+
max_iter = 5
|
127 |
+
else:
|
128 |
+
loss_scale = 0.8
|
129 |
+
max_iter = 1
|
130 |
+
|
131 |
+
loss_threshold = 0.1
|
132 |
+
max_index = 20
|
133 |
+
x = deepcopy(input["x"])
|
134 |
+
iteration = 0
|
135 |
+
loss = torch.tensor(10000)
|
136 |
+
input["timesteps"] = ts
|
137 |
+
|
138 |
+
print("optimize", index1)
|
139 |
+
self.model.train()
|
140 |
+
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
141 |
+
print('iter', iteration)
|
142 |
+
# import pdb; pdb.set_trace()
|
143 |
+
x = x.requires_grad_(True)
|
144 |
+
input['x'] = x
|
145 |
+
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
146 |
+
bboxes = input['boxes_att']
|
147 |
+
object_positions = input['object_position']
|
148 |
+
loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
|
149 |
+
object_positions=object_positions, t = index1)*loss_scale
|
150 |
+
loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
|
151 |
+
object_positions=object_positions, t = index1)*loss_scale
|
152 |
+
loss = loss1 + loss2
|
153 |
+
print('loss', loss, loss1, loss2)
|
154 |
+
hh = torch.autograd.backward(loss, retain_graph=True)
|
155 |
+
grad_cond = x.grad
|
156 |
+
x = x - grad_cond
|
157 |
+
x = x.detach()
|
158 |
+
iteration += 1
|
159 |
+
torch.cuda.empty_cache()
|
160 |
+
return x
|
161 |
+
|
162 |
+
def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
|
163 |
+
|
164 |
+
if index1 < 10:
|
165 |
+
loss_scale = 3
|
166 |
+
max_iter = 5
|
167 |
+
elif index1 < 20:
|
168 |
+
loss_scale = 2
|
169 |
+
max_iter = 5
|
170 |
+
else:
|
171 |
+
loss_scale = 1
|
172 |
+
max_iter = 1
|
173 |
+
loss_threshold = 0.1
|
174 |
+
|
175 |
+
max_index = 30
|
176 |
+
x = deepcopy(input["x"])
|
177 |
+
iteration = 0
|
178 |
+
loss = torch.tensor(10000)
|
179 |
+
input["timesteps"] = ts
|
180 |
+
|
181 |
+
print("optimize", index1)
|
182 |
+
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
183 |
+
print('iter', iteration)
|
184 |
+
x = x.requires_grad_(True)
|
185 |
+
input['x'] = x
|
186 |
+
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
187 |
+
|
188 |
+
bboxes = input['boxes']
|
189 |
+
object_positions = input['object_position']
|
190 |
+
loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
|
191 |
+
object_positions=object_positions, t = index1)*loss_scale
|
192 |
+
loss = loss2
|
193 |
+
print('loss', loss)
|
194 |
+
hh = torch.autograd.backward(loss)
|
195 |
+
grad_cond = x.grad
|
196 |
+
x = x - grad_cond
|
197 |
+
x = x.detach()
|
198 |
+
iteration += 1
|
199 |
+
torch.cuda.empty_cache()
|
200 |
+
return x
|
201 |
+
|
202 |
+
def update_only_self(self, input,index1, index, ts,type_loss='self_accross' ):
|
203 |
+
if index1 < 10:
|
204 |
+
loss_scale = 4
|
205 |
+
max_iter = 5
|
206 |
+
elif index1 < 20:
|
207 |
+
loss_scale = 3
|
208 |
+
max_iter = 5
|
209 |
+
else:
|
210 |
+
loss_scale = 1
|
211 |
+
max_iter = 1
|
212 |
+
loss_threshold = 0.1
|
213 |
+
|
214 |
+
max_index = 30
|
215 |
+
x = deepcopy(input["x"])
|
216 |
+
iteration = 0
|
217 |
+
loss = torch.tensor(10000)
|
218 |
+
input["timesteps"] = ts
|
219 |
+
|
220 |
+
print("optimize", index1)
|
221 |
+
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
222 |
+
print('iter', iteration)
|
223 |
+
x = x.requires_grad_(True)
|
224 |
+
input['x'] = x
|
225 |
+
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
226 |
+
|
227 |
+
bboxes = input['boxes']
|
228 |
+
object_positions = input['object_position']
|
229 |
+
loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
|
230 |
+
object_positions=object_positions, t = index1)*loss_scale
|
231 |
+
print('loss', loss)
|
232 |
+
hh = torch.autograd.backward(loss)
|
233 |
+
grad_cond = x.grad
|
234 |
+
|
235 |
+
x = x - grad_cond
|
236 |
+
x = x.detach()
|
237 |
+
iteration += 1
|
238 |
+
torch.cuda.empty_cache()
|
239 |
+
return x
|
240 |
+
|
241 |
+
@torch.no_grad()
|
242 |
+
def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
|
243 |
+
x = deepcopy(input["x"])
|
244 |
+
b = x.shape[0]
|
245 |
+
self.model.eval()
|
246 |
+
def get_model_output(input):
|
247 |
+
e_t, first, second, third,_,_,_ = self.model(input)
|
248 |
+
if uc is not None and guidance_scale != 1:
|
249 |
+
unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=None, grounding_extra_input=None)
|
250 |
+
# unconditional_input=input
|
251 |
+
e_t_uncond, _, _, _, _, _, _ = self.model( unconditional_input)
|
252 |
+
e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
|
253 |
+
return e_t
|
254 |
+
|
255 |
+
|
256 |
+
def get_x_prev_and_pred_x0(e_t, index):
|
257 |
+
# select parameters corresponding to the currently considered timestep
|
258 |
+
a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
|
259 |
+
a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
|
260 |
+
sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
|
261 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
|
262 |
+
|
263 |
+
# current prediction for x_0
|
264 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
265 |
+
|
266 |
+
# direction pointing to x_t
|
267 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
268 |
+
noise = sigma_t * torch.randn_like(x)
|
269 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
270 |
+
return x_prev, pred_x0
|
271 |
+
|
272 |
+
input["timesteps"] = t
|
273 |
+
e_t = get_model_output(input)
|
274 |
+
if len(old_eps) == 0:
|
275 |
+
# Pseudo Improved Euler (2nd order)
|
276 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
277 |
+
input["x"] = x_prev
|
278 |
+
input["timesteps"] = t_next
|
279 |
+
e_t_next = get_model_output(input)
|
280 |
+
e_t_prime = (e_t + e_t_next) / 2
|
281 |
+
elif len(old_eps) == 1:
|
282 |
+
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
283 |
+
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
284 |
+
elif len(old_eps) == 2:
|
285 |
+
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
286 |
+
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
287 |
+
elif len(old_eps) >= 3:
|
288 |
+
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
289 |
+
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
290 |
+
|
291 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
292 |
+
|
293 |
+
return x_prev, pred_x0, e_t
|
294 |
+
|
295 |
+
|
gligen/ldm/modules/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (13 kB). View file
|
|
gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc
ADDED
Binary file (18.3 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (188 Bytes). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (167 Bytes). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc
ADDED
Binary file (8.43 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
ADDED
Binary file (20.7 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc
ADDED
Binary file (1.94 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
ADDED
Binary file (13 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc
ADDED
Binary file (1.66 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc
ADDED
Binary file (10.1 kB). View file
|
|
gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
ADDED
Binary file (10.2 kB). View file
|
|