File size: 2,756 Bytes
2252f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import numpy as np
import neural_renderer as nr
from core import path_config

from models import SMPL


class PartRenderer:
    """Render silhouette masks and body-part segmentations of a mesh.

    Rendering is delegated to the Neural 3D Mesh Renderer
    (``neural_renderer``). The mesh is drawn with the textures loaded from
    ``path_config.VERTEX_TEXTURE_FILE``; the rendered colours are then mapped
    to part labels through the 3-D lookup table loaded from
    ``path_config.CUBE_PARTS_FILE``.

    NOTE: the constructor moves faces, textures and the lookup table to the
    default CUDA device, so constructing an instance requires CUDA.
    """

    def __init__(self, focal_length=5000., render_res=224):
        """Create the underlying renderer and load the static mesh assets.

        Args:
            focal_length: Focal length written into the intrinsic matrix and
                used to turn the camera scale into a depth translation.
            render_res: Size passed to the renderer as both ``orig_size`` and
                ``image_size``; half of it is used as the principal point.
        """
        # Parameters for rendering
        self.focal_length = focal_length
        self.render_res = render_res
        # Ambient-only lighting (directional intensity 0) -- presumably so the
        # texture colours reach the image unshaded and can be decoded back
        # into part labels in `get_parts`; confirm against neural_renderer.
        self.neural_renderer = nr.Renderer(
            dist_coeffs=None,
            orig_size=self.render_res,
            image_size=self.render_res,
            light_intensity_ambient=1,
            light_intensity_directional=0,
            anti_aliasing=False,
        )
        smpl_faces = SMPL(path_config.SMPL_MODEL_DIR).faces.astype(np.int32)
        self.faces = torch.from_numpy(smpl_faces).cuda()
        textures = np.load(path_config.VERTEX_TEXTURE_FILE)
        self.textures = torch.from_numpy(textures).cuda().float()
        # Built the same way as the other assets instead of through the
        # legacy `torch.cuda.FloatTensor(...)` constructor.
        cube_parts = np.load(path_config.CUBE_PARTS_FILE)
        self.cube_parts = torch.from_numpy(cube_parts).float().cuda()

    def get_parts(self, parts, mask):
        """Convert a rendered part-colour image into body-part indices.

        Args:
            parts: Rendered image of shape (B, 3, H, W). Each channel value
                is multiplied by 100 and floored to give one coordinate into
                ``self.cube_parts``.
            mask: Tensor with B * H * W elements. It is multiplied into the
                looked-up labels, so pixels where it is 0 get label 0.

        Returns:
            ``torch.long`` tensor of shape (B, H, W) holding the part labels.
        """
        bn, _, h, w = parts.shape
        # `reshape` rather than `view`: `view` raises on a non-contiguous
        # mask, while `reshape` copies only when it has to.
        mask = mask.reshape(-1, 1)
        # Channels-last, one row of three colour values per pixel.
        colours = parts.permute(0, 2, 3, 1).reshape(-1, 3)
        parts_index = torch.floor(100 * colours).long()
        # No-op when the table already lives on the image's device.
        lookup = self.cube_parts.to(parts.device)
        labels = lookup[parts_index[:, 0], parts_index[:, 1], parts_index[:, 2], None]
        labels = labels * mask
        return labels.view(bn, h, w).long()

    def __call__(self, vertices, camera):
        """Render the mask and the part segmentation for a batch of meshes.

        Args:
            vertices: Batched mesh vertices; the first dimension is the batch
                size (presumably (B, V, 3) -- confirm against the caller).
            camera: Tensor indexed as ``camera[:, 0..2]``. Column 0 is used as
                a scale that is converted into a depth; columns 1 and 2 are
                used directly as the x and y translation.

        Returns:
            Tuple ``(mask, parts)``: ``mask`` exactly as returned by the
            renderer, and ``parts`` as produced by :meth:`get_parts`.
        """
        device = vertices.device
        batch_size = vertices.shape[0]

        # Estimate camera translation given a fixed focal length. The 1e-9
        # term keeps the division finite when the scale is exactly zero.
        depth = 2 * self.focal_length / (self.render_res * camera[:, 0] + 1e-9)
        cam_t = torch.stack([camera[:, 1], camera[:, 2], depth], dim=-1)

        # Intrinsics: focal length on the diagonal, principal point at the
        # image centre (the identity already provides K[2, 2] == 1).
        K = torch.eye(3, device=device)
        K[0, 0] = self.focal_length
        K[1, 1] = self.focal_length
        K[0, 2] = self.render_res / 2.
        K[1, 2] = self.render_res / 2.
        K = K[None, :, :].expand(batch_size, -1, -1)
        # Identity rotation: the camera is only translated.
        R = torch.eye(3, device=device)[None, :, :].expand(batch_size, -1, -1)

        # K and R follow `vertices.device`; keep the cached assets on the same
        # device (no-op when they already match).
        faces = self.faces.to(device)[None, :, :].expand(batch_size, -1, -1)
        textures = self.textures.to(device).expand(batch_size, -1, -1, -1, -1, -1)

        parts, _, mask = self.neural_renderer(
            vertices,
            faces,
            textures=textures,
            K=K,
            R=R,
            t=cam_t.unsqueeze(1),
        )
        parts = self.get_parts(parts, mask)
        return mask, parts