Spaces:
Paused
Paused
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from models.base_blocks import ResBlock, StyleConv, ToRGB | |
class ENet(nn.Module):
    """Enhancement network wrapped around a frozen low-resolution network (LNet).

    ``forward`` runs ``lnet`` on audio plus face frames resized to 96x96, then
    decodes its output through style-modulated convolutions whose style code is
    encoded from reference frames.

    Args:
        num_style_feat: Width of the style code produced by ``final_linear`` and
            passed to every ``StyleConv`` / ``ToRGB`` layer.
        lnet: Low-resolution network, called as ``lnet(audio, lnet_input)``.
            Required; all of its parameters are frozen (``requires_grad=False``).
        concat: If True, ``lnet`` must return ``(image, features)`` and both are
            concatenated as decoder input; otherwise ``lnet`` returns only the
            image.

    Raises:
        ValueError: If ``lnet`` is None.
    """

    # Reflect padding (pixels per side) added to the LNet output before decoding.
    _PAD = 2
    # Border cropped from the decoder output. This is the 2 px pad after two
    # decoder stages, assuming each stage upsamples 2x (2 * 2 * 2 = 8).
    # NOTE(review): confirm the upsample factor against StyleConv/ToRGB.
    _CROP = 8

    def __init__(
        self,
        num_style_feat=512,
        lnet=None,
        concat=False
    ):
        super().__init__()
        if lnet is None:
            # Fail early and clearly instead of with AttributeError on None.parameters().
            raise ValueError("ENet requires an `lnet` (low-resolution network) module")
        self.low_res = lnet
        # The low-resolution network is kept frozen; only ENet's own layers train.
        for param in self.low_res.parameters():
            param.requires_grad = False

        channel_multiplier, narrow = 2, 1
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }

        # log2 of the side length of the style-encoder input (ref is resized to 256x256).
        self.log_size = 8
        first_out_size = 128
        # 1x1 conv lifting RGB to channels['128'] (=256). The spatial size is
        # unchanged; the ResBlocks below do the downsampling.
        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # Style encoder: (log_size - 2) = 6 downsampling ResBlocks. final_linear
        # below expects them to reduce the map to 4x4.
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2 ** (i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.num_style_feat = num_style_feat
        linear_out_channel = num_style_feat
        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # Not referenced within this class; kept to preserve the public attribute.
        self.noises = nn.Module()
        self.concat = concat
        # Decoder input is the LNet RGB image (3 channels), plus LNet features
        # when `concat` is set. The 32 assumes LNet's feature map has 32 channels.
        # TODO confirm against the LNet implementation.
        in_channels = 3 + 32 if concat else 3
        for i in range(7, 9):  # channel keys '128' and '256'
            out_channels = channels[f'{2 ** i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def forward(self, audio_sequences, face_sequences, gt_sequences):
        """Run LNet and enhance its output.

        Args:
            audio_sequences: Audio input, shape ``(B, ...)`` for single frames or
                ``(B, T, ...)`` when ``face_sequences`` is 5-D.
            face_sequences: Split on dim 1 into 3-channel chunks; the first
                chunk is the input frames and the second is the reference
                frames (presumably 6 channels total). Either 4-D
                ``(B, C, H, W)`` or 5-D ``(B, C, T, H, W)``.
            gt_sequences: Concatenated with the input frames on dim 1 to form
                the LNet input; same layout as ``face_sequences``.

        Returns:
            ``(outputs, low_res_img)``. For 5-D inputs both are ``(B, C, T, H, W)``
            and ``low_res_img`` is resized to the output resolution. For 4-D
            inputs ``low_res_img`` is returned at LNet resolution.
        """
        batch_size = audio_sequences.size(0)
        is_sequence = face_sequences.dim() > 4
        inp, ref = torch.split(face_sequences, 3, dim=1)
        if is_sequence:
            # Merge the time axis into the batch axis (time-major) so every
            # frame is processed as an independent image.
            audio_sequences = self._fold_time(audio_sequences, 1)
            inp = self._fold_time(inp, 2)
            ref = self._fold_time(ref, 2)
            gt_sequences = self._fold_time(gt_sequences, 2)

        style_code = self._encode_style(ref)
        low_res_img, out = self._run_low_res(audio_sequences, inp, gt_sequences)

        # Reflect-pad so the decoder has context at the borders; cropped below.
        out = F.pad(out, (self._PAD,) * 4, mode='reflect')
        skip = out
        for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
            out = conv1(out, style_code)
            out = conv2(out, style_code)
            skip = to_rgb(out, style_code, skip)

        crop = self._CROP
        outputs = skip[:, :, crop:-crop, crop:-crop]

        if is_sequence:
            outputs = self._unfold_time(outputs, batch_size)
            low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
            low_res_img = self._unfold_time(low_res_img, batch_size)
        return outputs, low_res_img

    def _encode_style(self, ref):
        """Encode reference frames into a style code of shape ``(N, 1, num_style_feat)``."""
        resized = F.interpolate(ref, size=(256, 256), mode='bilinear')
        feat = F.leaky_relu_(self.conv_body_first(resized), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
        style_code = self.final_linear(feat.reshape(feat.size(0), -1))
        return style_code.reshape(style_code.size(0), -1, self.num_style_feat)

    def _run_low_res(self, audio_sequences, inp, gt_sequences):
        """Run the frozen LNet and return ``(low_res_img, decoder_input)``, both detached."""
        lnet_input = torch.cat([inp, gt_sequences], dim=1)
        lnet_input = F.interpolate(lnet_input, size=(96, 96), mode='bilinear')
        if self.concat:
            low_res_img, low_res_feat = self.low_res(audio_sequences, lnet_input)
            # Tensor.detach() is not in-place: the result must be reassigned.
            low_res_img = low_res_img.detach()
            low_res_feat = low_res_feat.detach()
            return low_res_img, torch.cat([low_res_img, low_res_feat], dim=1)
        low_res_img = self.low_res(audio_sequences, lnet_input).detach()
        return low_res_img, low_res_img

    @staticmethod
    def _fold_time(x, time_dim):
        """Concatenate the slices along ``time_dim`` into dim 0 (row ``t * B + b`` is ``x[b, ..., t]``)."""
        return torch.cat(torch.unbind(x, dim=time_dim), dim=0)

    @staticmethod
    def _unfold_time(x, batch_size):
        """Inverse of ``_fold_time`` for 4-D maps: ``(T*B, C, H, W) -> (B, C, T, H, W)``."""
        return torch.stack(torch.split(x, batch_size, dim=0), dim=2)