import glob
import math
import os
from typing import List, Union

import torch
from PIL import Image
from transformers import (
    CLIPProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers import StableDiffusionPipeline

from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path

EXAMPLE_PROMPTS = [
"<obj> swimming in a pool",
"<obj> at a beach with a view of seashore",
"<obj> in times square",
"<obj> wearing sunglasses",
"<obj> in a construction outfit",
"<obj> playing with a ball",
"<obj> wearing headphones",
"<obj> oil painting ghibli inspired",
"<obj> working on the laptop",
"<obj> with mountains and sunset in background",
"Painting of <obj> at a beach by artist claude monet",
"<obj> digital painting 3d render geometric style",
"A screaming <obj>",
"A depressed <obj>",
"A sleeping <obj>",
"A sad <obj>",
"A joyous <obj>",
"A frowning <obj>",
"A sculpture of <obj>",
"<obj> near a pool",
"<obj> at a beach with a view of seashore",
"<obj> in a garden",
"<obj> in grand canyon",
"<obj> floating in ocean",
"<obj> and an armchair",
"A maple tree on the side of <obj>",
"<obj> and an orange sofa",
"<obj> with chocolate cake on it",
"<obj> with a vase of rose flowers on it",
"A digital illustration of <obj>",
"Georgia O'Keeffe style <obj> painting",
"A watercolor painting of <obj> on a beach",
]


def image_grid(_imgs, rows=None, cols=None):
    # Tile a list of PIL images into a single grid image; infer rows/cols when not given.
    if rows is None and cols is None:
        rows = cols = math.ceil(len(_imgs) ** 0.5)
    if rows is None:
        rows = math.ceil(len(_imgs) / cols)
    if cols is None:
        cols = math.ceil(len(_imgs) / rows)

    w, h = _imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(_imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
    # Evaluation protocol inspired by the textual inversion paper:
    # https://arxiv.org/abs/2208.01618

    # Text alignment: cosine similarity between generated-image and prompt embeddings.
    assert img_embeds.shape[0] == text_embeds.shape[0]
    text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
        img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
    )

    # Image alignment: cosine similarity between generated-image embeddings and the
    # average embedding of the target (training) images.
    img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)
    avg_target_img_embed = (
        (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
        .mean(dim=0)
        .unsqueeze(0)
        .repeat(img_embeds.shape[0], 1)
    )
    img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)

    return {
        "text_alignment_avg": text_img_sim.mean().item(),
        "image_alignment_avg": img_img_sim.mean().item(),
        "text_alignment_all": text_img_sim.tolist(),
        "image_alignment_all": img_img_sim.tolist(),
    }


def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
    # Load the CLIP text/vision towers and their tokenizer/processor used for evaluation.
    text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
    tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
    vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
    processor = CLIPProcessor.from_pretrained(eval_clip_id)

    return text_model, tokenizer, vis_model, processor


def evaluate_pipe(
    pipe,
    target_images: List[Image.Image],
    class_token: str = "",
    learnt_token: str = "",
    guidance_scale: float = 5.0,
    seed: int = 0,
    clip_model_sets=None,
    eval_clip_id: str = "openai/clip-vit-large-patch14",
    n_test: int = 10,
    n_step: int = 50,
):
    if clip_model_sets is not None:
        text_model, tokenizer, vis_model, processor = clip_model_sets
    else:
        text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
            eval_clip_id
        )

    images = []
    img_embeds = []
    text_embeds = []

    for prompt in EXAMPLE_PROMPTS[:n_test]:
        prompt = prompt.replace("<obj>", learnt_token)

        torch.manual_seed(seed)
        with torch.autocast("cuda"):
            img = pipe(
                prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
            ).images[0]
        images.append(img)

        # Embed the generated image with the CLIP vision tower.
        inputs = processor(images=img, return_tensors="pt")
        img_embed = vis_model(**inputs).image_embeds
        img_embeds.append(img_embed)

        # Embed the prompt with the CLIP text tower, substituting the class token
        # for the learnt token, as in the textual inversion evaluation.
        prompt = prompt.replace(learnt_token, class_token)
        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        outputs = text_model(**inputs)
        text_embed = outputs.text_embeds
        text_embeds.append(text_embed)

    # Embed the target (training) images for the image-alignment score.
    inputs = processor(images=target_images, return_tensors="pt")
    target_img_embeds = vis_model(**inputs).image_embeds

    img_embeds = torch.cat(img_embeds, dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)

    return text_img_alignment(img_embeds, text_embeds, target_img_embeds)


def visualize_progress(
    path_alls: Union[str, List[str]],
    prompt: str,
    model_id: str = "runwayml/stable-diffusion-v1-5",
    device="cuda:0",
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    unet_scale=1.0,
    text_scale=1.0,
    num_inference_steps=50,
    guidance_scale=5.0,
    offset: int = 0,
    limit: int = 10,
    seed: int = 0,
):
    imgs = []

    # Accept either a glob pattern or an explicit list of checkpoint paths.
    if isinstance(path_alls, str):
        alls = list(set(glob.glob(path_alls)))
        alls.sort(key=os.path.getmtime)
    else:
        alls = path_alls

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    print(f"Found {len(alls)} checkpoints")

    # Generate one image per checkpoint with a fixed seed so progress is comparable.
    for path in alls[offset:limit]:
        print(path)

        patch_pipe(
            pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
        )
        tune_lora_scale(pipe.unet, unet_scale)
        tune_lora_scale(pipe.text_encoder, text_scale)

        torch.manual_seed(seed)
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        imgs.append(image)

    return imgs
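

# Hypothetical usage sketch (not part of the evaluation utilities above): the
# checkpoint path, data glob, and token names below are placeholders, assuming a
# LoRA checkpoint and training images produced elsewhere with this repository.
if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda:0")

    # Placeholder checkpoint path; patch the UNet/text encoder and tune LoRA scales.
    patch_pipe(pipe, "./lora_weight.safetensors")
    tune_lora_scale(pipe.unet, 1.0)
    tune_lora_scale(pipe.text_encoder, 1.0)

    # Placeholder glob for the training (target) images used for image alignment.
    target_images = [Image.open(p) for p in glob.glob("./data/*.jpg")]

    scores = evaluate_pipe(
        pipe,
        target_images=target_images,
        class_token="dog",    # placeholder class noun
        learnt_token="<s1>",  # placeholder learnt token
        n_test=5,
        n_step=30,
    )
    print(scores)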