Spaces:
Configuration error
Configuration error
import torch.nn as nn | |
from lib.config import cfg | |
import torch | |
from lib.networks.renderer import tpose_renderer | |
from lib.train import make_optimizer | |
class NetworkWrapper(nn.Module): | |
def __init__(self, net): | |
super(NetworkWrapper, self).__init__() | |
self.net = net | |
self.renderer = tpose_renderer.Renderer(self.net) | |
self.img2mse = lambda x, y : torch.mean((x - y) ** 2) | |
self.acc_crit = torch.nn.functional.smooth_l1_loss | |
def forward(self, batch): | |
ret = self.renderer.render(batch) | |
scalar_stats = {} | |
loss = 0 | |
mask = batch['mask_at_box'] | |
img_loss = self.img2mse(ret['rgb_map'][mask], batch['rgb'][mask]) | |
scalar_stats.update({'img_loss': img_loss}) | |
loss += img_loss | |
if 'rgb0' in ret: | |
img_loss0 = self.img2mse(ret['rgb0'], batch['rgb']) | |
scalar_stats.update({'img_loss0': img_loss0}) | |
loss += img_loss0 | |
scalar_stats.update({'loss': loss}) | |
image_stats = {} | |
return ret, loss, scalar_stats, image_stats | |