PSHuman / lib /net /BasePIFuNet.py
fffiloni's picture
Migrated from GitHub
2252f3d verified
raw
history blame
3.05 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch.nn as nn
import pytorch_lightning as pl
from .geometry import index, orthogonal, perspective
class BasePIFuNet(pl.LightningModule):
    """Base network that maps 3D query points plus images to predictions.

    The class wires together a two-stage pipeline -- ``filter`` (image
    feature extraction) followed by ``query`` (per-point prediction) -- and
    holds the projection function and the loss used by ``get_error``.
    ``filter`` and ``query`` are placeholders here that return ``None``;
    subclasses are expected to override them.
    """

    def __init__(
        self,
        projection_mode='orthogonal',
        error_term=None,
    ):
        """
        :param projection_mode:
            Either 'orthogonal' or 'perspective'.
            Selects the function stored in ``self.projection``. Any value
            other than 'orthogonal' selects the perspective projection.
        :param error_term:
            nn Loss between the predicted [B, Res, N] and the label [B, Res, N].
            When ``None`` (the default), a new ``nn.MSELoss()`` is created
            for this instance.
        """
        super().__init__()

        self.name = 'base'

        # Build the default loss here rather than in the signature: a default
        # argument is evaluated once at definition time, which would make all
        # instances share a single loss module.
        self.error_term = nn.MSELoss() if error_term is None else error_term

        self.index = index
        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective

    def forward(self, points, images, calibs, transforms=None):
        """
        Run ``filter`` on the images, then ``query`` on the resulting features.

        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        """
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        """
        Filter the input images and store all intermediate features.

        The base implementation is a placeholder that returns ``None``.

        :param images: [B, C, H, W] input images
        :return: image features consumed by ``query`` (``None`` here)
        """
        return None

    def query(self, features, points, calibs, transforms=None):
        """
        Given 3D points, query the network predictions for each point.

        Image features should be pre-computed (see ``filter``) before this
        call. query() may behave differently during training/testing.
        The base implementation is a placeholder that returns ``None``.

        :param features: image features as returned by ``filter``
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point (``None`` here)
        """
        return None

    def get_error(self, preds, labels):
        """
        Compute the network loss by applying ``self.error_term``.

        :param preds: [B, Res, N] predictions for each point
        :param labels: [B, Res, N] gt labeling
        :return: loss term
        """
        return self.error_term(preds, labels)