import sys
sys.path.append('miniminiai/miniminiai')
import torchvision, torch
import fastcore.all as fc
import gradio as gr
from miniminiai import *
import numpy as np
from PIL import Image, ImageOps, ImageDraw
from torch import nn, tensor
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import models, transforms
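
# Gradio Space that compares three neural style-transfer losses (Gram matrix,
# sliced-OT, and a moment-matching / Wasserstein loss) by directly optimising a
# copy of the content image with a miniminiai Learner.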
class LengthDataset():
    def __init__(self, length=1): self.length=length
    def __len__(self): return self.length
    def __getitem__(self, idx): return 0,0

def get_dummy_dls(length=100):
    return DataLoaders(DataLoader(LengthDataset(length), batch_size=1), # Train
                       DataLoader(LengthDataset(1), batch_size=1))      # Valid (length 1)
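
# The "data" itself is irrelevant here: the image being optimised lives inside the
# model, so the dummy dataloaders only control how many optimisation steps one
# epoch of learn.fit() runs for.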
class TensorModel(nn.Module):
    def __init__(self, t):
        super().__init__()
        self.t = nn.Parameter(t.clone())
    def forward(self, x=0): return self.t

class ImageOptCB(TrainCB):
    def predict(self, learn): learn.preds = learn.model()
    def get_loss(self, learn): learn.loss = learn.loss_func(learn.preds)
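
# TensorModel stores the image as its only trainable parameter, and ImageOptCB
# overrides the predict/get_loss steps so the loss is computed on that image alone,
# with no batch inputs or targets involved.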
def calc_features(imgs, target_layers=(18, 25)):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    x = normalize(imgs)
    feats = []
    for i, layer in enumerate(vgg16[:max(target_layers)+1]):
        x = layer(x)
        if i in target_layers:
            feats.append(x.clone())
    return feats
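
# Feature extraction uses torchvision's pretrained VGG16: the image is normalised
# with ImageNet statistics, then passed through vgg16.features up to the deepest
# requested layer, collecting the activations at each index in target_layers.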
class ContentLossToTarget():
    def __init__(self, target_im, target_layers=(18, 25)):
        fc.store_attr()
        with torch.no_grad():
            self.target_features = calc_features(target_im, target_layers)
    def __call__(self, input_im):
        return sum((f1-f2).pow(2).mean() for f1, f2 in
                   zip(calc_features(input_im, self.target_layers), self.target_features))
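
# Content loss: mean squared difference between deep VGG features (layers 18 and 25
# by default) of the working image and of the original content image.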
def calc_grams(img, target_layers=(1, 6, 11, 18, 25)):
    return fc.L(torch.einsum('chw, dhw -> cd', x, x) / (x.shape[-2]*x.shape[-1]) # 'bchw, bdhw -> bcd' if batched
                for x in calc_features(img, target_layers))

class StyleLossToTarget():
    def __init__(self, target_im, target_layers=(1, 6, 11, 18, 25), size=394):
        fc.store_attr()
        with torch.no_grad(): self.target_grams = calc_grams(target_im, target_layers)
    def __call__(self, input_im):
        return sum((f1-f2).pow(2).mean() for f1, f2 in
                   zip(calc_grams(input_im, self.target_layers), self.target_grams))
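
# Classic Gram-matrix (Gatys-style) loss: match the channel-wise feature
# correlations of the generated image to those of the style image at several VGG
# layers. The size argument is accepted for API parity with the other losses but
# is not used by this class.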
class OTStyleLossToTarget(nn.Module):
    def __init__(self, target, size=128, style_layers=[1, 6, 11, 18, 25], scale_factor=2e-5):
        super(OTStyleLossToTarget, self).__init__()
        self.device = device
        self.resize = transforms.Compose([transforms.Resize(size), transforms.CenterCrop(size)])
        self.target = self.resize(target) # resize target image to size
        self.style_layers = style_layers
        self.scale_factor = scale_factor # Defaults tend to be very large, we scale to make them easier to work with
        with torch.no_grad():
            self.target_features = calc_features(self.target, self.style_layers)
    def project_sort(self, x, proj):
        return torch.einsum('bcn,cp->bpn', x, proj).sort()[0]
    def ot_loss(self, source, target, proj_n=32):
        # Features arrive unbatched as (c, h, w); flatten the spatial dims to (1, c, n)
        # so the random projections act on the channel dimension, as 'bcn' implies.
        source, target = source.flatten(1)[None], target.flatten(1)[None]
        ch, n = source.shape[-2:]
        projs = F.normalize(torch.randn(ch, proj_n).to(self.device), dim=0)
        source_proj = self.project_sort(source, projs)
        target_proj = self.project_sort(target, projs)
        target_interp = F.interpolate(target_proj, n, mode='nearest')
        return (source_proj-target_interp).square().sum()
    def forward(self, input):
        input = self.resize(input) # set size (assumes square images)
        # Sum the sliced-OT loss over all style layers, then scale it down
        return sum(self.ot_loss(x, y) for x, y in zip(calc_features(input, self.style_layers), self.target_features)) * self.scale_factor
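
# Sliced optimal-transport style loss: at each layer the feature vectors are
# projected onto random unit directions and sorted, and the squared difference
# between the sorted 1-D projections approximates the Wasserstein distance between
# the two feature distributions.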
class VincentStyleLossToTarget(nn.Module):
    def __init__(self, target, size=128, style_layers=[1, 6, 11, 18, 25], scale_factor=1e-5):
        super(VincentStyleLossToTarget, self).__init__()
        self.resize = transforms.Compose([transforms.Resize(size), transforms.CenterCrop(size)])
        self.target = self.resize(target) # resize target image to size
        self.style_layers = style_layers
        self.scale_factor = scale_factor # Defaults tend to be very large, we scale to make them easier to work with
        with torch.no_grad():
            self.target_features = calc_features(self.target, self.style_layers)
    def calc_2_moments(self, x):
        c, h, w = x.shape
        x = x.reshape(1, c, h*w) # b, c, n
        mu = x.mean(dim=-1, keepdim=True) # b, c, 1
        cov = torch.matmul(x-mu, torch.transpose(x-mu, -1, -2))
        return mu, cov
    def matrix_diag(self, diagonal):
        N = diagonal.shape[-1]
        shape = diagonal.shape[:-1] + (N, N)
        device, dtype = diagonal.device, diagonal.dtype
        result = torch.zeros(shape, dtype=dtype, device=device)
        indices = torch.arange(result.numel(), device=device).reshape(shape)
        indices = indices.diagonal(dim1=-2, dim2=-1)
        result.view(-1)[indices] = diagonal
        return result
    def l2wass_dist(self, mean_stl, cov_stl, mean_synth, cov_synth):
        # Calculate tr_cov and root_cov from mean_stl and cov_stl
        eigvals, eigvects = torch.linalg.eigh(cov_stl) # eigh (for symmetric matrices) matches tf self_adjoint_eig; eig would return complex tensors
        eigroot_mat = self.matrix_diag(torch.sqrt(eigvals.clip(0)))
        root_cov_stl = torch.matmul(torch.matmul(eigvects, eigroot_mat), torch.transpose(eigvects, -1, -2))
        tr_cov_stl = torch.sum(eigvals.clip(0), dim=1, keepdim=True)
        tr_cov_synth = torch.sum(torch.linalg.eigvalsh(cov_synth).clip(0), dim=1, keepdim=True)
        mean_diff_squared = torch.mean((mean_synth - mean_stl)**2)
        cov_prod = torch.matmul(torch.matmul(root_cov_stl, cov_synth), root_cov_stl)
        var_overlap = torch.sum(torch.sqrt(torch.linalg.eigvalsh(cov_prod).clip(0.1)), dim=1, keepdim=True) # clipping at 0 here caused errors when taking the eigvals, so clip at 0.1
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap
        return dist
    def forward(self, input):
        input = self.resize(input) # set size (assumes square images, center crops otherwise)
        input_features = calc_features(input, self.style_layers) # get features
        l = 0
        for x, y in zip(input_features, self.target_features):
            mean_synth, cov_synth = self.calc_2_moments(x) # input mean and cov
            mean_stl, cov_stl = self.calc_2_moments(y)     # target mean and cov
            l += self.l2wass_dist(mean_stl, cov_stl, mean_synth, cov_synth)
        return l.mean() * self.scale_factor
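
# Moment-matching ("Vincent's") style loss: the features at each layer are summarised
# by their mean and covariance, and the loss penalises the closed-form 2-Wasserstein
# distance between the corresponding Gaussians,
#   W2^2 = ||mu_1 - mu_2||^2 + tr(C_1) + tr(C_2) - 2*tr((C_2^{1/2} C_1 C_2^{1/2})^{1/2}),
# computed via the eigendecompositions above.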
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    grid_w, grid_h = grid.size
    for i, img in enumerate(imgs):
        grid.paste(img.resize((w, h)), box=(i%cols*w, i//cols*h))
    grid = ImageOps.expand(grid, border=20, fill=(255, 255, 255))
    draw = ImageDraw.Draw(grid)
    # fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf")
    # draw.text((0, 0), "Sample Text", (0, 0, 0))
    return grid
def style_image(content_image, style_image, style_losses):
    data = []
    content_image = content_image.resize((384, 384))
    style_image = style_image.resize((384, 384))
    output = [content_image]
    content_image = torch.tensor(np.array(content_image).astype(np.float32) / 255.).permute(2, 0, 1)
    style_image = torch.tensor(np.array(style_image).astype(np.float32) / 255.).permute(2, 0, 1)
    content_loss = ContentLossToTarget(content_image.to(device))
    sim = style_image.to(device)
    for style_loss in style_losses:
        style_loss = map_style_losses[style_loss](sim, size=384)
        model = TensorModel(content_image)
        def combined_loss(x): return style_loss(x) + content_loss(x)
        learn = Learner(model, get_dummy_dls(150), combined_loss, lr=1e-2, cbs=[ImageOptCB(), DeviceCB()], opt_func=torch.optim.Adam)
        learn.fit(1)
        im = to_cpu(learn.preds.clip(0, 1))
        output.append(Image.fromarray((im.permute(1, 2, 0).numpy() * 255).astype(np.uint8)))
    return image_grid(output, 1, len(style_losses) + 1)
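
# For each selected loss, a fresh copy of the content image is wrapped in a
# TensorModel and optimised for 150 steps against content loss + that style loss;
# the results are pasted next to the original content image in a single grid.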
def run():
    with gr.Blocks() as demo:
        # gr.Markdown("Start typing below and then click **Run** to see the output.")
        with gr.Row():
            with gr.Column(scale=1):
                content_im = gr.Image(shape=(318, 318), type='pil', label="Content image")
                style_img = gr.Image(shape=(318, 318), type='pil', label="Style image")
                style_losses = gr.CheckboxGroup(["Gram Matrix", "OT-Based", "Vincent's"], value=["Gram Matrix", "OT-Based", "Vincent's"], label="Style Loss")
                btn = gr.Button("Generate")
            with gr.Column(scale=1):
                out = gr.Image(shape=(384, 384))
        btn.click(fn=style_image, inputs=[content_im, style_img, style_losses], outputs=out)
    demo.launch(server_name="0.0.0.0", server_port=7860)

map_style_losses = {
    "Gram Matrix": StyleLossToTarget,
    "OT-Based": OTStyleLossToTarget,
    "Vincent's": VincentStyleLossToTarget
}

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).to(device)
    vgg16.eval()
    vgg16 = vgg16.features
    run()