# coding: utf-8

__author__ = 'cleardusk'

import sys

sys.path.append('..')

import os.path as osp
import numpy as np
import torch
import torch.nn as nn

from tddfa.utils.io import _load, _numpy_to_cuda, _numpy_to_tensor

make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)


def _to_ctype(arr):
    # ensure the array is stored C-contiguously, copying only when necessary
    if not arr.flags.c_contiguous:
        return arr.copy(order='C')
    return arr


def _load_tri(bfm_fp):
    if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl':
        tri = _load(make_abs_path('../configs/tri.pkl'))  # this tri/face is re-built for bfm_noneck_v3
    else:
        tri = _load(bfm_fp).get('tri')

    # transpose, cast to int32 and make C-contiguous
    tri = _to_ctype(tri.T).astype(np.int32)
    return tri


class BFMModel_ONNX(nn.Module):
    """BFM serves as a decoder"""

    def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):
        super(BFMModel_ONNX, self).__init__()

        _to_tensor = _numpy_to_tensor

        # load bfm
        bfm = _load(bfm_fp)

        # mean shape, reshaped to (3, N)
        u = _to_tensor(bfm.get('u').astype(np.float32))
        self.u = u.view(-1, 3).transpose(1, 0)
        # truncated shape and expression bases, stacked into one (3, N, shape_dim + exp_dim) tensor
        w_shp = _to_tensor(bfm.get('w_shp').astype(np.float32)[..., :shape_dim])
        w_exp = _to_tensor(bfm.get('w_exp').astype(np.float32)[..., :exp_dim])
        w = torch.cat((w_shp, w_exp), dim=1)
        self.w = w.view(-1, 3, w.shape[-1]).contiguous().permute(1, 0, 2)

        # self.u = _to_tensor(bfm.get('u').astype(np.float32))  # fix bug
        # w_shp = _to_tensor(bfm.get('w_shp').astype(np.float32)[..., :shape_dim])
        # w_exp = _to_tensor(bfm.get('w_exp').astype(np.float32)[..., :exp_dim])
        # self.w = torch.cat((w_shp, w_exp), dim=1)
        # self.keypoints = bfm.get('keypoints').astype(np.long)  # fix bug
        # self.u_base = self.u[self.keypoints].reshape(-1, 1)
        # self.w_shp_base = self.w_shp[self.keypoints]
        # self.w_exp_base = self.w_exp[self.keypoints]
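
    # forward() reconstructs the dense 3D face from the regressed parameters:
    #     pts3d = R @ (u + W @ alpha) + offset
    # with u: (3, N), W: (3, N, shape_dim + exp_dim), alpha: (shape_dim + exp_dim, 1),
    # R: (3, 3) and offset: (3, 1), so pts3d has shape (3, N).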
    def forward(self, *inps):
        R, offset, alpha_shp, alpha_exp = inps
        # concatenate shape and expression coefficients along dim 0
        alpha = torch.cat((alpha_shp, alpha_exp))
        # pts3d = R @ (self.u + self.w_shp.matmul(alpha_shp) + self.w_exp.matmul(alpha_exp)). \
        #     view(-1, 3).transpose(1, 0) + offset
        # pts3d = R @ (self.u + self.w.matmul(alpha)).view(-1, 3).transpose(1, 0) + offset
        pts3d = R @ (self.u + self.w.matmul(alpha).squeeze()) + offset
        return pts3d


def convert_bfm_to_onnx(bfm_onnx_fp, shape_dim=40, exp_dim=10):
    # print(shape_dim, exp_dim)
    bfm_fp = bfm_onnx_fp.replace('.onnx', '.pkl')
    bfm_decoder = BFMModel_ONNX(bfm_fp=bfm_fp, shape_dim=shape_dim, exp_dim=exp_dim)
    bfm_decoder.eval()

    # dummy inputs matching the decoder signature: R (3, 3), offset (3, 1),
    # shape coefficients (shape_dim, 1) and expression coefficients (exp_dim, 1)
    # dummy_input = torch.randn(12 + shape_dim + exp_dim)
    dummy_input = torch.randn(3, 3), torch.randn(3, 1), torch.randn(shape_dim, 1), torch.randn(exp_dim, 1)
    R, offset, alpha_shp, alpha_exp = dummy_input
    torch.onnx.export(
        bfm_decoder,
        (R, offset, alpha_shp, alpha_exp),
        bfm_onnx_fp,
        input_names=['R', 'offset', 'alpha_shp', 'alpha_exp'],
        output_names=['output'],
        dynamic_axes={
            'alpha_shp': [0],
            'alpha_exp': [0],
        },
        do_constant_folding=True
    )
    print(f'Convert {bfm_fp} to {bfm_onnx_fp} done.')
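

# A minimal usage sketch (not part of the original 3DDFA code): it assumes the
# `onnxruntime` package is installed and simply feeds random parameters to the
# exported decoder. The function name is illustrative only; the input/output
# names come from the torch.onnx.export call above.
def _run_exported_decoder_example(bfm_onnx_fp, shape_dim=40, exp_dim=10):
    import onnxruntime  # imported lazily so the rest of the module has no hard dependency on it

    session = onnxruntime.InferenceSession(bfm_onnx_fp, providers=['CPUExecutionProvider'])
    feed = {
        'R': np.random.randn(3, 3).astype(np.float32),
        'offset': np.random.randn(3, 1).astype(np.float32),
        'alpha_shp': np.random.randn(shape_dim, 1).astype(np.float32),
        'alpha_exp': np.random.randn(exp_dim, 1).astype(np.float32),
    }
    pts3d = session.run(['output'], feed)[0]  # reconstructed vertices, shape (3, N)
    return pts3d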


if __name__ == '__main__':
    convert_bfm_to_onnx('../configs/bfm_noneck_v3.onnx')