import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Dropout, Softmax, LayerNorm
from torch.nn.modules.utils import _triple
|
|
|
|
|
|
|
def get_activation(activation_type):
    # Look up the activation class in torch.nn case-insensitively
    # (e.g. 'ReLU', 'LeakyReLU', 'GELU'); fall back to ReLU if nothing matches.
    for name in dir(nn):
        if name.lower() == activation_type.lower():
            return getattr(nn, name)()
    return nn.ReLU()
|
|
|
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'): |
|
layers = [] |
|
layers.append(ConvBatchNorm(in_channels, out_channels, activation)) |
|
|
|
for _ in range(nb_Conv - 1): |
|
layers.append(ConvBatchNorm(out_channels, out_channels, activation)) |
|
return nn.Sequential(*layers) |
|
|
|
class ConvBatchNorm(nn.Module):

    """(convolution => [BN] => activation)"""
|
|
|
def __init__(self, in_channels, out_channels, activation='ReLU'): |
|
super(ConvBatchNorm, self).__init__() |
|
self.conv = nn.Conv3d(in_channels, out_channels, |
|
kernel_size=3, padding=1) |
|
self.norm = nn.BatchNorm3d(out_channels) |
|
self.activation = get_activation(activation) |
|
|
|
def forward(self, x): |
|
out = self.conv(x) |
|
out = self.norm(out) |
|
return self.activation(out) |
|
|
|
class DownBlock(nn.Module):

    """Downscaling with maxpool followed by nb_Conv ConvBatchNorm blocks"""
|
def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): |
|
super(DownBlock, self).__init__() |
|
self.maxpool = nn.MaxPool3d(2) |
|
self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) |
|
|
|
def forward(self, x): |
|
out = self.maxpool(x) |
|
return self.nConvs(out) |
|
|
|
class Flatten(nn.Module): |
|
def forward(self, x): |
|
return x.view(x.size(0), -1) |
|
|
|
class CCA(nn.Module):

    """
    Channel-wise Cross Attention (CCA) block: gates the skip-connection
    features x with channel attention computed jointly from x and the
    decoder (gating) signal g.
    """
|
def __init__(self, F_g, F_x): |
|
super().__init__() |
|
self.mlp_x = nn.Sequential( |
|
Flatten(), |
|
nn.Linear(F_x, F_x)) |
|
self.mlp_g = nn.Sequential( |
|
Flatten(), |
|
nn.Linear(F_g, F_x)) |
|
self.relu = nn.ReLU(inplace=True) |
|
|
|
def forward(self, g, x): |
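        # Squeeze g and x with global average pooling, project each to the
        # skip-connection channel dimension, average the two attention logits,
        # and use their sigmoid to gate x channel-wise.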
|
|
|
avg_pool_x = F.avg_pool3d( x, (x.size(2), x.size(3), x.size(4)), stride=(x.size(2), x.size(3), x.size(4))) |
|
channel_att_x = self.mlp_x(avg_pool_x) |
|
avg_pool_g = F.avg_pool3d( g, (g.size(2), g.size(3), g.size(4)), stride=(g.size(2), g.size(3), g.size(4))) |
|
channel_att_g = self.mlp_g(avg_pool_g) |
|
channel_att_sum = (channel_att_x + channel_att_g)/2.0 |
|
scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x) |
|
x_after_channel = x * scale |
|
out = self.relu(x_after_channel) |
|
return out |
|
|
|
class UpBlock_attention(nn.Module): |
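    """Decoder block: upsample by 2, gate the skip connection with CCA,
    concatenate, and refine with nb_Conv ConvBatchNorm blocks."""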
|
def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): |
|
super().__init__() |
|
self.up = nn.Upsample(scale_factor=2) |
|
self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2) |
|
self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) |
|
|
|
def forward(self, x, skip_x): |
|
up = self.up(x) |
|
skip_x_att = self.coatt(g=up, x=skip_x) |
|
x = torch.cat([skip_x_att, up], dim=1) |
|
return self.nConvs(x) |
|
|
|
class UCTransNet(nn.Module): |
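    """3D UCTransNet: a U-Net style encoder/decoder whose skip connections are
    fused by a multi-scale ChannelTransformer and merged back in the decoder
    through CCA-gated up-blocks. KV_size is expected to equal the sum of the
    four skip-connection channel counts, i.e. feature_size * (1 + 2 + 4 + 8)."""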
|
def __init__(self, in_channels, out_channels, num_layers, KV_size, num_heads, attention_dropout_rate, mlp_dropout_rate, feature_size, img_size, patch_sizes): |
|
super().__init__() |
|
self.inc = ConvBatchNorm(in_channels, feature_size) |
|
self.down1 = DownBlock(feature_size, feature_size*2, nb_Conv=2) |
|
self.down2 = DownBlock(feature_size*2, feature_size*4, nb_Conv=2) |
|
self.down3 = DownBlock(feature_size*4, feature_size*8, nb_Conv=2) |
|
self.down4 = DownBlock(feature_size*8, feature_size*8, nb_Conv=2) |
|
self.mtc = ChannelTransformer(img_size, num_layers, KV_size, num_heads, attention_dropout_rate, mlp_dropout_rate, |
|
channel_num=[feature_size, feature_size*2, feature_size*4, feature_size*8], |
|
patchSize=patch_sizes) |
|
self.up4 = UpBlock_attention(feature_size*16, feature_size*4, nb_Conv=2) |
|
self.up3 = UpBlock_attention(feature_size*8, feature_size*2, nb_Conv=2) |
|
self.up2 = UpBlock_attention(feature_size*4, feature_size, nb_Conv=2) |
|
self.up1 = UpBlock_attention(feature_size*2, feature_size, nb_Conv=2) |
|
self.outc = nn.Conv3d(feature_size, out_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1)) |
|
|
|
def forward(self, x): |
|
x = x.float() |
|
x1 = self.inc(x) |
|
x2 = self.down1(x1) |
|
x3 = self.down2(x2) |
|
x4 = self.down3(x3) |
|
x5 = self.down4(x4) |
|
x1,x2,x3,x4 = self.mtc(x1,x2,x3,x4) |
|
x = self.up4(x5, x4) |
|
x = self.up3(x, x3) |
|
x = self.up2(x, x2) |
|
x = self.up1(x, x1) |
|
|
|
logits = self.outc(x) |
|
|
|
return logits |
|
|
|
class Channel_Embeddings(nn.Module):

    """Construct patch embeddings from a feature map and add learnable
    position embeddings.
    """
|
def __init__(self, patchsize, img_size, in_channels, reduce_scale): |
|
super().__init__() |
|
patch_size = _triple(patchsize) |
|
n_patches = (img_size[0] // reduce_scale // patch_size[0]) * (img_size[1] // reduce_scale // patch_size[1]) * (img_size[2] // reduce_scale // patch_size[2]) |
|
|
|
self.patch_embeddings = nn.Conv3d(in_channels=in_channels, |
|
out_channels=in_channels, |
|
kernel_size=patch_size, |
|
stride=patch_size) |
|
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels)) |
|
self.dropout = Dropout(0.1) |
|
|
|
def forward(self, x): |
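        # x: (B, C, H, W, D) feature map -> token sequence (B, n_patches, C),
        # returned together with the patch-grid shape needed by Reconstruct.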
|
if x is None: |
|
return None |
|
x = self.patch_embeddings(x) |
|
h, w, d = x.shape[-3:] |
|
x = x.flatten(2) |
|
x = x.transpose(-1, -2) |
|
embeddings = x + self.position_embeddings |
|
embeddings = self.dropout(embeddings) |
|
return embeddings, (h, w, d) |
|
|
|
class Reconstruct(nn.Module): |
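    """Map a token sequence (B, n_patches, C) back to a 3D feature map and
    upsample it to the resolution of the corresponding encoder level."""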
|
def __init__(self, in_channels, out_channels, kernel_size, scale_factor): |
|
super(Reconstruct, self).__init__() |
|
if kernel_size == 3: |
|
padding = 1 |
|
else: |
|
padding = 0 |
|
self.conv = nn.Conv3d(in_channels, out_channels,kernel_size=kernel_size, padding=padding) |
|
self.norm = nn.BatchNorm3d(out_channels) |
|
self.activation = nn.ReLU(inplace=True) |
|
self.scale_factor = scale_factor |
|
|
|
def forward(self, x, shp): |
|
if x is None: |
|
return None |
|
|
|
B, n_patch, hidden = x.size() |
|
h, w, d = shp |
|
x = x.permute(0, 2, 1) |
|
x = x.contiguous().view(B, hidden, h, w, d) |
|
        x = F.interpolate(x, scale_factor=self.scale_factor)
|
|
|
out = self.conv(x) |
|
out = self.norm(out) |
|
out = self.activation(out) |
|
return out |
|
|
|
class Attention_org(nn.Module): |
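    """Multi-head channel-wise cross attention: each scale contributes its own
    queries, while keys and values come from the concatenation of all four
    scales (width KV_size), so attention is computed across channels rather
    than spatial positions."""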
|
def __init__(self, KV_size, channel_num, num_heads, attention_dropout_rate): |
|
super(Attention_org, self).__init__() |
|
self.KV_size = KV_size |
|
self.channel_num = channel_num |
|
self.num_attention_heads = num_heads |
|
|
|
self.query1 = nn.ModuleList() |
|
self.query2 = nn.ModuleList() |
|
self.query3 = nn.ModuleList() |
|
self.query4 = nn.ModuleList() |
|
self.key = nn.ModuleList() |
|
self.value = nn.ModuleList() |
|
|
|
for _ in range(num_heads): |
|
query1 = nn.Linear(channel_num[0], channel_num[0], bias=False) |
|
query2 = nn.Linear(channel_num[1], channel_num[1], bias=False) |
|
query3 = nn.Linear(channel_num[2], channel_num[2], bias=False) |
|
query4 = nn.Linear(channel_num[3], channel_num[3], bias=False) |
|
key = nn.Linear( self.KV_size, self.KV_size, bias=False) |
|
value = nn.Linear(self.KV_size, self.KV_size, bias=False) |
|
self.query1.append(copy.deepcopy(query1)) |
|
self.query2.append(copy.deepcopy(query2)) |
|
self.query3.append(copy.deepcopy(query3)) |
|
self.query4.append(copy.deepcopy(query4)) |
|
self.key.append(copy.deepcopy(key)) |
|
self.value.append(copy.deepcopy(value)) |
|
self.psi = nn.InstanceNorm2d(self.num_attention_heads) |
|
self.softmax = Softmax(dim=3) |
|
self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False) |
|
self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False) |
|
self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False) |
|
self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False) |
|
self.attn_dropout = Dropout(attention_dropout_rate) |
|
self.proj_dropout = Dropout(attention_dropout_rate) |
|
|
|
|
|
|
|
def forward(self, emb1,emb2,emb3,emb4, emb_all): |
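        # emb_i: (B, n_patches, C_i) tokens of scale i;
        # emb_all: (B, n_patches, KV_size) with KV_size == C_1 + C_2 + C_3 + C_4.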
|
multi_head_Q1_list = [] |
|
multi_head_Q2_list = [] |
|
multi_head_Q3_list = [] |
|
multi_head_Q4_list = [] |
|
multi_head_K_list = [] |
|
multi_head_V_list = [] |
|
if emb1 is not None: |
|
for query1 in self.query1: |
|
Q1 = query1(emb1) |
|
multi_head_Q1_list.append(Q1) |
|
if emb2 is not None: |
|
for query2 in self.query2: |
|
Q2 = query2(emb2) |
|
multi_head_Q2_list.append(Q2) |
|
if emb3 is not None: |
|
for query3 in self.query3: |
|
Q3 = query3(emb3) |
|
multi_head_Q3_list.append(Q3) |
|
if emb4 is not None: |
|
for query4 in self.query4: |
|
Q4 = query4(emb4) |
|
multi_head_Q4_list.append(Q4) |
|
for key in self.key: |
|
K = key(emb_all) |
|
multi_head_K_list.append(K) |
|
for value in self.value: |
|
V = value(emb_all) |
|
multi_head_V_list.append(V) |
|
|
|
|
|
multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None |
|
multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None |
|
multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None |
|
multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None |
|
multi_head_K = torch.stack(multi_head_K_list, dim=1) |
|
multi_head_V = torch.stack(multi_head_V_list, dim=1) |
|
|
|
multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None |
|
multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None |
|
multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None |
|
multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None |
|
|
|
attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None |
|
attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None |
|
attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None |
|
attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None |
|
|
|
attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None |
|
attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None |
|
attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None |
|
attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None |
|
|
|
attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None |
|
attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None |
|
attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None |
|
attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None |
|
|
|
|
|
attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None |
|
attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None |
|
attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None |
|
attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None |
|
|
|
multi_head_V = multi_head_V.transpose(-1, -2) |
|
context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None |
|
context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None |
|
context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None |
|
context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None |
|
|
|
context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None |
|
context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None |
|
context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None |
|
context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None |
|
context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None |
|
context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None |
|
context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None |
|
context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None |
|
|
|
O1 = self.out1(context_layer1) if emb1 is not None else None |
|
O2 = self.out2(context_layer2) if emb2 is not None else None |
|
O3 = self.out3(context_layer3) if emb3 is not None else None |
|
O4 = self.out4(context_layer4) if emb4 is not None else None |
|
O1 = self.proj_dropout(O1) if emb1 is not None else None |
|
O2 = self.proj_dropout(O2) if emb2 is not None else None |
|
O3 = self.proj_dropout(O3) if emb3 is not None else None |
|
O4 = self.proj_dropout(O4) if emb4 is not None else None |
|
return O1,O2,O3,O4 |
|
|
|
class Mlp(nn.Module): |
|
def __init__(self, in_channel, mlp_channel, dropout_rate): |
|
super(Mlp, self).__init__() |
|
self.fc1 = nn.Linear(in_channel, mlp_channel) |
|
self.fc2 = nn.Linear(mlp_channel, in_channel) |
|
self.act_fn = nn.GELU() |
|
self.dropout = Dropout(dropout_rate) |
|
self._init_weights() |
|
|
|
def _init_weights(self): |
|
nn.init.xavier_uniform_(self.fc1.weight) |
|
nn.init.xavier_uniform_(self.fc2.weight) |
|
nn.init.normal_(self.fc1.bias, std=1e-6) |
|
nn.init.normal_(self.fc2.bias, std=1e-6) |
|
|
|
def forward(self, x): |
|
x = self.fc1(x) |
|
x = self.act_fn(x) |
|
x = self.dropout(x) |
|
x = self.fc2(x) |
|
x = self.dropout(x) |
|
return x |
|
|
|
class Block_ViT(nn.Module): |
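    """One transformer layer: pre-norm channel-wise cross attention over the
    four scales, then a per-scale pre-norm MLP, each with a residual
    connection."""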
|
def __init__(self, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate): |
|
super(Block_ViT, self).__init__() |
|
self.attn_norm1 = LayerNorm(channel_num[0],eps=1e-6) |
|
self.attn_norm2 = LayerNorm(channel_num[1],eps=1e-6) |
|
self.attn_norm3 = LayerNorm(channel_num[2],eps=1e-6) |
|
self.attn_norm4 = LayerNorm(channel_num[3],eps=1e-6) |
|
self.attn_norm = LayerNorm(KV_size,eps=1e-6) |
|
self.channel_attn = Attention_org(KV_size, channel_num, num_heads, attention_dropout_rate) |
|
|
|
self.ffn_norm1 = LayerNorm(channel_num[0],eps=1e-6) |
|
self.ffn_norm2 = LayerNorm(channel_num[1],eps=1e-6) |
|
self.ffn_norm3 = LayerNorm(channel_num[2],eps=1e-6) |
|
self.ffn_norm4 = LayerNorm(channel_num[3],eps=1e-6) |
|
self.ffn1 = Mlp(channel_num[0],channel_num[0]*4, mlp_dropout_rate) |
|
self.ffn2 = Mlp(channel_num[1],channel_num[1]*4, mlp_dropout_rate) |
|
self.ffn3 = Mlp(channel_num[2],channel_num[2]*4, mlp_dropout_rate) |
|
self.ffn4 = Mlp(channel_num[3],channel_num[3]*4, mlp_dropout_rate) |
|
|
|
|
|
def forward(self, emb1,emb2,emb3,emb4): |
|
embcat = [] |
|
org1 = emb1 |
|
org2 = emb2 |
|
org3 = emb3 |
|
org4 = emb4 |
|
        for tmp_emb in (emb1, emb2, emb3, emb4):
            if tmp_emb is not None:
                embcat.append(tmp_emb)
|
|
|
emb_all = torch.cat(embcat,dim=2) |
|
cx1 = self.attn_norm1(emb1) if emb1 is not None else None |
|
cx2 = self.attn_norm2(emb2) if emb2 is not None else None |
|
cx3 = self.attn_norm3(emb3) if emb3 is not None else None |
|
cx4 = self.attn_norm4(emb4) if emb4 is not None else None |
|
emb_all = self.attn_norm(emb_all) |
|
cx1,cx2,cx3,cx4 = self.channel_attn(cx1,cx2,cx3,cx4,emb_all) |
|
cx1 = org1 + cx1 if emb1 is not None else None |
|
cx2 = org2 + cx2 if emb2 is not None else None |
|
cx3 = org3 + cx3 if emb3 is not None else None |
|
cx4 = org4 + cx4 if emb4 is not None else None |
|
|
|
org1 = cx1 |
|
org2 = cx2 |
|
org3 = cx3 |
|
org4 = cx4 |
|
x1 = self.ffn_norm1(cx1) if emb1 is not None else None |
|
x2 = self.ffn_norm2(cx2) if emb2 is not None else None |
|
x3 = self.ffn_norm3(cx3) if emb3 is not None else None |
|
x4 = self.ffn_norm4(cx4) if emb4 is not None else None |
|
x1 = self.ffn1(x1) if emb1 is not None else None |
|
x2 = self.ffn2(x2) if emb2 is not None else None |
|
x3 = self.ffn3(x3) if emb3 is not None else None |
|
x4 = self.ffn4(x4) if emb4 is not None else None |
|
x1 = x1 + org1 if emb1 is not None else None |
|
x2 = x2 + org2 if emb2 is not None else None |
|
x3 = x3 + org3 if emb3 is not None else None |
|
x4 = x4 + org4 if emb4 is not None else None |
|
|
|
return x1, x2, x3, x4 |
|
|
|
|
|
class Encoder(nn.Module): |
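    """Stack of num_layers Block_ViT layers with a final LayerNorm per scale."""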
|
def __init__(self, num_layers, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate): |
|
super(Encoder, self).__init__() |
|
self.layer = nn.ModuleList() |
|
self.encoder_norm1 = LayerNorm(channel_num[0],eps=1e-6) |
|
self.encoder_norm2 = LayerNorm(channel_num[1],eps=1e-6) |
|
self.encoder_norm3 = LayerNorm(channel_num[2],eps=1e-6) |
|
self.encoder_norm4 = LayerNorm(channel_num[3],eps=1e-6) |
|
for _ in range(num_layers): |
|
layer = Block_ViT(KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate) |
|
self.layer.append(copy.deepcopy(layer)) |
|
|
|
def forward(self, emb1,emb2,emb3,emb4): |
|
for layer_block in self.layer: |
|
emb1,emb2,emb3,emb4 = layer_block(emb1,emb2,emb3,emb4) |
|
emb1 = self.encoder_norm1(emb1) if emb1 is not None else None |
|
emb2 = self.encoder_norm2(emb2) if emb2 is not None else None |
|
emb3 = self.encoder_norm3(emb3) if emb3 is not None else None |
|
emb4 = self.encoder_norm4(emb4) if emb4 is not None else None |
|
return emb1,emb2,emb3,emb4 |
|
|
|
|
|
class ChannelTransformer(nn.Module): |
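    """Channel transformer over the four encoder scales: tokenize each scale
    with Channel_Embeddings, run the shared Encoder, reconstruct every scale
    back to a feature map, and add it to the original skip connection.

    The patch sizes must be chosen so that every scale yields the same number
    of tokens, i.e. img_size // 2**level must be divisible by patchSize[level]
    with the same quotient at every level."""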
|
def __init__(self, img_size, num_layers, KV_size, num_heads, attention_dropout_rate, mlp_dropout_rate, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]): |
|
super().__init__() |
|
|
|
self.patchSize_1 = patchSize[0] |
|
self.patchSize_2 = patchSize[1] |
|
self.patchSize_3 = patchSize[2] |
|
self.patchSize_4 = patchSize[3] |
|
self.embeddings_1 = Channel_Embeddings(self.patchSize_1, img_size=img_size, reduce_scale=1, in_channels=channel_num[0]) |
|
self.embeddings_2 = Channel_Embeddings(self.patchSize_2, img_size=img_size, reduce_scale=2, in_channels=channel_num[1]) |
|
self.embeddings_3 = Channel_Embeddings(self.patchSize_3, img_size=img_size, reduce_scale=4, in_channels=channel_num[2]) |
|
self.embeddings_4 = Channel_Embeddings(self.patchSize_4, img_size=img_size, reduce_scale=8, in_channels=channel_num[3]) |
|
self.encoder = Encoder(num_layers, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate) |
|
|
|
self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1,scale_factor=_triple(self.patchSize_1)) |
|
self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1,scale_factor=_triple(self.patchSize_2)) |
|
self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1,scale_factor=_triple(self.patchSize_3)) |
|
self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1,scale_factor=_triple(self.patchSize_4)) |
|
|
|
def forward(self, en1, en2, en3, en4): |
|
|
|
        emb1, shp1 = self.embeddings_1(en1) if en1 is not None else (None, None)
        emb2, shp2 = self.embeddings_2(en2) if en2 is not None else (None, None)
        emb3, shp3 = self.embeddings_3(en3) if en3 is not None else (None, None)
        emb4, shp4 = self.embeddings_4(en4) if en4 is not None else (None, None)
|
|
|
encoded1, encoded2, encoded3, encoded4 = self.encoder(emb1,emb2,emb3,emb4) |
|
x1 = self.reconstruct_1(encoded1, shp1) if en1 is not None else None |
|
x2 = self.reconstruct_2(encoded2, shp2) if en2 is not None else None |
|
x3 = self.reconstruct_3(encoded3, shp3) if en3 is not None else None |
|
x4 = self.reconstruct_4(encoded4, shp4) if en4 is not None else None |
|
|
|
x1 = x1 + en1 if en1 is not None else None |
|
x2 = x2 + en2 if en2 is not None else None |
|
x3 = x3 + en3 if en3 is not None else None |
|
x4 = x4 + en4 if en4 is not None else None |
|
|
|
return x1, x2, x3, x4 |
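

# Minimal smoke test with assumed hyperparameters (illustrative only, not a
# configuration taken from the original work). The constraints exercised here:
# KV_size must equal the sum of the four skip-connection channel counts, and
# (img_size / 2**level) must be divisible by patch_sizes[level] with the same
# quotient at every level so all scales produce the same token count.
if __name__ == "__main__":
    model = UCTransNet(
        in_channels=1,
        out_channels=2,
        num_layers=2,
        KV_size=16 + 32 + 64 + 128,   # feature_size * (1 + 2 + 4 + 8)
        num_heads=4,
        attention_dropout_rate=0.1,
        mlp_dropout_rate=0.1,
        feature_size=16,
        img_size=(64, 64, 64),
        patch_sizes=[8, 4, 2, 1],     # 64/8 == 32/4 == 16/2 == 8/1 == 8 tokens per axis
    )
    model.eval()
    x = torch.randn(1, 1, 64, 64, 64)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([1, 2, 64, 64, 64])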