Spaces:
Sleeping
Sleeping
import copy | |
import logging | |
from typing import Dict, Tuple | |
import pandas as pd | |
import pytorch_lightning as ptl | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DistributedSampler | |
from saicinpainting.evaluation import make_evaluator | |
from saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader | |
from saicinpainting.training.losses.adversarial import make_discrim_loss | |
from saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL | |
from saicinpainting.training.modules import make_generator, make_discriminator | |
from saicinpainting.training.visualizers import make_visualizer | |
from saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \ | |
get_has_ddp_rank | |
# Module-wide logger; records from this file are emitted under the module's import name.
LOGGER = logging.getLogger(name=__name__)
def make_optimizer(parameters, kind='adamw', **kwargs):
    """Instantiate a torch optimizer selected by name.

    Args:
        parameters: iterable of parameters (or param groups) to optimize.
        kind: either ``'adam'`` (``torch.optim.Adam``) or ``'adamw'``
            (``torch.optim.AdamW``, the default).
        **kwargs: forwarded unchanged to the optimizer constructor (lr, betas, ...).

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``kind`` is not one of the supported names.
    """
    supported = (
        ('adam', torch.optim.Adam),
        ('adamw', torch.optim.AdamW),
    )
    for candidate_kind, optimizer_type in supported:
        if kind == candidate_kind:
            return optimizer_type(parameters, **kwargs)
    raise ValueError(f'Unknown optimizer kind {kind}')
def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999):
    """Blend ``new_iterate_model``'s parameters into ``result`` in place (exponential moving average).

    For every named parameter ``p`` of ``result``:
    ``p <- decay * p + (1 - decay) * new_iterate_model.p``.
    Only parameters are updated; module buffers are left untouched. Both modules
    must expose the same parameter names (a missing name raises ``KeyError``).
    """
    latest = dict(new_iterate_model.named_parameters())
    with torch.no_grad():
        for param_name, averaged in result.named_parameters():
            averaged.data.mul_(decay).add_(latest[param_name].data, alpha=1 - decay)
def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'):
    """Stack Gaussian noise maps sampled at progressively coarser resolutions.

    Scale ``i`` draws standard-normal noise on a grid roughly ``2**i`` times smaller
    than ``base_tensor``'s spatial size and resizes it back to full resolution, so
    later channels are smoother (lower-frequency) noise.

    Args:
        base_tensor: 4-D tensor ``(batch, channels, height, width)``; only its
            shape and device are used (its values and dtype are ignored).
        scales: number of noise maps (output channels) to produce.
        scale_mode: ``F.interpolate`` mode used for upsampling. ``align_corners``
            is pinned to ``False`` for ``'bilinear'``/``'bicubic'`` and left unset otherwise.

    Returns:
        Tensor of shape ``(batch, scales, height, width)`` on ``base_tensor.device``
        with torch's default floating dtype.
    """
    batch_size, _, height, width = base_tensor.shape
    cur_height, cur_width = height, width
    result = []
    align_corners = False if scale_mode in ('bilinear', 'bicubic') else None
    for _ in range(scales):
        cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device)
        cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode,
                                          align_corners=align_corners)
        result.append(cur_sample_scaled)
        # Halve the grid for the next scale, but never below 1 pixel: a 0-sized grid
        # would make randn/interpolate fail for small inputs (e.g. 8x8 with 6 scales).
        # Once a dimension reaches 1 it stays at 1, giving spatially constant noise.
        cur_height = max(1, cur_height // 2)
        cur_width = max(1, cur_width // 2)
    return torch.cat(result, dim=1)
class BaseInpaintingTrainingModule(ptl.LightningModule):
    """Shared Lightning scaffolding for inpainting models.

    Always builds the generator. Unless ``predict_only`` is set, also builds the
    discriminator, adversarial loss, visualizer, validation/test/extra evaluators
    and the auxiliary losses enabled in ``config.losses``. Subclasses must
    implement ``forward``, ``generator_loss`` and ``discriminator_loss``.

    Args:
        config: experiment config, accessed by attribute and via ``.get``
            (presumably an OmegaConf ``DictConfig`` -- TODO confirm).
        use_ddp: if True, ``train_dataloader`` forwards distributed-sampler settings.
        predict_only: skip everything except the generator (inference use).
        visualize_each_iters: visualize every N-th batch (and every test batch).
        average_generator: maintain a running-average copy of the generator.
        generator_avg_beta: decay passed to ``update_running_average``.
        average_generator_start_step: global step at which averaging begins.
        average_generator_period: minimum number of steps between averaging updates.
        store_discr_outputs_for_vis: attach discriminator outputs to the batch
            before visualizing it.
    """

    def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
                 average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
                 average_generator_period=10, store_discr_outputs_for_vis=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        LOGGER.info('BaseInpaintingTrainingModule init called')

        self.config = config

        self.generator = make_generator(config, **self.config.generator)
        self.use_ddp = use_ddp

        if not get_has_ddp_rank():
            LOGGER.info(f'Generator\n{self.generator}')

        if not predict_only:
            self.save_hyperparameters(self.config)
            self.discriminator = make_discriminator(**self.config.discriminator)
            self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial)
            self.visualizer = make_visualizer(**self.config.visualizer)
            self.val_evaluator = make_evaluator(**self.config.evaluator)
            self.test_evaluator = make_evaluator(**self.config.evaluator)

            if not get_has_ddp_rank():
                LOGGER.info(f'Discriminator\n{self.discriminator}')

            extra_val = self.config.data.get('extra_val', ())
            if extra_val:
                self.extra_val_titles = list(extra_val)
                # ModuleDict so the extra evaluators are registered as submodules.
                self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator)
                                                       for k in extra_val})
            else:
                self.extra_evaluators = {}

            self.store_discr_outputs_for_vis = store_discr_outputs_for_vis

            # Auxiliary losses are only instantiated when their configured weight is positive.
            if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0:
                self.loss_l1 = nn.L1Loss(reduction='none')

            if self.config.losses.get("mse", {"weight": 0})['weight'] > 0:
                self.loss_mse = nn.MSELoss(reduction='none')

            if self.config.losses.perceptual.weight > 0:
                self.loss_pl = PerceptualLoss()

            if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
                self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
            else:
                self.loss_resnet_pl = None

        # Generator-averaging settings are stored even for predict-only instances, because
        # get_current_generator() reads them in eval mode. Previously they were only set
        # for training instances, so predict-only instances raised AttributeError there.
        self.average_generator = average_generator
        self.generator_avg_beta = generator_avg_beta
        self.average_generator_start_step = average_generator_start_step
        self.average_generator_period = average_generator_period
        self.generator_average = None
        self.last_generator_averaging_step = -1

        self.visualize_each_iters = visualize_each_iters
        LOGGER.info('BaseInpaintingTrainingModule init done')

    def configure_optimizers(self):
        """Return the generator optimizer (index 0) followed by the discriminator optimizer (index 1)."""
        discriminator_params = list(self.discriminator.parameters())
        return [
            dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)),
            dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)),
        ]

    def train_dataloader(self):
        """Build the training dataloader from ``config.data.train``.

        When ``use_ddp`` is set, ``ddp_kwargs`` (replica count, this process's global
        rank, shuffle=True) are added to the factory arguments.
        """
        kwargs = dict(self.config.data.train)
        if self.use_ddp:
            kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
                                        rank=self.trainer.global_rank,
                                        shuffle=True)
        # Pass the augmented kwargs. Passing the raw config here (as before) silently
        # dropped the ddp_kwargs built above.
        return make_default_train_dataloader(**kwargs)

    def val_dataloader(self):
        """Return ``[val, test, *extra_val]`` dataloaders.

        Index 1 ("test") comes from ``config.data.visual_test``. When that is None, the
        validation loader is reused in the test slot.
        """
        res = [make_default_val_dataloader(**self.config.data.val)]

        if self.config.data.visual_test is not None:
            res = res + [make_default_val_dataloader(**self.config.data.visual_test)]
        else:
            res = res + res

        extra_val = self.config.data.get('extra_val', ())
        if extra_val:
            res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles]

        return res

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Run one training step for the optimizer selected by ``optimizer_idx``."""
        self._is_training_step = True
        return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx)

    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Run one evaluation step.

        The mode is chosen by dataloader index (see ``val_dataloader``): 0 is 'val',
        1 is 'test', and 2+ is 'extra_val' for ``extra_val_titles[dataloader_idx - 2]``.
        """
        extra_val_key = None
        if dataloader_idx == 0:
            mode = 'val'
        elif dataloader_idx == 1:
            mode = 'test'
        else:
            mode = 'extra_val'
            extra_val_key = self.extra_val_titles[dataloader_idx - 2]
        self._is_training_step = False
        return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key)

    def training_step_end(self, batch_parts_outputs):
        """Update the generator average when due, then reduce and log step outputs.

        Averaging runs only in training mode, once ``global_step`` has reached
        ``average_generator_start_step`` and at least ``average_generator_period`` steps
        have passed since the last update. The first update snapshots the generator;
        later updates blend into that snapshot.

        Returns:
            The mean loss as a tensor (non-tensor losses are wrapped in a float tensor).
        """
        if self.training and self.average_generator \
                and self.global_step >= self.average_generator_start_step \
                and self.global_step >= self.last_generator_averaging_step + self.average_generator_period:
            if self.generator_average is None:
                self.generator_average = copy.deepcopy(self.generator)
            else:
                update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta)
            self.last_generator_averaging_step = self.global_step

        full_loss = (batch_parts_outputs['loss'].mean()
                     if torch.is_tensor(batch_parts_outputs['loss'])  # loss is not tensor when no discriminator used
                     else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
        log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
        self.log_dict(log_info, on_step=True, on_epoch=False)
        return full_loss

    def validation_epoch_end(self, outputs):
        """Aggregate per-step outputs from all eval dataloaders and log evaluator metrics."""
        # ``outputs`` holds one list of step results per dataloader; flatten them.
        outputs = [step_out for out_group in outputs for step_out in out_group]
        averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs)
        self.log_dict({k: v.mean() for k, v in averaged_logs.items()})

        # Widen pandas output so the metric tables are printed without truncation.
        pd.set_option('display.max_columns', 500)
        pd.set_option('display.width', 1000)

        # standard validation
        val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s]
        self._log_evaluator_results(self.val_evaluator, val_evaluator_states, 'Validation', 'val_')

        # standard visual test
        test_evaluator_states = [s['test_evaluator_state'] for s in outputs
                                 if 'test_evaluator_state' in s]
        self._log_evaluator_results(self.test_evaluator, test_evaluator_states, 'Test', 'test_')

        # extra validations
        if self.extra_evaluators:
            for cur_eval_title, cur_evaluator in self.extra_evaluators.items():
                cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state'
                cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s]
                self._log_evaluator_results(cur_evaluator, cur_states,
                                            f'Extra val {cur_eval_title}',
                                            f'extra_val_{cur_eval_title}_')

    def _log_evaluator_results(self, evaluator, states, title, log_prefix):
        """Finalize ``evaluator`` over ``states``, print a metrics table, and log each flattened metric."""
        evaluator_res = evaluator.evaluation_end(states=states)
        evaluator_res_df = pd.DataFrame(evaluator_res).stack(1).unstack(0)
        evaluator_res_df.dropna(axis=1, how='all', inplace=True)
        LOGGER.info(f'{title} metrics after epoch #{self.current_epoch}, '
                    f'total {self.global_step} iterations:\n{evaluator_res_df}')
        for k, v in flatten_dict(evaluator_res).items():
            self.log(f'{log_prefix}{k}', v)

    def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None):
        """Shared body of training and evaluation steps.

        ``optimizer_idx`` 0 computes the generator loss (discriminator frozen), and 1
        computes the discriminator loss (generator frozen; skipped when the adversarial
        weight is 0). ``None``, the evaluation case, computes the generator loss without
        touching ``requires_grad``. Visualizes on rank 0 (or without DDP) every
        ``visualize_each_iters`` batches and on every test batch.

        Returns:
            dict with ``loss``, prefixed ``log_info`` metrics and, in evaluation modes,
            the matching evaluator state.
        """
        if optimizer_idx == 0:  # step for generator
            set_requires_grad(self.generator, True)
            set_requires_grad(self.discriminator, False)
        elif optimizer_idx == 1:  # step for discriminator
            set_requires_grad(self.generator, False)
            set_requires_grad(self.discriminator, True)

        batch = self(batch)

        total_loss = 0
        metrics = {}

        if optimizer_idx is None or optimizer_idx == 0:  # step for generator (also used in evaluation)
            total_loss, metrics = self.generator_loss(batch)

        elif optimizer_idx == 1:  # step for discriminator
            if self.config.losses.adversarial.weight > 0:
                total_loss, metrics = self.discriminator_loss(batch)

        if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'):
            if self.config.losses.adversarial.weight > 0:
                if self.store_discr_outputs_for_vis:
                    with torch.no_grad():
                        self.store_discr_outputs(batch)
            vis_suffix = f'_{mode}'
            if mode == 'extra_val':
                vis_suffix += f'_{extra_val_key}'
            self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix)

        metrics_prefix = f'{mode}_'
        if mode == 'extra_val':
            metrics_prefix += f'{extra_val_key}_'
        result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix))
        if mode == 'val':
            result['val_evaluator_state'] = self.val_evaluator.process_batch(batch)
        elif mode == 'test':
            result['test_evaluator_state'] = self.test_evaluator.process_batch(batch)
        elif mode == 'extra_val':
            result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch)

        return result

    def get_current_generator(self, no_average=False):
        """Return the averaged generator if available and applicable, else the live one.

        The averaged copy is returned only when ``no_average`` is False, the module is in
        eval mode, averaging is enabled, and an average has already been created.
        """
        if not no_average and not self.training and self.average_generator and self.generator_average is not None:
            return self.generator_average
        return self.generator

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Pass data through generator and obtain at least 'predicted_image' and 'inpainted' keys"""
        raise NotImplementedError()

    def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(total_loss, metrics)`` for the generator step; implemented by subclasses."""
        raise NotImplementedError()

    def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(total_loss, metrics)`` for the discriminator step; implemented by subclasses."""
        raise NotImplementedError()

    def store_discr_outputs(self, batch):
        """Attach discriminator maps to ``batch`` for visualization.

        Runs the discriminator on ``batch['image']`` and ``batch['predicted_image']``
        and keeps the first element of each output pair. Each map is upsampled
        (nearest) to the image's spatial size and stored as ``discr_output_real``,
        ``discr_output_fake`` and their difference ``discr_output_diff``.
        """
        out_size = batch['image'].shape[2:]
        discr_real_out, _ = self.discriminator(batch['image'])
        discr_fake_out, _ = self.discriminator(batch['predicted_image'])
        batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest')
        batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest')
        batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake']

    def get_ddp_rank(self):
        """Return the trainer's global rank when more than one process runs in total, else None."""
        return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None