Spaces:
Runtime error
Runtime error
import torch | |
from PIL import Image | |
from PIL import ImageDraw | |
def encode_scene(obj_list, H=320, W=320, src_bbox_format='xywh', tgt_bbox_format='xyxy'):
    """Turn a list of scene objects into captions and bounding boxes.

    Args:
        obj_list: list of dicts. Each dict provides either
                'caption': str
            or the three keys
                'color': str, 'material': str, 'shape': str
            (joined with spaces to form the caption), and always
                'bbox': 4 unnormalized numbers, read as [x0, y0, w, h] when
                    src_bbox_format is 'xywh' and as [x0, y0, x1, y1] when
                    it is 'xyxy'.
        H: canvas height; upper bound for y1 and divisor for y values.
        W: canvas width; upper bound for x1 and divisor for x values.
        src_bbox_format: layout of the incoming boxes, 'xywh' or 'xyxy'.
        tgt_bbox_format: layout of the returned boxes, 'xywh' or 'xyxy'.

    Returns:
        dict with
            'box_captions': one caption per object,
            'boxes_normalized': boxes in the target layout divided by W / H,
            'boxes_unnormalized': boxes in the target layout, undivided.

    Raises:
        AssertionError: on an unknown bbox format, a box with x1 <= x0 or
            y1 <= y0, or a box extending past W / H.
    """
    captions = [
        item['caption'] if 'caption' in item
        else f"{item['color']} {item['material']} {item['shape']}"
        for item in obj_list
    ]

    assert src_bbox_format in ['xywh', 'xyxy'], f"src_bbox_format must be 'xywh' or 'xyxy', not {src_bbox_format}"
    assert tgt_bbox_format in ['xywh', 'xyxy'], f"tgt_bbox_format must be 'xywh' or 'xyxy', not {tgt_bbox_format}"

    pixel_boxes = []
    relative_boxes = []
    for item in obj_list:
        # Recover both corner and extent representations from the source layout.
        if src_bbox_format == 'xyxy':
            x0, y0, x1, y1 = item['bbox']
            width = x1 - x0
            height = y1 - y0
        elif src_bbox_format == 'xywh':
            x0, y0, width, height = item['bbox']
            x1 = x0 + width
            y1 = y0 + height

        # The box must have positive area and stay inside the canvas.
        assert x1 > x0, f"x1={x1} <= x0={x0}"
        assert y1 > y0, f"y1={y1} <= y0={y0}"
        assert x1 <= W, f"x1={x1} > W={W}"
        assert y1 <= H, f"y1={y1} > H={H}"

        if tgt_bbox_format == 'xyxy':
            pixel_box = [x0, y0, x1, y1]
            relative_box = [x0 / W, y0 / H, x1 / W, y1 / H]
        elif tgt_bbox_format == 'xywh':
            pixel_box = [x0, y0, width, height]
            relative_box = [x0 / W, y0 / H, width / W, height / H]

        pixel_boxes.append(pixel_box)
        relative_boxes.append(relative_box)

    assert len(captions) == len(relative_boxes), f"len(box_captions)={len(captions)} != len(boxes_normalized)={len(relative_boxes)}"

    return {
        'box_captions': captions,
        'boxes_normalized': relative_boxes,
        'boxes_unnormalized': pixel_boxes,
    }
def encode_from_custom_annotation(custom_annotations, size=512):
    """Encode x/y/width/height annotations through ``encode_scene``.

    Args:
        custom_annotations: sequence of dicts with keys 'x', 'y', 'width',
            'height' and 'label', for example
                [
                    {'x': 83, 'y': 335, 'width': 70, 'height': 69, 'label': 'blue metal cube'},
                    {'x': 162, 'y': 302, 'width': 110, 'height': 138, 'label': 'blue metal cube'},
                ]
        size: side length of the square canvas; passed as both H and W.

    Returns:
        The dict returned by ``encode_scene`` with 'xyxy' as both source and
        target box format.
    """
    scene_objects = []
    for annotation in custom_annotations:
        left = annotation['x']
        top = annotation['y']
        # Convert corner + extent into the two-corner (xyxy) layout.
        corners = [
            left,
            top,
            left + annotation['width'],
            top + annotation['height'],
        ]
        scene_objects.append({
            'caption': annotation['label'],
            'bbox': corners,
        })

    return encode_scene(
        scene_objects,
        H=size,
        W=size,
        src_bbox_format='xyxy',
        tgt_bbox_format='xyxy',
    )
#### The functions below are for HF diffusers
def iterinpaint_sample_diffusers(pipe, datum, paste=True, verbose=False, guidance_scale=4.0, size=512, background_instruction='Add gray background'):
    """Build an image one box at a time by repeated inpainting calls.

    For every box in ``datum`` the pipeline is called with the prompt
    ``"Add <caption>"``, the current context image, and a mask that is white
    (255) inside the box and black (0) elsewhere. After all boxes have been
    processed, one final call with ``background_instruction`` is made using a
    mask that is white everywhere except inside the boxes.

    Args:
        pipe: callable invoked as
            ``pipe(prompt, image, mask, guidance_scale=...)`` whose result
            exposes ``.images``; the first image is used.
            NOTE(review): presumably a diffusers inpainting pipeline - confirm.
        datum: dict with 'box_captions' (list of str) and
            'boxes_unnormalized' (one [x0, y0, x1, y1] box per caption, each
            a list, tuple or torch tensor). Side effect: the boxes are also
            stored back into this dict under the key 'unnormalized_boxes'.
        paste: if True, only the generated box region is pasted into the
            running context image; if False, the whole generated image
            becomes the next context image.
        verbose: print progress messages.
        guidance_scale: forwarded to every ``pipe`` call.
        size: side length of the square canvas and masks.
        background_instruction: prompt for the final background pass.

    Returns:
        dict with
            'context_imgs': context image given to each per-box call,
            'mask_imgs': mask of each call (per-box masks, then background),
            'prompts': prompt of each call,
            'generated_images': context image after each box, followed by
                the output of the background pass.
    """
    # Kept for backward compatibility: the boxes have always been exposed on
    # the caller's dict under this alias key.
    datum['unnormalized_boxes'] = datum['boxes_unnormalized']
    boxes = datum['unnormalized_boxes']
    captions = datum['box_captions']
    n_total_boxes = len(boxes)

    context_imgs = []
    mask_imgs = []
    generated_images = []
    prompts = []

    context_img = Image.new('RGB', (size, size))
    if verbose:
        print('Initiailzed context image')

    # Starts fully white; each object box is blacked out as it is drawn so
    # that the final pass only repaints the area outside the boxes.
    background_mask_img = Image.new('L', (size, size), 255)
    background_mask_draw = ImageDraw.Draw(background_mask_img)

    for i, box in enumerate(boxes):
        if verbose:
            print('Iter: ', i+1, 'total: ', n_total_boxes)

        target_caption = captions[i]
        if verbose:
            print('Drawing ', target_caption)

        # Accept plain sequences as well as tensors; integer pixel
        # coordinates are computed once and reused below.
        if isinstance(box, (list, tuple)):
            box = torch.tensor(box)
        src_box = box.long().tolist()

        # New 'L' images are black (0), so only the box has to be filled in.
        mask_img = Image.new('L', context_img.size)
        ImageDraw.Draw(mask_img).rectangle(src_box, fill=255)
        background_mask_draw.rectangle(src_box, fill=0)
        mask_imgs.append(mask_img.copy())

        prompt = f"Add {target_caption}"
        if verbose:
            print('prompt:', prompt)
        prompts.append(prompt)

        context_imgs.append(context_img.copy())

        generated_image = pipe(
            prompt,
            context_img,
            mask_img,
            guidance_scale=guidance_scale).images[0]

        if paste:
            # The generated crop is stretched onto a target one pixel larger
            # on every side than the source box.
            # NOTE(review): presumably to hide seams at the box border - confirm.
            paste_box = [
                src_box[0] - 1,
                src_box[1] - 1,
                src_box[2] + 1,
                src_box[3] + 1,
            ]
            box_w = paste_box[2] - paste_box[0]
            box_h = paste_box[3] - paste_box[1]
            patch = generated_image.crop(src_box).resize((box_w, box_h))
            context_img.paste(patch, paste_box)
        else:
            context_img = generated_image

        generated_images.append(context_img.copy())

    if verbose:
        print('Fill background')

    mask_img = background_mask_img
    mask_imgs.append(mask_img)

    prompt = background_instruction
    if verbose:
        print('prompt:', prompt)
    prompts.append(prompt)

    generated_image = pipe(
        prompt,
        context_img,
        mask_img,
        guidance_scale=guidance_scale).images[0]
    generated_images.append(generated_image)

    return {
        'context_imgs': context_imgs,
        'mask_imgs': mask_imgs,
        'prompts': prompts,
        'generated_images': generated_images,
    }