Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from .deform_conv import ModulatedDeformConv | |
from .dyrelu import h_sigmoid, DYReLU | |
class Conv3x3Norm(torch.nn.Module):
    """3x3 convolution (plain or modulated-deformable) followed by optional GroupNorm.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        stride: stride of the 3x3 convolution (padding is fixed at 1).
        deformable: if True build a ``ModulatedDeformConv``, otherwise an
            ``nn.Conv2d``.
        use_gn: if True apply ``nn.GroupNorm`` with 16 groups after the conv,
            so ``out_channels`` must be divisible by 16.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 deformable=False,
                 use_gn=False):
        super().__init__()
        # The conditional expression only evaluates the selected class name.
        conv_cls = ModulatedDeformConv if deformable else nn.Conv2d
        # ``conv`` / ``bn`` are state_dict keys; keep the names stable.
        self.conv = conv_cls(in_channels, out_channels,
                             kernel_size=3, stride=stride, padding=1)
        self.bn = (nn.GroupNorm(num_groups=16, num_channels=out_channels)
                   if use_gn else None)

    def forward(self, input, **kwargs):
        """Apply the convolution, then the normalisation if one was configured.

        Extra keyword arguments are forwarded unchanged to the convolution
        (DyConv passes ``offset=`` / ``mask=`` when deformable convs are used).
        """
        out = self.conv(input, **kwargs)
        if self.bn is None:
            return out
        return self.bn(out)
class DyConv(nn.Module):
    """One fusion block that mixes each pyramid level with its neighbours.

    For each level ``l`` of the input sequence ``x`` the block collects:

    * ``DyConv[1]`` (stride 1) applied to ``x[l]``;
    * ``DyConv[2]`` (stride 2) applied to ``x[l - 1]``, if it exists;
    * ``DyConv[0]`` (stride 1) applied to ``x[l + 1]``, if it exists,
      bilinearly resized to ``x[l]``'s spatial size.

    The terms are stacked and averaged. When ``use_dyfuse`` is set, each term
    is first weighted by ``h_sigmoid(AttnConv(term))``. The result is passed
    through ``self.relu`` (``nn.ReLU`` or ``DYReLU``).

    Args:
        in_channels: channels of every input level.
        out_channels: channels produced by each of the three convs.
        conv_func: factory ``(in_channels, out_channels, stride) -> module``.
            When ``use_deform`` is True it must build convs that accept
            ``offset=`` and ``mask=`` keyword arguments. The default
            ``Conv3x3Norm`` builds plain ``nn.Conv2d``, which does not accept
            them.
        use_dyfuse: enable the attention-weighted fusion.
        use_dyrelu: use ``DYReLU`` instead of ``nn.ReLU``.
        use_deform: predict per-level offsets and masks with a 3x3 conv and
            pass them to the three fusion convs.
    """

    def __init__(self,
                 in_channels=256,
                 out_channels=256,
                 conv_func=Conv3x3Norm,
                 use_dyfuse=True,
                 use_dyrelu=False,
                 use_deform=False
                 ):
        super().__init__()

        # Index 0: coarser neighbour, 1: same level, 2: finer neighbour (stride 2).
        # Attribute names below are state_dict keys; keep them stable.
        self.DyConv = nn.ModuleList()
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            # Global-average-pooled scalar gate per fused term.
            self.AttnConv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = h_sigmoid()
        else:
            self.AttnConv = None

        if use_dyrelu:
            # NOTE(review): built with in_channels but applied to the fused map,
            # which has out_channels channels. Confirm DYReLU's expectations when
            # in_channels != out_channels.
            self.relu = DYReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_deform:
            # 27 = 18 offset channels + 9 mask channels (split in forward).
            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        """Re-initialise every ``nn.Conv2d`` in ``DyConv`` and ``AttnConv``.

        Weights are drawn from N(0, 0.01) and biases are zeroed. ``self.offset``
        keeps PyTorch's default initialisation.
        """
        groups = [self.DyConv]
        if self.AttnConv is not None:
            groups.append(self.AttnConv)
        for group in groups:
            for m in group.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, x):
        """Fuse every level of ``x`` with its neighbours.

        Args:
            x: sequence of 4-D feature maps. Presumably NCHW and ordered from
                highest to lowest resolution, since the stride-2 output of
                ``x[l - 1]`` must match ``x[l]`` for ``torch.stack`` (TODO confirm).

        Returns:
            A list with one fused feature map per input level.
        """
        next_x = []
        num_levels = len(x)
        for level, feature in enumerate(x):
            conv_args = {}
            if self.offset is not None:
                offset_mask = self.offset(feature)
                conv_args = dict(offset=offset_mask[:, :18, :, :],
                                 mask=offset_mask[:, 18:, :, :].sigmoid())

            out_size = [feature.size(2), feature.size(3)]
            temp_fea = [self.DyConv[1](feature, **conv_args)]
            if level > 0:
                temp_fea.append(self.DyConv[2](x[level - 1], **conv_args))
            if level < num_levels - 1:
                # NOTE(review): offset/mask have `feature`'s spatial size, but this
                # stride-1 conv runs on x[level + 1]. If that level is smaller, the
                # shapes disagree. Kept unchanged for checkpoint compatibility;
                # confirm the deformable conv handles it.
                coarser = self.DyConv[0](x[level + 1], **conv_args)
                # Exact replacement for the deprecated F.upsample_bilinear.
                temp_fea.append(F.interpolate(coarser, size=out_size,
                                              mode="bilinear", align_corners=True))

            res_fea = torch.stack(temp_fea)
            if self.AttnConv is not None:
                # One gate per term, broadcast over channels and space.
                spa_pyr_attn = self.h_sigmoid(
                    torch.stack([self.AttnConv(fea) for fea in temp_fea]))
                mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False)
            else:
                mean_fea = torch.mean(res_fea, dim=0, keepdim=False)
            next_x.append(self.relu(mean_fea))
        return next_x
class DyHead(nn.Module):
    """Sequential stack of ``cfg.MODEL.DYHEAD.NUM_CONVS`` DyConv blocks.

    Args:
        cfg: config object. Reads ``MODEL.DYHEAD.{CHANNELS, USE_GN, USE_DYRELU,
            USE_DYFUSE, USE_DFCONV, NUM_CONVS}``.
        in_channels: channel count of the incoming features. Only the first
            block consumes it; later blocks use ``CHANNELS``.
    """

    def __init__(self, cfg, in_channels):
        super().__init__()
        self.cfg = cfg
        head_cfg = cfg.MODEL.DYHEAD
        channels = head_cfg.CHANNELS
        use_gn = head_cfg.USE_GN
        use_dyrelu = head_cfg.USE_DYRELU
        use_dyfuse = head_cfg.USE_DYFUSE
        use_deform = head_cfg.USE_DFCONV

        def make_conv(in_ch, out_ch, stride):
            # Every fusion conv shares the head-wide deformable / GroupNorm choice.
            return Conv3x3Norm(in_ch, out_ch, stride,
                               deformable=use_deform, use_gn=use_gn)

        blocks = [
            DyConv(
                channels if idx else in_channels,
                channels,
                conv_func=make_conv,
                use_dyrelu=use_dyrelu,
                use_dyfuse=use_dyfuse,
                use_deform=use_deform,
            )
            for idx in range(head_cfg.NUM_CONVS)
        ]
        # Registered as "dyhead_tower" so state_dict keys are unchanged.
        self.dyhead_tower = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the last output.

        If ``NUM_CONVS`` is 0 the empty ``nn.Sequential`` returns ``x`` unchanged.
        """
        return self.dyhead_tower(x)