File size: 2,018 Bytes
2252f3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from matplotlib import image
import nvdiffrast.torch as dr
import torch
def _warmup(glctx, device):
    """Rasterize one throwaway triangle with ``glctx`` and discard the result.

    Windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59.

    Args:
        glctx: nvdiffrast rasterization context to exercise.
        device: Device on which the dummy triangle tensors are created.
    """
    # A batch of one triangle given as three homogeneous (x, y, z, w) positions.
    triangle_pos = torch.tensor(
        [[
            [-0.8, -0.8, 0, 1],
            [0.8, -0.8, 0, 1],
            [-0.8, 0.8, 0, 1],
        ]],
        dtype=torch.float32,
        device=device,
    )
    triangle_idx = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device)
    # Only the side effect of running the rasterizer once matters here.
    dr.rasterize(glctx, triangle_pos, triangle_idx, resolution=[256, 256])
class NormalsRenderer:
    """Render per-vertex normals (or colors) of a mesh from a batch of cameras.

    The combined model-view-projection matrices are computed once in
    ``__init__``; ``render`` then projects the mesh, rasterizes it with
    nvdiffrast, interpolates the vertex attributes and antialiases the result.
    """

    # Rasterization context. The class-level ``None`` is only a placeholder;
    # ``__init__`` installs a CUDA context on every instance.
    _glctx: "dr.RasterizeCudaContext | None" = None

    def __init__(
        self,
        mv: torch.Tensor,  # C,4,4
        proj: torch.Tensor,  # C,4,4
        image_size: tuple[int, int],
        device: str,
    ):
        """Precompute the camera matrices and create the rasterizer context.

        Args:
            mv: Model-view matrices, one per camera (C,4,4).
            proj: Projection matrices, one per camera (C,4,4).
            image_size: Output resolution, forwarded unchanged to
                ``dr.rasterize`` as ``resolution``.
            device: Device on which the CUDA rasterization context is created.
        """
        self._mvp = proj @ mv  # C,4,4
        self._image_size = image_size
        # The CUDA rasterizer is used here instead of dr.RasterizeGLContext().
        self._glctx = dr.RasterizeCudaContext(device=device)
        _warmup(self._glctx, device)

    def render(
        self,
        vertices: torch.Tensor,  # V,3 float
        faces: torch.Tensor,  # F,3 long
        colors: torch.Tensor = None,  # V,3 float
        normals: torch.Tensor = None,  # V,3 float
        return_triangles: bool = False,
    ) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":  # C,H,W,4
        """Render the mesh from every camera.

        Args:
            vertices: Vertex positions (V,3).
            faces: Triangle vertex indices (F,3); converted to int32 internally.
            colors: Per-vertex attributes (V,3) used as-is when ``normals``
                is not given.
            normals: Per-vertex normals (V,3). When given they take precedence
                over ``colors`` and are remapped with ``(n + 1) / 2``.
            return_triangles: When true, also return the last channel of the
                rasterizer output (C,H,W).

        Returns:
            The antialiased image (C,H,W,4): three interpolated attribute
            channels plus an alpha channel. With ``return_triangles=True`` a
            tuple ``(image, rast_out[..., -1])`` is returned instead.

        Raises:
            ValueError: If neither ``normals`` nor ``colors`` is provided.
        """
        # Fail early with a clear message instead of handing None to
        # dr.interpolate after the rasterization work has already been done.
        if normals is None and colors is None:
            raise ValueError("render() requires either `normals` or `colors`")

        V = vertices.shape[0]
        faces = faces.type(torch.int32)

        # Append w=1 to obtain homogeneous coordinates, then project with
        # every camera's MVP matrix.
        vert_hom = torch.cat(
            (vertices, torch.ones(V, 1, device=vertices.device)), dim=-1
        )  # V,3 -> V,4
        vertices_clip = vert_hom @ self._mvp.transpose(-2, -1)  # C,V,4

        rast_out, _ = dr.rasterize(
            self._glctx,
            vertices_clip,
            faces,
            resolution=self._image_size,
            grad_db=False,
        )  # C,H,W,4

        # Normals are shifted from [-1, 1] into [0, 1]; colors pass through.
        vert_nrm = (normals + 1) / 2 if normals is not None else colors
        nrm, _ = dr.interpolate(vert_nrm, rast_out, faces)  # C,H,W,3

        # Clamp the last rasterizer channel to at most 1 to use it as alpha.
        alpha = torch.clamp(rast_out[..., -1:], max=1)  # C,H,W,1
        nrm = torch.cat((nrm, alpha), dim=-1)  # C,H,W,4
        nrm = dr.antialias(nrm, rast_out, vertices_clip, faces)  # C,H,W,4

        if return_triangles:
            return nrm, rast_out[..., -1]
        return nrm  # C,H,W,4
|