""" This file only for testing mask regularzation. If it works, it will be merged with `layers.py`. """ import torch import torch.nn as nn import torch.nn.functional as F class AADLayer(nn.Module): def __init__(self, c_x, attr_c, c_id=256): super(AADLayer, self).__init__() self.attr_c = attr_c self.c_id = c_id self.c_x = c_x self.conv1 = nn.Conv2d( attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True ) self.conv2 = nn.Conv2d( attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True ) self.fc1 = nn.Linear(c_id, c_x) self.fc2 = nn.Linear(c_id, c_x) self.norm = nn.InstanceNorm2d(c_x, affine=False) self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True) def forward(self, h_in, z_attr, z_id): # h_in cxnxn # zid 256x1x1 # zattr cxnxn h = self.norm(h_in) gamma_attr = self.conv1(z_attr) beta_attr = self.conv2(z_attr) gamma_id = self.fc1(z_id) beta_id = self.fc2(z_id) A = gamma_attr * h + beta_attr gamma_id = gamma_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h) beta_id = beta_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h) I = gamma_id * h + beta_id M = torch.sigmoid(self.conv_h(h)) out = (torch.ones_like(M).to(M.device) - M) * A + M * I return out, torch.mean(torch.ones_like(M).to(M.device) - M, dim=[1, 2, 3]) class AAD_ResBlk(nn.Module): def __init__(self, cin, cout, c_attr, c_id=256): super(AAD_ResBlk, self).__init__() self.cin = cin self.cout = cout self.AAD1 = AADLayer(cin, c_attr, c_id) self.conv1 = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False) self.relu1 = nn.ReLU(inplace=True) self.AAD2 = AADLayer(cin, c_attr, c_id) self.conv2 = nn.Conv2d( cin, cout, kernel_size=3, stride=1, padding=1, bias=False ) self.relu2 = nn.ReLU(inplace=True) if cin != cout: self.AAD3 = AADLayer(cin, c_attr, c_id) self.conv3 = nn.Conv2d( cin, cout, kernel_size=3, stride=1, padding=1, bias=False ) self.relu3 = nn.ReLU(inplace=True) def forward(self, h, z_attr, z_id): x, m1_ = self.AAD1(h, z_attr, z_id) x = self.relu1(x) x = self.conv1(x) x, m2_ = self.AAD2(x, z_attr, z_id) x = self.relu2(x) x = self.conv2(x) m = m1_ + m2_ if self.cin != self.cout: h, m3_ = self.AAD3(h, z_attr, z_id) h = self.relu3(h) h = self.conv3(h) m += m3_ x = x + h return x, m def weight_init(m): if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.001) m.bias.data.zero_() if isinstance(m, nn.Conv2d): nn.init.xavier_normal_(m.weight.data) if isinstance(m, nn.ConvTranspose2d): nn.init.xavier_normal_(m.weight.data) def conv4x4(in_c, out_c, norm=nn.BatchNorm2d): return nn.Sequential( nn.Conv2d( in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False, ), norm(out_c), nn.LeakyReLU(0.1, inplace=True), ) class deconv4x4(nn.Module): def __init__(self, in_c, out_c, norm=nn.BatchNorm2d): super(deconv4x4, self).__init__() self.deconv = nn.ConvTranspose2d( in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False, ) self.bn = norm(out_c) self.lrelu = nn.LeakyReLU(0.1, inplace=True) def forward(self, input, skip): x = self.deconv(input) x = self.bn(x) x = self.lrelu(x) return torch.cat((x, skip), dim=1) class MLAttrEncoder(nn.Module): def __init__(self, finetune=False, downup=False): super(MLAttrEncoder, self).__init__() self.downup = downup if self.downup: self.conv00 = conv4x4(3, 16) self.conv01 = conv4x4(16, 32) self.deconv7 = deconv4x4(64, 16) self.conv1 = conv4x4(3, 32) self.conv2 = conv4x4(32, 64) self.conv3 = conv4x4(64, 128) self.conv4 = conv4x4(128, 256) self.conv5 = conv4x4(256, 512) self.conv6 = conv4x4(512, 1024) self.conv7 = conv4x4(1024, 1024) self.deconv1 = deconv4x4(1024, 1024) self.deconv2 = deconv4x4(2048, 512) self.deconv3 = deconv4x4(1024, 256) self.deconv4 = deconv4x4(512, 128) self.deconv5 = deconv4x4(256, 64) self.deconv6 = deconv4x4(128, 32) self.apply(weight_init) self.finetune = finetune if finetune: for name, param in self.named_parameters(): param.requires_grad = False if self.downup: self.conv00.requires_grad_(True) self.conv01.requires_grad_(True) self.deconv7.requires_grad_(True) def forward(self, Xt): if self.downup: feat0 = self.conv00(Xt) # (16,256,256) feat1 = self.conv01(feat0) # (32,128,128) else: feat0 = None feat1 = self.conv1(Xt) # 32x128x128 feat2 = self.conv2(feat1) # 64x64x64 feat3 = self.conv3(feat2) # 128x32x32 feat4 = self.conv4(feat3) # 256x16xx16 feat5 = self.conv5(feat4) # 512x8x8 feat6 = self.conv6(feat5) # 1024x4x4 if self.downup: z_attr1 = self.conv7(feat6) # 1024x2x2 z_attr2 = self.deconv1(z_attr1, feat6) z_attr3 = self.deconv2(z_attr2, feat5) z_attr4 = self.deconv3(z_attr3, feat4) z_attr5 = self.deconv4(z_attr4, feat3) z_attr6 = self.deconv5(z_attr5, feat2) z_attr7 = self.deconv6(z_attr6, feat1) # (128,64,64)+(32,128,128)->(64,128,128) z_attr8 = self.deconv7(z_attr7, feat0) # (64,128,128)+(16,256,256)->(32,256,256) z_attr9 = F.interpolate( z_attr8, scale_factor=2, mode="bilinear", align_corners=True ) # (32,512,512) return ( z_attr1, z_attr2, z_attr3, z_attr4, z_attr5, z_attr6, z_attr7, z_attr8, z_attr9 ) else: z_attr1 = self.conv7(feat6) # 1024x2x2 z_attr2 = self.deconv1(z_attr1, feat6) z_attr3 = self.deconv2(z_attr2, feat5) z_attr4 = self.deconv3(z_attr3, feat4) z_attr5 = self.deconv4(z_attr4, feat3) z_attr6 = self.deconv5(z_attr5, feat2) z_attr7 = self.deconv6(z_attr6, feat1) z_attr8 = F.interpolate( z_attr7, scale_factor=2, mode="bilinear", align_corners=True ) return ( z_attr1, z_attr2, z_attr3, z_attr4, z_attr5, z_attr6, z_attr7, z_attr8, ) class AADGenerator(nn.Module): def __init__(self, c_id=256, finetune=False, downup=False): super(AADGenerator, self).__init__() self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0) self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id) self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id) self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id) self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id) self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id) self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id) self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id) self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id) self.downup = downup if downup: self.AADBlk8_0 = AAD_ResBlk(64, 32, 32, c_id) self.AADBlk8_1 = AAD_ResBlk(32, 3, 32, c_id) self.apply(weight_init) if finetune: for name, param in self.named_parameters(): param.requires_grad = False self.AADBlk8_0.requires_grad_(True) self.AADBlk8_1.requires_grad_(True) def forward(self, z_attr, z_id): m = self.up1(z_id.reshape(z_id.shape[0], -1, 1, 1)) scale= z_attr[0].shape[2] // 2 # adaptive support for 512x512, 1024x1024 m = F.interpolate(m, scale_factor=scale, mode='bilinear', align_corners=True) m2, m2_ = self.AADBlk1(m, z_attr[0], z_id) m2 = F.interpolate( m2, scale_factor=2, mode="bilinear", align_corners=True, ) m3, m3_ = self.AADBlk2(m2, z_attr[1], z_id) m3 = F.interpolate( m3, scale_factor=2, mode="bilinear", align_corners=True, ) m4, m4_ = self.AADBlk3(m3, z_attr[2], z_id) m4 = F.interpolate( m4, scale_factor=2, mode="bilinear", align_corners=True, ) m5, m5_ = self.AADBlk4(m4, z_attr[3], z_id) m5 = F.interpolate( m5, scale_factor=2, mode="bilinear", align_corners=True, ) m6, m6_ = self.AADBlk5(m5, z_attr[4], z_id) m6 = F.interpolate( m6, scale_factor=2, mode="bilinear", align_corners=True, ) m7, m7_ = self.AADBlk6(m6, z_attr[5], z_id) m7 = F.interpolate( m7, scale_factor=2, mode="bilinear", align_corners=True, ) m8, m8_ = self.AADBlk7(m7, z_attr[6], z_id) m8 = F.interpolate( m8, scale_factor=2, mode="bilinear", align_corners=True, ) if self.downup: y0, m9_ = self.AADBlk8_0(m8, z_attr[7], z_id) y0 = F.interpolate(y0, scale_factor=2, mode='bilinear', align_corners=True) y1, m10_ = self.AADBlk8_1(y0, z_attr[8], z_id) y = torch.tanh(y1) else: y, m9_ = self.AADBlk8(m8, z_attr[7], z_id) y = torch.tanh(y) return y # , m # yuange class AEI_Net(nn.Module): def __init__(self, c_id=512, finetune=False, downup=False): super(AEI_Net, self).__init__() self.encoder = MLAttrEncoder(finetune=finetune, downup=downup) self.generator = AADGenerator(c_id, finetune=finetune, downup=downup) def forward(self, Xt, z_id): attr = self.encoder(Xt) Y = self.generator(attr, z_id) # yuange return Y, attr def get_attr(self, X): return self.encoder(X) def trainable_params(self): train_params = [] for param in self.parameters(): if param.requires_grad: train_params.append(param) return train_params if __name__ == "__main__": aie = AEI_Net(512).eval() x = aie(torch.randn(1, 3, 512, 512), torch.randn(1, 512)) # def numel(m: torch.nn.Module, only_trainable: bool = False): # """ # returns the total number of parameters used by `m` (only counting # shared parameters once); if `only_trainable` is True, then only # includes parameters with `requires_grad = True` # """ # parameters = list(m.parameters()) # if only_trainable: # parameters = [p for p in parameters if p.requires_grad] # unique = {p.data_ptr(): p for p in parameters}.values() # return sum(p.numel() for p in unique) # # # print(numel(aie, True)) # print(x[0].size()) # print(len(x[-1])) import thop img = torch.randn(1, 3, 256, 256) latent = torch.randn(1, 512) net = aie flops, params = thop.profile(net, inputs=(img, latent), verbose=False) print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))