|
from __future__ import annotations |
|
|
|
import math |
|
import random |
|
import sys |
|
from argparse import ArgumentParser |
|
|
|
from tqdm.auto import trange |
|
import einops |
|
import gradio as gr |
|
import k_diffusion as K |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from omegaconf import OmegaConf |
|
from PIL import Image, ImageOps, ImageFilter |
|
from torch import autocast |
|
import cv2 |
|
import imageio |
|
import spaces |
|
|
|
sys.path.append("./stable_diffusion") |
|
|
|
from stable_diffusion.ldm.util import instantiate_from_config |
|
|
|
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a two-output denoiser.

    The wrapped ``inner_model`` is called once on a batch made of three
    stacked copies of the inputs, conditioned respectively on
    (text + image), (image only) and (nothing). The first output is combined
    with separate text and image guidance scales; the second output is taken
    from the fully conditioned copy without any guidance.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z_0, z_1, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        """Run the inner model with dual classifier-free guidance.

        Args:
            z_0, z_1: latents, batched along dim 0.
            sigma: 1-D tensor with one noise level per batch element.
            cond, uncond: dicts with ``"c_crossattn"`` and ``"c_concat"``
                entries, each a one-element list holding a tensor.
            text_cfg_scale: weight of the (full - image-only) difference.
            image_cfg_scale: weight of the (image-only - unconditional)
                difference.

        Returns:
            ``(guided_output_0, conditioned_output_1)``.
        """
        # Tile along the batch axis with torch.cat rather than an einops
        # pattern that requires batch size 1: the result is identical for a
        # single sample and the block layout (cond | image-cond | uncond)
        # stays aligned with the conditioning below for any batch size.
        cfg_z_0 = torch.cat([z_0] * 3)
        cfg_z_1 = torch.cat([z_1] * 3)
        cfg_sigma = torch.cat([sigma] * 3)

        text_cond = cond["c_crossattn"][0]
        text_uncond = uncond["c_crossattn"][0]
        image_cond = cond["c_concat"][0]
        image_uncond = uncond["c_concat"][0]
        cfg_cond = {
            "c_crossattn": [torch.cat([text_cond, text_uncond, text_uncond])],
            "c_concat": [torch.cat([image_cond, image_cond, image_uncond])],
        }

        output_0, output_1 = self.inner_model(cfg_z_0, cfg_z_1, cfg_sigma, cond=cfg_cond)
        out_cond_0, out_img_cond_0, out_uncond_0 = output_0.chunk(3)
        out_cond_1, _, _ = output_1.chunk(3)

        guided_0 = (
            out_uncond_0
            + text_cfg_scale * (out_cond_0 - out_img_cond_0)
            + image_cfg_scale * (out_img_cond_0 - out_uncond_0)
        )
        return guided_0, out_cond_1
|
|
|
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate ``config.model`` and load its weights from a checkpoint.

    Args:
        config: configuration object exposing a ``model`` entry that is
            passed to ``instantiate_from_config``.
        ckpt: path of the checkpoint; it must contain a ``"state_dict"`` key.
        vae_ckpt: optional path of a second checkpoint whose ``"state_dict"``
            replaces every ``first_stage_model.*`` weight.
        verbose: when true, print missing / unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model with the weights loaded (strict matching).
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_state = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        vae_prefix = "first_stage_model."
        patched = {}
        for key, value in state_dict.items():
            if key.startswith(vae_prefix):
                # The VAE checkpoint stores the same weights without the prefix.
                patched[key] = vae_state[key[len(vae_prefix):]]
            else:
                patched[key] = value
        state_dict = patched

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=True)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)
    return model
|
|
|
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Index with ``...`` followed by one ``None`` per trailing axis to add.
        return x[(Ellipsis, *([None] * missing))]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
|
|
class CompVisDenoiser(K.external.CompVisDenoiser):
    """k-diffusion CompVis wrapper for a model that returns two outputs.

    ``apply_model`` of the wrapped model is expected to return a pair
    ``(eps_0, eps_1)``; only the first is rescaled into a denoised estimate.
    """

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, quantize, device)

    def get_eps(self, *args, **kwargs):
        """Forward all arguments to the wrapped model's ``apply_model``."""
        return self.inner_model.apply_model(*args, **kwargs)

    def forward(self, input_0, input_1, sigma, **kwargs):
        """Denoise ``input_0`` at noise level ``sigma``.

        Args:
            input_0: noisy latent fed to the model.
            input_1: accepted for interface compatibility; it is not passed
                to the model.
            sigma: noise level(s) for the batch.
            **kwargs: forwarded to ``get_eps`` (e.g. ``cond``).

        Returns:
            ``(input_0 + eps_0 * c_out, eps_1)``.
        """
        c_out, c_in = [append_dims(x, input_0.ndim) for x in self.get_scalings(sigma)]
        # sigma_to_t looks sigma up in this wrapper's own schedule buffer,
        # which can live on a different device than the inputs (the wrapper
        # is built with device='cpu' while the model runs on the GPU). Move
        # sigma to the buffer's device for the lookup and bring the timestep
        # back to the input's device instead of hard-coding cpu()/cuda().
        timestep = self.sigma_to_t(sigma.to(self.log_sigmas.device)).to(input_0.device)
        eps_0, eps_1 = self.get_eps(input_0 * c_in, timestep, **kwargs)
        return input_0 + eps_0 * c_out, eps_1
|
|
|
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    Computes ``(x - denoised) / sigma`` with ``sigma`` given trailing
    singleton dimensions so that it has as many dimensions as ``x``.
    """
    residual = x - denoised
    sigma_expanded = append_dims(sigma, x.ndim)
    return residual / sigma_expanded
|
|
|
def default_noise_sampler(x):
    """Return a sampler that draws standard normal noise shaped like ``x``.

    The returned callable takes ``(sigma, sigma_next)`` and ignores both.
    """

    def sample(sigma, sigma_next):
        return torch.randn_like(x)

    return sample
|
|
|
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    With a falsy ``eta`` no noise is added and ``(sigma_to, 0.)`` is returned.
    """
    if eta:
        variance = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
        # Never add more noise than the target level itself.
        sigma_up = min(sigma_to, eta * variance ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up
    return sigma_to, 0.
|
|
|
def decode_mask(mask, height=256, width=256):
    """Turn a single-sample mask latent into a binary uint8 image array.

    The input is bilinearly resized to ``(height, width)``, thresholded at 0
    (positive -> 255, otherwise 0), laid out as ``h w c`` and repeated three
    times along the last axis.

    Returns:
        A ``numpy`` array of dtype ``uint8`` on the CPU.
    """
    resized = nn.functional.interpolate(mask, size=(height, width), mode="bilinear", align_corners=False)
    signed = torch.where(resized > 0, 1, -1)
    unit = torch.clamp((signed + 1.0) / 2.0, min=0.0, max=1.0)
    # The pattern requires a batch of exactly one.
    scaled = 255.0 * rearrange(unit, "1 c h w -> h w c")
    stacked = torch.cat([scaled, scaled, scaled], dim=-1)
    return stacked.type(torch.uint8).cpu().numpy()
|
|
|
def sample_euler_ancestral(model, x_0, x_1, sigmas, height, width, extra_args=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    ``x_0`` is integrated with Euler ancestral steps; ``x_1`` is replaced at
    every step by the model's second output. Every step's first output and a
    decoded mask frame (``decode_mask`` at ``height`` x ``width``) are kept.

    Returns:
        ``(x_0, x_1, image_list, mask_list)`` where ``image_list`` is the
        per-step first outputs concatenated along dim 0 and ``mask_list`` is
        the list of decoded mask frames.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x_0)
    s_in = x_0.new_ones([x_0.shape[0]])

    denoised_frames = []
    mask_frames = []
    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[step]
        sigma_next = sigmas[step + 1]

        denoised_0, denoised_1 = model(x_0, x_1, sigma_cur * s_in, **extra_args)
        denoised_frames.append(denoised_0)

        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigma_next, eta=eta)
        derivative = to_d(x_0, sigma_cur, denoised_0)

        # Deterministic Euler step down to sigma_down ...
        x_0 = x_0 + derivative * (sigma_down - sigma_cur)
        # ... then re-inject noise, except on the final step to sigma == 0.
        if sigma_next > 0:
            x_0 = x_0 + noise_sampler(sigma_cur, sigma_next) * s_noise * sigma_up

        x_1 = denoised_1
        mask_frames.append(decode_mask(x_1, height, width))

    return x_0, x_1, torch.cat(denoised_frames, dim=0), mask_frames
|
|
|
# Command-line options; parsed at import time because the model is loaded
# as part of module initialisation.
parser = ArgumentParser()
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--config", type=str, default="configs/generate_diffree.yaml")
parser.add_argument("--ckpt", type=str, default="checkpoints/epoch=000041-step=000010999.ckpt")
parser.add_argument("--vae-ckpt", type=str, default=None)
args = parser.parse_args()

# Build the model once and wrap it for sampling with classifier-free guidance.
config = OmegaConf.load(args.config)
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
model.eval()
model.cuda()
model_wrap = CompVisDenoiser(model)
model_wrap_cfg = CFGDenoiser(model_wrap)
# Conditioning of the empty prompt, used as the unconditional text branch.
null_token = model.get_learned_conditioning([""])
|
|
|
def _write_video(frames, path, fps=30):
    """Write ``frames`` to the video file ``path`` with imageio and return ``path``."""
    with imageio.get_writer(path, fps=fps) as video:
        for frame in frames:
            video.append_data(frame)
    return path


def _decode_latent_frames(latents, batch_size):
    """Decode ``latents`` with the first-stage model, ``batch_size`` at a time, into uint8 ``h w c`` frames."""
    frames = []
    for start in range(0, len(latents), batch_size):
        # Slicing past the end is safe, so the last (short) chunk needs no special case.
        decoded = model.decode_first_stage(latents[start:start + batch_size])
        decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        decoded = 255.0 * rearrange(decoded, "b c h w -> b h w c")
        frames.extend(decoded.type(torch.uint8).cpu().numpy())
    return frames


def _blend_with_mask(original, edited, mask):
    """Composite ``edited`` over ``original`` through a dilated, blurred ``mask``.

    Returns ``(mix_image, soft_mask)`` where ``soft_mask`` is the float mask
    (values scaled by 1/255) that was used for blending.
    """
    kernel = np.ones((3, 3), np.uint8)
    dilated = Image.fromarray(cv2.dilate(np.array(mask), kernel, iterations=3))
    soft_mask = np.asarray(dilated.filter(ImageFilter.GaussianBlur(radius=3))).astype('float') / 255.0
    original_np = np.asarray(original).astype('float') / 255.0
    edited_np = np.asarray(edited).astype('float') / 255.0
    mixed = soft_mask * edited_np + (1 - soft_mask) * original_np
    mix_image = Image.fromarray((mixed * 255).astype(np.uint8)).convert('RGB')
    return mix_image, soft_mask


def _overlay_red_mask(image, soft_mask):
    """Tint ``image`` towards red (R=180) with half the strength of ``soft_mask``."""
    base = np.array(image).astype('float')
    red = np.zeros_like(base)
    red[:, :, 0] = 180.0
    half_mask = soft_mask / 2.0
    return Image.fromarray((base * (1 - half_mask) + half_mask * red).astype('uint8'))


@spaces.GPU(duration=20)
def generate(
    input_image: Image.Image,
    instruction: str,
    steps: int,
    randomize_seed: bool,
    seed: int,
    randomize_cfg: bool,
    text_cfg_scale: float,
    image_cfg_scale: float,
    weather_hide_video: bool = True,
    decode_image_batch: int = 10,
):
    """Add the object described by ``instruction`` to ``input_image``.

    Args:
        input_image: source picture; it is resized so that both sides are
            multiples of 64 and converted to RGB.
        instruction: text description of the object to add. When empty, no
            sampling is done and the resized input is returned unchanged.
        steps: number of sampling steps.
        randomize_seed: when truthy, draw a random seed in [0, 100000].
        seed: seed used when ``randomize_seed`` is falsy.
        randomize_cfg: when truthy, draw both guidance scales at random.
        text_cfg_scale: text guidance scale.
        image_cfg_scale: image guidance scale.
        weather_hide_video: when truthy, skip decoding and writing the
            per-step image video.
        decode_image_batch: number of intermediate latents decoded at once
            for the image video; values below 1 are treated as 1.

    Returns:
        A list of ten values, in this order: seed, text guidance scale,
        image guidance scale, edited image, mix image (edit blended into the
        input through the mask), predicted mask, mask video path, image
        video path (``None`` when skipped), resized original image, and the
        mix image with a red mask overlay.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    seed = random.randint(0, 100000) if randomize_seed else int(seed)
    if randomize_cfg:
        text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2)
        image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2)
    # A batch size below 1 would make range() fail or yield no frames at all.
    decode_image_batch = max(1, int(decode_image_batch or 1))

    # Scale towards the target resolution, snapping both sides to multiples of 64.
    width, height = input_image.size
    factor = args.resolution / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    # Convert to RGB before building the model tensor so that RGBA or
    # grayscale inputs do not reach the encoder with the wrong channel count.
    source_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS).convert("RGB")

    if not instruction:
        # Nothing to add: return one value per declared output component.
        return [seed, text_cfg_scale, image_cfg_scale, source_image, source_image, None, None, None, source_image, None]

    model.cuda()
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        image_tensor = 2 * torch.tensor(np.array(source_image)).float() / 255 - 1
        image_tensor = rearrange(image_tensor, "h w c -> 1 c h w").to(model.device)
        cond = {
            "c_crossattn": [model.get_learned_conditioning([instruction]).to(model.device)],
            "c_concat": [model.encode_first_stage(image_tensor).mode().to(model.device)],
        }
        uncond = {
            "c_crossattn": [null_token.to(model.device)],
            "c_concat": [torch.zeros_like(cond["c_concat"][0])],
        }

        sigmas = model_wrap.get_sigmas(steps)
        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": text_cfg_scale,
            "image_cfg_scale": image_cfg_scale,
        }
        torch.manual_seed(seed)
        z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
        z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]

        z_0, z_1, image_list, mask_list = sample_euler_ancestral(
            model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args
        )

        x_0 = model.decode_first_stage(z_0)
        if model.first_stage_downsample:
            # The mask is upsampled directly and thresholded at 0.
            x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
            x_1 = torch.where(x_1 > 0, 1, -1)
        else:
            x_1 = model.decode_first_stage(z_1)

        x_0 = torch.clamp((x_0 + 1.0) / 2.0, min=0.0, max=1.0)
        x_1 = torch.clamp((x_1 + 1.0) / 2.0, min=0.0, max=1.0)
        x_0 = 255.0 * rearrange(x_0, "1 c h w -> h w c")
        x_1 = 255.0 * rearrange(x_1, "1 c h w -> h w c")
        x_1 = torch.cat([x_1, x_1, x_1], dim=-1)
        edited_image = Image.fromarray(x_0.type(torch.uint8).cpu().numpy())
        edited_mask = Image.fromarray(x_1.type(torch.uint8).cpu().numpy())

        image_video_path = None
        if not weather_hide_video:
            # NOTE(review): fixed output file names are shared by all requests.
            image_video_path = _write_video(
                _decode_latent_frames(image_list, decode_image_batch), "image.mp4"
            )

    # The un-dilated mask is what gets returned; blending works on a copy.
    mix_image, soft_mask = _blend_with_mask(source_image, edited_image, edited_mask)
    mix_result_with_red_mask = _overlay_red_mask(mix_image, soft_mask)
    mask_video_path = _write_video(mask_list, "mask.mp4")

    return [
        seed,
        text_cfg_scale,
        image_cfg_scale,
        edited_image,
        mix_image,
        edited_mask,
        mask_video_path,
        image_video_path,
        source_image,
        mix_result_with_red_mask,
    ]
|
|
|
def reset():
    """Return the default value for every component wired to the Reset button.

    Order: the six sampling controls, seven cleared outputs, then the two
    video options.
    """
    control_defaults = [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5]
    cleared_outputs = [None] * 7
    video_defaults = ["Hide Image Video", 10]
    return control_defaults + cleared_outputs + video_defaults
|
|
|
def get_example():
    """Return the example rows shown in the UI.

    Each row is ``[image path, object description, steps, seed mode, seed,
    CFG mode, text CFG, image CFG]``; all rows share the same settings.
    """
    shared_settings = (100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5)
    samples = (
        ("test/dufu.png", "black and white suit"),
        ("test/girl.jpeg", "reflective sunglasses"),
        ("test/road_sign.png", "stop sign"),
        ("test/dufu.png", "blue medical mask"),
        ("test/people_standing.png", "dark green pleated skirt"),
        ("test/girl.jpeg", "shiny golden crown"),
        ("test/dufu.png", "sunglasses"),
        ("test/girl.jpeg", "diamond necklace"),
        ("test/iron_man.jpg", "sunglasses"),
        ("test/girl.jpeg", "the queen's crown"),
        ("test/girl.jpeg", "gorgeous yellow gown"),
    )
    return [[path, description, *shared_settings] for path, description in samples]
|
|
|
# Gradio interface. Component creation order defines the page layout, so the
# statements below are kept in their original order.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    # Page title.
    with gr.Row():
        gr.Markdown(
            "<div align='center'><font size='14'>Diffree: Text-Guided Shape Free Object Inpainting with Diffusion Model</font></div>"
        )

    with gr.Row():
        # Left column: inputs and sampling controls.
        with gr.Column(scale=1, min_width=100):
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            with gr.Row():
                instruction = gr.Textbox(lines=1, label="Object description", interactive=True)
            with gr.Row():
                steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
                # type="index": the callback receives the position of the
                # selected choice, so "Randomize Seed" arrives as 1.
                randomize_seed = gr.Radio(
                    ["Fix Seed", "Randomize Seed"],
                    value="Randomize Seed",
                    type="index",
                    label="Seed Selection",
                    show_label=False,
                    interactive=True,
                )
                seed = gr.Number(value=1372, precision=0, label="Seed", interactive=True)
                randomize_cfg = gr.Radio(
                    ["Fix CFG", "Randomize CFG"],
                    value="Fix CFG",
                    type="index",
                    label="CFG Selection",
                    show_label=False,
                    interactive=True,
                )
                text_cfg_scale = gr.Number(value=7.5, label="Text CFG", interactive=True)
                image_cfg_scale = gr.Number(value=1.5, label="Image CFG", interactive=True)
            with gr.Row():
                reset_button = gr.Button("Reset")
                generate_button = gr.Button("Generate")
        # Right column: main results.
        with gr.Column(scale=1, min_width=100):
            with gr.Column():
                mix_image = gr.Image(label="Mix Image", type="pil", interactive=False)
            with gr.Column():
                edited_mask = gr.Image(label="Output Mask", type="pil", interactive=False)

    # Secondary outputs, collapsed by default.
    with gr.Accordion('More outputs', open=False):
        with gr.Row():
            # Index 1 ("Hide Image Video") skips the per-step image video.
            weather_hide_video = gr.Radio(
                ["Show Image Video", "Hide Image Video"],
                value="Hide Image Video",
                type="index",
                label="Dynamic Process",
                show_label=False,
                interactive=True,
            )
            decode_image_batch = gr.Number(value=10, precision=0, label="Decode Image Batch (<steps)", interactive=True)
        with gr.Row():
            image_video = gr.Video(label="Real-time Image Output")
            mask_video = gr.Video(label="Real-time Mask Output")
        with gr.Row():
            original_image = gr.Image(label="Original Image", type="pil", interactive=False)
            edited_image = gr.Image(label="Output Image", type="pil", interactive=False)
            mix_result_with_red_mask = gr.Image(label="Mix Image With Red Mask", type="pil", interactive=False)

    with gr.Row():
        # Examples only fill the first eight inputs; they are not cached.
        gr.Examples(
            examples=get_example(),
            fn=generate,
            inputs=[input_image, instruction, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
            cache_examples=False,
        )

    # Generate: ten inputs in the order of generate()'s parameters, ten outputs
    # in the order of its returned list.
    generate_button.click(
        fn=generate,
        inputs=[
            input_image,
            instruction,
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
            weather_hide_video,
            decode_image_batch,
        ],
        outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
    )
    # Reset: fifteen outputs matching the fifteen values returned by reset().
    reset_button.click(
        fn=reset,
        inputs=[],
        outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask, weather_hide_video, decode_image_batch],
    )

demo.queue().launch()
|
|