|
import gradio as gr |
|
import torch, torchvision |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from PIL import Image, ImageColor |
|
from diffusers import DDPMPipeline |
|
from diffusers import DDIMScheduler |
|
|
|
# Pick the compute device: Apple's MPS backend first, then CUDA, else the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the pretrained DDPM pipeline once at import time and move it to `device`.
pipeline_name = 'Skier8402/ddpm-celebahq-finetuned-butterflies-16epochs'
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample with DDIM instead of the pipeline's own scheduler. Building it from
# the already-loaded scheduler config reuses the same settings without a
# second lookup of the repository.
scheduler = DDIMScheduler.from_config(image_pipe.scheduler.config)
scheduler.set_timesteps(num_inference_steps=20)
|
|
|
|
|
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return how far, on average, the images' pixels are from a target color.

    Args:
        images: Tensor whose second axis holds the color channels (the target
            is broadcast as ``(1, C, 1, 1)``), with pixel values scaled to
            ``[-1, 1]``.
        target_color: ``(R, G, B)`` with each channel in ``[0, 1]``.
            Defaults to a light teal: ``(0.1, 0.9, 0.5)``.

    Returns:
        A scalar tensor: the mean absolute difference between every pixel and
        the target color.
    """
    # Build the target with the images' dtype and device so no CPU->device
    # copy is needed and no unsupported dtype (e.g. float64 on MPS) is created.
    target = torch.tensor(target_color, dtype=images.dtype, device=images.device)
    # Map the color from [0, 1] to the [-1, 1] range used by the images.
    target = target * 2 - 1
    # Reshape to (1, C, 1, 1) so it broadcasts over batch, height and width.
    target = target[None, :, None, None]
    return torch.abs(images - target).mean()
|
|
|
|
|
def generate(color, guidance_loss_scale):
    """Sample one image with the DDIM scheduler, steering it toward ``color``.

    Args:
        color: Color specifier understood by ``PIL.ImageColor`` (for example
            ``"#55FFAA"``). A bare 3- or 6-digit hex string without the
            leading ``#`` is accepted as well.
        guidance_loss_scale: Factor applied to the color loss before its
            gradient is taken; larger values push harder toward ``color``.

    Returns:
        The generated sample as a ``PIL.Image.Image``. The image is also
        written to ``test.jpeg`` in the working directory.
    """
    # PIL rejects hex colors that lack the leading "#", so add it when missing.
    if len(color) in (3, 6) and all(ch in "0123456789abcdefABCDEF" for ch in color):
        color = "#" + color
    # Scale the 0-255 channel values to [0, 1] before passing them to color_loss.
    target_color = [channel / 255 for channel in ImageColor.getcolor(color, "RGB")]

    # Start from pure noise and denoise over every scheduler timestep.
    x = torch.randn(1, 3, 256, 256, device=device)
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(x, t)
        # The UNet forward pass is kept out of the autograd graph.
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Track gradients on x so the loss on the predicted clean image can be
        # differentiated with respect to the current noisy sample.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = color_loss(x0, target_color) * guidance_loss_scale
        cond_grad = -torch.autograd.grad(loss, x)[0]

        # Move x down the loss gradient, then take the regular scheduler step.
        x = x.detach() + cond_grad
        x = scheduler.step(noise_pred, t, x).prev_sample

    grid = torchvision.utils.make_grid(x, nrow=4)
    # Channels-last layout, values clipped to [-1, 1] and rescaled to [0, 1].
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray((im * 255).detach().numpy().astype(np.uint8))
    # NOTE(review): presumably a debugging leftover; kept so existing behavior
    # (a copy of the last result on disk) does not change.
    im.save('test.jpeg')
    return im
|
|
|
|
|
# UI controls, in the order their values are passed to `generate`.
inputs = [
    # Hex colors need the leading "#"; PIL's ImageColor raises ValueError on
    # a bare "55FFAA".
    gr.ColorPicker(label="color", value='#55FFAA'),
    gr.Slider(label="guidance_scale", minimum=0, maximum=30, value=3),
]
outputs = gr.Image(label="result")
|
|
|
|
|
# Wire the controls, the generator and the result view into one interface.
# Each example row is [color, guidance_scale], matching the order of `inputs`.
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["#BB2266", 3],
        ["#44CCAA", 5],
    ],
)
|
|
|
|
|
if __name__ == "__main__":
    # `launch(enable_queue=True)` is deprecated and rejected by newer Gradio
    # releases; enabling the queue on the interface itself is the supported
    # way to queue requests.
    demo.queue().launch()
|
|