pengsida
initial commit
1ba539f
raw
history blame
1.07 kB
import torch.nn as nn
from lib.config import cfg
import torch
from lib.networks.renderer import tpose_renderer
from lib.train import make_optimizer
class NetworkWrapper(nn.Module):
    """Wrap a network with a renderer and compute image reconstruction losses.

    The wrapper keeps the network, a ``tpose_renderer.Renderer`` built around
    it, and the loss callables that :meth:`forward` relies on.
    """

    def __init__(self, net):
        """Store ``net`` and create the renderer and loss callables.

        Args:
            net: Network passed to ``tpose_renderer.Renderer``; also kept as
                ``self.net``.
        """
        super().__init__()
        self.net = net
        self.renderer = tpose_renderer.Renderer(self.net)
        # Mean squared error between two tensors.
        self.img2mse = lambda pred, target: torch.mean((pred - target) ** 2)
        # Not referenced by forward(); kept as a public attribute.
        self.acc_crit = torch.nn.functional.smooth_l1_loss

    def forward(self, batch):
        """Render ``batch`` and return the outputs together with the loss.

        Args:
            batch: Mapping that provides ``'rgb'`` and ``'mask_at_box'``; it
                is also handed unchanged to ``self.renderer.render``.

        Returns:
            A tuple ``(ret, loss, scalar_stats, image_stats)``: the renderer
            output, the summed loss, a dict holding each loss term plus the
            total under ``'loss'``, and an empty dict.
        """
        ret = self.renderer.render(batch)
        box_mask = batch['mask_at_box']

        # Main term: only entries selected by the mask contribute.
        scalar_stats = {
            'img_loss': self.img2mse(ret['rgb_map'][box_mask],
                                     batch['rgb'][box_mask]),
        }

        if 'rgb0' in ret:
            # NOTE(review): unlike the term above, this one is not restricted
            # by 'mask_at_box' -- confirm that this is intended.
            scalar_stats['img_loss0'] = self.img2mse(ret['rgb0'], batch['rgb'])

        # Start from a plain 0 so the first addition yields a new tensor and
        # later in-place additions never touch the stored individual terms.
        loss = 0
        for term in scalar_stats.values():
            loss += term
        scalar_stats['loss'] = loss

        return ret, loss, scalar_stats, {}