PKaushik commited on
Commit
6d7be64
1 Parent(s): bf05275
Files changed (1) hide show
  1. yolov6/utils/figure_iou.py +114 -0
yolov6/utils/figure_iou.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ import math
4
+ import torch
5
+
6
+
7
+ class IOUloss:
8
+ """ Calculate IoU loss.
9
+ """
10
+ def __init__(self, box_format='xywh', iou_type='ciou', reduction='none', eps=1e-7):
11
+ """ Setting of the class.
12
+ Args:
13
+ box_format: (string), must be one of 'xywh' or 'xyxy'.
14
+ iou_type: (string), can be one of 'ciou', 'diou', 'giou' or 'siou'
15
+ reduction: (string), specifies the reduction to apply to the output, must be one of 'none', 'mean','sum'.
16
+ eps: (float), a value to avoid divide by zero error.
17
+ """
18
+ self.box_format = box_format
19
+ self.iou_type = iou_type.lower()
20
+ self.reduction = reduction
21
+ self.eps = eps
22
+
23
+ def __call__(self, box1, box2):
24
+ """ calculate iou. box1 and box2 are torch tensor with shape [M, 4] and [Nm 4].
25
+ """
26
+ box2 = box2.T
27
+ if self.box_format == 'xyxy':
28
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
29
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
30
+ elif self.box_format == 'xywh':
31
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
32
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
33
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
34
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
35
+
36
+ # Intersection area
37
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
38
+ (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
39
+
40
+ # Union Area
41
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps
42
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps
43
+ union = w1 * h1 + w2 * h2 - inter + self.eps
44
+ iou = inter / union
45
+
46
+ cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex width
47
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
48
+ if self.iou_type == 'giou':
49
+ c_area = cw * ch + self.eps # convex area
50
+ iou = iou - (c_area - union) / c_area
51
+ elif self.iou_type in ['diou', 'ciou']:
52
+ c2 = cw ** 2 + ch ** 2 + self.eps # convex diagonal squared
53
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
54
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
55
+ if self.iou_type == 'diou':
56
+ iou = iou - rho2 / c2
57
+ elif self.iou_type == 'ciou':
58
+ v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
59
+ with torch.no_grad():
60
+ alpha = v / (v - iou + (1 + self.eps))
61
+ iou = iou - (rho2 / c2 + v * alpha)
62
+ elif self.iou_type == 'siou':
63
+ # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
64
+ s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
65
+ s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
66
+ sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
67
+ sin_alpha_1 = torch.abs(s_cw) / sigma
68
+ sin_alpha_2 = torch.abs(s_ch) / sigma
69
+ threshold = pow(2, 0.5) / 2
70
+ sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
71
+ angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
72
+ rho_x = (s_cw / cw) ** 2
73
+ rho_y = (s_ch / ch) ** 2
74
+ gamma = angle_cost - 2
75
+ distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
76
+ omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
77
+ omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
78
+ shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
79
+ iou = iou - 0.5 * (distance_cost + shape_cost)
80
+ loss = 1.0 - iou
81
+
82
+ if self.reduction == 'sum':
83
+ loss = loss.sum()
84
+ elif self.reduction == 'mean':
85
+ loss = loss.mean()
86
+
87
+ return loss
88
+
89
+
90
+ def pairwise_bbox_iou(box1, box2, box_format='xywh'):
91
+ """Calculate iou.
92
+ This code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/boxes.py
93
+ """
94
+ if box_format == 'xyxy':
95
+ lt = torch.max(box1[:, None, :2], box2[:, :2])
96
+ rb = torch.min(box1[:, None, 2:], box2[:, 2:])
97
+ area_1 = torch.prod(box1[:, 2:] - box1[:, :2], 1)
98
+ area_2 = torch.prod(box2[:, 2:] - box2[:, :2], 1)
99
+
100
+ elif box_format == 'xywh':
101
+ lt = torch.max(
102
+ (box1[:, None, :2] - box1[:, None, 2:] / 2),
103
+ (box2[:, :2] - box2[:, 2:] / 2),
104
+ )
105
+ rb = torch.min(
106
+ (box1[:, None, :2] + box1[:, None, 2:] / 2),
107
+ (box2[:, :2] + box2[:, 2:] / 2),
108
+ )
109
+
110
+ area_1 = torch.prod(box1[:, 2:], 1)
111
+ area_2 = torch.prod(box2[:, 2:], 1)
112
+ valid = (lt < rb).type(lt.type()).prod(dim=2)
113
+ inter = torch.prod(rb - lt, 2) * valid
114
+ return inter / (area_1[:, None] + area_2 - inter)