Spaces:
Runtime error
Runtime error
File size: 5,205 Bytes
128757a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import torch
import torch.nn.functional as F
from torch import nn
from .deform_conv import ModulatedDeformConv
from .dyrelu import h_sigmoid, DYReLU
class Conv3x3Norm(torch.nn.Module):
    """3x3 convolution, optionally deformable, optionally followed by GroupNorm.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution. When
            ``use_gn`` is True this must be divisible by 16, because the
            GroupNorm layer is built with ``num_groups=16``.
        stride: Stride of the 3x3 convolution (padding is fixed to 1).
        deformable: If True, use ``ModulatedDeformConv`` instead of
            ``nn.Conv2d``.
        use_gn: If True, apply ``nn.GroupNorm`` after the convolution;
            otherwise no normalization is applied.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 deformable=False,
                 use_gn=False):
        super(Conv3x3Norm, self).__init__()

        if deformable:
            self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

        # The attribute keeps the name ``bn`` (although it holds a GroupNorm)
        # so that existing state_dict keys stay valid.
        if use_gn:
            self.bn = nn.GroupNorm(num_groups=16, num_channels=out_channels)
        else:
            self.bn = None

    def forward(self, input, **kwargs):
        """Apply the convolution and, if configured, the normalization.

        Args:
            input: Feature map passed to the convolution.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                convolution's forward call.

        Returns:
            The convolved (and possibly normalized) feature map.
        """
        x = self.conv(input, **kwargs)
        # Compare against None explicitly: module truthiness can be False
        # for a perfectly valid layer that defines ``__len__``.
        if self.bn is not None:
            x = self.bn(x)
        return x
class DyConv(nn.Module):
    """One dynamic-head block that fuses each pyramid level with its neighbours.

    For every level in the input list, the block convolves the level itself,
    the previous level (stride-2 conv) and the next level (stride-1 conv
    followed by bilinear resizing to the current level's spatial size). The
    resulting feature maps are averaged, optionally weighted by a learned
    per-feature scalar attention, and passed through an activation.

    Args:
        in_channels: Channel count of the incoming feature maps.
        out_channels: Channel count produced by each level convolution.
        conv_func: Callable ``(in_channels, out_channels, stride) -> module``
            used to build the three level convolutions. When ``use_deform``
            is True the built modules must accept ``offset`` and ``mask``
            keyword arguments in their forward call.
        use_dyfuse: If True, weight each fused feature by an attention scalar
            before averaging.
        use_dyrelu: If True, use ``DYReLU`` as the activation, else
            ``nn.ReLU``.
        use_deform: If True, predict offsets and masks from the current level
            and pass them to the level convolutions.
    """

    def __init__(self,
                 in_channels=256,
                 out_channels=256,
                 conv_func=Conv3x3Norm,
                 use_dyfuse=True,
                 use_dyrelu=False,
                 use_deform=False
                 ):
        super(DyConv, self).__init__()

        # Index 0: applied to the next level (stride 1, resized afterwards).
        # Index 1: applied to the current level (stride 1).
        # Index 2: applied to the previous level (stride 2).
        self.DyConv = nn.ModuleList()
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            # The attention branch consumes the *outputs* of the level convs,
            # so its 1x1 conv must take ``out_channels`` inputs. Building it
            # with ``in_channels`` only works when both counts are equal.
            self.AttnConv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(out_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = h_sigmoid()
        else:
            self.AttnConv = None

        if use_dyrelu:
            # NOTE(review): DYReLU is defined elsewhere; confirm that its
            # first argument is meant to be the block's input channel count.
            self.relu = DYReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_deform:
            # 27 output channels: the first 18 are used as offsets and the
            # remaining 9 as the modulation mask (see forward). Presumably
            # 2 offsets + 1 mask value per tap of a 3x3 kernel.
            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        """Initialise every ``nn.Conv2d`` in the level convs and attention branch.

        Weights are drawn from N(0, 0.01) and biases are zeroed. The offset
        predictor is not touched here.
        """
        groups = [self.DyConv]
        if self.AttnConv is not None:
            groups.append(self.AttnConv)
        for group in groups:
            for m in group.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, x):
        """Fuse every level of ``x`` with its neighbouring levels.

        Args:
            x: Sequence of 4-D feature maps, one per pyramid level.
                Presumably ordered from highest to lowest resolution, since
                the previous level is down-sampled and the next level is
                up-sampled - confirm against the caller.

        Returns:
            A list with one fused feature map per input level.
        """
        next_x = []
        num_levels = len(x)

        for level, feature in enumerate(x):
            conv_args = {}
            if self.offset is not None:
                # Offsets and mask are predicted from the current level and
                # shared by all three level convolutions.
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                conv_args = dict(offset=offset, mask=mask)

            temp_fea = [self.DyConv[1](feature, **conv_args)]

            if level > 0:
                temp_fea.append(self.DyConv[2](x[level - 1], **conv_args))
            if level < num_levels - 1:
                # F.interpolate with align_corners=True is the documented
                # replacement for the deprecated F.upsample_bilinear.
                temp_fea.append(F.interpolate(
                    self.DyConv[0](x[level + 1], **conv_args),
                    size=[feature.size(2), feature.size(3)],
                    mode='bilinear',
                    align_corners=True))

            res_fea = torch.stack(temp_fea)
            if self.AttnConv is not None:
                # One attention scalar per sample and per fused feature
                # (global average pool -> 1x1 conv -> ReLU -> h_sigmoid).
                attn_fea = [self.AttnConv(fea) for fea in temp_fea]
                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea))
                res_fea = res_fea * spa_pyr_attn

            mean_fea = torch.mean(res_fea, dim=0, keepdim=False)
            next_x.append(self.relu(mean_fea))

        return next_x
class DyHead(nn.Module):
    """Tower of ``DyConv`` blocks configured from ``cfg.MODEL.DYHEAD``.

    Args:
        cfg: Configuration object. The constructor reads ``CHANNELS``,
            ``USE_GN``, ``USE_DYRELU``, ``USE_DYFUSE``, ``USE_DFCONV`` and
            ``NUM_CONVS`` from ``cfg.MODEL.DYHEAD``.
        in_channels: Channel count fed to the first ``DyConv``; every later
            block takes ``CHANNELS`` input channels.
    """

    def __init__(self, cfg, in_channels):
        super(DyHead, self).__init__()
        self.cfg = cfg

        dyhead_cfg = cfg.MODEL.DYHEAD
        channels = dyhead_cfg.CHANNELS
        use_gn = dyhead_cfg.USE_GN
        use_dyrelu = dyhead_cfg.USE_DYRELU
        use_dyfuse = dyhead_cfg.USE_DYFUSE
        use_deform = dyhead_cfg.USE_DFCONV

        def conv_func(conv_in, conv_out, stride):
            # Factory handed to DyConv so that every level conv shares the
            # deformable / GroupNorm settings of this head.
            return Conv3x3Norm(conv_in, conv_out, stride, deformable=use_deform, use_gn=use_gn)

        dyhead_tower = [
            DyConv(
                in_channels if idx == 0 else channels,
                channels,
                conv_func=conv_func,
                use_dyrelu=use_dyrelu,
                use_dyfuse=use_dyfuse,
                use_deform=use_deform
            )
            for idx in range(dyhead_cfg.NUM_CONVS)
        ]

        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))

    def forward(self, x):
        """Run the input through the stacked ``DyConv`` blocks.

        Args:
            x: Input passed to the first ``DyConv`` block.

        Returns:
            The output of the last block in the tower.
        """
        return self.dyhead_tower(x)