|
""" |
|
Implementation of a YOLO loss function similar to the one in the YOLOv3 paper.
The main difference, as far as I can tell, is that the class term uses
cross-entropy instead of binary cross-entropy.
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from pytorch_lightning import LightningModule |
|
from utils import intersection_over_union |
|
|
|
|
|
class YoloLoss_basic(LightningModule):
    """YOLO loss for a single prediction scale.

    The result is a weighted sum of four terms:

    * no-object loss: BCE-with-logits on the objectness logit of cells whose
      target objectness (``target[..., 0]``) is 0;
    * object loss: MSE between sigmoid(objectness logit) and the (detached)
      IoU of the decoded predicted box with the target box, on cells whose
      target objectness is 1;
    * box loss: MSE between (sigmoid(x), sigmoid(y), raw w, raw h) and
      (x, y, log(w / anchor_w), log(h / anchor_h)) on object cells;
    * class loss: cross-entropy between the class logits
      (``predictions[..., 5:]``) and the integer class id (``target[..., 5]``).

    Args:
        lambda_class: weight of the class term.
        lambda_noobj: weight of the no-object term.
        lambda_obj: weight of the object term.
        lambda_box: weight of the box term.
    """

    def __init__(self, lambda_class=1, lambda_noobj=10, lambda_obj=1, lambda_box=10):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Relative weights of the four loss terms.
        self.lambda_class = lambda_class
        self.lambda_noobj = lambda_noobj
        self.lambda_obj = lambda_obj
        self.lambda_box = lambda_box

    def cal_loss(self, predictions, target, anchors):
        """Return the weighted scalar loss for one scale.

        Neither ``predictions`` nor ``target`` is modified.

        Args:
            predictions: raw network output for this scale. Channel layout
                along the last axis: 0 objectness logit, 1:3 x/y logits,
                3:5 w/h logits, 5: class logits. Presumably shaped
                (batch, num_anchors, S, S, 5 + num_classes) - TODO confirm.
            target: ground truth for this scale. Channel layout: 0 objectness
                (1 = object, 0 = no object; other values are ignored), 1:5 box
                (x, y, w, h), 5 class id.
            anchors: anchor (w, h) pairs for this scale. They are reshaped to
                broadcast against the anchor axis (dim 1) of ``predictions``.
        """
        obj = target[..., 0] == 1
        noobj = target[..., 0] == 0

        no_object_loss = self.bce(predictions[..., 0:1][noobj], target[..., 0:1][noobj])

        if obj.any():
            box_loss, object_loss, class_loss = self._object_losses(
                predictions, target, anchors, obj
            )
        else:
            # Mean-reduced MSELoss / CrossEntropyLoss over an empty selection
            # return NaN, which would poison the summed loss. With no objects at
            # this scale only the no-object term carries signal.
            box_loss = object_loss = class_loss = predictions.new_zeros(())

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )

    def _object_losses(self, predictions, target, anchors, obj):
        """Compute (box_loss, object_loss, class_loss) over the cells selected by ``obj``."""
        # Put anchors on the predictions' device (instead of a hard-coded "cuda")
        # and shape them to broadcast over the anchor axis (dim 1).
        anchors = anchors.reshape(1, -1, 1, 1, 2).to(predictions.device)

        xy_pred = self.sigmoid(predictions[..., 1:3])
        wh_logit = predictions[..., 3:5]

        # Decoded boxes are used only to build the IoU target of the object term.
        box_preds = torch.cat([xy_pred, torch.exp(wh_logit) * anchors], dim=-1)
        ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
        object_loss = self.mse(
            self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj]
        )

        # Box regression compares raw w/h logits with log(target / anchor). The
        # tensors are built out of place so the caller's predictions and target
        # are left untouched; 1e-16 guards against log(0).
        pred_box = torch.cat([xy_pred, wh_logit], dim=-1)
        target_box = torch.cat(
            [target[..., 1:3], torch.log(1e-16 + target[..., 3:5] / anchors)], dim=-1
        )
        box_loss = self.mse(pred_box[obj], target_box[obj])

        class_loss = self.entropy(predictions[..., 5:][obj], target[..., 5][obj].long())
        return box_loss, object_loss, class_loss

    def forward(self, predictions, target, anchors):
        """Alias for :meth:`cal_loss`, so the module can be called directly."""
        return self.cal_loss(predictions, target, anchors)
|
|
|
|
|
class YoloLoss(LightningModule):
    """Multi-scale YOLO loss that sums a per-scale ``YoloLoss_basic`` over all scales."""

    def __init__(self):
        super().__init__()
        # One per-scale criterion, reused for every scale.
        self.yolo_basic = YoloLoss_basic()

    def forward(self, predictions, target, scaled_anchors):
        """Return the sum of the per-scale losses.

        ``predictions`` and ``scaled_anchors`` are indexed in lockstep with
        ``target``, and the number of scales is taken from ``target``. The
        result is the integer ``0`` when ``target`` is empty.
        """
        return sum(
            self.yolo_basic(predictions[scale], scale_target, scaled_anchors[scale])
            for scale, scale_target in enumerate(target)
        )
|
|