Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
from monoscene.modules import ( | |
Process, | |
ASPP, | |
) | |
class CPMegaVoxels(nn.Module):
    """Context-prior block operating on a 3D voxel feature volume.

    A stride-2 convolution condenses the (ASPP-processed) volume into a
    half-resolution "mega context" with ``2 * feature`` channels.  For each
    of ``n_relations`` relations a 1x1x1 convolution predicts, for every
    full-resolution voxel, one logit per mega-context cell; the sigmoid of
    those logits weights the mega-context features that are gathered back
    to every voxel.  The gathered features of all relations are concatenated
    with the input and projected back to ``feature`` channels.

    Args:
        feature: Channel count of the input volume.
        size: Indexable with three spatial extents; the input to
            ``forward`` is reshaped against ``size[0]``, ``size[1]`` and
            ``size[2]``.
        n_relations: Number of relation matrices to predict.
        bn_momentum: Forwarded to ``Process`` together with
            ``nn.BatchNorm3d``.
    """

    def __init__(self, feature, size, n_relations=4, bn_momentum=0.0003):
        super().__init__()
        self.size = size
        self.n_relations = n_relations
        print("n_relations", self.n_relations)

        dim_x, dim_y, dim_z = size[0], size[1], size[2]
        self.flatten_size = dim_x * dim_y * dim_z
        self.feature = feature
        self.context_feature = feature * 2
        # Number of cells in the half-resolution mega-context volume.
        self.flatten_context_size = (dim_x // 2) * (dim_y // 2) * (dim_z // 2)

        # Pad by 1 on even extents and by 0 on odd ones so that the
        # kernel-3 / stride-2 convolution yields exactly ``extent // 2``
        # outputs along every axis.
        half_res_padding = tuple((d + 1) % 2 for d in (dim_x, dim_y, dim_z))
        self.mega_context = nn.Sequential(
            nn.Conv3d(
                feature,
                self.context_feature,
                kernel_size=3,
                stride=2,
                padding=half_res_padding,
            ),
        )

        # One 1x1x1 head per relation; each emits one channel per
        # mega-context cell.
        relation_heads = [
            nn.Sequential(
                nn.Conv3d(
                    self.feature,
                    self.flatten_context_size,
                    kernel_size=1,
                    padding=0,
                ),
            )
            for _ in range(n_relations)
        ]
        self.context_prior_logits = nn.ModuleList(relation_heads)

        self.aspp = ASPP(feature, [1, 2, 3])

        # Fuses [input, gathered context of every relation] back down to
        # ``feature`` channels.
        fused_channels = self.context_feature * self.n_relations + feature
        self.resize = nn.Sequential(
            nn.Conv3d(
                fused_channels,
                feature,
                kernel_size=1,
                padding=0,
                bias=False,
            ),
            Process(feature, nn.BatchNorm3d, bn_momentum, dilations=[1]),
        )

    def forward(self, input):
        """Gather relation-weighted mega-context features for every voxel.

        Args:
            input: Tensor whose dim 0 is the batch, dim 1 holds ``feature``
                channels and whose remaining dims flatten to
                ``size[0] * size[1] * size[2]`` elements.
                NOTE(review): assumes ``self.aspp`` preserves this shape --
                confirm against the ASPP implementation.

        Returns:
            dict with two entries:
                ``"P_logits"``: relation logits of shape
                ``(bs, n_relations, flatten_context_size, flatten_size)``.
                ``"x"``: output of ``self.resize`` applied to the input
                concatenated with the gathered context along dim 1.
        """
        batch = input.shape[0]
        aggregated = self.aspp(input)

        # Mega context as (bs, n_context_cells, context_feature).
        mega = self.mega_context(aggregated)
        mega = mega.reshape(batch, self.context_feature, -1).permute(0, 2, 1)

        logits_per_relation = []
        gathered_per_relation = []
        for rel in range(self.n_relations):
            # (bs, n_context_cells, n_voxels) relation logits.
            logits = self.context_prior_logits[rel](aggregated).reshape(
                batch, self.flatten_context_size, self.flatten_size
            )
            logits_per_relation.append(logits.unsqueeze(1))

            # (bs, n_voxels, n_context_cells) weights in (0, 1); multiplying
            # with the mega context yields per-voxel context features.
            prior = torch.sigmoid(logits.permute(0, 2, 1))
            gathered_per_relation.append(torch.bmm(prior, mega))  # bs, N, f

        # Stack relations on the channel axis and restore the voxel grid.
        context = torch.cat(gathered_per_relation, dim=2).permute(0, 2, 1)
        context = context.reshape(
            batch, context.shape[1], self.size[0], self.size[1], self.size[2]
        )

        fused = self.resize(torch.cat([input, context], dim=1))

        return {
            "P_logits": torch.cat(logits_per_relation, dim=1),
            "x": fused,
        }