Update videoretalking/third_part/GPEN/face_detect/layers/modules/multibox_loss.py
Browse files
@@ -1,125 +1,125 @@
1 |
import torch
2 |
import torch.nn as nn
3 |
import torch.nn.functional as F
4 |
from torch.autograd import Variable
5 |
from face_detect.utils.box_utils import match, log_sum_exp
6 |
from face_detect.data import cfg_mnet
7 |
# Device switch read from the 'gpu_train' entry of cfg_mnet. When truthy,
# MultiBoxLoss.forward moves the matched target tensors to CUDA via .cuda().
GPU = cfg_mnet['gpu_train']
8 |
9 |
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function

    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold
           parameter (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of
           ground truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative
           examples that comes with using a large number of default bounding
           boxes (default negative:positive ratio 3:1).

    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.

        c: class confidences,
        l: predicted boxes,
        g: ground truth boxes
        N: number of matched default boxes

    In addition to the two SSD terms, ``forward`` returns a SmoothL1 loss over
    10 landmark coordinates per prior.

    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        """Store the loss configuration.

        Args:
            num_classes: number of columns in the confidence predictions.
            overlap_thresh: threshold handed to ``match`` for prior matching.
            prior_for_matching: stored as ``use_prior_for_matching``; not read
                by ``forward``.
            bkg_label: stored as ``background_label``; not read by ``forward``.
            neg_mining: stored as ``do_neg_mining``; not read by ``forward``.
            neg_pos: negatives kept per positive during hard negative mining.
            neg_overlap: stored; not read by ``forward``.
            encode_target: stored; not read by ``forward``.
        """
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # Variances handed to ``match`` for encoding the regression targets.
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Multibox Loss

        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
                and landmark preds from the net.
                loc shape: torch.size(batch_size,num_priors,4)
                conf shape: torch.size(batch_size,num_priors,num_classes)
                landm shape: torch.size(batch_size,num_priors,10)
            priors (tensor): prior boxes, shape: torch.size(num_priors,4)
            targets (sequence of tensors): per-image ground truth. Each entry
                is sliced as ``[:, :4]`` (box), ``[:, 4:14]`` (landmarks) and
                ``[:, -1]`` (label).

        Returns:
            tuple: ``(loss_l, loss_c, loss_landm)``. The localization and
            confidence losses are divided by the number of priors with a
            non-zero matched label, the landmark loss by the number of priors
            with a strictly positive matched label (each divisor is at
            least 1).
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)

        # match priors (default boxes) and ground truth boxes
        # The target buffers start uninitialized; ``match`` is expected to
        # fill row ``idx`` of each one in place.
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, 10)
        conf_t = torch.LongTensor(num, num_priors)
        defaults = priors.data
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()

        # landm Loss (Smooth L1)
        # Shape: [batch,num_priors,10]
        # Compare against the Python scalar 0 instead of a CUDA tensor so the
        # loss also runs on machines without CUDA when GPU is falsy.
        # NOTE(review): only strictly positive labels contribute landmarks;
        # presumably ``match`` marks faces without landmarks with a negative
        # label -- confirm against box_utils.match.
        pos1 = conf_t > 0
        N1 = pos1.long().sum().clamp(min=1).float()
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # Every non-zero label counts as a positive for box regression and
        # classification; collapse them to class index 1.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Hard Negative Mining
        # The per-prior loss below is used only to rank negatives, so it is
        # computed without building an autograd graph.
        num_pos = pos.long().sum(1, keepdim=True)
        with torch.no_grad():
            batch_conf = conf_data.view(-1, self.num_classes)
            mining_loss = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
            mining_loss[pos.view(-1, 1)] = 0  # filter out pos boxes for now
            mining_loss = mining_loss.view(num, -1)
            # Sorting the sort indices yields each prior's rank by loss.
            _, loss_idx = mining_loss.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        selected = pos | neg
        selected_idx = selected.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[selected_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[selected]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        # Clamp so a batch without positives divides by 1 rather than 0.
        N = num_pos.sum().clamp(min=1).float()
        loss_l = loss_l / N
        loss_c = loss_c / N
        loss_landm = loss_landm / N1

        return loss_l, loss_c, loss_landm
1 |
import torch
2 |
import torch.nn as nn
3 |
import torch.nn.functional as F
4 |
from torch.autograd import Variable
5 |
from videoretalking.third_part.GPEN.face_detect.utils.box_utils import match, log_sum_exp
6 |
from videoretalking.third_part.GPEN.face_detect.data import cfg_mnet
7 |
# Device switch read from the 'gpu_train' entry of cfg_mnet. When truthy,
# MultiBoxLoss.forward moves the matched target tensors to CUDA via .cuda().
GPU = cfg_mnet['gpu_train']
8 |
9 |
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function

    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold
           parameter (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of
           ground truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative
           examples that comes with using a large number of default bounding
           boxes (default negative:positive ratio 3:1).

    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.

        c: class confidences,
        l: predicted boxes,
        g: ground truth boxes
        N: number of matched default boxes

    In addition to the two SSD terms, ``forward`` returns a SmoothL1 loss over
    10 landmark coordinates per prior.

    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        """Store the loss configuration.

        Args:
            num_classes: number of columns in the confidence predictions.
            overlap_thresh: threshold handed to ``match`` for prior matching.
            prior_for_matching: stored as ``use_prior_for_matching``; not read
                by ``forward``.
            bkg_label: stored as ``background_label``; not read by ``forward``.
            neg_mining: stored as ``do_neg_mining``; not read by ``forward``.
            neg_pos: negatives kept per positive during hard negative mining.
            neg_overlap: stored; not read by ``forward``.
            encode_target: stored; not read by ``forward``.
        """
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # Variances handed to ``match`` for encoding the regression targets.
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Multibox Loss

        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
                and landmark preds from the net.
                loc shape: torch.size(batch_size,num_priors,4)
                conf shape: torch.size(batch_size,num_priors,num_classes)
                landm shape: torch.size(batch_size,num_priors,10)
            priors (tensor): prior boxes, shape: torch.size(num_priors,4)
            targets (sequence of tensors): per-image ground truth. Each entry
                is sliced as ``[:, :4]`` (box), ``[:, 4:14]`` (landmarks) and
                ``[:, -1]`` (label).

        Returns:
            tuple: ``(loss_l, loss_c, loss_landm)``. The localization and
            confidence losses are divided by the number of priors with a
            non-zero matched label, the landmark loss by the number of priors
            with a strictly positive matched label (each divisor is at
            least 1).
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)

        # match priors (default boxes) and ground truth boxes
        # The target buffers start uninitialized; ``match`` is expected to
        # fill row ``idx`` of each one in place.
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, 10)
        conf_t = torch.LongTensor(num, num_priors)
        defaults = priors.data
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()

        # landm Loss (Smooth L1)
        # Shape: [batch,num_priors,10]
        # Compare against the Python scalar 0 instead of a CUDA tensor so the
        # loss also runs on machines without CUDA when GPU is falsy.
        # NOTE(review): only strictly positive labels contribute landmarks;
        # presumably ``match`` marks faces without landmarks with a negative
        # label -- confirm against box_utils.match.
        pos1 = conf_t > 0
        N1 = pos1.long().sum().clamp(min=1).float()
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # Every non-zero label counts as a positive for box regression and
        # classification; collapse them to class index 1.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Hard Negative Mining
        # The per-prior loss below is used only to rank negatives, so it is
        # computed without building an autograd graph.
        num_pos = pos.long().sum(1, keepdim=True)
        with torch.no_grad():
            batch_conf = conf_data.view(-1, self.num_classes)
            mining_loss = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
            mining_loss[pos.view(-1, 1)] = 0  # filter out pos boxes for now
            mining_loss = mining_loss.view(num, -1)
            # Sorting the sort indices yields each prior's rank by loss.
            _, loss_idx = mining_loss.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        selected = pos | neg
        selected_idx = selected.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[selected_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[selected]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        # Clamp so a batch without positives divides by 1 rather than 0.
        N = num_pos.sum().clamp(min=1).float()
        loss_l = loss_l / N
        loss_c = loss_c / N
        loss_landm = loss_landm / N1

        return loss_l, loss_c, loss_landm