import abc
import os
from typing import Optional, Sequence, Tuple

import numpy as np
import torch
import torch.optim.lr_scheduler
from torch import nn

try:
    import matplotlib.pyplot as plt
except ImportError:
    # The regularizers defined below never touch `plt`; keep the name bound
    # so the module still imports where matplotlib is not installed.
    plt = None


# Accepted values for the `what` argument of the plane regularizers.
_VALID_TARGETS = frozenset({'field', 'proposal_network'})


def _check_what(what):
    """Raise ValueError unless `what` is 'field' or 'proposal_network'."""
    if what not in _VALID_TARGETS:
        raise ValueError(f'what must be one of "field" or "proposal_network" '
                         f'but {what} was passed.')


def _select_grids(model, what) -> Sequence[nn.ParameterList]:
    """Return the multi-resolution grid collections of `model` named by `what`.

    'field' reads `model.field.grids`; 'proposal_network' collects `p.grids`
    for every `p` in `model.proposal_networks`. Anything else raises
    NotImplementedError.
    """
    if what == 'field':
        return model.field.grids
    if what == 'proposal_network':
        return [p.grids for p in model.proposal_networks]
    raise NotImplementedError(what)


def compute_plane_tv(t):
    """Return the total-variation penalty of a 4-D tensor of planes.

    Squared differences between neighbouring entries are summed along each
    of the last two axes, each sum is divided by its number of terms
    (which includes the batch and channel axes, so this is a mean, not a
    sum), and the two means are added and doubled.

    Args:
        t: tensor whose shape unpacks as (batch_size, c, h, w).

    Returns:
        A 0-dim tensor.

    Note:
        If `h` or `w` is 1 the corresponding count is zero and the division
        produces a non-finite value; callers should pass planes with at
        least two entries along both of the last two axes.
    """
    batch_size, c, h, w = t.shape
    count_h = batch_size * c * (h - 1) * w
    count_w = batch_size * c * h * (w - 1)
    h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum()
    w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum()
    return 2 * (h_tv / count_h + w_tv / count_w)


def compute_plane_smoothness(t):
    """Return the mean squared second difference of `t` along dimension 2.

    Args:
        t: tensor whose shape unpacks as (batch_size, c, h, w). The original
            author treats dimension 2 (`h`) as the time axis.

    Returns:
        A 0-dim tensor. With `h < 3` the second difference is empty and the
        mean is NaN.
    """
    batch_size, c, h, w = t.shape
    # Convolve with a second derivative filter, in the time dimension which is dimension 2
    first_difference = t[..., 1:, :] - t[..., :h-1, :]  # [batch, c, h-1, w]
    second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :]  # [batch, c, h-2, w]
    # Take the L2 norm of the result
    return torch.square(second_difference).mean()


class Regularizer():
    """Base class for weighted regularization terms.

    Subclasses implement `_regularize`; callers use `regularize`, which
    scales the raw value by `self.weight` and remembers the detached result
    so `report` can log it.

    NOTE: this class does not use `abc.ABCMeta`, so `@abc.abstractmethod` on
    `_regularize` is not enforced at instantiation time; the
    NotImplementedError raised there is the effective guard.
    """

    def __init__(self, reg_type, initialization):
        """Store the name and the initial weight.

        Args:
            reg_type: name of the term; also the key used by `report`.
            initialization: initial weight; converted with `float()`.
        """
        self.reg_type = reg_type
        self.initialization = initialization
        self.weight = float(self.initialization)
        # Detached value of the most recent `regularize` call, if any.
        self.last_reg: Optional[torch.Tensor] = None

    def step(self, global_step):
        """Hook called with the current step; does nothing by default."""
        pass

    def report(self, d):
        """Push the last computed value into `d[self.reg_type]`.

        `d[self.reg_type]` must expose `.update(float)` (presumably a
        running-average meter -- confirm against the caller). Nothing
        happens before the first `regularize` call.
        """
        if self.last_reg is not None:
            d[self.reg_type].update(self.last_reg.item())

    def regularize(self, *args, **kwargs) -> torch.Tensor:
        """Return `_regularize(*args, **kwargs)` scaled by `self.weight`.

        The raw value is passed through `torch.as_tensor` first, so an
        implementation that returns a plain Python number (e.g. an untouched
        `0` accumulator) does not break the `.detach()` call below. Tensors
        pass through `as_tensor` unchanged, keeping their autograd history.
        """
        out = torch.as_tensor(self._regularize(*args, **kwargs)) * self.weight
        self.last_reg = out.detach()
        return out

    @abc.abstractmethod
    def _regularize(self, *args, **kwargs) -> torch.Tensor:
        """Compute the unweighted regularization value (subclass hook)."""
        raise NotImplementedError()

    def __str__(self):
        return f"Regularizer({self.reg_type}, weight={self.weight})"


class PlaneTV(Regularizer):
    """Total-variation penalty over the planes of a model's grids."""

    def __init__(self, initial_value, what: str = 'field'):
        """
        Args:
            initial_value: initial weight of the term.
            what: 'field' or 'proposal_network' -- which grids to penalize.

        Raises:
            ValueError: if `what` is neither of the two accepted values.
        """
        _check_what(what)
        name = f'planeTV-{what[:2]}'
        super().__init__(name, initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum `compute_plane_tv` over the selected grids of `model`.

        Returns a tensor even when there are no grids, matching the other
        regularizers in this module.
        """
        total = 0
        # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w]
        for grids in _select_grids(model, self.what):
            if len(grids) == 3:
                spatial_grids = [0, 1, 2]
            else:
                # These are the spatial grids; the others are spatiotemporal
                spatial_grids = [0, 1, 3]
            for grid_id in spatial_grids:
                total += compute_plane_tv(grids[grid_id])
            # Every plane is penalized once more here, so the spatial planes
            # above end up counted twice. Kept as in the original code.
            for grid in grids:
                # grid: [1, c, h, w]
                total += compute_plane_tv(grid)
        return torch.as_tensor(total)


class TimeSmoothness(Regularizer):
    """Second-difference smoothness penalty on the spatiotemporal planes."""

    def __init__(self, initial_value, what: str = 'field'):
        """
        Args:
            initial_value: initial weight of the term.
            what: 'field' or 'proposal_network' -- which grids to penalize.

        Raises:
            ValueError: if `what` is neither of the two accepted values.
        """
        _check_what(what)
        name = f'time-smooth-{what[:2]}'
        super().__init__(name, initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum `compute_plane_smoothness` over planes 2, 4 and 5 of each grid set.

        Grid sets with exactly three planes contribute nothing.
        """
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in _select_grids(model, self.what):
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids = [2, 4, 5]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return torch.as_tensor(total)


class L1ProposalNetwork(Regularizer):
    """Mean-absolute-value penalty on every grid of every proposal network."""

    def __init__(self, initial_value):
        super().__init__('l1-proposal-network', initial_value)

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum `abs(grid).mean()` over all grids in `model.proposal_networks`."""
        total = 0.0
        for network in model.proposal_networks:
            for grid in network.grids:
                total += torch.abs(grid).mean()
        return torch.as_tensor(total)


class DepthTV(Regularizer):
    """Total-variation penalty on a rendered depth patch."""

    def __init__(self, initial_value, patch_shape: Tuple[int, int] = (64, 64)):
        """
        Args:
            initial_value: initial weight of the term.
            patch_shape: (height, width) the flat depth output is reshaped
                to. Defaults to the previously hard-coded 64 x 64.
        """
        super().__init__('tv-depth', initial_value)
        self.patch_shape = tuple(patch_shape)

    def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
        """Return `compute_plane_tv` of `model_out['depth']` viewed as one plane.

        The depth tensor must hold exactly `patch_shape[0] * patch_shape[1]`
        elements, otherwise `reshape` raises.
        """
        depth = model_out['depth']
        # Add leading batch and channel axes: [1, 1, h, w].
        plane = depth.reshape(self.patch_shape)[None, None, :, :]
        return compute_plane_tv(plane)


class L1TimePlanes(Regularizer):
    """Penalty on the distance of the spatiotemporal planes from 1."""

    def __init__(self, initial_value, what='field'):
        """
        Args:
            initial_value: initial weight of the term.
            what: 'field' or 'proposal_network' -- which grids to penalize.

        Raises:
            ValueError: if `what` is neither of the two accepted values.
        """
        _check_what(what)
        super().__init__(f'l1-time-{what[:2]}', initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum `abs(1 - grid).mean()` over planes 2, 4 and 5 of each grid set.

        Grid sets with exactly three planes are skipped.
        """
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        total = 0.0
        for grids in _select_grids(model, self.what):
            if len(grids) == 3:
                continue
            # These are the spatiotemporal grids
            for grid_id in (2, 4, 5):
                total += torch.abs(1 - grids[grid_id]).mean()
        return torch.as_tensor(total)