# Spaces: Runtime error
# (page-status banner captured along with the source; not part of the program)
import os

# Environment bootstrap. Each command is handed to the shell in order and its
# exit status is deliberately not checked (best effort): a pinned CUDA 11.1
# build of torch/torchvision, a checkout of OpenAI's CLIP installed in
# editable mode, and the remaining Python dependencies.
for _bootstrap_command in (
    "pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html",
    "git clone https://github.com/openai/CLIP",
    "pip install -e ./CLIP",
    "pip install einops ninja scipy numpy Pillow tqdm",
):
    os.system(_bootstrap_command)

import sys

# Put the cloned checkout on the import path so `import clip` can resolve
# against it in this same process.
sys.path.append('./CLIP')
import io | |
import os, time | |
import pickle | |
import shutil | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
import requests | |
import torchvision.transforms as transforms | |
import torchvision.transforms.functional as TF | |
import clip | |
from tqdm.notebook import tqdm | |
from torchvision.transforms import Compose, Resize, ToTensor, Normalize | |
from einops import rearrange | |
import gradio as gr | |
# Device every tensor and model in this script is moved to. Prefer the first
# CUDA GPU, but fall back to the CPU so that `.to(device)` calls do not fail
# on hosts without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def fetch(url_or_path):
    """Open a resource for binary reading, from the web or from disk.

    Args:
        url_or_path: An ``http://`` / ``https://`` URL, or a filesystem path.

    Returns:
        A binary file-like object positioned at offset 0: an in-memory
        ``io.BytesIO`` holding the whole response body for URLs, or the
        file opened in ``'rb'`` mode for paths. The caller is responsible
        for closing it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    location = str(url_or_path)
    if not location.startswith(('http://', 'https://')):
        return open(url_or_path, 'rb')

    response = requests.get(url_or_path)
    response.raise_for_status()
    # A BytesIO built from the payload already starts at offset 0.
    return io.BytesIO(response.content)
def fetch_model(url_or_path):
    """Return the local filename of a model file, downloading it if absent.

    The file is looked up (and downloaded to) the current working directory
    under the last path component of ``url_or_path``.

    Args:
        url_or_path: Location of the model file; only its basename is used
            for the local lookup.

    Returns:
        The basename of ``url_or_path``, which exists in the current working
        directory once this function returns.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero
            status (download failed).
        FileNotFoundError: If the ``wget`` executable is not installed.
    """
    import subprocess  # local import keeps this function self-contained

    basename = os.path.basename(url_or_path)
    if not os.path.exists(basename):
        # Pass the location as its own argv entry: it is really substituted
        # (the old shell string was missing its f-prefix and asked wget for
        # the literal text "{url_or_path}") and needs no shell quoting.
        # "-c" resumes a partial download; "--" ends option parsing.
        subprocess.run(["wget", "-c", "--", url_or_path], check=True)
    return basename
def norm1(prompt):
    """Scale ``prompt`` to unit Euclidean length along its last dimension.

    No epsilon is applied, so an all-zero vector yields NaNs.
    """
    squared_length = prompt.square().sum(dim=-1, keepdim=True)
    length = squared_length.sqrt()
    return prompt / length
def spherical_dist_loss(x, y):
    """Return half the squared angle between ``x`` and ``y``.

    Both inputs are normalised along the last dimension first. For unit
    vectors separated by angle ``t`` the chord length is ``2*sin(t/2)``, so
    ``arcsin(chord / 2)`` recovers ``t/2`` and the result is
    ``2 * (t/2)**2``. The last dimension is reduced; leading dimensions
    broadcast.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1) / 2
    half_angle = half_chord.arcsin()
    return half_angle.pow(2) * 2
class MakeCutouts(torch.nn.Module):
    """Sample random square crops of an image batch at a fixed output size.

    Args:
        cut_size: Side length every crop is pooled to.
        cutn: Number of crops taken per call.
        cut_pow: Exponent applied to the uniform random draw that picks each
            crop's side length.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def _random_cutout(self, input, smallest, largest):
        """Take one random square crop and pool it to ``cut_size``."""
        height, width = input.shape[2:4]
        # Draw order matters for reproducibility under a fixed seed:
        # side length first, then the x offset, then the y offset.
        scale = torch.rand([]) ** self.cut_pow
        side = int(scale * (largest - smallest) + smallest)
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        crop = input[:, :, top:top + side, left:left + side]
        return F.adaptive_avg_pool2d(crop, self.cut_size)

    def forward(self, input):
        """Return ``cutn`` crops of ``input`` concatenated along dim 0.

        ``input`` is indexed as a 4-D tensor whose dims 2 and 3 are height
        and width. Crop sides range from ``min(H, W, cut_size)`` up to
        ``min(H, W)``. The same crop window is applied to every item in the
        batch, so the result has ``cutn * input.shape[0]`` entries.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        pieces = [
            self._random_cutout(input, smallest, largest)
            for _ in range(self.cutn)
        ]
        return torch.cat(pieces)
# Shared cutout sampler used by embed_image: 32 crops per call, each pooled
# to 224x224, with crop sizes drawn using exponent 0.5.
make_cutouts = MakeCutouts(cut_size=224, cutn=32, cut_pow=0.5)
def embed_image(image):
    """Embed random cutouts of an image batch with CLIP.

    Args:
        image: Batch tensor; its first dimension is taken as the batch size
            ``n``. Presumably values are in [0, 1], as callers rescale
            before calling - TODO confirm.

    Returns:
        The cutout embeddings regrouped from a flat ``(cc * n, c)`` layout
        into ``(cc, n, c)``, where ``cc`` is the number of cutouts per image.
    """
    batch = image.shape[0]
    crops = make_cutouts(image)
    flat_embeds = clip_model.embed_cutout(crops)
    # The sampler concatenates one group of `batch` crops per cutout, so the
    # cutout index is the slow axis of the flattened leading dimension.
    return rearrange(flat_embeds, '(cc n) c -> cc n c', n=batch)
def embed_url(url):
    """Return a single CLIP embedding for the image at ``url``.

    Args:
        url: Anything ``fetch`` accepts (an http(s) URL or a local path).

    Returns:
        The mean of the image's cutout embeddings, with the singleton batch
        dimension removed.
    """
    # Close the handle once decoding is done; previously a local file opened
    # by fetch() was never closed. convert() returns a fully loaded copy, so
    # the image remains usable after the source is closed.
    with fetch(url) as source:
        image = Image.open(source).convert('RGB')
    pixels = TF.to_tensor(image).to(device).unsqueeze(0)
    return embed_image(pixels).mean(0).squeeze(0)
class CLIP(object):
    """Wrapper around a frozen CLIP ViT-B/32 model producing unit-length embeddings."""

    def __init__(self):
        model_name = "ViT-B/32"
        # clip.load also returns a preprocessing transform, which is unused
        # here because inputs are normalised manually in embed_cutout.
        model, _preprocess = clip.load(model_name)
        # Gradients are never needed for CLIP's own weights.
        self.model = model.requires_grad_(False)
        # Per-channel statistics applied before encoding images
        # (presumably CLIP's own preprocessing constants - TODO confirm).
        self.normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )

    def embed_text(self, prompt):
        """Normalized clip text embedding (cast to float32)."""
        tokens = clip.tokenize(prompt).to(device)
        features = self.model.encode_text(tokens).float()
        return norm1(features)

    def embed_cutout(self, image):
        """Normalized clip image embedding."""
        features = self.model.encode_image(self.normalize(image))
        return norm1(features)
clip_model = CLIP()

# Load stylegan model
base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
#model_name = "stylegan3-r-metfacesu-1024x1024.pkl"
#model_name = "stylegan3-t-afhqv2-512x512.pkl"
network_url = base_url + model_name

# Download (or resume, via -c) the checkpoint selected above. The URL and the
# filename are derived from model_name, so switching to one of the
# alternatives actually changes which network is fetched and loaded.
os.system(f"wget -c '{network_url}'")

# NOTE(review): unpickling executes arbitrary code from the file; only load
# checkpoints from a trusted source. StyleGAN3 pickles presumably also need
# NVIDIA's `dnnlib` / `torch_utils` packages to be importable, which nothing
# in this file sets up - confirm.
with open(model_name, 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)

# Per-coordinate standard deviation of the mapping network's output over
# 10000 random z samples; used to express latents in standardised units.
zs = torch.randn([10000, G.mapping.z_dim], device=device)
w_stds = G.mapping(zs, None).std(0)
def inference(text, steps=20, seed=2):
    """Generate an image matching ``text`` by steering StyleGAN3 with CLIP.

    A random search picks a starting latent, which is then optimised so the
    CLIP embedding of the synthesised image moves towards the CLIP embedding
    of the prompt. The returned image is rendered from an exponential moving
    average of the optimised latent.

    Args:
        text: Prompt to match.
        steps: Number of optimisation steps (default 20).
        seed: Seed passed to ``torch.manual_seed`` (default 2).

    Returns:
        A ``PIL.Image`` rendered from the averaged latent. With ``steps=0``
        this is the best latent found by the initial random search.
    """
    target = clip_model.embed_text(text)
    torch.manual_seed(seed)

    # Random search: 8 rounds of 4 truncated samples each; keep the best
    # candidate per round, then the best overall. Latents are stored in
    # standardised units: (w - w_avg) / w_stds.
    with torch.no_grad():
        qs = []
        losses = []
        for _ in range(8):
            z = torch.randn([4, G.mapping.z_dim], device=device)
            q = (G.mapping(z, None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
            images = G.synthesis(q * w_stds + G.mapping.w_avg)
            # Generator output is rescaled from [-1, 1] to [0, 1] for CLIP.
            embeds = embed_image(images.add(1).div(2))
            # Average over cutouts -> one loss per candidate.
            loss = spherical_dist_loss(embeds, target).mean(0)
            best = torch.argmin(loss)
            qs.append(q[best])
            losses.append(loss[best])
        qs = torch.stack(qs)
        losses = torch.stack(losses)
        print(losses)
        print(losses.shape, qs.shape)
        q = qs[torch.argmin(losses)].unsqueeze(0)

    q.requires_grad_()
    # detach() shares storage with q on purpose: as before, the average
    # starts out aliased to q, so the first update sees the post-step value.
    # Keeping it detached stops an autograd graph from growing across steps.
    q_ema = q.detach()
    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0, 0.999))
    loop = tqdm(range(steps))
    for _ in loop:
        opt.zero_grad()
        w = q * w_stds
        image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
        embed = embed_image(image.add(1).div(2))
        loss = spherical_dist_loss(embed, target).mean()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
        with torch.no_grad():
            q_ema = q_ema * 0.9 + q * 0.1

    # Only the final averaged latent is ever returned, so render it once
    # here rather than on every step.
    # NOTE(review): assumes G.synthesis draws no random numbers with
    # noise_mode='const', so skipping the per-step renders leaves the
    # seeded cutout sequence unchanged - confirm.
    with torch.no_grad():
        image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
    pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
    return pil_image
# Text shown on the demo page.
title = "StyleGAN3+CLIP"
description = "Gradio demo for StyleGAN3+CLIP. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'>colab by https://twitter.com/nshepperd1 <a href='https://colab.research.google.com/drive/1eYlenR1GHPZXt-YuvXabzO9wfh9CWY36' target='_blank'>Colab</a></p>"
examples = [['elon musk']]

# One text box in, one PIL image out; requests are queued.
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs=gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
)
demo.launch(debug=True)