import pytorch_lightning as pl

import models
from systems.utils import parse_optimizer, parse_scheduler, update_module_step
from utils.mixins import SaverMixin
from utils.misc import config_to_primitive, get_rank


class BaseSystem(pl.LightningModule, SaverMixin):
    """
    Two ways to print to the console:
    1. self.print: correctly handles the progress bar
    2. rank_zero_info: uses the logging module (rank zero only)
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.prepare()
        # build the model specified by config.model.name
        self.model = models.make(self.config.model.name, self.config.model)

    def prepare(self):
        # hook for subclasses; runs in __init__ before the model is created
        pass

    def forward(self, batch):
        raise NotImplementedError

    def C(self, value):
        """
        Resolve a (possibly scheduled) scalar from the config.
        A plain int/float is returned unchanged; a list is treated as
        [start_step, start_value, end_value, end_step] (a 3-element list gets
        start_step=0) and linearly interpolated, by global step if end_step
        is an int, or by epoch if end_step is a float.
        """
        if isinstance(value, (int, float)):
            pass
        else:
            value = config_to_primitive(value)
            if not isinstance(value, list):
                raise TypeError(f'Scalar specification only supports list, got {type(value)}')
            if len(value) == 3:
                # default start_step to 0
                value = [0] + value
            assert len(value) == 4
            start_step, start_value, end_value, end_step = value
            if isinstance(end_step, int):
                # integer end_step: interpolate over global training steps
                current_step = self.global_step
                value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
            elif isinstance(end_step, float):
                # float end_step: interpolate over epochs
                current_step = self.current_epoch
                value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
        return value
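
    # Usage sketch (hedged; the config key below is illustrative, not from this file):
    # with `lambda_sparsity: [0, 1.0, 0.1, 5000]` in the config,
    # `self.C(self.config.system.loss.lambda_sparsity)` yields 1.0 at step 0 and decays
    # linearly to 0.1 by step 5000; a bare number such as 0.1 is returned unchanged.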

    def preprocess_data(self, batch, stage):
        pass

    """
    Implementing on_after_batch_transfer of the DataModule would do the same thing,
    but on_after_batch_transfer does not support DP.
    """
    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.dataset = self.trainer.datamodule.train_dataloader().dataset
        self.preprocess_data(batch, 'train')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.val_dataloader().dataset
        self.preprocess_data(batch, 'validation')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.test_dataloader().dataset
        self.preprocess_data(batch, 'test')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.predict_dataloader().dataset
        self.preprocess_data(batch, 'predict')
        update_module_step(self.model, self.current_epoch, self.global_step)
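
    # A minimal subclass sketch of preprocess_data (illustrative only; the batch keys
    # 'rays' and 'rgb' are assumptions, not defined by this base class):
    #
    #   class MySystem(BaseSystem):
    #       def preprocess_data(self, batch, stage):
    #           # the batch is already on the correct device here, so per-stage
    #           # tensors can be derived in place
    #           batch['rays'] = batch['rays'].float()
    #           if stage in ('train', 'validation'):
    #               batch['rgb'] = batch['rgb'].float()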

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    """
    # aggregate outputs from different devices (DP)
    def training_step_end(self, out):
        pass
    """

    """
    # aggregate outputs from different iterations
    def training_epoch_end(self, out):
        pass
    """

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    """
    # aggregate outputs from different devices when using DP
    def validation_step_end(self, out):
        pass
    """

    def validation_epoch_end(self, out):
        """
        Gather metrics from all devices, compute mean.
        Purge repeated results using data index.
        """
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_epoch_end(self, out):
        """
        Gather metrics from all devices, compute mean.
        Purge repeated results using data index.
        """
        raise NotImplementedError
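
    # A hedged sketch of what a subclass's *_epoch_end might look like (the 'index' and
    # 'psnr' keys are assumptions about the step outputs, not part of this base class):
    #
    #   def validation_epoch_end(self, out):
    #       out = self.all_gather(out)                       # collect step outputs from all devices
    #       if self.trainer.is_global_zero:
    #           seen, psnrs = set(), []
    #           for step_out in out:
    #               for idx, psnr in zip(step_out['index'].view(-1).tolist(),
    #                                    step_out['psnr'].view(-1).tolist()):
    #                   if idx not in seen:                   # purge duplicates from distributed padding
    #                       seen.add(idx)
    #                       psnrs.append(psnr)
    #           self.log('val/psnr', sum(psnrs) / len(psnrs), prog_bar=True, rank_zero_only=True)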

    def export(self):
        raise NotImplementedError

    def configure_optimizers(self):
        optim = parse_optimizer(self.config.system.optimizer, self.model)
        ret = {
            'optimizer': optim,
        }
        if 'scheduler' in self.config.system:
            # the LR scheduler is optional and only attached when configured
            ret.update({
                'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim),
            })
        return ret
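
    # A hedged sketch of the config layout these helpers are assumed to consume
    # (the exact fields accepted by parse_optimizer/parse_scheduler are not verified here):
    #
    #   system:
    #     optimizer:
    #       name: Adam
    #       args: {lr: 1.e-2, betas: [0.9, 0.99], eps: 1.e-15}
    #     scheduler:
    #       name: MultiStepLR
    #       args: {milestones: [10000, 15000], gamma: 0.3}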