File size: 2,018 Bytes
2252f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from matplotlib import image
import nvdiffrast.torch as dr
import torch

def _warmup(glctx, device):
    #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
 
    pos = torch.tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32, device=device)
    tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device)
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])

class NormalsRenderer:
    
    _glctx:dr.RasterizeGLContext = None
    
    def __init__(
            self,
            mv: torch.Tensor, #C,4,4
            proj: torch.Tensor, #C,4,4
            image_size: tuple[int,int],
            device: str
            ):
        self._mvp = proj @ mv #C,4,4
        self._image_size = image_size
        # self._glctx = dr.RasterizeGLContext()
        self._glctx = dr.RasterizeCudaContext(device=device)
        _warmup(self._glctx, device)

    def render(self,
            vertices: torch.Tensor, #V,3 float
            faces: torch.Tensor, #F,3 long
            colors: torch.Tensor = None, #V,3 float
            normals: torch.Tensor = None, #V,3 float
            return_triangles: bool = False  
            ) -> torch.Tensor: #C,H,W,4

        V = vertices.shape[0]
        faces = faces.type(torch.int32)
        vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
        vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
        rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
        vert_nrm = (normals+1)/2 if normals is not None else colors
        nrm, _ = dr.interpolate(vert_nrm, rast_out, faces) #C,H,W,3
        alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
        nrm = torch.concat((nrm,alpha),dim=-1) #C,H,W,4
        nrm = dr.antialias(nrm, rast_out, vertices_clip, faces) #C,H,W,4
        if return_triangles:
            return nrm, rast_out[..., -1]
        return nrm #C,H,W,4