import math
from collections import namedtuple

import torch
from torch import nn
from torch.nn import functional as F
import torchvision.models.vgg as vgg

from op import fused_leaky_relu

# Named container for the five VGG ReLU activations used as style features.
FeatureOutput = namedtuple(
    "FeatureOutput", ["relu1", "relu2", "relu3", "relu4", "relu5"])

def gram_matrix(y):
    """Return the batch of Gram matrices of a batch of feature maps."""
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    # Normalize by the number of elements so the statistic is size-invariant.
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

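# Illustration (not part of the original module): for a relu1 activation of
# shape [N, 64, 256, 256], gram_matrix returns an [N, 64, 64] tensor:
#   y = torch.randn(2, 64, 256, 256)  # hypothetical feature map
#   gram_matrix(y).shape              # torch.Size([2, 64, 64])
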
class FeatureExtractor(nn.Module):
    """Extract intermediate VGG19 activations.

    Reference:
    https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/3
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # `pretrained=True` is deprecated in recent torchvision releases;
        # `weights=vgg.VGG19_Weights.IMAGENET1K_V1` is the current equivalent.
        self.vgg_layers = vgg.vgg19(pretrained=True).features
        # Indices into vgg19().features: the ReLU outputs relu1_2, relu2_2,
        # relu3_4, relu4_4 and relu5_4.
        self.layer_name_mapping = {
            '3': "relu1",
            '8': "relu2",
            '17': "relu3",
            '26': "relu4",
            '35': "relu5",
        }

    def forward(self, x):
        output = {}
        # Run the input through VGG sequentially, recording the tapped layers.
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return FeatureOutput(**output)

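# Usage sketch (hypothetical; shapes assume a 3x256x256 input):
#   fx = FeatureExtractor().eval()
#   with torch.no_grad():
#       feats = fx(torch.randn(1, 3, 256, 256))
#   [f.shape[1] for f in feats]  # [64, 128, 256, 512, 512] channels
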
class StyleEmbedder(nn.Module):
    def __init__(self):
        super(StyleEmbedder, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.feature_extractor.eval()
        # Pool inputs to a fixed 256x256 so the Gram statistics are comparable
        # across input resolutions.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

    def forward(self, img):
        N = img.shape[0]
        features = self.feature_extractor(self.avg_pool(img))
        # Flatten the Gram matrix of each tapped layer and concatenate them
        # into one style vector per image.
        grams = []
        for feature in features:
            gram = gram_matrix(feature)
            grams.append(gram.view(N, -1))
        out = torch.cat(grams, dim=1)
        return out

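# The concatenated Gram vector has a fixed length. With VGG19 channel counts
# (64, 128, 256, 512, 512), the flattened Gram blocks contribute
#   64**2 + 128**2 + 256**2 + 512**2 + 512**2
#   = 4096 + 16384 + 65536 + 262144 + 262144
#   = 610304
# entries, which is the e_dim hard-coded in StyleEncoder below.
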
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Scale each feature vector to unit average magnitude (StyleGAN-style).
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)

class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate (as in StyleGAN2)."""

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Note: both branches assume bias=True; with bias=False, self.bias is
        # None and the multiplications below would raise.
        if self.activation:
            # The bias is added inside the fused leaky-ReLU kernel.
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )
        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )

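# Equalized learning rate in brief: weights are stored pre-divided by lr_mul
# and rescaled at run time by scale = lr_mul / sqrt(in_dim), so the effective
# forward pass (assuming bias=True) is
#   out = input @ (W * scale).T + b * lr_mul
# which keeps per-layer update magnitudes comparable under a single optimizer
# learning rate.
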
class StyleEncoder(nn.Module):
    def __init__(
        self,
        style_dim=512,
        n_mlp=4,
    ):
        super().__init__()

        self.style_dim = style_dim
        # Length of the concatenated Gram vector produced by StyleEmbedder
        # (see the breakdown above).
        e_dim = 610304

        self.embedder = StyleEmbedder()

        # n_mlp layers in total: a projection from e_dim down to style_dim,
        # n_mlp - 2 hidden layers, and a final linear layer (no activation).
        layers = []
        layers.append(EqualLinear(e_dim, style_dim, lr_mul=1, activation='fused_lrelu'))
        for i in range(n_mlp - 2):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=1, activation='fused_lrelu'
                )
            )
        layers.append(EqualLinear(style_dim, style_dim, lr_mul=1, activation=None))
        self.embedder_mlp = nn.Sequential(*layers)

    def forward(self, image):
        z_embed = self.embedder_mlp(self.embedder(image))  # [N, style_dim]
        return z_embed

class Projector(nn.Module):
    def __init__(self, style_dim=512, n_mlp=4):
        super().__init__()

        layers = []
        for i in range(n_mlp - 1):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=1, activation='fused_lrelu'
                )
            )
        layers.append(EqualLinear(style_dim, style_dim, lr_mul=1, activation=None))
        self.projector = nn.Sequential(*layers)

    def forward(self, x):
        return self.projector(x)

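# Minimal smoke test; a sketch, not part of the original module. It assumes
# the pretrained VGG19 weights can be downloaded and that the fused op behind
# `fused_leaky_relu` (imported from `op`) is built for this machine.
if __name__ == "__main__":
    with torch.no_grad():
        encoder = StyleEncoder(style_dim=512, n_mlp=4)
        projector = Projector(style_dim=512, n_mlp=4)
        img = torch.randn(2, 3, 256, 256)  # dummy image batch
        z = encoder(img)                   # style code, shape [2, 512]
        w = projector(z)                   # projected code, shape [2, 512]
        print(z.shape, w.shape)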