Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import pytorch_lightning as pl | |
import torch.nn as nn | |
from .geometry import index, orthogonal, perspective | |
class BasePIFuNet(pl.LightningModule):
    """Skeleton network: filter images into features, then query points.

    ``forward`` chains ``filter`` and ``query``. Both are placeholders in
    this base class (they return ``None``), so a concrete network has to
    provide real implementations for ``forward`` to produce predictions.
    """

    def __init__(
        self,
        projection_mode="orthogonal",
        error_term=None,
    ):
        """
        :param projection_mode:
            Either orthogonal or perspective.
            It will call the corresponding function for projection.
            Any value other than "orthogonal" selects the perspective
            projection.
        :param error_term:
            nn Loss between the predicted [B, Res, N] and the label [B, Res, N].
            When None (the default), a new nn.MSELoss() is created for this
            instance.
        """
        super().__init__()

        # Build the default loss per instance instead of in the signature:
        # a default evaluated at definition time would be one module object
        # shared by every network constructed without an explicit loss.
        if error_term is None:
            error_term = nn.MSELoss()

        self.name = "base"
        self.error_term = error_term

        # Helpers imported from the package's geometry module, kept as
        # attributes so they are reachable through the instance.
        self.index = index
        self.projection = orthogonal if projection_mode == "orthogonal" else perspective

    def forward(self, points, images, calibs, transforms=None):
        """
        Run ``filter`` on the images, then ``query`` the given points.

        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        """
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        """
        Filter the input images and store all intermediate features.

        Placeholder in the base class: performs no work and returns None.

        :param images: [B, C, H, W] input images
        :return: None in this base implementation
        """
        return None

    def query(self, features, points, calibs, transforms=None):
        """
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        query() function may behave differently during training/testing.

        Placeholder in the base class: performs no work and returns None.

        :param features: image features, as returned by ``filter``
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
            (None in this base implementation)
        """
        return None

    def get_error(self, preds, labels):
        """
        Compute the network loss by applying ``error_term`` to the inputs.

        :param preds: [B, Res, N] predictions for each point
        :param labels: [B, Res, N] gt labeling
        :return: loss term
        """
        return self.error_term(preds, labels)