|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Base class for trainable models. |
|
""" |
|
|
|
from abc import ABCMeta, abstractmethod |
|
from copy import copy |
|
|
|
from omegaconf import OmegaConf |
|
from torch import nn |
|
|
|
|
|
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract base for trainable models.

    Child classes implement ``_init`` (setup from the config) and ``_forward``
    (the actual computation), and may override ``loss`` and ``metrics``.
    ``forward`` validates the input against ``required_data_keys`` before
    delegating to ``_forward``.
    """

    # Keys that ``forward`` requires in its ``data`` argument. Either a flat
    # collection of keys, or a dict mapping each key to a nested spec that is
    # checked against ``data[key]``. Copied per instance in ``__init__``.
    required_data_keys = []

    # NOTE(review): not read anywhere in this class; struct mode is always
    # enabled in ``__init__``. Confirm whether callers/subclasses rely on it.
    strict_conf = True

    def __init__(self, conf):
        """Freeze the configuration and call the child model's ``_init``.

        Args:
            conf: An OmegaConf config, or a plain ``dict`` which is first
                converted with ``OmegaConf.create``. An OmegaConf config is
                made read-only and struct in place, so the caller's object is
                frozen too.
        """
        super().__init__()
        from copy import deepcopy

        if isinstance(conf, dict):
            # OmegaConf.set_readonly / set_struct only accept OmegaConf nodes.
            conf = OmegaConf.create(conf)
        self.conf = conf
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Per-instance copy so a child can extend its required keys without
        # mutating the class attribute shared by all instances. Deep, because a
        # nested spec (dict of lists) would otherwise still share inner lists.
        self.required_data_keys = deepcopy(self.required_data_keys)
        self._init(conf)

    def forward(self, data):
        """Check ``data`` against ``required_data_keys``, then call ``_forward``.

        Args:
            data: Mapping of model inputs. Wherever ``required_data_keys``
                holds a nested dict spec, ``data[key]`` is checked recursively.

        Returns:
            Whatever the child's ``_forward`` returns.

        Raises:
            AssertionError: If a required key is missing from ``data``.
        """

        def recursive_key_check(expected, given):
            # A dict spec maps each key to a nested spec for ``given[key]``;
            # anything else is a flat collection of required keys.
            nested = isinstance(expected, dict)
            for key in expected:
                if key not in given:
                    # Explicit raise instead of ``assert`` so the check still
                    # runs under ``python -O``; the exception type and message
                    # are unchanged for callers that catch AssertionError.
                    raise AssertionError(f"Missing key {key} in data")
                if nested:
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """Build the child model from the frozen ``conf``. Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """Run the child model on validated ``data``. Must be overridden."""
        raise NotImplementedError

    def loss(self, pred, data):
        """Compute the loss from predictions ``pred`` and inputs ``data``.

        Not abstract, so models that never compute a loss can still be
        instantiated; calling it without an override raises.

        Raises:
            NotImplementedError: Unless overridden by the child class.
        """
        raise NotImplementedError

    def metrics(self):
        """Return the model's metrics; the base implementation has none ({})."""
        return {}
|
|