import os
import uuid
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import PIL
import requests
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torch import autocast
from torchvision import transforms

from inpainting import StableDiffusionInpaintingPipeline

# Hugging Face token for downloading the gated model weights.
auth_token = os.environ.get("API_TOKEN") or True


def read_content(file_path: str) -> str:
    # Assumed helper (not defined in the original snippet): returns the raw
    # contents of a local file such as header.html used by the UI below.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


# Load the fp16 inpainting pipeline on GPU if available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=auth_token,
).to(device)

# ImageNet-style preprocessing transform (defined here but not used by predict below).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((512, 512)),
    ]
)


def predict(dict, prompt=""):
    # The sketch tool passes a dict with the uploaded image and the drawn mask.
    with autocast("cuda"):
        init_image = dict["image"].convert("RGB").resize((512, 512))
        mask = dict["mask"].convert("RGB").resize((512, 512))
        images = pipe(prompt=prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
    return images[0]


examples = [
    [dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]
]

css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''

image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Box():
            image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                prompt = gr.Textbox(label='Your prompt (what you want to replace)')
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

        # Wire the examples list defined above into the Examples component.
        ex = gr.Examples(examples=examples, fn=predict, inputs=[image, prompt], outputs=image, cache_examples=False)
        ex.dataset.headers = [""]

        btn.click(fn=predict, inputs=[image, prompt], outputs=image)
    gr.HTML(
        """