"""Loss functions and checkpoint/state-dict utilities (PyTorch)."""
import os
import shutil
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional
from termcolor import colored
from torch import nn
def sigmoid(x):
    """Sigmoid clamped into [1e-4, 1 - 1e-4] so downstream logs stay finite."""
    return torch.clamp(torch.sigmoid(x), min=1e-4, max=1 - 1e-4)
def _neg_loss(pred, gt): | |
''' Modified focal loss. Exactly the same as CornerNet. | |
Runs faster and costs a little bit more memory | |
Arguments: | |
pred (batch x c x h x w) | |
gt_regr (batch x c x h x w) | |
''' | |
pos_inds = gt.eq(1).float() | |
neg_inds = gt.lt(1).float() | |
neg_weights = torch.pow(1 - gt, 4) | |
loss = 0 | |
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds | |
neg_loss = torch.log(1 - pred) * torch.pow(pred, | |
2) * neg_weights * neg_inds | |
num_pos = pos_inds.float().sum() | |
pos_loss = pos_loss.sum() | |
neg_loss = neg_loss.sum() | |
if num_pos == 0: | |
loss = loss - neg_loss | |
else: | |
loss = loss - (pos_loss + neg_loss) / num_pos | |
return loss | |
class FocalLoss(nn.Module):
    """nn.Module wrapper around the modified focal loss `_neg_loss`."""

    def __init__(self):
        super(FocalLoss, self).__init__()
        # Kept as an attribute so the underlying loss can be swapped out.
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Return the focal loss between predicted and target heatmaps."""
        return self.neg_loss(out, target)
def smooth_l1_loss(vertex_pred,
                   vertex_targets,
                   vertex_weights,
                   sigma=1.0,
                   normalize=True,
                   reduce=True):
    """Weighted smooth-L1 (Huber) loss.

    :param vertex_pred: [b, vn*2, h, w]
    :param vertex_targets: [b, vn*2, h, w]
    :param vertex_weights: [b, 1, h, w]
    :param sigma: controls the quadratic/linear transition point (1/sigma^2)
    :param normalize: divide each sample's loss by its total weight mass
    :param reduce: average the per-sample losses over the batch
    :return: scalar tensor if reduce, else per-sample loss of shape [b]
    """
    batch, ver_dim = vertex_pred.shape[0], vertex_pred.shape[1]
    sigma_sq = sigma * sigma

    weighted_diff = vertex_weights * (vertex_pred - vertex_targets)
    abs_val = torch.abs(weighted_diff)
    # 1 inside the quadratic zone (|x| < 1/sigma^2), 0 in the linear zone.
    quad_mask = (abs_val < 1. / sigma_sq).detach().float()
    in_loss = (torch.pow(weighted_diff, 2) * (sigma_sq / 2.) * quad_mask
               + (abs_val - (0.5 / sigma_sq)) * (1. - quad_mask))

    if normalize:
        weight_mass = torch.sum(vertex_weights.view(batch, -1), 1)
        in_loss = (torch.sum(in_loss.view(batch, -1), 1)
                   / (ver_dim * weight_mass + 1e-3))
    if reduce:
        in_loss = torch.mean(in_loss)
    return in_loss
class SmoothL1Loss(nn.Module):
    """nn.Module wrapper around the functional `smooth_l1_loss`."""

    def __init__(self):
        super(SmoothL1Loss, self).__init__()
        # Attribute indirection mirrors FocalLoss and allows substitution.
        self.smooth_l1_loss = smooth_l1_loss

    def forward(self, preds, targets, weights, sigma=1.0, normalize=True,
                reduce=True):
        """See `smooth_l1_loss` for parameter shapes and semantics."""
        return self.smooth_l1_loss(preds, targets, weights, sigma, normalize,
                                   reduce)
class AELoss(nn.Module):
    """Associative-embedding loss: pulls the embeddings of parts belonging to
    the same object toward the object's mean embedding, and pushes the means
    of different objects at least 1 apart."""

    def __init__(self):
        super(AELoss, self).__init__()

    def forward(self, ae, ind, ind_mask):
        """
        ae: [b, 1, h, w]
        ind: [b, max_objs, max_parts]
        ind_mask: [b, max_objs, max_parts]
        obj_mask: [b, max_objs]
        """
        # first index
        b, _, h, w = ae.shape
        b, max_objs, max_parts = ind.shape
        # An object counts as present iff it has at least one valid part.
        obj_mask = torch.sum(ind_mask, dim=2) != 0
        ae = ae.view(b, h * w, 1)
        seed_ind = ind.view(b, max_objs * max_parts, 1)
        # Look up the embedding value at each part's flattened spatial index.
        tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)
        # compute the mean (epsilon guards objects with zero valid parts)
        tag_mean = tag * ind_mask
        tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)
        # pull ae of the same object to their mean
        pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
        obj_num = obj_mask.sum(dim=1).float()
        pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
        pull /= b
        # push away the mean of different objects
        push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
        # Hinge at margin 1: pairs already more than 1 apart contribute 0.
        push_dist = 1 - push_dist
        push_dist = nn.functional.relu(push_dist, inplace=True)
        # Keep only pairs where both objects are present.
        obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
        push_dist = push_dist * obj_mask.float()
        # Subtract obj_num to remove the diagonal (each self-pair adds 1);
        # divide by the number of ordered distinct pairs.
        push = ((push_dist.sum(dim=(1, 2)) - obj_num) /
                (obj_num * (obj_num - 1) + 1e-4)).sum()
        push /= b
        return pull, push
class PolyMatchingLoss(nn.Module):
    """Rotation-invariant polygon matching loss.

    Compares the predicted polygon against every cyclic rotation of the
    ground-truth polygon and scores it by the best (minimum-distance)
    alignment, so the loss does not depend on which vertex is "first".
    """

    def __init__(self, pnum):
        """
        :param pnum: number of polygon vertices
        """
        super(PolyMatchingLoss, self).__init__()
        self.pnum = pnum

        # Precompute, for every starting vertex i, the cyclic index order
        # (i, i+1, ..., i-1) of the ground-truth vertices.
        batch_size = 1
        pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int64)
        for b in range(batch_size):
            for i in range(pnum):
                pidxall[b, i] = (np.arange(pnum) + i) % pnum

        pidxall = torch.from_numpy(np.reshape(pidxall, (batch_size, -1)))
        # Built on CPU and moved to the input's device in forward(); the
        # original hard-coded torch.device('cuda'), which broke CPU-only use.
        self.feature_id = pidxall.unsqueeze(2).long().expand(
            pidxall.size(0), pidxall.size(1), 2).detach()

    def forward(self, pred, gt, loss_type="L2"):
        """
        :param pred: [b, pnum, 2] predicted polygon vertices
        :param gt: [b, pnum, 2] ground-truth polygon vertices
        :param loss_type: "L2" (Euclidean per vertex) or "L1" (Manhattan)
        :return: scalar mean over the batch of the best-rotation distances
        """
        pnum = self.pnum
        batch_size = pred.size(0)
        feature_id = self.feature_id.to(gt.device).expand(
            batch_size, self.feature_id.size(1), 2)

        # All cyclic rotations of the ground truth: [b, pnum, pnum, 2].
        gt_expand = torch.gather(gt, 1,
                                 feature_id).view(batch_size, pnum, pnum, 2)
        pred_expand = pred.unsqueeze(1)

        dis = pred_expand - gt_expand
        if loss_type == "L2":
            dis = (dis**2).sum(3).sqrt().sum(2)
        elif loss_type == "L1":
            dis = torch.abs(dis).sum(3).sum(2)

        # Best alignment over all rotations, per sample.
        min_dis, _ = torch.min(dis, dim=1, keepdim=True)
        return torch.mean(min_dis)
class AttentionLoss(nn.Module):
    """Class-balanced cross entropy where confident mistakes are up-weighted
    by beta ** ((1 - p) ** gamma) on edges and beta ** (p ** gamma) on
    background."""

    def __init__(self, beta=4, gamma=0.5):
        super(AttentionLoss, self).__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(self, pred, gt):
        num_pos = torch.sum(gt)
        num_neg = torch.sum(1 - gt)
        # Fraction of negative pixels; balances the two class terms.
        alpha = num_neg / (num_pos + num_neg)

        edge_weight = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
        bg_weight = torch.pow(self.beta, torch.pow(pred, self.gamma))

        pos_term = alpha * edge_weight * torch.log(pred) * gt
        neg_term = (1 - alpha) * bg_weight * torch.log(1 - pred) * (1 - gt)
        return torch.mean(-pos_term - neg_term)
def _gather_feat(feat, ind, mask=None): | |
dim = feat.size(2) | |
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) | |
feat = feat.gather(1, ind) | |
if mask is not None: | |
mask = mask.unsqueeze(2).expand_as(feat) | |
feat = feat[mask] | |
feat = feat.view(-1, dim) | |
return feat | |
def _tranpose_and_gather_feat(feat, ind): | |
feat = feat.permute(0, 2, 3, 1).contiguous() | |
feat = feat.view(feat.size(0), -1, feat.size(3)) | |
feat = _gather_feat(feat, ind) | |
return feat | |
class Ind2dRegL1Loss(nn.Module):
    """L1 / smooth-L1 regression loss on features gathered at 2-d part
    indices, masked by validity and normalized by the valid-entry count."""

    def __init__(self, type='l1'):
        super(Ind2dRegL1Loss, self).__init__()
        if type == 'l1':
            self.loss = torch.nn.functional.l1_loss
        elif type == 'smooth_l1':
            self.loss = torch.nn.functional.smooth_l1_loss

    def forward(self, output, target, ind, ind_mask):
        """ind: [b, max_objs, max_parts]"""
        b, max_objs, max_parts = ind.shape
        flat_ind = ind.view(b, max_objs * max_parts)
        pred = _tranpose_and_gather_feat(output, flat_ind)
        pred = pred.view(b, max_objs, max_parts, output.size(1))
        mask = ind_mask.unsqueeze(3).expand_as(pred)
        total = self.loss(pred * mask, target * mask, reduction='sum')
        return total / (mask.sum() + 1e-4)
class IndL1Loss1d(nn.Module):
    """L1 / smooth-L1 loss on features gathered at 1-d indices, weighted per
    location and normalized by the total weight."""

    def __init__(self, type='l1'):
        super(IndL1Loss1d, self).__init__()
        if type == 'l1':
            self.loss = torch.nn.functional.l1_loss
        elif type == 'smooth_l1':
            self.loss = torch.nn.functional.smooth_l1_loss

    def forward(self, output, target, ind, weight):
        """ind: [b, n]"""
        gathered = _tranpose_and_gather_feat(output, ind)
        w = weight.unsqueeze(2)
        total = self.loss(gathered * w, target * w, reduction='sum')
        return total / (w.sum() * gathered.size(2) + 1e-4)
class GeoCrossEntropyLoss(nn.Module):
    """Cross entropy over candidate polygon vertices, with soft target
    weights given by a spatial Gaussian kernel around the target vertex.

    NOTE(review): the layout assumed below (poly splits into 4 groups of
    poly.size(1)//4 candidate vertices, matching output's class dim) is
    inferred from the reshapes — confirm against the caller.
    """

    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        # Log-probabilities, clamped so log never sees an exact zero.
        output = torch.nn.functional.softmax(output, dim=1)
        output = torch.log(torch.clamp(output, min=1e-4))
        # Split the vertices into 4 groups: [b, 4, pnum//4, 2].
        poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
        target = target[..., None, None].expand(poly.size(0), poly.size(1), 1,
                                                poly.size(3))
        # Coordinates of the target vertex within each group.
        target_poly = torch.gather(poly, 2, target)
        # Kernel scale: squared distance between each group's first two
        # vertices (a per-group spacing estimate).
        sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
        # Gaussian weights centered on the target vertex.
        kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
        loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
        return loss
def load_model(net,
               optim,
               scheduler,
               recorder,
               model_dir,
               resume=True,
               epoch=-1):
    """Restore net/optim/scheduler/recorder state from a checkpoint dir.

    :param resume: when False, wipe model_dir and start from scratch
    :param epoch: epoch number to load, or -1 for the most recent checkpoint
    :return: the epoch to resume training from (0 when nothing was loaded)
    """
    if not resume:
        # Portable replacement for os.system('rm -rf ...').
        shutil.rmtree(model_dir, ignore_errors=True)

    if not os.path.exists(model_dir):
        return 0

    # Checkpoints are named '<epoch>.pth'; skip 'latest.pth' and any stray
    # non-numeric files instead of crashing on int().
    pths = [
        int(pth.split('.')[0]) for pth in os.listdir(model_dir)
        if pth != 'latest.pth' and pth.split('.')[0].isdigit()
    ]
    if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir):
        return 0

    if epoch == -1:
        if 'latest.pth' in os.listdir(model_dir):
            pth = 'latest'
        else:
            pth = max(pths)
    else:
        pth = epoch

    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    # Load onto CPU first; load_state_dict then copies into the net's device.
    pretrained_model = torch.load(model_path, 'cpu')
    net.load_state_dict(pretrained_model['net'])
    optim.load_state_dict(pretrained_model['optim'])
    scheduler.load_state_dict(pretrained_model['scheduler'])
    recorder.load_state_dict(pretrained_model['recorder'])
    return pretrained_model['epoch'] + 1
def save_model(net, optim, scheduler, recorder, model_dir, epoch, last=False):
    """Save a training checkpoint into model_dir.

    Writes 'latest.pth' when last=True, otherwise '<epoch>.pth', then prunes
    the oldest numbered checkpoint once more than 20 have accumulated.
    """
    os.makedirs(model_dir, exist_ok=True)  # portable 'mkdir -p'
    model = {
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }
    if last:
        torch.save(model, os.path.join(model_dir, 'latest.pth'))
    else:
        torch.save(model, os.path.join(model_dir, '{}.pth'.format(epoch)))

    # remove previous pretrained model if the number of models is too big;
    # skip 'latest.pth' and non-numeric filenames instead of crashing on int()
    pths = [
        int(pth.split('.')[0]) for pth in os.listdir(model_dir)
        if pth != 'latest.pth' and pth.split('.')[0].isdigit()
    ]
    if len(pths) <= 20:
        return
    os.remove(os.path.join(model_dir, '{}.pth'.format(min(pths))))
def load_network(net, model_dir, resume=True, epoch=-1, strict=True):
    """Load network weights from a checkpoint file or directory.

    :param model_dir: a '.pth' file path, or a directory of '<epoch>.pth'
    :param epoch: epoch to load, or -1 for the most recent checkpoint
    :param strict: passed through to net.load_state_dict
    :return: the epoch to resume from (0 when nothing was loaded)
    """
    if not resume:
        return 0
    if not os.path.exists(model_dir):
        print(colored('pretrained model does not exist', 'red'))
        return 0

    if os.path.isdir(model_dir):
        # Skip 'latest.pth' and stray non-numeric files instead of crashing
        # on int().
        pths = [
            int(pth.split('.')[0]) for pth in os.listdir(model_dir)
            if pth != 'latest.pth' and pth.split('.')[0].isdigit()
        ]
        if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir):
            return 0
        if epoch == -1:
            if 'latest.pth' in os.listdir(model_dir):
                pth = 'latest'
            else:
                pth = max(pths)
        else:
            pth = epoch
        model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    else:
        model_path = model_dir

    print('load model: {}'.format(model_path))
    # map_location='cpu' (consistent with load_model) so GPU-saved
    # checkpoints also load on CPU-only hosts.
    pretrained_model = torch.load(model_path, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=strict)
    return pretrained_model['epoch'] + 1
def remove_net_prefix(net, prefix):
    """Return a new OrderedDict with `prefix` stripped from every key that
    starts with it; other keys are kept unchanged."""
    stripped = OrderedDict()
    for key, value in net.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
def add_net_prefix(net, prefix):
    """Return a new OrderedDict with `prefix` prepended to every key."""
    return OrderedDict((prefix + key, value) for key, value in net.items())
def replace_net_prefix(net, orig_prefix, prefix):
    """Return a new OrderedDict where keys starting with `orig_prefix` have
    it swapped for `prefix`; other keys are kept unchanged."""
    renamed = OrderedDict()
    for key, value in net.items():
        if key.startswith(orig_prefix):
            key = prefix + key[len(orig_prefix):]
        renamed[key] = value
    return renamed
def remove_net_layer(net, layers):
    """Delete, in place, every key that starts with any prefix in `layers`.

    Each key is deleted at most once (the original deleted per matching
    prefix and raised KeyError when a key matched two entries of `layers`).

    :return: the same dict, mutated
    """
    for key in list(net.keys()):
        if any(key.startswith(layer) for layer in layers):
            del net[key]
    return net