|
import open_clip |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
|
|
from tqdm.auto import tqdm |
|
from PIL import Image, ImageColor |
|
from torchvision import transforms |
|
from diffusers import DDIMScheduler, DDPMPipeline |
|
|
|
|
|
# Pick the best available device: Apple MPS, then CUDA, then CPU.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
|
|
|
|
|
# Load the fine-tuned ukiyo-e DDPM pipeline and pair it with a DDIM scheduler
# so sampling takes far fewer steps than the schedule used in training.
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
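
# With 40 inference steps, DDIM walks a descending subset of the (roughly
# 1000) training timesteps, e.g. approximately tensor([975, 950, ..., 25, 0]).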
|
|
|
|
|
|
|
|
|
|
|
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Given a target color (R, G, B), return a loss for how far away on average
    the images' pixels are from that color. Defaults to a light teal: (0.1, 0.9, 0.5)."""
    # Map the target color from [0, 1] to [-1, 1] to match the model's output range.
    target = torch.tensor(target_color).to(images.device) * 2 - 1
    # Reshape to (1, 3, 1, 1) so it broadcasts over the (B, C, H, W) batch.
    target = target[None, :, None, None]
    # Mean absolute error between every pixel and the target color.
    return torch.abs(images - target).mean()
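
# Sanity-check sketch (illustrative values, not part of the app): a batch
# painted exactly the target color should score ~0.
#   swatch = (torch.tensor([0.1, 0.9, 0.5]) * 2 - 1)[None, :, None, None]
#   color_loss(swatch.expand(1, 3, 64, 64))  # -> tensor close to 0.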
|
|
|
|
|
|
|
|
|
|
|
# Load OpenAI's CLIP ViT-B/32 for text-prompt guidance. The bundled
# `preprocess` transform is unused; the custom `tfms` below is applied instead.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)
|
|
|
|
|
# CLIP input pipeline: random resized crops, a small affine jitter, and flips
# act as the "cuts" that make guidance gradients more robust; the Normalize
# constants are CLIP's standard mean/std.
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomAffine(5),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
|
|
|
|
|
|
|
def clip_loss(image, text_features):
    # Augment and normalize the image, then embed it with CLIP.
    image_features = clip_model.encode_image(tfms(image))
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    # Squared spherical distance between the image and text embeddings.
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    return dists.mean()
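
# Why this works: for unit vectors u, v at angle theta, ||u - v|| = 2 * sin(theta / 2),
# so .div(2).arcsin().pow(2).mul(2) evaluates to theta**2 / 2 -- a squared
# spherical distance, which is smoother near theta = 0 than a raw dot product.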
|
|
|
|
|
|
|
|
|
|
|
def generate(
    color,
    color_loss_scale,
    num_examples=4,
    seed=None,
    prompt=None,
    prompt_loss_scale=None,
    prompt_n_cuts=None,
    inference_steps=50,
):
    # Cast UI values defensively: Gradio may hand floats to integer slots.
    scheduler.set_timesteps(num_inference_steps=int(inference_steps))

    if seed:
        torch.manual_seed(int(seed))

    # Embed the text prompt once up front; no gradients are needed here.
    if prompt:
        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)

    # Convert the picker's hex color to RGB in [0, 1].
    target_color = ImageColor.getcolor(color, "RGB")
    target_color = [a / 255 for a in target_color]

    # Start from pure Gaussian noise.
    x = torch.randn(int(num_examples), 3, 256, 256).to(device)

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        # Predict the noise residual; no gradients through the UNet itself.
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Color guidance: differentiate the color loss of the predicted
        # denoised image x0 with respect to the current noisy sample x.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]

        # Nudge the sample in the direction that lowers the color loss.
        x_cond = x.detach() + cond_color_grad

        # Prompt guidance: average the CLIP-loss gradient over several random
        # augmentations ("cuts") to get a less noisy guidance direction.
        if prompt:
            cond_prompt_grad = 0
            for cut in range(int(prompt_n_cuts)):
                x = x.detach().requires_grad_()
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample

                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale

                cond_prompt_grad -= (
                    torch.autograd.grad(prompt_loss, x, retain_graph=True)[0]
                    / prompt_n_cuts
                )

            # Scale the prompt gradient by the current alpha_bar before adding it.
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = x_cond + cond_prompt_grad * alpha_bar.sqrt()

        # Take the guided DDIM step.
        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    # Map from [-1, 1] to [0, 1], tile into a grid, and convert to PIL.
    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))

    return im
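
# Usage sketch (hypothetical values), e.g. for testing outside the Gradio UI:
#   im = generate(
#       "#2EB8A0", color_loss_scale=6.7, num_examples=2, seed=666,
#       prompt="a mountain at sunset", prompt_loss_scale=10,
#       prompt_n_cuts=4, inference_steps=40,
#   )
#   im.save("samples.png")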
|
|
|
|
|
|
|
|
|
|
|
TITLE="Ukiyo-e postal generator service 🎴!" |
|
DESCRIPTION="This model is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨. \nThe model was train using fine-tuning with the google/ddpm-celebahq-256 pretrain-model and the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo" |
|
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}" |
|
|
|
|
|
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),
    gr.Slider(label="color_guidance_scale (how strongly to blend in the color)", minimum=0, maximum=30, value=6.7),
    gr.Slider(label="num_examples (# images generated)", minimum=2, maximum=12, value=2, step=4),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (how strongly to steer toward the prompt)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (more steps -> stronger guidance effect)", minimum=40, maximum=60, value=40, step=1),
]
|
|
|
outputs = gr.Image(label="result") |
|
|
|
|
|
demo = gr.Interface( |
|
fn=generate, |
|
inputs=inputs, |
|
outputs=outputs, |
|
    css=CSS,
    title=TITLE,
    description=DESCRIPTION,
|
) |
|
|
|
if __name__ == "__main__":
    # Queue requests so long-running generations don't hit request timeouts.
    demo.queue().launch()
|
|