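# Gradio Space: generate images from text with StyleGAN XL + CLIP.
# A text prompt is embedded with CLIP and used to guide optimization of a StyleGAN XL
# latent; the demo returns the final image plus an mp4 of the optimization trajectory.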
import gradio as gr
from git.repo.base import Repo
from os.path import exists as path_exists

if not path_exists("stylegan_xl"):
    Repo.clone_from("https://github.com/autonomousvision/stylegan_xl", "stylegan_xl")

import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan_xl')
import io
import os, time, glob
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
import unicodedata
import re
from tqdm import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange
import dnnlib
import legacy
import subprocess

torch.cuda.empty_cache()
device = torch.device('cuda:0')
print('Using device:', device, file=sys.stderr)
def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')

def fetch_model(url_or_path, network_name):
    print(network_name)
    torch.hub.download_url_to_file(f'{url_or_path}', f'{network_name}')
    print(os.listdir())
def slugify(value, allow_unicode=False):
    """
    Taken from https://github.com/django/django/blob/master/django/utils/text.py
    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value.lower())
    return re.sub(r'[-\s]+', '-', value).strip('-_')
def norm1(prompt):
    "Normalize to the unit sphere."
    return prompt / prompt.square().sum(dim=-1, keepdim=True).sqrt()

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1:  # Keeps consistent results vs. the previous method for single-objective guidance
        return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)
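# Note: spherical_dist_loss is the squared great-circle distance between unit-normalized
# embeddings (arcsin of half the chord length, squared and doubled); prompts_dist_loss
# sums that distance over every "|"-separated prompt, with a single-prompt fast path.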
class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)

make_cutouts = MakeCutouts(224, 32, 0.5)
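# 32 random square crops per image, each pooled down to 224x224 (CLIP ViT-B/32 input size);
# cut_pow=0.5 biases the random crop sizes towards larger cutouts.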
def embed_image(image):
    n = image.shape[0]
    cutouts = make_cutouts(image)
    embeds = clip_model.embed_cutout(cutouts)
    embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
    return embeds

def embed_url(url):
    image = Image.open(fetch(url)).convert('RGB')
    return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
class CLIP(object):
    def __init__(self):
        clip_model = "ViT-B/32"
        self.model, _ = clip.load(clip_model)
        self.model = self.model.requires_grad_(False)
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])

    def embed_text(self, prompt):
        "Normalized clip text embedding."
        return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())

    def embed_cutout(self, image):
        "Normalized clip image embedding."
        return norm1(self.model.encode_image(self.normalize(image)))

clip_model = CLIP()
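# Frozen ViT-B/32 CLIP model; the Normalize constants above are CLIP's standard image
# preprocessing statistics. Text and image embeddings are unit-normalized so that
# spherical_dist_loss compares them on the unit sphere.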
#@markdown #**Model selection** 🎭
Models = ["imagenet256", "Pokemon", "FFHQ"]
#@markdown ---
network_url = {
    "imagenet256": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl",
    #"Imagenet512": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl",
    #"Imagenet1024": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl",
    "Pokemon": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl",
    "FFHQ": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl"
}

for Model in Models:
    network_name = network_url[Model].split("/")[-1]
    if not path_exists(network_name):
        fetch_model(network_url[Model], network_name)
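# All checkpoints are fetched once at startup, so switching models in the UI only
# reloads a local .pkl instead of re-downloading it.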
def load_current_model(current_model="imagenet256.pkl"):
    with dnnlib.util.open_url(current_model) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

    zs = torch.randn([10000, G.mapping.z_dim], device=device)
    cs = torch.zeros([10000, G.mapping.c_dim], device=device)
    for i in range(cs.shape[0]):
        cs[i, i//10] = 1
    w_stds = G.mapping(zs, cs)
    w_stds = w_stds.reshape(10, 1000, G.num_ws, -1)
    w_stds = w_stds.std(0).mean(0)[0]
    w_all_classes_avg = G.mapping.w_avg.mean(0)
    return G, w_stds, w_all_classes_avg

G, w_stds, w_all_classes_avg = load_current_model()
previousModel = 'imagenet256'
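# w_stds and w_all_classes_avg summarize the mapped latent space: 10,000 latents
# (10 per class over 1,000 class labels) give a per-dimension standard deviation estimate
# and a class-averaged mean w, so the CLIP-guided optimization below can work in a
# roughly rescaled, recentered "q" space.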
def run(prompt, steps, model):
    global G, w_stds, w_all_classes_avg, previousModel
    if model == 'imagenet256' and previousModel != 'imagenet256':
        G, w_stds, w_all_classes_avg = load_current_model('imagenet256.pkl')
    #if(model == 'imagenet512' and previousModel != 'imagenet512'):
    #    G, w_stds, w_all_classes_avg = load_current_model('imagenet512.pkl')
    #elif(model=='imagenet1024' and previousModel != 'imagenet1024'):
    #    G, w_stds, w_all_classes_avg = load_current_model('imagenet1024.pkl')
    elif model == 'pokemon256' and previousModel != 'pokemon256':
        G, w_stds, w_all_classes_avg = load_current_model('pokemon256.pkl')
    elif model == 'ffhq256' and previousModel != 'ffhq256':
        G, w_stds, w_all_classes_avg = load_current_model('ffhq256.pkl')
    previousModel = model
    texts = prompt
    seed = -1  # @param {type:"number"}
    # @markdown ---
    if seed == -1:
        seed = np.random.randint(0, 9e9)
        print(f"Your random seed is: {seed}")

    texts = [frase.strip() for frase in texts.split("|") if frase]
    targets = [clip_model.embed_text(text) for text in texts]

    tf = Compose(
        [
            # Resize(224),
            lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
        ]
    )

    initial_batch = 2  # actually that will be multiplied by initial_image_steps
    initial_image_steps = 8
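    # Initial latent search: sample initial_batch candidates for initial_image_steps rounds
    # (16 candidates total), score them with the CLIP prompt loss, and keep the single
    # lowest-loss candidate as the starting point for optimization.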
    def get_image(timestring):
        os.makedirs(f"samples/{timestring}", exist_ok=True)
        torch.manual_seed(seed)
        with torch.no_grad():
            qs = []
            losses = []
            for _ in range(initial_image_steps):
                a = torch.randn([initial_batch, 512], device=device) * 0.4 + w_stds * 0.4
                q = (a - w_all_classes_avg) / w_stds
                images = G.synthesis(
                    (q * w_stds + w_all_classes_avg).unsqueeze(1).repeat([1, G.num_ws, 1])
                )
                embeds = embed_image(images.add(1).div(2))
                loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
                i = torch.argmin(loss)
                qs.append(q[i])
                losses.append(loss[i])
            qs = torch.stack(qs)
            losses = torch.stack(losses)
            i = torch.argmin(losses)
            q = qs[i].unsqueeze(0).repeat([G.num_ws, 1]).requires_grad_()
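        # Guided sampling: optimize q with AdamW so the CLIP embeddings of the synthesized
        # image match the prompt embeddings; an exponential moving average of q (q_ema) is
        # what actually gets rendered and saved as a frame at each step.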
        # Sampling loop
        q_ema = q
        print(q.shape)
        opt = torch.optim.AdamW([q], lr=0.05, betas=(0.0, 0.999), weight_decay=0.025)
        loop = tqdm(range(steps))
        for i in loop:
            opt.zero_grad()
            w = q * w_stds
            image = G.synthesis((q * w_stds + w_all_classes_avg)[None], noise_mode="const")
            embed = embed_image(image.add(1).div(2))
            loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
            loss.backward()
            opt.step()
            loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())

            q_ema = q_ema * 0.98 + q * 0.02
            image = G.synthesis(
                (q_ema * w_stds + w_all_classes_avg)[None], noise_mode="const"
            )
            pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
            pil_image.save(f"samples/{timestring}/{i:04}.jpg")

            if (i + 1) % steps == 0:
                subprocess.call(['ffmpeg', '-r', '60', '-i', f'samples/{timestring}/%04d.jpg', '-vcodec', 'libx264', '-crf', '18', '-pix_fmt', 'yuv420p', f'{timestring}.mp4'])
                shutil.rmtree(f"samples/{timestring}")
                pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
                return pil_image, f'{timestring}.mp4'
    try:
        timestring = time.strftime("%Y%m%d%H%M%S")
        image, video = get_image(timestring)
        return [image, video]
    except KeyboardInterrupt:
        pass
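# Gradio UI: a prompt textbox, a steps slider, and a model dropdown feed run(), which
# returns an image and a video output (this uses the older gr.inputs / gr.outputs API).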
image = gr.outputs.Image(type="pil", label="Your image")
video = gr.outputs.Video(type="mp4", label="Your video")
css = ".output-image{height: 528px !important} .output-video{height: 528px !important}"
iface = gr.Interface(fn=run, inputs=[
    gr.inputs.Textbox(label="Prompt", default="Hong Kong by Studio Ghibli"),
    gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=300, maximum=500, minimum=10, step=1),
    #gr.inputs.Radio(label="Aspect Ratio", choices=["Square", "Horizontal", "Vertical"], default="Horizontal"),
    gr.inputs.Dropdown(label="Model", choices=["imagenet256", "pokemon256", "ffhq256"], default="imagenet256")
    #gr.inputs.Radio(label="Height", choices=[32,64,128,256,512], default=256),
    #gr.inputs.Slider(label="Images - How many images you wish to generate", default=2, step=1, minimum=1, maximum=4),
    #gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be", default=5.0, minimum=1.0, maximum=15.0),
    #gr.inputs.Slider(label="ETA - between 0 and 1. Lower values can provide better quality, higher values can be more diverse", default=0.0, minimum=0.0, maximum=1.0, step=0.1),
    ],
    outputs=[image, video],
    css=css,
    title="Generate images from text with StyleGAN XL + CLIP",
    description="<div>By typing a prompt and pressing submit you generate images based on it. <a href='https://github.com/autonomousvision/stylegan_xl' target='_blank'>StyleGAN XL</a> is a general-purpose StyleGAN, and its CLIP Guidance notebook was created by <a href='https://github.com/CasualGANPapers/StyleGANXL-CLIP' target='_blank'>ryudrigo and ouhenio</a> and optimised by <a href='https://twitter.com/rivershavewings' target='_blank'>Katherine Crowson</a>. This Spaces Gradio UI to the model was assembled by <a style='color: rgb(99, 102, 241);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a>; keep up with the <a style='color: rgb(99, 102, 241);' href='https://multimodal.art/news' target='_blank'>latest multimodal AI art news here</a> and consider <a style='color: rgb(99, 102, 241);' href='https://www.patreon.com/multimodalart' target='_blank'>supporting us on Patreon</a></div>",
    article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, be aware that this model may output content that reinforces or exacerbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>: <i>\"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The models are meant to be used for research purposes, such as this one.</div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definitely not me! Probably you do. I say probably because the Copyright discussion about AI-generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or is in the public domain.</div>")

iface.launch(enable_queue=True)