File size: 6,135 Bytes
1ba539f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import torch.nn as nn
from .lbs import lbs, batch_rodrigues
import os.path as osp
import pickle
import numpy as np


def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')):
    """Return ``array`` as a torch tensor located on ``device``.

    Args:
        array: a ``torch.Tensor`` or anything ``torch.tensor`` accepts
            (numpy array, nested list, scalar, ...).
        dtype: dtype used when a new tensor has to be built. It is NOT
            applied to inputs that already are tensors.
        device: device the result is moved to.

    Returns:
        ``array`` itself moved to ``device`` when it already is a tensor
        (its dtype and autograd history are kept), otherwise a new tensor
        built from ``array`` with ``dtype`` on ``device``.
    """
    # The previous check looked for the lowercase string 'torch.tensor' in
    # str(type(array)); the real type name is 'torch.Tensor', so the check
    # never matched and tensors were re-wrapped through torch.tensor(),
    # which warns, detaches them and silently recasts their dtype.
    if isinstance(array, torch.Tensor):
        return array.to(device)
    return torch.tensor(array, dtype=dtype).to(device)


def to_np(array, dtype=np.float32):
    """Convert ``array`` to a numpy ``ndarray`` of the requested ``dtype``.

    Objects whose type name contains ``scipy.sparse`` are densified with
    their ``todense()`` method first; everything else is handed to
    ``np.array`` directly.
    """
    type_name = str(type(array))
    dense = array.todense() if 'scipy.sparse' in type_name else array
    return np.array(dense, dtype=dtype)


class SMPLlayer(nn.Module):
    """SMPL body model wrapped as a ``torch.nn.Module``.

    The model parameters are read from a pickle file and registered as
    buffers; ``forward`` evaluates them through ``lbs`` and applies a global
    rotation, scale and translation to the result.

    When ``regressor_path`` is given, a second set of ``j_*`` buffers is
    built by applying the (joint + extra) regressor to the model parameters.
    ``forward(..., return_verts=False)`` uses those buffers instead of the
    full-mesh ones.
    """

    def __init__(self,
                 model_path,
                 gender='neutral',
                 device=None,
                 regressor_path=None) -> None:
        """Load the model file and register its parameters as buffers.

        Args:
            model_path: either a directory containing
                ``SMPL_<GENDER>.pkl`` or the path of the model file itself.
            gender: used (upper-cased) to build the file name when
                ``model_path`` is a directory.
            device: stored on ``self.device`` and used for non-tensor inputs
                in ``forward``; also the device ``j_J_regressor`` is created
                on. All other buffers are created on the CPU.
            regressor_path: optional ``.npy`` file with extra regressor rows
                that are appended to ``J_regressor``.

        Raises:
            AssertionError: if the resolved model path does not exist.
        """
        super(SMPLlayer, self).__init__()
        dtype = torch.float32
        self.dtype = dtype
        self.device = device
        # create the SMPL model
        if osp.isdir(model_path):
            model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
            smpl_path = osp.join(model_path, model_fn)
        else:
            smpl_path = model_path
        assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
            smpl_path)

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only load model files that come from a trusted source.
        with open(smpl_path, 'rb') as smpl_file:
            data = pickle.load(smpl_file, encoding='latin1')
        self.faces = data['f']
        self.register_buffer(
            'faces_tensor',
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long))
        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
        num_pose_basis = data['posedirs'].shape[-1]
        # Keep the un-flattened basis; it is needed below for j_posedirs.
        posedirs = data['posedirs']
        # 207 x 20670
        data['posedirs'] = np.reshape(data['posedirs'], [-1, num_pose_basis]).T

        for key in [
                'J_regressor', 'v_template', 'weights', 'posedirs', 'shapedirs'
        ]:
            val = to_tensor(to_np(data[key]), dtype=dtype)
            self.register_buffer(key, val)
        # indices of parents for each joints
        parents = to_tensor(to_np(data['kintree_table'][0])).long()
        # The root has no parent; overwrite whatever the table stores for it.
        parents[0] = -1
        self.register_buffer('parents', parents)
        # joints regressor
        if regressor_path is not None:
            X_regressor = to_tensor(np.load(regressor_path))
            # Joint regressor rows first, extra regressor rows after them.
            X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0)

            # Rectangular identity: row i selects regressed point i, i.e. the
            # first 24 regressed points (the J_regressor rows, presumably the
            # 24 SMPL joints) are used as the joints of the reduced model.
            j_J_regressor = torch.eye(24,
                                      X_regressor.shape[0],
                                      device=device)
            j_v_template = X_regressor @ self.v_template
            # Apply the regressor along the vertex axis of the shape basis.
            j_shapedirs = torch.einsum('vij,kv->kij', self.shapedirs,
                                       X_regressor)
            # Skinning weights of the regressed points.
            j_weights = X_regressor @ self.weights
            # Regress the un-flattened pose basis, then flatten and transpose
            # it the same way as 'posedirs' above.
            j_posedirs = torch.einsum('ab, bde->ade', X_regressor,
                                      torch.Tensor(posedirs)).numpy()
            j_posedirs = np.reshape(j_posedirs, [-1, num_pose_basis]).T
            j_posedirs = to_tensor(j_posedirs)
            self.register_buffer('j_posedirs', j_posedirs)
            self.register_buffer('j_shapedirs', j_shapedirs)
            self.register_buffer('j_weights', j_weights)
            self.register_buffer('j_v_template', j_v_template)
            self.register_buffer('j_J_regressor', j_J_regressor)

    def forward(self,
                poses,
                shapes,
                Rh=None,
                Th=None,
                return_verts=True,
                return_tensor=True,
                scale=1,
                new_params=False,
                **kwargs):
        """ Forward pass for SMPL model

        Args:
            poses (n, 72)
            shapes (n, 10); a single row is expanded to ``n`` rows.
            Rh (n, 3): global orientation, passed to ``batch_rodrigues``.
                Defaults to zeros when None.
            Th (n, 3): global translation. Defaults to zeros when None.
            return_verts (bool, optional): if True evaluate the full mesh
                buffers, otherwise the ``j_*`` buffers (which only exist
                when the layer was built with ``regressor_path``) and drop
                the first 24 output points. Defaults to True.
            return_tensor (bool, optional): if False the result is detached
                and converted to a numpy array. Defaults to True.
            scale: factor applied to the rotated points before translation.
            new_params: forwarded unchanged to ``lbs``.
            **kwargs: accepted and ignored.

        Returns:
            The transformed points of the FIRST batch element only
            (``vertices[0]``), as a tensor or a numpy array.
        """
        if not isinstance(poses, torch.Tensor):
            dtype = self.dtype
            # Fall back to where the model buffers live when no device was
            # given at construction time.
            device = self.device
            if device is None:
                device = self.v_template.device
            poses = to_tensor(poses, dtype, device)
            shapes = to_tensor(shapes, dtype, device)
            # Rh / Th are optional: only convert them when provided,
            # torch.tensor(None) would raise.
            if Rh is not None:
                Rh = to_tensor(Rh, dtype, device)
            if Th is not None:
                Th = to_tensor(Th, dtype, device)
        bn = poses.shape[0]
        if Rh is None:
            Rh = torch.zeros(bn, 3, device=poses.device)
        if Th is None:
            # Previously a missing Th crashed on Th.unsqueeze below.
            Th = torch.zeros(bn, 3, device=poses.device)
        rot = batch_rodrigues(Rh)
        transl = Th.unsqueeze(dim=1)
        if shapes.shape[0] < bn:
            shapes = shapes.expand(bn, -1)
        if return_verts:
            vertices, joints = lbs(shapes,
                                   poses,
                                   self.v_template,
                                   self.shapedirs,
                                   self.posedirs,
                                   self.J_regressor,
                                   self.parents,
                                   self.weights,
                                   pose2rot=True,
                                   new_params=new_params,
                                   dtype=self.dtype)
        else:
            vertices, joints = lbs(shapes,
                                   poses,
                                   self.j_v_template,
                                   self.j_shapedirs,
                                   self.j_posedirs,
                                   self.j_J_regressor,
                                   self.parents,
                                   self.j_weights,
                                   pose2rot=True,
                                   new_params=new_params,
                                   dtype=self.dtype)
            # Skip the first 24 points (the J_regressor rows); keep only the
            # points produced by the extra regressor.
            vertices = vertices[:, 24:, :]
        # Apply the global rotation, then scale and translate.
        vertices = torch.matmul(vertices, rot.transpose(1, 2)) * scale + transl
        if not return_tensor:
            vertices = vertices.detach().cpu().numpy()
        return vertices[0]