|
import re |
|
import torch.nn as nn |
|
|
|
|
|
class BaseObject(nn.Module):
    """``nn.Module`` base class that exposes a display name via ``__name__``.

    Args:
        name: Optional explicit name. When ``None``, the name is derived
            from the class name (see ``__name__``).
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        """Return the explicit name, or the class name in snake_case.

        The conversion inserts an underscore before each capitalized word and
        between a lowercase letter/digit and a following capital, then
        lowercases the result (``FooBar`` -> ``foo_bar``).
        """
        # An explicitly supplied name always wins over the derived one.
        if self._name is not None:
            return self._name

        class_name = self.__class__.__name__
        # Pass 1: split before "Word"-shaped chunks (capital + lowercase run).
        partially_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name)
        # Pass 2: split remaining lowercase/digit -> capital boundaries.
        fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", partially_split)
        return fully_split.lower()
|
|
|
|
|
class Metric(BaseObject):
    """Marker base class for metrics; adds no behavior to ``BaseObject``."""
|
|
|
|
|
class Loss(BaseObject):
    """Base class for losses that can be combined with ``+`` and ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses`` and ``loss * number`` (or
    ``number * loss``) builds a ``MultipliedLoss``.
    """

    def __add__(self, other):
        """Return ``SumOfLosses(self, other)``.

        Raises:
            ValueError: If ``other`` is not a ``Loss`` instance.
        """
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        """Handle ``other + self`` when ``other`` does not know about losses.

        A numeric zero on the left is treated as the additive identity so
        that the built-in ``sum()`` (which starts from ``0``) can be used on
        an iterable of losses. Anything else is delegated to ``__add__``.

        Raises:
            ValueError: If ``other`` is neither zero nor a ``Loss`` instance.
        """
        if isinstance(other, (int, float)) and other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        """Return ``MultipliedLoss(self, value)`` for a numeric ``value``.

        Raises:
            ValueError: If ``value`` is not an ``int`` or ``float``.
        """
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        # The failure here is about the multiplier, not the loss type.
        raise ValueError(
            "Loss multiplier should be an `int` or `float`, got `{}`".format(
                type(value).__name__
            )
        )

    def __rmul__(self, other):
        """Handle ``number * loss``; scaling is commutative."""
        return self.__mul__(other)
|
|
|
|
|
class SumOfLosses(Loss):
    """Loss that evaluates two losses on the same inputs and adds the results.

    Args:
        l1: First loss; must expose ``__name__`` and be callable.
        l2: Second loss; must expose ``__name__`` and be callable.

    The combined name is ``"<l1 name> + <l2 name>"``.
    """

    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        """Return ``l1(*inputs) + l2(*inputs)``.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        this object works both when called and when a parent composite
        invokes ``.forward`` on it.
        """
        # Invoke children through the call operator: it works whether a child
        # implements ``forward`` or only ``__call__``, so composites can nest.
        return self.l1(*inputs) + self.l2(*inputs)
|
|
|
|
|
class MultipliedLoss(Loss):
    """Loss that scales the result of another loss by a constant factor.

    Args:
        loss: Wrapped loss; must expose ``__name__`` and be callable.
        multiplier: Factor applied to the wrapped loss's result.

    The name is ``"<multiplier> * <loss name>"``; the loss name is wrapped in
    parentheses when it contains a ``+`` so the displayed expression stays
    unambiguous.
    """

    def __init__(self, loss, multiplier):
        inner_name = loss.__name__
        if "+" in inner_name:
            # Parenthesize sums, e.g. "0.5 * (a + b)".
            name = "{} * ({})".format(multiplier, inner_name)
        else:
            name = "{} * {}".format(multiplier, inner_name)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        """Return ``multiplier * loss(*inputs)``.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        this object works both when called and when a parent composite
        invokes ``.forward`` on it.
        """
        # Invoke the child through the call operator: it works whether the
        # child implements ``forward`` or only ``__call__``.
        return self.multiplier * self.loss(*inputs)
|
|