import gc
import sys
from os.path import exists as path_exists

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
import gradio as gr
from git.repo.base import Repo
from huggingface_hub import hf_hub_download

# Fetch the sampling and CLIP code if it is not already present.
if not path_exists("v-diffusion-pytorch"):
    Repo.clone_from("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch")
if not path_exists("CLIP"):
    Repo.clone_from("https://github.com/openai/CLIP", "CLIP")
sys.path.append('v-diffusion-pytorch')

from CLIP import clip
from diffusion import get_model, sampling, utils


class MakeCutouts(nn.Module):
    """Takes `cutn` random square crops of an image batch and resizes them to `cut_size` for CLIP."""

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
            cutouts.append(cutout)
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    """Squared spherical (great-circle) distance between unit-normalized embeddings."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


# Download the cc12m_1_cfg checkpoint and load it, plus the CLIP model it was trained with.
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg",
                              filename="cc12m_1_cfg.pth")
model = get_model('cc12m_1_cfg')()
_, side_y, side_x = model.shape
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
model = model.half().cuda().eval().requires_grad_(False)
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
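# Illustrative sketch of how the helpers above fit together, kept commented out so it
# does not run at import time. The shapes assume a hypothetical 256x256 RGB batch;
# the exact cutout size and embedding width depend on the CLIP variant that is loaded.
#
#   img = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1      # hypothetical image in [-1, 1]
#   cuts = make_cutouts((img + 1) / 2)                            # -> [16, 3, cut_size, cut_size]
#   embeds = clip_model.encode_image(normalize(cuts)).float()     # -> [16, embed_dim]
#   dists = spherical_dist_loss(embeds, embeds[:1])               # -> [16]; ~0 for the identical first pair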
def run_all(prompt, steps, n_images, weight, clip_guided):
    import random
    seed = random.randint(0, 2147483647)
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to('cuda')).float()

    if clip_guided:
        # Build a (possibly weighted) list of prompt embeddings for CLIP guidance.
        prompts = [prompt]
        target_embeds, weights = [], []

        def parse_prompt(prompt):
            # Splits "text:weight" into (text, weight); URLs keep their "scheme:" prefix.
            if prompt.startswith('http://') or prompt.startswith('https://'):
                vals = prompt.rsplit(':', 2)
                vals = [vals[0] + ':' + vals[1], *vals[2:]]
            else:
                vals = prompt.rsplit(':', 1)
            vals = vals + ['', '1'][len(vals):]
            return vals[0], float(vals[1])

        for prompt in prompts:
            # Use a separate name so the CFG guidance `weight` argument is not clobbered.
            txt, prompt_weight = parse_prompt(prompt)
            target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to('cuda')).float())
            weights.append(prompt_weight)
        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device='cuda')
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('The weights must not sum to 0.')
        weights /= weights.sum().abs()
        # With a single prompt this weighted sum reduces to the normalized target embedding.
        clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)

    # One copy of the prompt embedding per image in the batch.
    clip_embed = target_embed.repeat([n_images, 1])

    def cfg_model_fn(x, t):
        """The CFG wrapper: mixes conditional and unconditional v predictions."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        v = v_uncond + (v_cond - v_uncond) * weight
        return v

    def make_cond_model_fn(model, cond_fn):
        # Wraps the model so a gradient-based guidance function can steer the predicted velocity.
        def cond_model_fn(x, t, **extra_args):
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                v = model(x, t, **extra_args)
                alphas, sigmas = utils.t_to_alpha_sigma(t)
                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
            return v
        return cond_model_fn

    def cond_fn(x, t, pred, clip_embed):
        # CLIP guidance: pull the denoised prediction toward the prompt embedding.
        if min(pred.shape[2:4]) < 256:
            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([16, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * 500.
        grad = -torch.autograd.grad(loss, x)[0]
        return grad

    gc.collect()
    torch.cuda.empty_cache()
    torch.manual_seed(seed)

    # Start from Gaussian noise and build the timestep schedule.
    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    if model.min_t == 0:
        step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    else:
        step_list = utils.get_ddpm_schedule(t)

    if not clip_guided:
        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})
    else:
        extra_args = {'clip_embed': clip_embed}
        model_fn = make_cond_model_fn(model, cond_fn)
        outs = sampling.plms_sample(model_fn, x, step_list, extra_args)

    images_out = []
    for out in outs:
        images_out.append(utils.to_pil_image(out))
    return images_out


##################### START GRADIO HERE ############################
gallery = gr.Gallery(css={"height": "256px", "width": "256px"})
iface = gr.Interface(
    fn=run_all,
    inputs=[
        gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'", default="chalk pastel drawing of a dog wearing a funny hat"),
        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=40, maximum=80, minimum=1, step=1),
        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
        gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
        gr.inputs.Checkbox(label="CLIP Guided - improves coherence with prompt, makes it slower"),
    ],
    outputs=gallery,
    title="Generate images from text with V-Diffusion CC12M CFG",
    description="