import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from collections import Counter
import math
from gradio import processing_utils
from typing import Optional
import warnings
from datetime import datetime

import torch
from diffusers import StableDiffusionInpaintPipeline
from accelerate.utils import set_seed


class Instance:
    def __init__(self, capacity=2):
        self.model_type = 'base'
        self.loaded_model_list = {}
        self.counter = Counter()
        self.global_counter = Counter()
        self.capacity = capacity
        self.loaded_model = None

    def _log(self, model_type, batch_size, instruction, phrase_list):
        self.counter[model_type] += 1
        self.global_counter[model_type] += 1
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
            current_time, dict(self.counter), dict(self.global_counter),
            batch_size, instruction, phrase_list))

    def get_model(self):
        # Lazy-load the pipeline on first use and cache it.
        if self.loaded_model is None:
            self.loaded_model = self.load_model()
        return self.loaded_model

    def load_model(self, model_id='j-min/IterInpaint-CLEVR'):
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)

        # Disable the safety checker (CLEVR renders contain no unsafe content).
        def dummy(images, **kwargs):
            return images, False
        pipe.safety_checker = dummy
        print("Disabled safety checker")
        print("Loaded model")

        # This call loads the individual model components onto the GPU on demand,
        # so we don't need to call pipe.to("cuda") explicitly.
        pipe.enable_model_cpu_offload()

        # xformers memory-efficient attention
        pipe.enable_xformers_memory_efficient_attention()

        return pipe


instance = Instance()
# Warm up through get_model() so the pipeline is cached; a bare load_model()
# call would not store the result, and the first request would load it again.
instance.get_model()

from gen_utils import encode_from_custom_annotation, iterinpaint_sample_diffusers


class ImageMask(gr.components.Image):
    """
    Sets: source="upload", tool="sketch"
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        if x is None:
            return x
        if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
            decode_image = processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            mask = np.zeros((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)


class Blocks(gr.Blocks):
    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        self.extra_configs = {
            'thumbnail': kwargs.pop('thumbnail', ''),
            'url': kwargs.pop('url', 'https://gradio.app/'),
            'creator': kwargs.pop('creator', '@teamGradio'),
        }
        super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        config = super(Blocks, self).get_config_file()
        for k, v in self.extra_configs.items():
            config[k] = v
        return config


def draw_box(boxes=[], texts=[], img=None):
    if len(boxes) == 0 and img is None:
        return None
    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    font = ImageFont.truetype("DejaVuSansMono.ttf", size=20)
    for bid, box in enumerate(boxes):
        draw.rectangle([box[0], box[1], box[2], box[3]],
                       outline=colors[bid % len(colors)], width=4)
        # Filled label strip along the bottom edge of the box.
        anno_text = texts[bid]
        draw.rectangle(
            [box[0], box[3] - int(font.size * 1.2),
             box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
            outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
        draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)],
                  anno_text, font=font, fill=(255, 255, 255))
    return img
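
# Usage sketch for draw_box (hypothetical values; kept commented so nothing runs
# at import time). draw_box paints onto `img` in place when one is passed, so
# hand it a copy if the original must stay untouched.
#
#   preview = draw_box(
#       boxes=[(50, 60, 200, 230), (260, 100, 380, 210)],
#       texts=['blue metal cube', 'brown rubber sphere'],
#   )  # returns a 512x512 PIL.Image with labeled boxes
#   preview.save('boxes_preview.png')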

def get_concat(ims):
    if len(ims) == 1:
        n_col = 1
    else:
        n_col = 2
    n_row = math.ceil(len(ims) / 2)
    dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
    for i, im in enumerate(ims):
        row_id = i // n_col
        col_id = i % n_col
        dst.paste(im, (im.width * col_id, im.height * row_id))
    return dst


def inference(language_instruction, grounding_texts, boxes, guidance_scale):
    # Expected annotation format (pixel coordinates on the 512x512 canvas), e.g.:
    # custom_annotations = [
    #     {'x': 19, 'y': 61, 'width': 158, 'height': 169, 'label': 'blue metal cube'},
    #     {'x': 183, 'y': 94, 'width': 103, 'height': 109, 'label': 'brown rubber sphere'},
    # ]
    custom_annotations = []
    for i in range(len(boxes)):
        box = boxes[i]
        custom_annotations.append({
            'x': box[0],
            'y': box[1],
            'width': box[2] - box[0],
            'height': box[3] - box[1],
            'label': grounding_texts[i],
        })

    # encode_from_custom_annotation:
    # 1) converts xywh to xyxy
    # 2) normalizes coordinates
    scene = encode_from_custom_annotation(custom_annotations, size=512)
    print(scene['box_captions'])
    print(scene['boxes_normalized'])

    pipe = instance.get_model()
    out = iterinpaint_sample_diffusers(
        pipe, scene, paste=True, verbose=True, size=512, guidance_scale=guidance_scale)

    final_image = out['generated_images'][-1].copy()

    # Create generation GIF
    prompts = out['prompts']
    fps = 4

    def create_gif_source_images(images, prompts):
        """Create source images for the GIF.

        Each frame is an intermediate image with its prompt drawn as a title.
        The original image sizes are left unchanged.
        """
        step_images = []
        font = ImageFont.truetype("DejaVuSansMono.ttf", size=20)
        for i, img in enumerate(images):
            draw = ImageDraw.Draw(img)
            draw.text((0, 0), prompts[i], (255, 255, 255), font=font)
            step_images.append(img)
        return step_images

    import imageio

    step_images = create_gif_source_images(out['generated_images'], prompts)
    print("Number of frames in GIF: {}".format(len(step_images)))

    # Save the GIF to a temp path
    import tempfile
    import os
    gif_save_path = os.path.join(tempfile.gettempdir(), 'gen.gif')
    imageio.mimsave(gif_save_path, step_images, fps=fps)
    print('GIF saved to {}'.format(gif_save_path))

    out_images = [
        final_image,
        gif_save_path,
    ]
    return out_images
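
# Example invocation of inference() (a sketch with a hypothetical guidance scale;
# boxes are unnormalized xyxy pixel coordinates on the 512x512 canvas, matching
# what generate() passes in from the sketch-pad state; note that
# language_instruction is currently unused inside inference()):
#
#   image, gif_path = inference(
#       language_instruction='',
#       grounding_texts=['blue metal cube', 'brown rubber sphere'],
#       boxes=[(19, 61, 177, 230), (183, 94, 286, 203)],
#       guidance_scale=4.0,
#   )
#   image.save('generated.png')  # final pasted composite
#   # gif_path points to a temp GIF of the iterative inpainting steps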

def generate(task, language_instruction, grounding_texts, sketch_pad,
             alpha_sample, guidance_scale, batch_size,
             fix_seed, rand_seed, use_actual_mask, append_grounding,
             style_cond_image, state):
    if 'boxes' not in state:
        state['boxes'] = []
    boxes = state['boxes']

    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    if len(boxes) != len(grounding_texts):
        if len(boxes) < len(grounding_texts):
            raise ValueError(
                """The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(
                    len(boxes), len(grounding_texts)))
        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))

    print('input boxes: ', boxes)
    print('input grounding_texts: ', grounding_texts)
    print('input language instruction: ', language_instruction)

    if fix_seed:
        set_seed(rand_seed)
        print('seed set to: ', rand_seed)

    gen_image, gen_animation = inference(
        language_instruction, grounding_texts, boxes,
        guidance_scale=guidance_scale,
    )

    gen_images = [
        gr.Image.update(value=gen_image, visible=True),
        gr.Image.update(value=gen_animation, visible=True),
    ]

    return gen_images + [state]


def binarize(x):
    return (x != 0).astype('uint8') * 255


def sized_center_crop(img, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]


def sized_center_fill(img, fill, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    img[starty:starty + cropy, startx:startx + cropx] = fill
    return img


def sized_center_mask(img, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
    img = (img * 0.2).astype('uint8')
    img[starty:starty + cropy, startx:startx + cropx] = center_region
    return img


def center_crop(img, HW=None, tgt_size=(512, 512)):
    if HW is None:
        H, W = img.shape[:2]
        HW = min(H, W)
    img = sized_center_crop(img, HW, HW)
    img = Image.fromarray(img)
    img = img.resize(tgt_size)
    return np.array(img)
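
# Quick sanity checks for the mask/crop helpers above (pure numpy and PIL,
# cheap to run at import time; a minimal self-test, safe to delete).
_m = np.zeros((8, 8), dtype=np.uint8)
_m[2:6, 2:6] = 3
assert binarize(_m).max() == 255                              # any nonzero pixel -> 255
assert sized_center_crop(_m, 4, 4).shape == (4, 4)            # crop around the center
assert center_crop(_m, tgt_size=(16, 16)).shape == (16, 16)   # crop, then resize
del _m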

def draw(task, input, grounding_texts, new_image_trigger, state):
    if type(input) == dict:
        image = input['image']
        mask = input['mask']
    else:
        image = None  # the sketch tool normally supplies a dict
        mask = input

    if mask.ndim == 3:
        mask = mask[..., 0]

    image_scale = 1.0

    # resize trigger
    if task == "Grounded Inpainting":
        mask_cond = mask.sum() == 0
        # size_cond = mask.shape != (512, 512)
        if mask_cond and 'original_image' not in state:
            image = Image.fromarray(image)
            width, height = image.size
            scale = 600 / min(width, height)
            image = image.resize((int(width * scale), int(height * scale)))
            state['original_image'] = np.array(image).copy()
            image_scale = float(height / width)
            return [None, new_image_trigger + 1, image_scale, state]
        else:
            original_image = state['original_image']
            H, W = original_image.shape[:2]
            image_scale = float(H / W)
            mask = binarize(mask)

    if mask.shape != (512, 512):
        # assert False, "should not receive any non-512x512 masks."
        if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
            mask = center_crop(mask, state['inpaint_hw'])
            image = center_crop(state['original_image'], state['inpaint_hw'])
        else:
            mask = np.zeros((512, 512), dtype=np.uint8)
    # mask = center_crop(mask)
    mask = binarize(mask)

    if type(mask) != np.ndarray:
        mask = np.array(mask)

    if mask.sum() == 0 and task != "Grounded Inpainting":
        state = {}

    if task != 'Grounded Inpainting':
        image = None
    else:
        image = Image.fromarray(image)

    if 'boxes' not in state:
        state['boxes'] = []
    if 'masks' not in state or len(state['masks']) == 0:
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    if type(mask) == np.ndarray and mask.size > 1:
        diff_mask = mask - last_mask
    else:
        diff_mask = np.zeros([])

    if diff_mask.sum() > 0:
        # Bound the freshly painted region with a rectangle
        # (see the standalone sketch after clear() below).
        x1x2 = np.where(diff_mask.max(0) != 0)[0]
        y1y2 = np.where(diff_mask.max(1) != 0)[0]
        y1, y2 = y1y2.min(), y1y2.max()
        x1, x2 = x1x2.min(), x1x2.max()

        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    grounding_texts = [x for x in grounding_texts if len(x) > 0]
    if len(grounding_texts) < len(state['boxes']):
        grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]

    box_image = draw_box(state['boxes'], grounding_texts, image)

    if box_image is not None and state.get('inpaint_hw', None):
        inpaint_hw = state['inpaint_hw']
        box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
        original_image = state['original_image'].copy()
        box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)

    return [box_image, new_image_trigger, image_scale, state]


def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
    if task != 'Grounded Inpainting':
        sketch_pad_trigger = sketch_pad_trigger + 1
    # Two outputs: the generated image and the generation GIF.
    out_images = [gr.Image.update(value=None, visible=True) for _ in range(2)]
    state = {}
    return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
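
# How draw() turns a new scribble into a box, as referenced above (a minimal
# sketch with toy data): subtracting the previous mask isolates the freshly
# painted pixels, and the box is the bounding rectangle of the nonzero
# columns/rows of that diff.
_prev = np.zeros((16, 16), dtype=np.uint8)
_curr = _prev.copy()
_curr[4:12, 3:10] = 255  # pretend the user just painted this region
_diff = _curr - _prev
_xs = np.where(_diff.max(0) != 0)[0]  # columns touched -> x extent
_ys = np.where(_diff.max(1) != 0)[0]  # rows touched -> y extent
assert (_xs.min(), _ys.min(), _xs.max(), _ys.max()) == (3, 4, 9, 11)
del _prev, _curr, _diff, _xs, _ys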

css = """
#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
{
    height: var(--height) !important;
    max-height: var(--height) !important;
    min-height: var(--height) !important;
}
#paper-info a {
    color: #008AD7;
    text-decoration: none;
}
#paper-info a:hover {
    cursor: pointer;
    text-decoration: none;
}
"""

# rescale_js recomputes the --height CSS variable from the current image width
# and scale, so the sketch pad keeps the aspect ratio of the loaded image.
rescale_js = """
function(x) {
    const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
    let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
    const image_width = root.querySelector('#img2img_image').clientWidth;
    const target_height = parseInt(image_width * image_scale);
    document.body.style.setProperty('--height', `${target_height}px`);
    root.querySelectorAll('button.justify-center.rounded')[0].style.display = 'none';
    root.querySelectorAll('button.justify-center.rounded')[1].style.display = 'none';
    return x;
}
"""

with Blocks(
    # css=css,
    analytics_enabled=False,
    title="IterInpaint demo",
) as main:
    description = """
IterInpaint CLEVR Demo
[Project Page]
[Paper]
[GitHub]
(1) ⌨️ Enter the object names in Region Captions
(2) 🖱️ Draw their corresponding bounding boxes one by one using Sketch Pad -- the parsed boxes will be displayed automatically.
For faster inference without waiting in the queue, you may duplicate this Space and upgrade to a GPU in Settings.