ghlee94's picture
Init
2a13495
import re
import torch.nn as nn
class BaseObject(nn.Module):
    """``nn.Module`` base that exposes a human-readable ``__name__``.

    Args:
        name: explicit name returned by ``__name__``. When ``None`` the name
            is derived from the concrete class name converted to snake_case.
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        """Return the explicit name, or the snake_case form of the class name."""
        if self._name is not None:
            return self._name
        class_name = self.__class__.__name__
        # Pass 1 splits before Capitalized words; pass 2 splits remaining
        # lower/digit -> upper boundaries, e.g. "BCEWithLogitsLoss"
        # -> "BCE_With_Logits_Loss" before lower-casing.
        words_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name)
        snake = re.sub("([a-z0-9])([A-Z])", r"\1_\2", words_split)
        return snake.lower()
class Metric(BaseObject):
    """Marker base class for metrics; adds nothing beyond ``BaseObject``."""
class Loss(BaseObject):
    """Base class for losses that compose with ``+`` and scale with ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses``; ``k * loss`` and
    ``loss * k`` (``k`` an ``int`` or ``float``) build a ``MultipliedLoss``.
    """

    def __add__(self, other):
        """Return a loss computing ``self + other``.

        Raises:
            ValueError: if ``other`` is not a ``Loss``.
        """
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        else:
            raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        """Support ``other + self``, treating the number ``0`` as identity.

        ``sum(losses)`` starts accumulating from the integer ``0``, so
        accepting ``0`` here lets ``sum([loss_a, loss_b])`` build the same
        combined loss as ``loss_a + loss_b``.

        Raises:
            ValueError: if ``other`` is neither ``0`` nor a ``Loss``.
        """
        if isinstance(other, (int, float)) and other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        """Return a loss computing ``value * self``.

        Raises:
            ValueError: if ``value`` is not an ``int`` or ``float``.
        """
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        else:
            # The offending operand is the multiplier, not a loss class.
            raise ValueError(
                "Loss can only be multiplied by an int or float, got `{}`".format(
                    type(value).__name__
                )
            )

    def __rmul__(self, other):
        """Support ``other * self``; scaling is commutative."""
        return self.__mul__(other)
class SumOfLosses(Loss):
    """Loss that evaluates two losses on the same inputs and adds the results.

    Args:
        l1: first loss; its ``__name__`` forms the left part of this name.
        l2: second loss; its ``__name__`` forms the right part of this name.
    """

    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        """Return ``l1(*inputs) + l2(*inputs)``.

        Implemented as ``forward`` (not a ``__call__`` override) so that
        ``nn.Module.__call__`` dispatches here and ``self.forward(...)`` also
        works. Children are invoked through ``__call__`` rather than
        ``.forward`` so that a child works whether it implements ``forward``
        or only overrides ``__call__`` -- e.g. the inner term of ``a + b + c``.
        """
        return self.l1(*inputs) + self.l2(*inputs)
class MultipliedLoss(Loss):
    """Loss that scales another loss by a constant factor.

    Args:
        loss: wrapped loss whose value is scaled.
        multiplier: factor applied to the wrapped loss's value.
    """

    def __init__(self, loss, multiplier):
        # Parenthesize summed names so the name reads "2 * (a + b)", not "2 * a + b".
        if "+" in loss.__name__:
            name = "{} * ({})".format(multiplier, loss.__name__)
        else:
            name = "{} * {}".format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        """Return ``multiplier * loss(*inputs)``.

        Implemented as ``forward`` (not a ``__call__`` override) so that
        ``nn.Module.__call__`` dispatches here and ``self.forward(...)`` also
        works. The wrapped loss is invoked through ``__call__`` rather than
        ``.forward`` so that it works whether it implements ``forward`` or
        only overrides ``__call__`` -- e.g. the summed term in ``k * (a + b)``.
        """
        return self.multiplier * self.loss(*inputs)