|
|
|
import argparse, os |
|
|
|
|
|
import torch |
|
import requests |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
from io import BytesIO |
|
from tqdm.auto import tqdm |
|
from matplotlib import pyplot as plt |
|
from torchvision import transforms as tfms |
|
from diffusers import ( |
|
StableDiffusionPipeline, |
|
DDIMScheduler, |
|
DiffusionPipeline, |
|
StableDiffusionXLPipeline, |
|
) |
|
from diffusers.image_processor import VaeImageProcessor |
|
|
import torchvision |
|
import torchvision.transforms as transforms |
|
from torchvision.utils import save_image |
|
|
from torchvision.utils import make_grid |
|
import numpy |
|
|
from models import attn_injection |
|
from omegaconf import OmegaConf |
|
from typing import List, Tuple |
|
|
|
import omegaconf |
|
import utils.exp_utils |
|
import json |
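# Training-free stylization via DDIM inversion: the content image is inverted to
# noise under the source prompt, then re-sampled under a style prompt while
# attention / ResNet features are shared across the batch (models.attn_injection)
# to keep the content anchored.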
|
|
|
device = "cuda" |
|
|
|
|
|
def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device): |
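    """Tokenize and encode `prompt` with one SDXL text encoder.

    Returns the penultimate hidden state and the pooled embedding; when `prompt`
    is the empty string, zeroed-out tensors are returned instead so they can be
    used as the unconditional input.
    """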
|
|
|
text_inputs = tokenizer( |
|
prompt, |
|
padding="max_length", |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
text_input_ids = text_inputs.input_ids |
|
|
|
with torch.no_grad(): |
|
prompt_embeds = text_encoder( |
|
text_input_ids.to(device), |
|
output_hidden_states=True, |
|
) |
|
|
|
pooled_prompt_embeds = prompt_embeds[0] |
|
prompt_embeds = prompt_embeds.hidden_states[-2] |
|
if prompt == "": |
|
negative_prompt_embeds = torch.zeros_like(prompt_embeds) |
|
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) |
|
return negative_prompt_embeds, negative_pooled_prompt_embeds |
|
return prompt_embeds, pooled_prompt_embeds |
|
|
|
|
|
def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: List[str]):
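    """Encode prompts with both SDXL text encoders.

    The two hidden states are concatenated along the channel dimension, and the
    pooled embedding of the second encoder plus 1024x1024 micro-conditioning
    time ids are packed into `added_cond_kwargs` for the SDXL UNet.
    """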
|
device = model._execution_device |
|
( |
|
prompt_embeds, |
|
pooled_prompt_embeds, |
|
) = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device) |
|
( |
|
prompt_embeds_2, |
|
pooled_prompt_embeds_2, |
|
) = _get_text_embeddings(prompt, model.tokenizer_2, model.text_encoder_2, device) |
|
prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1) |
|
text_encoder_projection_dim = model.text_encoder_2.config.projection_dim |
|
add_time_ids = model._get_add_time_ids( |
|
(1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim |
|
).to(device) |
|
|
|
add_time_ids = add_time_ids.repeat(len(prompt), 1) |
|
added_cond_kwargs = { |
|
"text_embeds": pooled_prompt_embeds_2, |
|
"time_ids": add_time_ids, |
|
} |
|
return added_cond_kwargs, prompt_embeds |
|
|
|
|
|
def _encode_text_sdxl_with_negative( |
|
model: StableDiffusionXLPipeline, prompt: List[str] |
|
): |
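    """Build classifier-free-guidance conditioning for SDXL.

    Encodes an all-empty ("") batch and the given prompts, then concatenates
    them (unconditional first) for both the prompt embeddings and the added
    cond kwargs, matching the doubled latent batch used during sampling.
    """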
|
|
|
B = len(prompt) |
|
added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt) |
|
added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl( |
|
model, ["" for _ in range(B)] |
|
) |
|
prompt_embeds = torch.cat( |
|
( |
|
prompt_embeds_uncond, |
|
prompt_embeds, |
|
) |
|
) |
|
added_cond_kwargs = { |
|
"text_embeds": torch.cat( |
|
(added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"]) |
|
), |
|
"time_ids": torch.cat( |
|
(added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"]) |
|
), |
|
} |
|
return added_cond_kwargs, prompt_embeds |
|
|
|
|
|
|
|
@torch.no_grad() |
|
def sample( |
|
pipe, |
|
prompt, |
|
start_step=0, |
|
start_latents=None, |
|
intermediate_latents=None, |
|
guidance_scale=3.5, |
|
num_inference_steps=30, |
|
num_images_per_prompt=1, |
|
do_classifier_free_guidance=True, |
|
negative_prompt="", |
|
device=device, |
|
): |
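    """DDIM sampling with classifier-free guidance.

    Batch index 0 is re-anchored to the provided DDIM-inverted latents at every
    step so the first image reconstructs the content image.
    """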
|
negative_prompt = [""] * len(prompt) |
|
|
|
if isinstance(pipe, StableDiffusionPipeline): |
|
text_embeddings = pipe._encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
) |
|
added_cond_kwargs = None |
|
elif isinstance(pipe, StableDiffusionXLPipeline): |
|
added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative( |
|
pipe, prompt |
|
) |
|
|
|
|
|
pipe.scheduler.set_timesteps(num_inference_steps, device=device) |
|
|
|
|
|
if start_latents is None: |
|
start_latents = torch.randn(1, 4, 64, 64, device=device) |
|
start_latents *= pipe.scheduler.init_noise_sigma |
|
|
|
latents = start_latents.clone() |
|
|
|
latents = latents.repeat(len(prompt), 1, 1, 1) |
|
|
|
for i in tqdm(range(start_step, num_inference_steps)): |
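        # Keep batch index 0 on the DDIM-inverted trajectory for this step.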
|
        if intermediate_latents is not None:
            latents[0] = intermediate_latents[(-i + 1)]
|
t = pipe.scheduler.timesteps[i] |
|
|
|
|
|
latent_model_input = ( |
|
torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
) |
|
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = pipe.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=text_embeddings, |
|
added_cond_kwargs=added_cond_kwargs, |
|
).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
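            # CFG: eps = eps_uncond + scale * (eps_text - eps_uncond)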
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * ( |
|
noise_pred_text - noise_pred_uncond |
|
) |
|
latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample |
|
|
|
|
|
images = pipe.decode_latents(latents) |
|
images = pipe.numpy_to_pil(images) |
|
|
|
return images |
|
|
|
|
|
|
|
@torch.no_grad() |
|
def sample_disentangled( |
|
pipe, |
|
prompt, |
|
start_step=0, |
|
start_latents=None, |
|
intermediate_latents=None, |
|
guidance_scale=3.5, |
|
num_inference_steps=30, |
|
num_images_per_prompt=1, |
|
do_classifier_free_guidance=True, |
|
use_content_anchor=True, |
|
negative_prompt="", |
|
device=device, |
|
): |
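    """Disentangled sampling used for stylization.

    With the 3-prompt call ["", style_prompt, ""], batch index 0 reconstructs
    the content (anchored to the inverted latents), index 1 is denoised from a
    fresh latent as the style delegate, and index 2 becomes the stylized output.
    The VAE decode is done in float32 for numerical stability.
    """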
|
negative_prompt = [""] * len(prompt) |
|
vae_decoder = VaeImageProcessor(vae_scale_factor=pipe.vae.config.scaling_factor) |
|
|
|
if isinstance(pipe, StableDiffusionPipeline): |
|
text_embeddings = pipe._encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
) |
|
added_cond_kwargs = None |
|
elif isinstance(pipe, StableDiffusionXLPipeline): |
|
added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative( |
|
pipe, prompt |
|
) |
|
|
|
|
|
pipe.scheduler.set_timesteps(num_inference_steps, device=device) |
|
|
|
|
|
    latent_shape = (1, 4, 64, 64)  # same shape for SD and the SDXL path used here
|
generative_latent = torch.randn(latent_shape, device=device) |
|
generative_latent *= pipe.scheduler.init_noise_sigma |
|
|
|
latents = start_latents.clone() |
|
latents = latents.repeat(len(prompt), 1, 1, 1) |
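    # Batch setup: index 0 follows the inverted (content) trajectory; index 1 is
    # restarted from the freshly sampled latent as the style delegate.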
|
|
|
latents[1] = generative_latent |
|
|
|
num_intermediate_latents = len(intermediate_latents) if intermediate_latents is not None else 0 |
|
|
|
for i in range(start_step, num_inference_steps): |
|
if use_content_anchor and intermediate_latents is not None: |
|
|
|
if -i >= -num_intermediate_latents: |
|
latents[0] = intermediate_latents[-i] |
|
else: |
|
|
|
|
|
latents[0] = intermediate_latents[0] |
|
|
|
t = pipe.scheduler.timesteps[i] |
|
|
|
|
|
latent_model_input = ( |
|
torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
) |
|
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = pipe.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=text_embeddings, |
|
added_cond_kwargs=added_cond_kwargs, |
|
).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * ( |
|
noise_pred_text - noise_pred_uncond |
|
) |
|
|
|
latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample |
|
|
|
|
|
|
|
pipe.vae.to(dtype=torch.float32) |
|
latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype) |
|
latents = 1 / pipe.vae.config.scaling_factor * latents |
|
images = pipe.vae.decode(latents, return_dict=False)[0] |
|
images = (images / 2 + 0.5).clamp(0, 1) |
|
|
|
images = images.cpu().permute(0, 2, 3, 1).float().numpy() |
|
images = pipe.numpy_to_pil(images) |
|
if isinstance(pipe, StableDiffusionXLPipeline): |
|
pipe.vae.to(dtype=torch.float16) |
|
|
|
return images |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
def invert( |
|
pipe, |
|
start_latents, |
|
prompt, |
|
guidance_scale=3.5, |
|
num_inference_steps=50, |
|
num_images_per_prompt=1, |
|
do_classifier_free_guidance=True, |
|
negative_prompt="", |
|
device=device, |
|
): |
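    """DDIM inversion: run the scheduler's timesteps in reverse and record the
    progressively noisier latents. The returned tensor is stacked along the
    batch dimension, so index -1 is the noisiest latent.
    """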
|
|
|
|
|
if isinstance(pipe, StableDiffusionPipeline): |
|
text_embeddings = pipe._encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
) |
|
added_cond_kwargs = None |
|
latents = start_latents.clone().detach() |
|
elif isinstance(pipe, StableDiffusionXLPipeline): |
|
added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative( |
|
pipe, [prompt] |
|
) |
|
latents = start_latents.clone().detach().half() |
|
|
|
|
|
intermediate_latents = [] |
|
|
|
|
|
pipe.scheduler.set_timesteps(num_inference_steps, device=device) |
|
|
|
|
|
timesteps = list(reversed(pipe.scheduler.timesteps)) |
|
|
|
for i in range(num_inference_steps): |
|
if i >= num_inference_steps - 1: |
|
continue |
|
|
|
t = timesteps[i] |
|
|
|
|
|
latent_model_input = ( |
|
torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
) |
|
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = pipe.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=text_embeddings, |
|
added_cond_kwargs=added_cond_kwargs, |
|
).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * ( |
|
noise_pred_text - noise_pred_uncond |
|
) |
|
|
|
current_t = max(0, t.item() - (1000 // num_inference_steps)) |
|
next_t = t |
|
alpha_t = pipe.scheduler.alphas_cumprod[current_t] |
|
alpha_t_next = pipe.scheduler.alphas_cumprod[next_t] |
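        # Reversed DDIM update: estimate x0 from the noise prediction, then move
        # the latent from noise level alpha_t to the noisier alpha_t_next.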
|
|
|
|
|
latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * ( |
|
alpha_t_next.sqrt() / alpha_t.sqrt() |
|
) + (1 - alpha_t_next).sqrt() * noise_pred |
|
|
|
|
|
intermediate_latents.append(latents) |
|
|
|
return torch.cat(intermediate_latents) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def style_image_with_inversion( |
|
pipe, |
|
input_image, |
|
input_image_prompt, |
|
style_prompt, |
|
num_steps=100, |
|
start_step=30, |
|
guidance_scale=3.5, |
|
disentangle=False, |
|
share_attn=False, |
|
share_cross_attn=False, |
|
share_resnet_layers=[0, 1], |
|
share_attn_layers=[], |
|
c2s_layers=[0, 1], |
|
share_key=True, |
|
share_query=True, |
|
share_value=False, |
|
use_adain=True, |
|
use_content_anchor=True, |
|
output_dir: str = None, |
|
resnet_mode: str = None, |
|
return_intermediate=False, |
|
intermediate_latents=None, |
|
): |
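    """End-to-end stylization.

    Encodes the content image to VAE latents, DDIM-inverts them under
    `input_image_prompt` (or reuses cached `intermediate_latents`), registers
    the attention / ResNet injection processors, samples under `style_prompt`,
    then removes the processors. Returns PIL images, plus the inverted latents
    if `return_intermediate` is set.
    """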
|
with torch.no_grad(): |
|
pipe.vae.to(dtype=torch.float32) |
|
latent = pipe.vae.encode(input_image.to(device) * 2 - 1) |
|
|
|
l = pipe.vae.config.scaling_factor * latent.latent_dist.sample() |
|
if isinstance(pipe, StableDiffusionXLPipeline): |
|
pipe.vae.to(dtype=torch.float16) |
|
if intermediate_latents is None: |
|
inverted_latents = invert( |
|
pipe, l, input_image_prompt, num_inference_steps=num_steps |
|
) |
|
else: |
|
inverted_latents = intermediate_latents |
|
|
|
attn_injection.register_attention_processors( |
|
pipe, |
|
base_dir=output_dir, |
|
resnet_mode=resnet_mode, |
|
attn_mode="artist" if disentangle else "pnp", |
|
disentangle=disentangle, |
|
share_resblock=True, |
|
share_attn=share_attn, |
|
share_cross_attn=share_cross_attn, |
|
share_resnet_layers=share_resnet_layers, |
|
share_attn_layers=share_attn_layers, |
|
share_key=share_key, |
|
share_query=share_query, |
|
share_value=share_value, |
|
use_adain=use_adain, |
|
c2s_layers=c2s_layers, |
|
) |
|
|
|
if disentangle: |
|
final_im = sample_disentangled( |
|
pipe, |
|
style_prompt, |
|
start_latents=inverted_latents[-(start_step + 1)][None], |
|
intermediate_latents=inverted_latents, |
|
start_step=start_step, |
|
num_inference_steps=num_steps, |
|
guidance_scale=guidance_scale, |
|
use_content_anchor=use_content_anchor, |
|
) |
|
else: |
|
final_im = sample( |
|
pipe, |
|
style_prompt, |
|
start_latents=inverted_latents[-(start_step + 1)][None], |
|
intermediate_latents=inverted_latents, |
|
start_step=start_step, |
|
num_inference_steps=num_steps, |
|
guidance_scale=guidance_scale, |
|
) |
|
|
|
|
|
attn_injection.unset_attention_processors( |
|
pipe, |
|
unset_share_attn=True, |
|
unset_share_resblock=True, |
|
) |
|
if return_intermediate: |
|
return final_im, inverted_latents |
|
return final_im |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    parser = argparse.ArgumentParser(
        description="Training-free image stylization with Stable Diffusion (configured via OmegaConf)"
    )
|
parser.add_argument( |
|
"--config", type=str, default="config.yaml", help="Path to the config file" |
|
) |
|
parser.add_argument( |
|
"--mode", |
|
type=str, |
|
default="dataset", |
|
choices=["dataset", "cli", "app"], |
|
help="Path to the config file", |
|
) |
|
parser.add_argument( |
|
"--image_dir", type=str, default="test.png", help="Path to the image" |
|
) |
|
parser.add_argument( |
|
"--prompt", |
|
type=str, |
|
default="an impressionist painting", |
|
help="Stylization prompt", |
|
) |
|
|
|
args = parser.parse_args() |
|
config_dir = args.config |
|
mode = args.mode |
|
|
|
out_name = ["content_delegation", "style_delegation", "style_out"] |
|
|
|
if mode == "app": |
|
|
|
import gradio as gr |
|
import spaces |
|
|
|
|
|
pipe = StableDiffusionPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1-base" |
|
).to(device) |
|
|
|
|
|
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) |
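        # Swap in DDIM: invert() relies on the deterministic DDIM schedule.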
|
|
|
@spaces.GPU() |
|
def style_transfer_app( |
|
prompt, |
|
image, |
|
cfg_scale=7.5, |
|
num_content_layers=4, |
|
num_style_layers=9, |
|
seed=0, |
|
progress=gr.Progress(track_tqdm=True), |
|
): |
|
utils.exp_utils.seed_all(seed) |
|
image = utils.exp_utils.process_image(image, device, 512) |
|
|
|
tgt_prompt = prompt |
|
src_prompt = "" |
|
prompt_in = [ |
|
"", |
|
tgt_prompt, |
|
"", |
|
] |
|
|
|
share_resnet_layers = ( |
|
list(range(num_content_layers)) if num_content_layers != 0 else None |
|
) |
|
share_attn_layers = ( |
|
list(range(num_style_layers)) if num_style_layers != 0 else None |
|
) |
|
imgs = style_image_with_inversion( |
|
pipe, |
|
image, |
|
src_prompt, |
|
style_prompt=prompt_in, |
|
num_steps=50, |
|
start_step=0, |
|
guidance_scale=cfg_scale, |
|
disentangle=True, |
|
resnet_mode="hidden", |
|
share_attn=True, |
|
share_cross_attn=True, |
|
share_resnet_layers=share_resnet_layers, |
|
share_attn_layers=share_attn_layers, |
|
share_key=True, |
|
share_query=True, |
|
share_value=False, |
|
use_content_anchor=True, |
|
use_adain=True, |
|
output_dir="./", |
|
) |
|
|
|
return imgs[2] |
|
|
|
|
|
examples = [] |
|
        with open("data/example/annotation.json") as f:
            annotation = json.load(f)
|
for entry in annotation: |
|
image = utils.exp_utils.get_processed_image( |
|
entry["image_path"], device, 512 |
|
) |
|
image = transforms.ToPILImage()(image[0]) |
|
|
|
examples.append([entry["target_prompt"], image, None, None, None]) |
|
|
|
text_input = gr.Textbox( |
|
value="An impressionist painting", |
|
label="Text Prompt", |
|
info="Describe the style you want to apply to the image, do not include the description of the image content itself", |
|
lines=2, |
|
placeholder="Enter a text prompt", |
|
) |
|
image_input = gr.Image( |
|
height="80%", |
|
width="80%", |
|
label="Content image (will be resized to 512x512)", |
|
interactive=True, |
|
) |
|
cfg_slider = gr.Slider( |
|
0, |
|
15, |
|
value=7.5, |
|
label="Classifier Free Guidance (CFG) Scale", |
|
info="higher values give more style, 7.5 should be good for most cases", |
|
) |
|
content_slider = gr.Slider( |
|
0, |
|
9, |
|
value=4, |
|
step=1, |
|
label="Number of content control layer", |
|
info="higher values make it more similar to original image. Default to control first 4 layers", |
|
) |
|
style_slider = gr.Slider( |
|
0, |
|
9, |
|
value=9, |
|
step=1, |
|
label="Number of style control layer", |
|
info="higher values make it more similar to target style. Default to control first 9 layers, usually not necessary to change.", |
|
) |
|
seed_slider = gr.Slider( |
|
0, |
|
100, |
|
value=0, |
|
step=1, |
|
label="Seed", |
|
info="Random seed for the model", |
|
) |
|
app = gr.Interface( |
|
fn=style_transfer_app, |
|
inputs=[ |
|
text_input, |
|
image_input, |
|
cfg_slider, |
|
content_slider, |
|
style_slider, |
|
seed_slider, |
|
], |
|
outputs=["image"], |
|
title="Artist Interactive Demo", |
|
examples=examples, |
|
cache_examples=False |
|
) |
|
app.launch(show_api=False, show_error=True) |
|
|