|
from matplotlib import image |
|
import nvdiffrast.torch as dr |
|
import torch |
|
|
|
def _warmup(glctx, device):
    """Rasterize one small triangle at 256x256 with ``glctx`` and discard the result.

    Presumably used to pay nvdiffrast's one-time setup cost before the first
    real render (TODO confirm intent with the author).

    Args:
        glctx: nvdiffrast rasterizer context to exercise.
        device: Torch device on which the dummy geometry is allocated.
    """
    tri_indices = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device)
    # One instance (leading dim 1) of three homogeneous clip-space vertices.
    clip_pos = torch.tensor(
        [[
            [-0.8, -0.8, 0, 1],
            [0.8, -0.8, 0, 1],
            [-0.8, 0.8, 0, 1],
        ]],
        dtype=torch.float32,
        device=device,
    )
    dr.rasterize(glctx, clip_pos, tri_indices, resolution=[256, 256])
|
|
|
class NormalsRenderer:
    """Rasterize a triangle mesh with nvdiffrast, shaded by per-vertex normals or colors.

    The model-view and projection matrices are fixed at construction time;
    ``render`` can then be called repeatedly with different geometry.
    """

    # Class-level default only; ``__init__`` replaces it per instance with a
    # ``dr.RasterizeCudaContext``. String annotation so it is not evaluated at
    # class-creation time.
    _glctx: "dr.RasterizeCudaContext | None" = None

    def __init__(
        self,
        mv: torch.Tensor,
        proj: torch.Tensor,
        image_size: tuple[int, int],
        device: str,
    ):
        """Precompute the MVP matrix and create the CUDA rasterizer context.

        Args:
            mv: Model-view matrix (or batch of matrices). Presumably
                (batch, 4, 4) so that nvdiffrast runs in instanced mode --
                TODO confirm with callers.
            proj: Projection matrix; combined as ``proj @ mv``.
            image_size: Output resolution forwarded to ``dr.rasterize``
                (nvdiffrast documents this as (height, width)).
            device: Torch device string for the rasterizer context.
        """
        self._mvp = proj @ mv
        self._image_size = image_size

        self._glctx = dr.RasterizeCudaContext(device=device)
        # Rasterize a dummy triangle once, presumably so one-time setup cost
        # is not paid by the first real render.
        _warmup(self._glctx, device)

    def render(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
        colors: "torch.Tensor | None" = None,
        normals: "torch.Tensor | None" = None,
        return_triangles: bool = False,
    ) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
        """Rasterize the mesh and return an antialiased attribute image with alpha.

        Args:
            vertices: Vertex positions with 3 coordinates in the last dim;
                any leading dims are kept when appending w=1.
            faces: Triangle vertex indices; converted to contiguous int32 as
                nvdiffrast expects.
            colors: Per-vertex attributes, used only when ``normals`` is None.
            normals: Per-vertex normals, remapped with ``(n + 1) / 2``
                (presumably from [-1, 1] to [0, 1]). They take precedence
                over ``colors``.
            return_triangles: If True, also return the last channel of the
                raster output (per nvdiffrast docs: triangle_id + 1, 0 for
                empty pixels).

        Returns:
            The interpolated attributes with an alpha channel appended on the
            last dim, passed through ``dr.antialias``. Alpha is 1 where the
            raster's last channel is >= 1 and 0 where it is 0. When
            ``return_triangles`` is True, the result is an
            ``(image, triangle_channel)`` tuple.

        Raises:
            ValueError: If neither ``colors`` nor ``normals`` is provided.
        """
        if normals is None and colors is None:
            raise ValueError("render() requires either `normals` or `colors`")

        faces = faces.to(torch.int32).contiguous()

        # Append a homogeneous w=1 coordinate. The ones column keeps the
        # default dtype, as before, and is placed on the vertices' device.
        ones = torch.ones((*vertices.shape[:-1], 1), device=vertices.device)
        vert_hom = torch.cat((vertices, ones), dim=-1)
        # Row-vector convention: v_clip = v @ MVP^T.
        vertices_clip = vert_hom @ self._mvp.transpose(-2, -1)

        rast_out, _ = dr.rasterize(
            self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False
        )

        vert_attr = (normals + 1) / 2 if normals is not None else colors
        image, _ = dr.interpolate(vert_attr, rast_out, faces)

        # The raster's last channel is 0 for uncovered pixels and >= 1 otherwise;
        # clamping to 1 turns it into a coverage mask.
        alpha = torch.clamp(rast_out[..., -1:], max=1)
        image = torch.cat((image, alpha), dim=-1)
        image = dr.antialias(image, rast_out, vertices_clip, faces)

        if return_triangles:
            return image, rast_out[..., -1]
        return image
|
|
|
|