Spaces:
Running
on
A10G
Running
on
A10G
File size: 8,382 Bytes
c25e2cc 3cdacdf 6255790 5e25b83 6255790 c25e2cc 6255790 e79152d 6255790 e79152d 6255790 5e25b83 9b96547 f55706c 27e096e 6255790 7f61c74 b12e6a1 71e0668 b12e6a1 7f61c74 7078734 05f89f0 277aca5 05f89f0 3be3aae 6f62fc3 5a65206 8134c37 48725b2 6255790 277aca5 6f62fc3 8134c37 6f62fc3 6255790 774cc5f 6255790 3489b04 774cc5f 88f076f 3be3aae 88f076f 5e25b83 88f076f 3489b04 5e25b83 6f62fc3 1a248f3 3489b04 5e25b83 3db28ba 6255790 88f076f db50056 660a4aa 3489b04 1ff3548 3489b04 7bb8383 77d316c d64d565 88f076f d64d565 77d316c f948a49 d58e1aa 77d316c 7bb8383 ffe201d 7bb8383 b885715 448a301 ffe201d 448a301 ffe201d d4c9ca7 6f62fc3 5a65206 ffe201d d4c9ca7 88f076f 448a301 7bb8383 fdf34ba 7bb8383 7078734 7bb8383 1df825b 6f62fc3 88f076f 6f62fc3 5a65206 277aca5 da366cb 277aca5 b12e6a1 7c70a52 b12e6a1 7f61c74 3489b04 a93910d b30a076 3489b04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import gradio as gr
import torch
import requests
from io import BytesIO
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from utils import *
from inversion_utils import *
from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
from torch import autocast, inference_mode
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
    """Invert a real image into edit-friendly DDPM latents and noise maps.

    Follows Algorithm 1 of https://arxiv.org/pdf/2304.06140.pdf, based on the
    code in https://github.com/inbarhub/DDPM_inversion.

    Side effect: sets the timesteps of the shared ``sd_pipe.scheduler`` to
    ``num_diffusion_steps``.

    Args:
        x0: image passed to ``sd_pipe.vae.encode`` (presumably the tensor
            produced by ``load_512``; TODO confirm).
        prompt_src: prompt describing the source image.
        num_diffusion_steps: number of diffusion steps used for the inversion.
        cfg_scale_src: guidance scale forwarded as ``cfg_scale`` to the
            forward process.
        eta: forwarded as ``etas`` to the forward process.

    Returns:
        Tuple ``(wt, zs, wts)``: the inverted latent, the noise maps, and the
        intermediate inverted latents.
    """
    scheduler = sd_pipe.scheduler
    scheduler.set_timesteps(num_diffusion_steps)

    # Encode the image with the VAE; 0.18215 is the latent scaling constant
    # (its inverse is applied before decoding in `sample`).
    with autocast("cuda"), inference_mode():
        posterior = sd_pipe.vae.encode(x0).latent_dist
        w0 = (posterior.mode() * 0.18215).float()

    # Forward process: recover the noise maps and the latent trajectory.
    forward_result = inversion_forward_process(
        sd_pipe,
        w0,
        etas=eta,
        prompt=prompt_src,
        cfg_scale=cfg_scale_src,
        prog_bar=True,
        num_inference_steps=num_diffusion_steps,
    )
    wt, zs, wts = forward_result
    return wt, zs, wts
def sample(wt, zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
    """Regenerate an image from inverted DDPM noise maps under a target prompt.

    Starts the reverse process from ``wts[skip]`` and consumes the noise maps
    ``zs[skip:]``, then decodes the result with the VAE and arranges it with
    ``image_grid``.

    Args:
        wt: inverted latent returned by ``invert``; not used here (kept so
            callers can pass the full ``invert`` output).
        zs: noise maps returned by ``invert``.
        wts: intermediate inverted latents returned by ``invert``.
        prompt_tar: target prompt for the reverse process.
        cfg_scale_tar: guidance scale for the target prompt.
        skip: index of the intermediate latent to start from; the noise maps
            before this index are skipped.
        eta: forwarded as ``etas`` to the reverse process.

    Returns:
        Whatever ``image_grid`` returns for the decoded batch.
    """
    start_latent = wts[skip]
    remaining_noise = zs[skip:]
    w0, _ = inversion_reverse_process(
        sd_pipe,
        xT=start_latent,
        etas=eta,
        prompts=[prompt_tar],
        cfg_scales=[cfg_scale_tar],
        prog_bar=True,
        zs=remaining_noise,
    )

    # Undo the latent scaling applied in `invert` before decoding.
    inv_scale = 1 / 0.18215
    with autocast("cuda"), inference_mode():
        decoded = sd_pipe.vae.decode(inv_scale * w0).sample

    # Make sure image_grid always receives a batched (4-D) tensor.
    if decoded.dim() < 4:
        decoded = decoded[None, :, :, :]
    return image_grid(decoded)
# load pipelines
# NOTE(review): both pipelines below load their own copy of the same weights;
# verify that this repo id is still available on the Hub.
sd_model_id = "runwayml/stable-diffusion-v1-5"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pipeline used for DDPM inversion (`invert`) and pure-DDPM sampling (`sample`).
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
# Swap in a DDIM scheduler configured from the repo's "scheduler" subfolder.
# `from_pretrained` is the supported way to load a scheduler from a repo id;
# passing a repo id to `from_config` is deprecated in diffusers.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")

# Pipeline used for the SEGA edit in `edit`.
sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)

# NOTE(review): not referenced in this file (gr.Examples is not given
# cache_examples); kept in case other code reads it.
cache_examples = True
def get_example():
    """Return the rows for the demo's Examples panel.

    Each row holds, in order: the source image path, the source prompt, the
    target prompt, the SEGA edit concept, the DDPM output image path and the
    DDPM + SEGA output image path.
    """
    hoodie_row = [
        'examples/source_a_man_wearing_a_brown_hoodie_in_a_crowded_street.jpeg',
        'a man wearing a brown hoodie in a crowded street',
        'a robot wearing a brown hoodie in a crowded street',
        'painting',
        'examples/ddpm_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png',
        'examples/ddpm_sega_painting_of_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png',
    ]
    return [hoodie_row]
def _parse_edit_concepts(edit_concept):
    """Split a comma-separated SEGA concept string into prompts and directions.

    A concept starting with "-" gets reversed (negative) guidance. Leading and
    trailing "+", "|" and "-" characters are removed from each concept, as is
    surrounding whitespace. Empty entries (e.g. from a trailing comma) are
    dropped.

    Returns:
        Tuple ``(concepts, neg_guidance)`` of equal length: the cleaned concept
        prompts and, for each one, whether its editing direction is reversed.
    """
    concepts, neg_guidance = [], []
    for raw in (edit_concept or "").split(","):
        item = raw.strip()
        # Strip the sign characters, then any whitespace they were hiding
        # (e.g. "- dog" -> "dog").
        cleaned = item.strip("+|-").strip()
        if not cleaned:
            continue
        concepts.append(cleaned)
        # Test the sign only after removing whitespace, so "cat, -dog" negates "dog".
        neg_guidance.append(item.startswith("-"))
    return concepts, neg_guidance


def edit(input_image,
         src_prompt="",
         tar_prompt="",
         steps=100,
         skip=36,
         tar_cfg_scale=15,
         edit_concept="",
         sega_edit_guidance=0,
         flip=False,
         left=0,
         right=0,
         top=0,
         bottom=0):
    """Gradio handler: DDPM-invert an image, edit it, and optionally apply SEGA.

    This is a generator that yields exactly one ``(ddpm_image, sega_image)``
    pair. When no usable edit concept is given, or ``sega_edit_guidance`` is
    0, both elements are the pure DDPM output.

    Args:
        input_image: image from the Gradio input, forwarded to ``load_512``.
        src_prompt: prompt describing the input image (used for inversion; the
            source guidance scale is ``invert``'s default).
        tar_prompt: target prompt for both the DDPM and the SEGA outputs.
        steps: number of diffusion steps.
        skip: number of initial steps to skip when regenerating.
        tar_cfg_scale: guidance scale for the target prompt.
        edit_concept: comma-separated SEGA concepts; prefix a concept with "-"
            to steer away from it.
        sega_edit_guidance: SEGA edit guidance scale applied to every concept.
        flip: mirror the inverted latents and noise maps before sampling.
        left, right, top, bottom: offsets forwarded to ``load_512`` (labelled
            "... Shift" in the UI).
    """
    # Gradio sliders/number boxes can deliver floats; both values are used as
    # step counts / tensor indices below.
    steps = int(steps)
    skip = int(skip)

    x0 = load_512(input_image, left, right, top, bottom, device)

    # DDPM inversion.
    wt, zs, wts = invert(x0=x0, prompt_src=src_prompt, num_diffusion_steps=steps)

    if flip:
        # Mirror along the last two dims (presumably H and W; TODO confirm).
        wt, zs, wts = (torch.flip(t, [2, 3]) for t in (wt, zs, wts))

    latents = wts[skip].expand(1, -1, -1, -1)

    # Pure DDPM output under the target prompt.
    pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
                           cfg_scale_tar=tar_cfg_scale, skip=skip)

    concepts, neg_guidance = _parse_edit_concepts(edit_concept)
    if not concepts or not sega_edit_guidance:
        # This function is a generator (it yields below), so `return <value>`
        # would emit nothing to Gradio; yield the result, then stop.
        yield pure_ddpm_out, pure_ddpm_out
        return

    # SEGA: every per-concept list has one entry per concept. A single-element
    # list presumably fails per-concept indexing when there are several
    # concepts (TODO confirm against the modified pipeline).
    num_concepts = len(concepts)
    editing_args = dict(
        editing_prompt=concepts,
        reverse_editing_direction=neg_guidance,
        edit_warmup_steps=[1] * num_concepts,
        edit_guidance_scale=[sega_edit_guidance] * num_concepts,
        edit_threshold=[.93] * num_concepts,
        edit_momentum_scale=0.5,
        edit_mom_beta=0.6,
    )
    sega_out = sem_pipe(prompt=tar_prompt, eta=1, latents=latents,
                        guidance_scale=tar_cfg_scale,
                        num_images_per_prompt=1,
                        num_inference_steps=steps,
                        use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
    yield pure_ddpm_out, sega_out.images[0]
########
# demo #
########

# HTML header rendered at the top of the demo page via gr.HTML.
intro = """
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
Edit Friendly DDPM X Semantic Guidance: Editing Real Images
</h1>
<p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
<a href="https://huggingface.co/spaces/LinoyTsaban/ddpm_sega?duplicate=true">
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
</p>"""
with gr.Blocks() as demo:
    gr.HTML(intro)

    # Prompts and SEGA edit concepts.
    with gr.Row():
        src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True)
        tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True)
        edit_concept = gr.Textbox(lines=1, label="SEGA Edit Concepts", interactive=True)

    # Input image and the two outputs yielded by `edit`.
    with gr.Row():
        input_image = gr.Image(label="Input Image", interactive=True)
        ddpm_edited_image = gr.Image(label="DDPM Reconstructed Image", interactive=False)
        sega_edited_image = gr.Image(label="DDPM + SEGA Edited Image", interactive=False)
        input_image.style(height=512, width=512)
        ddpm_edited_image.style(height=512, width=512)
        sega_edited_image.style(height=512, width=512)

    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            generate_button = gr.Button("Run")

    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            # inversion
            steps = gr.Number(value=100, precision=0, label="Num Diffusion Steps", interactive=True)
            # reconstruction
            # `step=1` keeps the value integral: it is used as a tensor index in
            # `edit`, and gr.Slider has no `precision` option (its default step
            # for this range would allow fractional values).
            skip = gr.Slider(minimum=0, maximum=40, value=36, step=1, label="Skip Steps", interactive=True)
            tar_cfg_scale = gr.Slider(minimum=7, maximum=18, value=15, label="Guidance Scale", interactive=True)
            flip = gr.Checkbox(label="Flip")
            left = gr.Number(value=0, precision=0, label="Left Shift", interactive=True)
            right = gr.Number(value=0, precision=0, label="Right Shift", interactive=True)
            top = gr.Number(value=0, precision=0, label="Top Shift", interactive=True)
            bottom = gr.Number(value=0, precision=0, label="Bottom Shift", interactive=True)
            # edit
            sega_edit_guidance = gr.Slider(value=10, label="SEGA Edit Guidance Scale", interactive=True)

    # The order of `inputs` must match the positional parameters of `edit`.
    generate_button.click(
        fn=edit,
        inputs=[input_image,
                src_prompt,
                tar_prompt,
                steps,
                skip,
                tar_cfg_scale,
                edit_concept,
                sega_edit_guidance,
                flip,
                left,
                right,
                top,
                bottom,
                ],
        outputs=[ddpm_edited_image, sega_edited_image],
    )

    # Each example row (see get_example) fills these six components in order.
    gr.Examples(
        label='Examples',
        examples=get_example(),
        inputs=[input_image, src_prompt, tar_prompt, edit_concept, ddpm_edited_image, sega_edited_image],
        outputs=[ddpm_edited_image, sega_edited_image])
# Enable request queuing (`edit` is a generator handler) and start the server
# without creating a public share link. `Blocks.queue()` returns the Blocks
# object itself, so the calls can be chained.
demo.queue().launch(share=False)
|