|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
import trimesh |
|
import math |
|
from typing import NewType |
|
from pytorch3d.structures import Meshes |
|
from pytorch3d.renderer.mesh import rasterize_meshes |
|
|
|
# Annotation-only alias: a distinct static type name wrapping torch.Tensor,
# used in the signatures of the geometry helpers in this file.
Tensor = NewType('Tensor', torch.Tensor)
|
|
|
|
|
def solid_angles(points: Tensor,
                 triangles: Tensor,
                 thresh: float = 1e-8) -> Tensor:
    ''' Signed solid angle subtended by each triangle at each query point

        Uses the closed-form arctangent expression from:
        The Solid Angle of a Plane Triangle
        A. VAN OOSTEROM AND J. STRACKEE
        IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
        VOL. BME-30, NO. 2, FEBRUARY 1983

        Parameters
        -----------
            points: BxQx3
                Tensor of input query points
            triangles: BxFx3x3
                Target triangles, three 3D corners per face
            thresh: float
                Kept in the signature for callers; this function never
                reads it
        Returns
        -------
            solid_angles: BxQxF
                Solid angle of every triangle as seen from every query
                point. The sign follows the sign of the scalar triple
                product of the three corner vectors.
    '''
    # Corner positions expressed relative to each query point: BxQxFx3x3
    rel = triangles[:, None] - points[:, :, None, None]

    # Distance from the query point to each of the three corners: BxQxFx3
    lengths = torch.norm(rel, dim=-1)

    corner_a = rel[:, :, :, 0]
    corner_b = rel[:, :, :, 1]
    corner_c = rel[:, :, :, 2]

    # Scalar triple product a . (b x c)
    b_cross_c = torch.cross(corner_b, corner_c, dim=-1)
    triple = (corner_a * b_cross_c).sum(dim=-1)
    del b_cross_c

    # Pairwise dot products between the corner vectors
    dot_ab = (corner_a * corner_b).sum(dim=-1)
    dot_bc = (corner_b * corner_c).sum(dim=-1)
    dot_ac = (corner_a * corner_c).sum(dim=-1)
    # The corner views share storage with `rel`; drop all of them together
    # so the large BxQxFx3x3 buffer can be released early.
    del corner_a, corner_b, corner_c, rel

    denom = (lengths.prod(dim=-1) + dot_ab * lengths[:, :, :, 2] +
             dot_ac * lengths[:, :, :, 1] + dot_bc * lengths[:, :, :, 0])
    del dot_ab, dot_bc, dot_ac, lengths

    # atan2 yields half of the solid angle
    half_angle = torch.atan2(triple, denom)
    del triple, denom

    torch.cuda.empty_cache()

    return 2 * half_angle
|
|
|
|
|
def winding_numbers(points: Tensor,
                    triangles: Tensor,
                    thresh: float = 1e-8) -> Tensor:
    ''' Generalized winding number of a triangle set at each query point

        The per-triangle solid angles are summed over all faces and
        divided by 4 * pi. Callers use the result for inside/outside
        tests.

        References:
        Robust inside-outside segmentation using generalized winding numbers
        Alec Jacobson, Ladislav Kavan, Olga Sorkine-Hornung

        Fast Winding Numbers for Soups and Clouds, SIGGRAPH 2018
        Gavin Barill, Neil G. Dickson, Ryan Schmidt, David I.W. Levin
        and Alec Jacobson

        Parameters
        -----------
            points: BxQx3
                Tensor of input query points
            triangles: BxFx3x3
                Target triangles
            thresh: float
                Forwarded unchanged to solid_angles
        Returns
        -------
            winding_numbers: BxQ
                A tensor containing the Generalized winding numbers
    '''
    # BxQxF solid angle of every face seen from every query point
    per_face_angles = solid_angles(points, triangles, thresh=thresh)
    # Reduce over the face axis and normalize by 4 * pi
    return 1 / (4 * math.pi) * per_face_angles.sum(dim=-1)
|
|
|
|
|
def batch_contains(verts, faces, points):
    """Per-mesh point containment test, encoded as +1 / -1.

    Each batch element is turned into a ``trimesh.Trimesh`` and queried with
    its ``contains`` method. All inputs are detached and moved to the CPU
    first, and the result is a CPU tensor.

    :param verts: batched mesh vertices, indexed along dim 0 per mesh
        (presumably [B, V, 3] -- confirm against callers)
    :param faces: batched mesh faces, indexed along dim 0 per mesh
        (presumably [B, F, 3] -- confirm against callers)
    :param points: batched query points; ``points.shape[1]`` is the number
        of queries per mesh
    :return: [B, N] float tensor, 1.0 where ``contains`` reported True and
        -1.0 where it reported False
    """
    batch_size = verts.shape[0]
    num_points = points.shape[1]

    # trimesh runs on the CPU, so pull everything out of the autograd graph
    # and off the accelerator up front.
    verts_cpu = verts.detach().cpu()
    faces_cpu = faces.detach().cpu()
    points_cpu = points.detach().cpu()

    inside = torch.zeros(batch_size, num_points)
    for i in range(batch_size):
        mesh = trimesh.Trimesh(verts_cpu[i], faces_cpu[i])
        hits = mesh.contains(points_cpu[i])
        inside[i] = torch.as_tensor(hits)

    # Map {0, 1} -> {-1, +1}
    return 2.0 * (inside - 0.5)
|
|
|
|
|
def dict2obj(d):
    """Recursively turn a dict into an object with attribute access.

    Non-dict inputs are returned unchanged (lists are not descended into).
    For a dict, a fresh instance of a throwaway local class is created and
    every key becomes an entry of the instance ``__dict__``, with nested
    dicts converted the same way.

    :param d: a dict, or any other value
    :return: ``d`` itself when it is not a dict, otherwise the new object
    """
    if isinstance(d, dict):

        class C(object):
            pass

        holder = C()
        # Write through __dict__ directly, as keys need not be identifiers.
        namespace = holder.__dict__
        for key in d:
            namespace[key] = dict2obj(d[key])
        return holder

    return d
|
|
|
|
|
def face_vertices(vertices, faces):
    """
    Gather the corner coordinates of every face in a batch of meshes.

    :param vertices: [batch size, number of vertices, 3]
    :param faces: [batch size, number of faces, 3]
    :return: [batch size, number of faces, 3, 3]
    """
    num_verts = vertices.shape[1]
    batch_size = faces.shape[0]

    # Offset each mesh's indices by batch_index * num_verts so that they
    # address rows of the flattened (batch * vertices) table below.
    batch_offsets = torch.arange(batch_size, dtype=torch.int32).to(
        vertices.device) * num_verts
    flat_indices = faces + batch_offsets[:, None, None]

    flat_vertices = vertices.reshape(
        (batch_size * num_verts, vertices.shape[-1]))

    return flat_vertices[flat_indices.long()]
|
|
|
|
|
class Pytorch3dRasterizer(nn.Module):
    """ Borrowed from https://github.com/facebookresearch/pytorch3d

    Rasterizes a batch of meshes with ``rasterize_meshes`` and interpolates
    per-face-corner attributes at every pixel using the returned barycentric
    coordinates. A visibility channel is appended to the output.

    Notice:
        x,y,z are in image space, normalized
        can only render squared image now
    """

    def __init__(self, image_size=224):
        """
        use fixed raster_settings for rendering faces

        :param image_size: value passed to ``rasterize_meshes`` as
            ``image_size``
        """
        super().__init__()
        # NOTE(review): 'cull_backfaces' is stored here but forward() never
        # passes it to rasterize_meshes, so it currently has no effect.
        # Confirm whether that is intended before wiring it through.
        raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': True,
            'cull_backfaces': True,
        }
        self.raster_settings = dict2obj(raster_settings)

    def forward(self, vertices, faces, attributes=None):
        """
        Rasterize the meshes and interpolate ``attributes`` per pixel.

        :param vertices: batched vertex positions; a copy is made and its
            first two coordinates are negated before rasterization
            (presumably [B, V, 3] -- confirm against callers)
        :param faces: batched face indices, cast to long
            (presumably [B, F, 3] -- confirm against callers)
        :param attributes: per-face, per-corner values; flattened over its
            first two dims to [dim0 * dim1, 3, D] (presumably given as
            [B, F, 3, D] -- confirm against callers). Required, despite
            the ``None`` default kept for signature compatibility.
        :return: [N, D + 1, H, W]; the first D channels are interpolated
            attributes (zero where no face covers the pixel) and the last
            channel is the visibility mask (1.0 where covered, else 0.0)
        :raises ValueError: if ``attributes`` is None
        """
        # Fail fast: the original code dereferenced `attributes` without a
        # check, so the default produced an AttributeError only after the
        # rasterization work had already been done.
        if attributes is None:
            raise ValueError(
                'Pytorch3dRasterizer.forward requires `attributes`, '
                'got None')

        # Negate x and y on a copy so the caller's tensor is untouched.
        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(),
                               faces=faces.long())
        raster_settings = self.raster_settings
        # zbuf and dists are returned by the rasterizer but not used here.
        pix_to_face, _, bary_coords, _ = rasterize_meshes(
            meshes_screen,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            perspective_correct=raster_settings.perspective_correct,
        )
        # A pixel is visible when some face index (> -1) was assigned.
        vismask = (pix_to_face > -1).float()

        D = attributes.shape[-1]
        attributes = attributes.clone()
        # Flatten the first two dims so faces are addressed by one index.
        # NOTE(review): this relies on pix_to_face holding indices into the
        # faces of the whole batch laid out consecutively -- confirm against
        # the rasterize_meshes documentation.
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
                                     3, attributes.shape[-1])
        N, H, W, K, _ = bary_coords.shape

        # Empty pixels carry -1; point them at face 0 so gather stays in
        # range, then zero their interpolated values afterwards.
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        pix_to_face[mask] = 0
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)

        # Barycentric blend of the three corner values of the hit face.
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0

        # Keep the first face per pixel and move channels first: [N, D, H, W]
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat(
            [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals
|
|