PSHuman / lib /pymafx /models /res_module.py
fffiloni's picture
Migrated from GitHub
2252f3d verified
raw
history blame
15.9 kB
# code brought in part from https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from lib.pymafx.core.cfgs import cfg
# from .transformers.tokenlearner import TokenLearner
import logging
# Module logger; SmplResNet.init_weights reports checkpoint loading through it.
logger = logging.getLogger(__name__)
# Momentum passed to most BatchNorm2d layers built in this module.
# 0.1 is also PyTorch's default BatchNorm2d momentum.
BN_MOMENTUM = 0.1
def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1.

    ``in_planes`` and ``out_planes`` are *per-group* channel counts: the
    returned layer has ``in_planes * groups`` input and
    ``out_planes * groups`` output channels, split into ``groups`` groups.
    """
    total_in = in_planes * groups
    total_out = out_planes * groups
    return nn.Conv2d(total_in, total_out, 3, stride=stride, padding=1, bias=bias, groups=groups)
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block with an additive shortcut.

    Channel arguments are per group; the block emits ``planes * groups``
    channels. When ``downsample`` is None the input itself is added to the
    main branch, so it must already have the output's shape.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()
        width = planes * groups
        self.conv1 = conv3x3(inplanes, planes, stride, groups=groups)
        self.bn1 = nn.BatchNorm2d(width, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, groups=groups)
        self.bn2 = nn.BatchNorm2d(width, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Shortcut: identity unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block.

    Channel arguments are per group. The block emits
    ``planes * expansion * groups`` channels; the stride is applied by the
    3x3 conv. Without ``downsample`` the input is added unchanged, so it
    must already match the output shape.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()
        mid = planes * groups
        wide = planes * self.expansion * groups
        self.conv1 = nn.Conv2d(inplanes * groups, mid, kernel_size=1, bias=False, groups=groups)
        self.bn1 = nn.BatchNorm2d(mid, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            mid, mid, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups
        )
        self.bn2 = nn.BatchNorm2d(mid, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(mid, wide, kernel_size=1, bias=False, groups=groups)
        self.bn3 = nn.BatchNorm2d(wide, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        # The first two conv/bn stages are each followed by a ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        # The final projection gets its ReLU only after the residual add.
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
# ResNet depth -> (residual block class, number of blocks in each of the
# four stages layer1..layer4).
resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3]),
}
class IUV_predict_layer(nn.Module):
    """Set of parallel conv heads, chosen by ``mode``, applied to one feature map.

    Heads built per mode (all stride 1; padding 1 for a 3x3 kernel, else 0):
        * ``'iuv'``:  ``predict_u`` (25 ch), ``predict_v`` (25 ch),
          ``predict_ann_index`` (15 ch), ``predict_uv_index`` (25 ch)
        * ``'seg'``:  ``predict_ann_index`` (15 ch), ``predict_uv_index`` (25 ch)
        * ``'pncc'``: ``predict_pncc`` (3 ch)

    Args:
        feat_dim: input channels of every head.
        final_cov_k: kernel size of every head.
        out_channels: accepted but not used; the channel counts above are fixed.
        with_uv: accepted but not used.
        mode: one of ``'iuv'``, ``'seg'``, ``'pncc'``.
    """

    def __init__(self, feat_dim=256, final_cov_k=3, out_channels=25, with_uv=True, mode='iuv'):
        super().__init__()
        assert mode in ['iuv', 'seg', 'pncc']
        self.mode = mode

        # Heads are created in the same order as before, so parameter init
        # (which draws from the RNG) and parameter ordering are unchanged.
        if mode == 'seg':
            self.predict_ann_index = self._make_head(feat_dim, 15, final_cov_k)
            self.predict_uv_index = self._make_head(feat_dim, 25, final_cov_k)
        elif mode == 'iuv':
            self.predict_u = self._make_head(feat_dim, 25, final_cov_k)
            self.predict_v = self._make_head(feat_dim, 25, final_cov_k)
            self.predict_ann_index = self._make_head(feat_dim, 15, final_cov_k)
            self.predict_uv_index = self._make_head(feat_dim, 25, final_cov_k)
        elif mode == 'pncc':
            self.predict_pncc = self._make_head(feat_dim, 3, final_cov_k)

        self.inplanes = feat_dim

    @staticmethod
    def _make_head(in_channels, out_channels, kernel_size):
        """Build one stride-1 prediction conv; padding is 1 only for a 3x3 kernel."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=1 if kernel_size == 3 else 0,
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks, starting from ``self.inplanes`` channels.

        Not used by ``__init__``/``forward`` in this class. Updates
        ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut when the first block changes stride or width.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the heads of the configured mode to ``x``; return a dict of outputs."""
        return_dict = {}
        if self.mode in ('iuv', 'seg'):
            return_dict['predict_uv_index'] = self.predict_uv_index(x)
            return_dict['predict_ann_index'] = self.predict_ann_index(x)
            if self.mode == 'iuv':
                return_dict['predict_u'] = self.predict_u(x)
                return_dict['predict_v'] = self.predict_v(x)
            else:
                # 'seg' builds no U/V heads; the keys are still present (as None),
                # presumably so consumers can read them unconditionally.
                return_dict['predict_u'] = None
                return_dict['predict_v'] = None

        if self.mode == 'pncc':
            return_dict['predict_pncc'] = self.predict_pncc(x)

        return return_dict
class Seg_predict_layer(nn.Module):
    """Single stride-1 conv head returning ``{'predict_seg_index': tensor}``.

    Padding is 1 for a 3x3 kernel and 0 otherwise, so only 1x1 and 3x3
    kernels preserve the spatial size.
    """

    def __init__(self, feat_dim=256, final_cov_k=3, out_channels=25):
        super().__init__()
        pad = 0
        if final_cov_k == 3:
            pad = 1
        self.predict_seg_index = nn.Conv2d(
            feat_dim, out_channels, final_cov_k, stride=1, padding=pad
        )
        self.inplanes = feat_dim

    def forward(self, x):
        return {'predict_seg_index': self.predict_seg_index(x)}
class Kps_predict_layer(nn.Module):
    """Keypoint prediction head returning ``{'predict_kps': tensor}``.

    Args:
        feat_dim: input channels of the default conv head (also stored as
            ``self.inplanes``).
        final_cov_k: kernel size of the default conv head (padding 1 for a
            3x3 kernel, else 0).
        out_channels: output channels of the default conv head.
        add_module: optional module used as the head instead of the default
            conv; when given, ``final_cov_k`` and ``out_channels`` are ignored.
    """

    def __init__(self, feat_dim=256, final_cov_k=3, out_channels=3, add_module=None):
        super().__init__()
        if add_module is not None:
            # The nn.Sequential wrapper keeps the submodule's state-dict keys
            # under 'predict_kps.0.*'. No extra conv is built on this path.
            self.predict_kps = nn.Sequential(add_module)
        else:
            self.predict_kps = nn.Conv2d(
                in_channels=feat_dim,
                out_channels=out_channels,
                kernel_size=final_cov_k,
                stride=1,
                padding=1 if final_cov_k == 3 else 0
            )
        self.inplanes = feat_dim

    def forward(self, x):
        return_dict = {}
        return_dict['predict_kps'] = self.predict_kps(x)
        return return_dict
class SmplResNet(nn.Module):
    """ResNet backbone with an optional globally pooled linear head.

    Args:
        resnet_nums: depth key into ``resnet_spec`` (18, 34, 50, 101 or 152).
        in_channels: channels of the input image tensor.
        num_classes: output size of the linear head; ``<= 0`` disables the
            head and ``forward`` returns ``None`` in its place.
        last_stride: stride of the first block of ``layer4``.
        n_extra_feat: channels of the optional ``infeat`` tensor that
            ``forward`` fuses into the deepest feature map with a 1x1 conv;
            ``0`` disables fusion.
        truncate: ``>= 1`` omits ``layer4``; ``>= 2`` also omits ``layer3``.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        resnet_nums,
        in_channels=3,
        num_classes=229,
        last_stride=2,
        n_extra_feat=0,
        truncate=0,
        **kwargs
    ):
        super().__init__()
        self.inplanes = 64
        self.truncate = truncate
        block, layers = resnet_spec[resnet_nums]

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2) if truncate < 2 else None
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=last_stride
        ) if truncate < 1 else None

        # Width of the deepest stage actually built. With truncation, forward()
        # returns layer3's or layer2's output, so the head and the fusion conv
        # must be sized for that map, not for layer4.
        if truncate < 1:
            feat_dim = 512 * block.expansion
        elif truncate < 2:
            feat_dim = 256 * block.expansion
        else:
            feat_dim = 128 * block.expansion

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.num_classes = num_classes
        if num_classes > 0:
            self.final_layer = nn.Linear(feat_dim, num_classes)
            # Small-gain init keeps the head's initial outputs small.
            nn.init.xavier_uniform_(self.final_layer.weight, gain=0.01)

        self.n_extra_feat = n_extra_feat
        if n_extra_feat > 0:
            self.trans_conv = nn.Sequential(
                nn.Conv2d(n_extra_feat + feat_dim, feat_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(feat_dim, momentum=BN_MOMENTUM),
                nn.ReLU(True),
            )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one ResNet stage of ``blocks`` blocks and advance ``self.inplanes``."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut when the first block changes stride or width.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x, infeat=None):
        """Run the backbone and the optional head.

        Returns ``(cls, {'x4': x4})``. ``x4`` is the deepest computed feature
        map, after optional fusion with ``infeat``. ``cls`` is the head output,
        or ``None`` when ``num_classes <= 0``.

        ``infeat`` is concatenated in front of ``x4`` along dim 1, so it must
        match x4's batch and spatial sizes. It requires ``n_extra_feat > 0``;
        otherwise ``trans_conv`` does not exist.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2) if self.truncate < 2 else x2
        x4 = self.layer4(x3) if self.truncate < 1 else x3

        if infeat is not None:
            x4 = self.trans_conv(torch.cat([infeat, x4], 1))

        if self.num_classes > 0:
            xp = self.avg_pooling(x4)
            cls = self.final_layer(xp.view(xp.size(0), -1))
            if not cfg.DANET.USE_MEAN_PARA:
                # Make the first output (a scale term) non-negative.
                scale = F.relu(cls[:, 0]).unsqueeze(1)
                cls = torch.cat((scale, cls[:, 1:]), dim=1)
        else:
            cls = None

        return cls, {'x4': x4}

    def init_weights(self, pretrained=''):
        """Load weights from the checkpoint file ``pretrained``.

        Two checkpoint layouts are accepted:
        * a raw ``OrderedDict`` state dict;
        * a dict with a ``'state_dict'`` entry, where a leading ``'module.'``
          on each key (left by DataParallel) is stripped.

        Entries whose tensor shape differs from this model's are dropped.
        Missing and unexpected keys are ignored (``strict=False``).

        Raises:
            ValueError: ``pretrained`` is not an existing file.
            RuntimeError: the file holds neither supported layout.
        """
        if not os.path.isfile(pretrained):
            logger.error('=> imagenet pretrained model does not exist')
            logger.error('=> please download it first')
            raise ValueError('imagenet pretrained model does not exist')

        logger.info('=> loading pretrained model {}'.format(pretrained))
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the values onto this module's own devices.
        checkpoint = torch.load(pretrained, map_location='cpu')

        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                # Drop the prefix added when the model was wrapped in DataParallel.
                if key.startswith('module.'):
                    key = key[len('module.'):]
                state_dict[key] = value
        else:
            raise RuntimeError('No state_dict found in checkpoint file {}'.format(pretrained))

        # strict=False ignores missing/unexpected keys but still raises on size
        # mismatches, so filter those out for both checkpoint layouts.
        own_state = self.state_dict()
        state_dict = OrderedDict(
            (key, value) for key, value in state_dict.items()
            if key not in own_state or own_state[key].shape == value.shape
        )
        self.load_state_dict(state_dict, strict=False)
class LimbResLayers(nn.Module):
    """One ResNet stage (``layer3``-sized) with optional grouped convolutions.

    Args:
        resnet_nums: depth key into ``resnet_spec``; selects the block type and
            the number of blocks (the stage-3 count, ``layers[2]``).
        inplanes: per-group input channels.
        outplanes: per-group stage width (``planes``); defaults to 256. The
            stage emits ``outplanes * block.expansion * groups`` channels.
        groups: number of independent channel groups passed to every conv.
        **kwargs: accepted and ignored.
    """

    def __init__(self, resnet_nums, inplanes, outplanes=None, groups=1, **kwargs):
        super().__init__()
        self.inplanes = inplanes
        block, layers = resnet_spec[resnet_nums]
        self.outplanes = 256 if outplanes is None else outplanes
        self.layer3 = self._make_layer(block, self.outplanes, layers[2], stride=2, groups=groups)
        # NOTE: constructed but not applied in forward().
        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

    def _make_layer(self, block, planes, blocks, stride=1, groups=1):
        """Build a grouped residual stage and advance ``self.inplanes``.

        ``self.inplanes`` and ``planes`` are tracked per group; every conv and
        BatchNorm is widened by ``groups``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Grouped 1x1 projection shortcut for the first block.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes * groups,
                    planes * block.expansion * groups,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    groups=groups
                ),
                nn.BatchNorm2d(planes * block.expansion * groups, momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample, groups=groups)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=groups))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stage (its first block has stride 2).

        Returns ``(features, None)``; the second element is always ``None``.
        """
        x = self.layer3(x)
        return x, None