import sys
sys.path.append('miniminiai/miniminiai')  # make the bundled miniminiai checkout importable
import torch
import torchvision
import fastcore.all as fc
import gradio as gr
from miniminiai import *
import numpy as np
from PIL import Image, ImageOps
from torch import nn, tensor
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import models, transforms
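
# Style transfer is framed as a miniminiai training loop: the "model" is just a
# learnable image tensor, so the dataloaders below exist only to set how many
# optimization steps run per epoch; the dummy batches themselves are ignored.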
class LengthDataset():
    def __init__(self, length=1): self.length = length
    def __len__(self): return self.length
    def __getitem__(self, idx): return 0, 0
def get_dummy_dls(length=100):
    return DataLoaders(DataLoader(LengthDataset(length), batch_size=1),  # Train
                       DataLoader(LengthDataset(1), batch_size=1))       # Valid (length 1)
class TensorModel(nn.Module):
    def __init__(self, t):
        super().__init__()
        self.t = nn.Parameter(t.clone())
    def forward(self, x=0): return self.t
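
# TrainCB variant for image optimization: the "prediction" is the image tensor
# itself, and the loss function takes only that prediction (there are no labels).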
class ImageOptCB(TrainCB):
    def predict(self, learn): learn.preds = learn.model()
    def get_loss(self, learn): learn.loss = learn.loss_func(learn.preds)
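
# Layer indices refer to torchvision's vgg16.features: 1, 6, 11, 18 and 25 are
# the ReLU outputs after the first conv of each block (relu1_1 through relu5_1);
# the content loss defaults to the deeper layers, 18 and 25.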
def calc_features(imgs, target_layers=(18, 25)):
    # Relies on the module-level `vgg16` feature extractor set up in __main__
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    x = normalize(imgs)
    feats = []
    for i, layer in enumerate(vgg16[:max(target_layers)+1]):
        x = layer(x)
        if i in target_layers:
            feats.append(x.clone())
    return feats
class ContentLossToTarget():
    def __init__(self, target_im, target_layers=(18, 25)):
        fc.store_attr()
        with torch.no_grad():
            self.target_features = calc_features(target_im, target_layers)
    def __call__(self, input_im):
        return sum((f1-f2).pow(2).mean() for f1, f2 in
                   zip(calc_features(input_im, self.target_layers), self.target_features))
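
# Gram matrix of a (C, H, W) feature map: G[c, d] = sum_{h,w} F[c] * F[d] / (H*W),
# i.e. channel-to-channel correlations with spatial position averaged out.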
def calc_grams(img, target_layers=(1, 6, 11, 18, 25)):
    return fc.L(torch.einsum('chw, dhw -> cd', x, x) / (x.shape[-2]*x.shape[-1])  # 'bchw, bdhw -> bcd' if batched
                for x in calc_features(img, target_layers))
class StyleLossToTarget():
    def __init__(self, target_im, target_layers=(1, 6, 11, 18, 25), size=394):
        fc.store_attr()  # `size` is unused here; kept so all losses share the same signature
        with torch.no_grad(): self.target_grams = calc_grams(target_im, target_layers)
    def __call__(self, input_im):
        return sum((f1-f2).pow(2).mean() for f1, f2 in
                   zip(calc_grams(input_im, self.target_layers), self.target_grams))
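
# Sliced (projected) Wasserstein style loss: treat each layer's activations as N
# C-dimensional samples, project them onto random unit directions, sort, and
# penalize the squared distance between the sorted input and target projections.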
class OTStyleLossToTarget(nn.Module):
    def __init__(self, target, size=128, style_layers=[1, 6, 11, 18, 25], scale_factor=2e-5):
        super().__init__()
        self.device = device  # module-level device set in __main__
        self.resize = transforms.Compose([transforms.Resize(size), transforms.CenterCrop(size)])
        self.target = self.resize(target)  # resize target image to size
        self.style_layers = style_layers
        self.scale_factor = scale_factor  # Raw values tend to be very large; we scale to make them easier to work with
        with torch.no_grad():
            self.target_features = calc_features(self.target, self.style_layers)
    def project_sort(self, x, proj):
        return torch.einsum('bcn,cp->bpn', x, proj).sort()[0]
    def ot_loss(self, source, target, proj_n=32):
        # Flatten (C, H, W) features to (1, C, N) so the projections are over channels
        source = source.reshape(1, source.shape[0], -1)
        target = target.reshape(1, target.shape[0], -1)
        ch, n = source.shape[-2:]
        projs = F.normalize(torch.randn(ch, proj_n).to(self.device), dim=0)
        source_proj = self.project_sort(source, projs)
        target_proj = self.project_sort(target, projs)
        target_interp = F.interpolate(target_proj, n, mode='nearest')
        return (source_proj - target_interp).square().sum()
    def forward(self, input):
        input = self.resize(input)  # set size (assumes square images)
        input_features = calc_features(input, self.style_layers)
        # Sum the squared error between sorted projections over all style layers
        return sum(self.ot_loss(x, y) for x, y in zip(input_features, self.target_features)) * self.scale_factor
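
# Matches per-layer feature statistics to the target via the closed-form
# 2-Wasserstein distance between Gaussians fitted to the features:
#   W2^2 = |mu_s - mu_t|^2 + tr(C_s) + tr(C_t) - 2*tr((C_t^0.5 C_s C_t^0.5)^0.5)
# (the code below uses the mean rather than the sum for the first term).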
class VincentStyleLossToTarget(nn.Module):
    def __init__(self, target, size=128, style_layers=[1, 6, 11, 18, 25], scale_factor=1e-5):
        super().__init__()
        self.resize = transforms.Compose([transforms.Resize(size), transforms.CenterCrop(size)])
        self.target = self.resize(target)  # resize target image to size
        self.style_layers = style_layers
        self.scale_factor = scale_factor  # Raw values tend to be very large; we scale to make them easier to work with
        with torch.no_grad():
            self.target_features = calc_features(self.target, self.style_layers)
    def calc_2_moments(self, x):
        c, w, h = x.shape
        x = x.reshape(1, c, w*h)  # b, c, n
        mu = x.mean(dim=-1, keepdim=True)  # b, c, 1
        cov = torch.matmul(x - mu, torch.transpose(x - mu, -1, -2))
        return mu, cov
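    # Builds a batch of diagonal matrices from a batch of eigenvalues
    # (equivalent to torch.diag_embed).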
    def matrix_diag(self, diagonal):
        N = diagonal.shape[-1]
        shape = diagonal.shape[:-1] + (N, N)
        device, dtype = diagonal.device, diagonal.dtype
        result = torch.zeros(shape, dtype=dtype, device=device)
        indices = torch.arange(result.numel(), device=device).reshape(shape)
        indices = indices.diagonal(dim1=-2, dim2=-1)
        result.view(-1)[indices] = diagonal
        return result
    def l2wass_dist(self, mean_stl, cov_stl, mean_synth, cov_synth):
        # Trace and matrix square root of cov_stl via its eigendecomposition
        eigvals, eigvects = torch.linalg.eigh(cov_stl)  # eig returns complex tensors; eigh matches tf self_adjoint_eig for symmetric matrices
        eigroot_mat = self.matrix_diag(torch.sqrt(eigvals.clip(0)))
        root_cov_stl = torch.matmul(torch.matmul(eigvects, eigroot_mat), torch.transpose(eigvects, -1, -2))
        tr_cov_stl = torch.sum(eigvals.clip(0), dim=1, keepdim=True)
        tr_cov_synth = torch.sum(torch.linalg.eigvalsh(cov_synth).clip(0), dim=1, keepdim=True)
        mean_diff_squared = torch.mean((mean_synth - mean_stl)**2)
        cov_prod = torch.matmul(torch.matmul(root_cov_stl, cov_synth), root_cov_stl)
        var_overlap = torch.sum(torch.sqrt(torch.linalg.eigvalsh(cov_prod).clip(0.1)), dim=1, keepdim=True)  # clipping at 0 caused errors when computing eigvals
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap
        return dist
    def forward(self, input):
        input = self.resize(input)  # set size (assumes square images, center crops otherwise)
        input_features = calc_features(input, self.style_layers)  # get features
        l = 0
        for x, y in zip(input_features, self.target_features):
            mean_synth, cov_synth = self.calc_2_moments(x)  # input mean and cov
            mean_stl, cov_stl = self.calc_2_moments(y)  # target mean and cov
            l += self.l2wass_dist(mean_stl, cov_stl, mean_synth, cov_synth)
        return l.mean() * self.scale_factor
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs):
        grid.paste(img.resize((w, h)), box=(i % cols * w, i // cols * h))
    grid = ImageOps.expand(grid, border=20, fill=(255, 255, 255))
    return grid
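
# Runs one optimization per selected style loss and returns a 1-row grid: the
# content image first, then one stylized result per loss, in checkbox order.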
def style_image(content_image, style_image, style_losses):
    content_image = content_image.resize((384, 384))
    style_image = style_image.resize((384, 384))
    output = [content_image]
    content_image = torch.tensor(np.array(content_image).astype(np.float32) / 255.).permute(2, 0, 1)
    style_image = torch.tensor(np.array(style_image).astype(np.float32) / 255.).permute(2, 0, 1)
    content_loss = ContentLossToTarget(content_image.to(device))
    sim = style_image.to(device)
    for loss_name in style_losses:
        style_loss = map_style_losses[loss_name](sim, size=384)
        model = TensorModel(content_image)
        def combined_loss(x): return style_loss(x) + content_loss(x)
        learn = Learner(model, get_dummy_dls(150), combined_loss, lr=1e-2, cbs=[ImageOptCB(), DeviceCB()], opt_func=torch.optim.Adam)
        learn.fit(1)
        im = to_cpu(learn.preds.clip(0, 1))
        output.append(Image.fromarray((im.permute(1, 2, 0).numpy() * 255).astype(np.uint8)))
    return image_grid(output, 1, len(style_losses) + 1)
def run():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1):
                content_im = gr.Image(shape=(318, 318), type='pil', label="Content image")
                style_img = gr.Image(shape=(318, 318), type='pil', label="Style image")
                style_losses = gr.CheckboxGroup(["Gram Matrix", "OT-Based", "Vincent's"], value=["Gram Matrix", "OT-Based", "Vincent's"], label="Style Loss")
                btn = gr.Button("Generate")
            with gr.Column(scale=1):
                out = gr.Image(shape=(384, 384))
        btn.click(fn=style_image, inputs=[content_im, style_img, style_losses], outputs=out)
    demo.launch(server_name="0.0.0.0", server_port=7860)
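
# Maps the CheckboxGroup labels in the UI to the corresponding loss classes.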
map_style_losses = {
    "Gram Matrix": StyleLossToTarget,
    "OT-Based": OTStyleLossToTarget,
    "Vincent's": VincentStyleLossToTarget,
}
if __name__ == "__main__":
    # device and vgg16 are module-level globals used by calc_features and the loss classes
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).to(device)
    vgg16.eval()
    vgg16 = vgg16.features
    run()