import math
import numpy as np
import pandas as pd
import gradio as gr
from huggingface_hub import from_pretrained_fastai
from fastai.vision.all import *
from torchvision.models import vgg19, vgg16
pascal_source = '.'
EXAMPLES_PATH = Path('/content/examples')
# Hub repo holding the exported fastai Learner
repo_id = "hugginglearners/fastai-style-transfer"
# Style losses use every hooked feature map except the last (content) one
def get_stl_fs(fs): return fs[:-1]
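
# `style_loss` below calls a `gram` helper that is not defined in this file;
# what follows is a minimal sketch of the standard Gram-matrix computation,
# assuming feature maps shaped (batch, channels, height, width):
def gram(x:Tensor):
    "Gram matrix of a batch of feature maps, normalized by layer size"
    b, c, h, w = x.shape
    f = x.view(b, c, h * w)
    return (f @ f.transpose(1, 2)) / (c * h * w)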
def style_loss(inp:Tensor, out_feat:Tensor):
    "Calculate style loss, assumes the style image's `im_grams` are precomputed"
    # Get batch size
    bs = inp[0].shape[0]
    loss = []
    # Compare each style layer's Gram matrix against the style image's
    for y, f in zip(*map(get_stl_fs, [im_grams, inp])):
        # Calculate MSE
        loss.append(F.mse_loss(y.repeat(bs, 1, 1), gram(f)))
    # Weight their sum by 3e5
    return 3e5 * sum(loss)
class FeatureLoss(Module):
    "Combines style and activation losses and features into a usable loss function"
    def __init__(self, feats, style_loss, act_loss, hooks, feat_net):
        # `store_attr` stores every init argument (including hooks and feat_net) on self
        store_attr()
        self.reset_metrics()
    def forward(self, pred, targ):
        # First get the features of our prediction and target
        pred_feat = self.feats(self.feat_net, self.hooks, pred)
        targ_feat = self.feats(self.feat_net, self.hooks, targ)
        # Calculate style and activation loss
        style_loss = self.style_loss(pred_feat, targ_feat)
        act_loss = self.act_loss(pred_feat, targ_feat)
        # Store the loss
        self._add_loss(style_loss, act_loss)
        # Return the sum
        return style_loss + act_loss
    def reset_metrics(self):
        # Start fresh style/content metric lists
        self.metrics = dict(style=[], content=[])
    def _add_loss(self, style_loss, act_loss):
        # Add to our metrics
        self.metrics['style'].append(style_loss)
        self.metrics['content'].append(act_loss)
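
# The `feats`, `hooks`, and `feat_net` arguments above are created elsewhere
# in the original project; this is a hedged sketch of the typical wiring (the
# layer indices and the `get_feats` name are illustrative assumptions):
feat_net = vgg19(pretrained=True).features.eval()
for p in feat_net.parameters(): p.requires_grad_(False)
hooks = hook_outputs([feat_net[i] for i in (3, 8, 17, 26)], detach=False)

def get_feats(net, hooks, x):
    "Run `x` through `net` and return the hooked activations"
    net(x)
    return [h.stored for h in hooks]

# `im_grams`, assumed by `style_loss`, would then be precomputed once from the
# style image's features, e.g.:
# im_grams = [gram(f) for f in get_feats(feat_net, hooks, style_im)]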
def act_loss(inp:Tensor, targ:Tensor):
    "Calculate the MSE loss of the activation layers"
    return F.mse_loss(inp[-1], targ[-1])
class ReflectionLayer(Module):
    "Reflection padding followed by a convolution"
    def __init__(self, in_channels, out_channels, ks=3, stride=2):
        reflection_padding = ks // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, ks, stride)
    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out
class ResidualBlock(Module):
    "Two reflection layers and an added activation function with residual"
    def __init__(self, channels):
        self.conv1 = ReflectionLayer(channels, channels, ks=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ReflectionLayer(channels, channels, ks=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = out + residual
        return out
class UpsampleConvLayer(Module):
    "Nearest-neighbor upsampling followed by a reflection-padded convolution"
    def __init__(self, in_channels, out_channels, ks=3, stride=1, upsample=None):
        self.upsample = upsample
        reflection_padding = ks // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, ks, stride)
    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out
class TransformerNet(Module):
    "A simple network for style transfer"
    def __init__(self):
        # Initial convolution layers
        self.conv1 = ReflectionLayer(3, 32, ks=9, stride=1)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ReflectionLayer(32, 64, ks=3, stride=2)
        self.in2 = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ReflectionLayer(64, 128, ks=3, stride=2)
        self.in3 = nn.InstanceNorm2d(128, affine=True)
        # Residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Upsampling layers
        self.deconv1 = UpsampleConvLayer(128, 64, ks=3, stride=1, upsample=2)
        self.in4 = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, ks=3, stride=1, upsample=2)
        self.in5 = nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ReflectionLayer(32, 3, ks=9, stride=1)
        # Non-linearities
        self.relu = nn.ReLU()
    def forward(self, X):
        # Downsample
        y = self.relu(self.in1(self.conv1(X)))
        y = self.relu(self.in2(self.conv2(y)))
        y = self.relu(self.in3(self.conv3(y)))
        # Residual blocks at the bottleneck
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        # Upsample back to the input resolution
        y = self.relu(self.in4(self.deconv1(y)))
        y = self.relu(self.in5(self.deconv2(y)))
        y = self.deconv3(y)
        return y
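
# A quick shape check of the network above (purely illustrative): a 3-channel
# image batch goes in, a same-sized stylized batch comes out.
net = TransformerNet().eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
assert out.shape == (1, 3, 256, 256)

# A minimal sketch of loading the exported Learner from the Hub and serving it
# with Gradio; the `stylize` post-processing is an assumption, not taken from
# this file:
learn = from_pretrained_fastai(repo_id)

def stylize(img):
    # `Learner.predict` returns (decoded, raw, probs); the decoded output is
    # assumed here to be directly displayable as an image
    return learn.predict(img)[0]

gr.Interface(fn=stylize, inputs=gr.Image(type='pil'), outputs=gr.Image()).launch()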