# NOTE(review): the following lines are residue from a web (Hugging Face) page
# capture, not source code; commented out so the module remains valid Python.
# 3v324v23's picture
# init
# 2f85de4
# raw
# history blame
# 8.19 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.signal
from .blurpool import BlurPool
from .official_stylegan3_model_helper import SEL, SEL_unet_pro, MappingNetwork, FullyConnectedLayer, modulated_conv2d, SynthesisInput
from third_party.stylegan3_official_ops import filtered_lrelu
from third_party.stylegan3_official_ops import upfirdn2d
from third_party.stylegan3_official_ops import bias_act
class UNetBlock(nn.Module):
    """One UNet stage of 2 or 3 style-modulated convolutions.

    Each convolution is a `modulated_conv2d` whose per-channel style vector
    comes from a `FullyConnectedLayer` affine applied to one latent `w`,
    followed by a biased leaky-ReLU (`bias_act`). With `layer_num == 3` an
    extra latent->latent convolution is inserted between the two.
    """

    def __init__(self, w_dim, in_channel, latent_channel, out_channel, ks=3, layer_num=2):
        super().__init__()
        self.ks = ks
        self.layer_num = layer_num
        # First (in -> latent) and last (latent -> out) convolution weights.
        self.weight1 = nn.Parameter(torch.randn([latent_channel, in_channel, ks, ks]))
        self.weight2 = nn.Parameter(torch.randn([out_channel, latent_channel, ks, ks]))
        self.bias1 = nn.Parameter(torch.zeros([latent_channel]))
        self.bias2 = nn.Parameter(torch.zeros([out_channel]))
        # Affines mapping a latent w to per-input-channel modulation styles.
        self.affine1 = FullyConnectedLayer(w_dim, in_channel, bias_init=1)
        self.affine2 = FullyConnectedLayer(w_dim, latent_channel, bias_init=1)
        if self.layer_num == 3:
            # Optional middle convolution (latent -> latent).
            self.weight_mid = nn.Parameter(torch.randn([latent_channel, latent_channel, ks, ks]))
            self.bias_mid = nn.Parameter(torch.zeros([latent_channel]))
            self.affine_mid = FullyConnectedLayer(w_dim, latent_channel, bias_init=1)

    def forward(self, x, *w):
        """Apply the stage to `x`, consuming one latent per convolution.

        Expects len(w) == self.layer_num (2 or 3), in execution order.
        """
        pad = self.ks // 2
        # Assemble (weight, bias, style) triples in execution order.
        stages = [(self.weight1, self.bias1, self.affine1(w[0]))]
        if self.layer_num == 3:
            stages.append((self.weight_mid, self.bias_mid, self.affine_mid(w[1])))
            stages.append((self.weight2, self.bias2, self.affine2(w[2])))
        else:
            stages.append((self.weight2, self.bias2, self.affine2(w[1])))
        for weight, bias, style in stages:
            x = modulated_conv2d(x, w=weight, s=style, padding=pad)
            x = bias_act.bias_act(x, bias.to(x.dtype), act='lrelu')
        return x
class UNet(nn.Module):
    """SEL-conditioned UNet synthesis network driven by StyleGAN3-style latents.

    An encoder/decoder with `block_num` stages on each side. Every stage is a
    `UNetBlock` (style-modulated convs) preceded by a SEL layer that injects
    the spatial `heatmap` condition. Downsampling is either average pooling
    (`sel_type == 'normal'`) or an antialiased FIR filter via `upfirdn2d`;
    upsampling is bilinear interpolation. The input feature map itself is
    produced by `SynthesisInput` from the first latent.
    """
    def __init__(self, w_dim, in_dim=3, base_dim=64, ks=3, block_num=3, layer_num=2, filt_size=3, output_dim=3, label_nc=14, sel_type='normal', img_resolution=256, wo_transform = False,):
        super().__init__()
        self.block_num = block_num
        self.layer_num = layer_num
        self.sel_type = sel_type
        if self.sel_type == 'normal':
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            # Antialiasing low-pass filters, one per encoder stage; cutoffs
            # halve with each stage's sampling rate (StyleGAN3 recipe).
            for i in range(block_num):
                self.register_buffer(f'down_filter_{i}', self.design_lowpass_filter(numtaps=12, cutoff=2**((block_num-i+1)/2), width=None, fs=img_resolution//(2**i)))
                # NOTE(review): for i == 0 the fs expression divides by 2**(-1)
                # (a float) — presumably intentional to get 2*img_resolution;
                # confirm against the SEL_unet_pro contract.
                self.register_buffer(f'sel_down_filter_{i}', self.design_lowpass_filter(numtaps=6*2**i, cutoff=2**((block_num-i+2)/2), width=None, fs=img_resolution//(2**(i-1))))
        self.input = SynthesisInput(w_dim=w_dim, channels=in_dim, size=img_resolution, sampling_rate=img_resolution, bound_len=0, bandwidth=4, wo_transform=wo_transform) # what is the bandwidth
        encoder_list, sel_enc_list, sel_dec_list, decoder_list, bp_list = [], [], [], [], []
        for i in range(block_num):
            # Encoder i: channels double each stage (in_dim -> base_dim -> 2*base_dim ...).
            if i == 0:
                encoder_list.append(UNetBlock(w_dim, in_dim, base_dim, base_dim, layer_num=layer_num))
            else:
                encoder_list.append(UNetBlock(w_dim, base_dim * 2 ** (i-1), base_dim * 2 ** i, base_dim * 2 ** i, layer_num=layer_num))
            # Decoder i consumes the upsampled features concatenated with the
            # mirrored skip connection, hence the doubled input channel count.
            decoder_list.append(UNetBlock(w_dim,
                                          base_dim * 2 ** (block_num-i),
                                          base_dim * 2 ** (block_num-i-1),
                                          base_dim * 2 ** (block_num-i-2) if i < block_num-1 else base_dim * 2 ** (block_num-i-1),
                                          layer_num=layer_num
                                          ))
            if self.sel_type == 'normal':
                sel_enc_list.append(SEL(in_dim if i==0 else base_dim * 2 ** (i-1), label_nc))
                sel_dec_list.append(SEL(base_dim * 2 ** (block_num-i-1), label_nc))
            else:
                # Filtered SEL variant; decoder reuses the mirrored stage's filter.
                sel_enc_list.append(SEL_unet_pro(in_dim if i==0 else base_dim * 2 ** (i-1), label_nc, down_filter=getattr(self, f'sel_down_filter_{i}')))
                sel_dec_list.append(SEL_unet_pro(base_dim * 2 ** (block_num-i-1), label_nc, down_filter=getattr(self, f'sel_down_filter_{block_num-i-1}')))
        self.encoders = nn.ModuleList(encoder_list)
        self.decoders = nn.ModuleList(decoder_list)
        self.enc_sels = nn.ModuleList(sel_enc_list)
        self.dec_sels = nn.ModuleList(sel_dec_list)
        # Final projection to output_dim channels; uses 2 latents (layer_num=2 default).
        self.torgb = UNetBlock(w_dim, base_dim, base_dim, output_dim)
    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, fs, width=None):
        """Return a 1-D FIR low-pass filter (float32 tensor), or None for numtaps==1."""
        if numtaps == 1:
            return None
        f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
        return torch.as_tensor(f, dtype=torch.float32)
    def forward(self, ws, heatmap, **kwargs):
        """Synthesize from per-layer latents `ws` (shape [N, num_ws, w_dim]) and `heatmap`.

        Latent budget: 1 for SynthesisInput, layer_num per encoder stage,
        layer_num per decoder stage, and 2 for the final torgb block
        (total 2*layer_num*block_num + 3).
        """
        ws = ws.unbind(1)
        x = self.input(ws[0])
        ws = ws[1:]
        enc_x = []
        for i in range(self.block_num):
            # modulate with SEL
            x = self.enc_sels[i] (x, heatmap)
            if self.layer_num==2:
                x = self.encoders[i] (x, ws[2*i], ws[2*i+1])
            else:
                x = self.encoders[i] (x, ws[3*i], ws[3*i+1], ws[3*i+2])
            enc_x.append(x)  # keep pre-downsample features for skip connections
            if self.sel_type == 'normal':
                x = self.pool(x)
            else:
                x = upfirdn2d.upfirdn2d(x=x, f=getattr(self, f'down_filter_{i}'), down=2, flip_filter=False, padding=5)
        ws = ws[self.layer_num*self.block_num: ]  # drop the consumed encoder latents
        for i in range(self.block_num):
            x = F.interpolate(x, size=x.shape[-1] * 2, mode='bilinear', align_corners=False)
            # modulate with SEL
            x = self.dec_sels[i] (x, heatmap)
            if self.layer_num==2:
                x = self.decoders[i] (torch.cat([x, enc_x[-1-i]], 1), ws[2*i], ws[2*i+1])
            else:
                x = self.decoders[i] (torch.cat([x, enc_x[-1-i]], 1), ws[3*i], ws[3*i+1], ws[3*i+2])
        ws = ws[self.layer_num*self.block_num: ]  # remaining 2 latents feed torgb
        x = self.torgb(x, ws[0], ws[1])
        return x
class Generator(nn.Module):
    """Top-level generator: MappingNetwork (z, c -> ws) feeding the SEL-UNet synthesis.

    The mapping network emits 2*layer_num*block_num + 3 latents: one for the
    SynthesisInput, layer_num per encoder/decoder stage, and two for torgb.
    Note `output_dim` is retained for signature compatibility but the synthesis
    output width is driven by `img_channels`, matching the original behavior.
    """
    def __init__(self, z_dim, c_dim, w_dim, img_resolution=256, img_channels=3,
                 in_dim=3, base_dim=64, ks=3, block_num=3, layer_num=2, filt_size=3, output_dim=3, label_nc=14, sel_type='normal', wo_transform=False, **kwargs):
        super().__init__()
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=2*layer_num*block_num+3)
        # FIX: base_dim, ks and filt_size were previously hard-coded to 64/3/3
        # here, silently ignoring the constructor arguments; forward them now.
        # Defaults are unchanged, so existing call sites behave identically.
        self.synthesis = UNet(w_dim=w_dim, in_dim=in_dim, base_dim=base_dim, ks=ks, block_num=block_num, layer_num=layer_num, filt_size=filt_size, output_dim=img_channels, label_nc=label_nc, sel_type=sel_type, img_resolution=img_resolution, wo_transform=wo_transform)

    def forward(self, z, c, heatmap, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Map (z, c) to per-layer latents, then synthesize conditioned on `heatmap`."""
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        ret = self.synthesis(ws, heatmap=heatmap)
        return ret
# class SELUNet(UNet):
# def forward(self, x, hm):
if __name__ == '__main__':
    # Smoke test: build a 4-block, 3-layer generator at two resolutions and
    # run a forward pass with dummy latents and a constant heatmap.
    for res in (512, 256):
        g = Generator(z_dim=64, c_dim=0, w_dim=512, block_num=4, layer_num=3, img_resolution=res, img_channels=32, sel_type='abn')
        hm = torch.ones([10, 14, res, res])
        z = torch.zeros([10, 64])
        c = None
        opt = g(z, c, hm)