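# PSHuman/lib/pymafx/utils/renderer.py
# fffiloni: "Migrated from GitHub" (commit 2252f3d, verified), 23 kB
# NOTE: the two lines above are file-viewer page residue (path, uploader, commit
# badge, size) kept as comments so that the module remains importable.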
import imp
import os
from pickle import NONE
# os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
import torch
import trimesh
import numpy as np
# import neural_renderer as nr
from skimage.transform import resize
from torchvision.utils import make_grid
import torch.nn.functional as F
from models.smpl import get_smpl_faces, get_model_faces, get_model_tpose
from utils.densepose_methods import DensePoseMethods
from core import constants, path_config
import json
from .geometry import convert_to_full_img_cam
from utils.imutils import crop
try:
import math
import pyrender
from pyrender.constants import RenderFlags
except:
pass
try:
from opendr.renderer import ColoredRenderer
from opendr.lighting import LambertianPointLight, SphericalHarmonics
from opendr.camera import ProjectPoints
except:
pass
from pytorch3d.structures.meshes import Meshes
# from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments
from pytorch3d.renderer import (
look_at_view_transform, FoVPerspectiveCameras, PerspectiveCameras, AmbientLights, PointLights,
RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, SoftPhongShader,
SoftSilhouetteShader, HardPhongShader, HardGouraudShader, HardFlatShader, TexturesVertex
)
import logging
logger = logging.getLogger(__name__)
class WeakPerspectiveCamera(pyrender.Camera):
    """pyrender camera whose projection is a per-axis scale plus an in-plane shift."""

    def __init__(
        self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None
    ):
        """Store the weak-perspective parameters and forward the rest to ``pyrender.Camera``.

        Args:
            scale: indexable; entries 0 and 1 are read as the x and y scale.
            translation: indexable; entries 0 and 1 are read as the x and y shift.
            znear, zfar, name: passed unchanged to the base-class constructor.
        """
        super().__init__(znear=znear, zfar=zfar, name=name)
        self.scale = scale
        self.translation = translation

    def get_projection_matrix(self, width=None, height=None):
        """Return the 4x4 projection matrix; ``width`` and ``height`` are ignored."""
        scale_x, scale_y = self.scale[0], self.scale[1]
        shift_x, shift_y = self.translation[0], self.translation[1]

        proj = np.eye(4)
        proj[0, 0], proj[1, 1], proj[2, 2] = scale_x, scale_y, -1
        # The shift is expressed in scaled units; the y component is negated.
        proj[0, 3] = shift_x * scale_x
        proj[1, 3] = -shift_y * scale_y
        return proj
class PyRenderer:
    """Render a mesh off-screen with pyrender and blend it over background images.

    The scene (ambient light, three point lights, two spot lights) is built once
    in ``__init__``. Every call temporarily adds a mesh node and a camera node,
    renders, composites the result, and removes both nodes again.
    """

    def __init__(
        self, resolution=(224, 224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1.
    ):
        """Create the off-screen renderer and the static lighting rig.

        Args:
            resolution: pair that, multiplied by ``scale_ratio``, gives the
                initial viewport width (entry 0) and height (entry 1).
            orig_img: stored on the instance; not read by this class.
            wireframe: when true, ``RenderFlags.ALL_WIREFRAME`` is added at render time.
            scale_ratio: multiplier applied to ``resolution``.
            vis_ratio: blend weight of the rendered mesh over the background.
        """
        self.resolution = (resolution[0] * scale_ratio, resolution[1] * scale_ratio)

        self.faces = {
            'smplx': get_model_faces('smplx'),
            'smpl': get_model_faces('smpl'),
        }
        self.orig_img = orig_img
        self.wireframe = wireframe
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=self.resolution[0], viewport_height=self.resolution[1], point_size=1.0
        )
        self.vis_ratio = vis_ratio

        # set the scene
        self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3))

        # One pose buffer is reused and overwritten for each light placement.
        light_pose = np.eye(4)

        light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=1)
        for position in ([0, -1, 1], [0, 1, 1], [1, 1, 2]):
            light_pose[:3, 3] = position
            self.scene.add(light, pose=light_pose)

        spot_l = pyrender.SpotLight(
            color=np.ones(3), intensity=15.0, innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2
        )
        for position in ([1, 2, 2], [-1, 2, 2]):
            light_pose[:3, 3] = position
            self.scene.add(spot_l, pose=light_pose)

        # Named base colours selectable through ``color_type`` in ``__call__``.
        self.colors_dict = {
            'red': np.array([0.5, 0.2, 0.2]),
            'pink': np.array([0.7, 0.5, 0.5]),
            'neutral': np.array([0.7, 0.7, 0.6]),
            'purple': np.array([0.55, 0.4, 0.9]),
            'green': np.array([0.5, 0.55, 0.3]),
            'sky': np.array([0.3, 0.5, 0.55]),
            'white': np.array([1.0, 0.98, 0.94]),
        }

    def __call__(
        self,
        verts,
        faces=None,
        img=np.zeros((224, 224, 3)),
        cam=np.array([1, 0, 0]),
        focal_length=[5000, 5000],
        camera_rotation=np.eye(3),
        crop_info=None,
        angle=None,
        axis=None,
        mesh_filename=None,
        color_type=None,
        color=[1.0, 1.0, 0.9],
        iwp_mode=True,
        crop_img=True,
        mesh_type='smpl',
        scale_ratio=1.,
        rgba_mode=False
    ):
        """Render ``verts`` and composite the result over ``img``.

        Args:
            verts: mesh vertices handed to ``trimesh.Trimesh``.
            faces: mesh faces; defaults to ``self.faces[mesh_type]``.
            img: one background image or a list of them. The default arrays in
                the signature are shared between calls but never written to.
            cam: array with 3 entries ``(s, tx, ty)`` or 4 entries
                ``(sx, sy, tx, ty)``; the length is only checked when
                ``iwp_mode`` is true.
            focal_length: pair ``(fx, fy)``.
            camera_rotation: 3x3 rotation used for the camera pose.
            crop_info: dict read when ``iwp_mode`` is false (``opt_cam_t``,
                ``bbox_scale``, ``bbox_center``, ``img_w``, ``img_h``) and for
                the optional final crop.
            angle, axis: optional extra rotation of the mesh, in degrees about ``axis``.
            mesh_filename: when given, the flipped mesh is exported to this path.
            color_type: key into ``self.colors_dict``; overrides ``color``.
            color: RGB base colour of the material.
            iwp_mode: derive the camera translation from ``cam`` and the image
                size instead of from ``crop_info``.
            crop_img: crop the rendering with ``crop_info`` when it is provided.
            mesh_type: key into ``self.faces`` used when ``faces`` is None.
            scale_ratio: scale applied to the image resolution and backgrounds.
            rgba_mode: return RGBA images whose alpha is 255 on mesh pixels.

        Returns:
            For a single ``img``: the composited ``uint8`` image. For a list:
            a list that alternates composited image and (possibly resized)
            background for each input.

        Raises:
            ValueError: if ``iwp_mode`` is true and ``cam`` has neither 3 nor 4 entries.
        """
        cam = cam.copy()
        # Fail early with a clear message instead of an unbound-variable error later.
        if iwp_mode and len(cam) not in (3, 4):
            raise ValueError(
                'cam must have 3 (s, tx, ty) or 4 (sx, sy, tx, ty) entries, got {}'.format(
                    len(cam)
                )
            )

        img_is_list = isinstance(img, list)
        # Image whose shape defines the resolution when a list of images is given.
        ref_img = img[0] if img_is_list else img

        if faces is None:
            faces = self.faces[mesh_type]
        mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)

        # Rotate the mesh by 180 degrees about the x axis before rendering.
        Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0])
        mesh.apply_transform(Rx)

        if mesh_filename is not None:
            mesh.export(mesh_filename)

        # ``axis is not None`` keeps NumPy-array axes from raising on truth-testing.
        if angle and axis is not None:
            R = trimesh.transformations.rotation_matrix(math.radians(angle), axis)
            mesh.apply_transform(R)

        if iwp_mode:
            resolution = np.array(ref_img.shape[:2]) * scale_ratio
            if len(cam) == 4:
                _, sy, tx, ty = cam
                camera_translation = np.array(
                    [tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
                )
            else:
                sx, tx, ty = cam
                sy = sx
                # The x shift is negated only for the 3-parameter camera.
                camera_translation = np.array(
                    [-tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
                )
            render_res = resolution
            self.renderer.viewport_width = render_res[1]
            self.renderer.viewport_height = render_res[0]
        else:
            if crop_info['opt_cam_t'] is None:
                camera_translation = convert_to_full_img_cam(
                    pare_cam=cam[None],
                    bbox_height=crop_info['bbox_scale'] * 200.,
                    bbox_center=crop_info['bbox_center'],
                    img_w=crop_info['img_w'],
                    img_h=crop_info['img_h'],
                    focal_length=focal_length[0],
                )
            else:
                camera_translation = crop_info['opt_cam_t']
            if torch.is_tensor(camera_translation):
                camera_translation = camera_translation[0].cpu().numpy()
            camera_translation = camera_translation.copy()
            camera_translation[0] *= -1

            if 'img_h' in crop_info and 'img_w' in crop_info:
                render_res = (int(crop_info['img_h'][0]), int(crop_info['img_w'][0]))
            else:
                render_res = ref_img.shape[:2]
            self.renderer.viewport_width = render_res[1]
            self.renderer.viewport_height = render_res[0]
            # NOTE(review): the transpose is applied on this path only — confirm
            # against upstream, since the flattened source does not show nesting.
            camera_rotation = camera_rotation.T

        camera = pyrender.IntrinsicsCamera(
            fx=focal_length[0], fy=focal_length[1], cx=render_res[1] / 2., cy=render_res[0] / 2.
        )

        if color_type is not None:
            color = self.colors_dict[color_type]

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.2,
            roughnessFactor=0.6,
            alphaMode='OPAQUE',
            baseColorFactor=(color[0], color[1], color[2], 1.0)
        )
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        mesh_node = self.scene.add(mesh, 'mesh')
        cam_node = None
        # The scene is shared across calls, so the per-call nodes must be removed
        # even when rendering or compositing fails.
        try:
            camera_pose = np.eye(4)
            camera_pose[:3, :3] = camera_rotation
            camera_pose[:3, 3] = camera_rotation @ camera_translation
            cam_node = self.scene.add(camera, pose=camera_pose)

            render_flags = RenderFlags.RGBA | RenderFlags.SHADOWS_SPOT
            if self.wireframe:
                render_flags |= RenderFlags.ALL_WIREFRAME

            rgb, _ = self.renderer.render(self.scene, flags=render_flags)
            if crop_info is not None and crop_img:
                crop_res = ref_img.shape[:2]
                rgb, _, _ = crop(
                    rgb, crop_info['bbox_center'][0], crop_info['bbox_scale'][0], crop_res
                )

            # Pixels with non-zero alpha belong to the mesh.
            valid_mask = (rgb[:, :, -1] > 0)[:, :, np.newaxis]

            image_list = img if img_is_list else [img]

            return_img = []
            for item in image_list:
                if scale_ratio != 1:
                    orig_size = item.shape[:2]
                    item = resize(
                        item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio),
                        anti_aliasing=True
                    )
                    item = (item * 255).astype(np.uint8)
                # Alpha-blend the mesh over the background with weight ``vis_ratio``.
                output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + (
                    1 - valid_mask * self.vis_ratio
                ) * item
                if rgba_mode:
                    output_img_rgba = np.zeros((output_img.shape[0], output_img.shape[1], 4))
                    output_img_rgba[:, :, :3] = output_img
                    output_img_rgba[:, :, 3][valid_mask[:, :, 0]] = 255
                    output_img = output_img_rgba.astype(np.uint8)
                image = output_img.astype(np.uint8)
                return_img.append(image)
                return_img.append(item)

            if not img_is_list:
                return_img = return_img[0]
        finally:
            self.scene.remove_node(mesh_node)
            if cam_node is not None:
                self.scene.remove_node(cam_node)

        return return_img
class OpenDRenderer:
    """Render a mesh with OpenDR's ``ColoredRenderer`` and paste it over images."""

    def __init__(self, resolution=(224, 224), ratio=1):
        """Set up intrinsics, the colour table, the renderer and the default faces.

        Args:
            resolution: ``(height, width)`` pair before scaling by ``ratio``.
            ratio: multiplier applied to the resolution and to background images.
        """
        self.ratio = ratio
        self.focal_length = 5000.
        # Sets ``self.resolution`` and ``self.K`` in one place.
        self.reset_res(resolution)
        self.colors_dict = {
            'red': np.array([0.5, 0.2, 0.2]),
            'pink': np.array([0.7, 0.5, 0.5]),
            'neutral': np.array([0.7, 0.7, 0.6]),
            'purple': np.array([0.5, 0.5, 0.7]),
            'green': np.array([0.5, 0.55, 0.3]),
            'sky': np.array([0.3, 0.5, 0.55]),
            'white': np.array([1.0, 0.98, 0.94]),
        }
        self.renderer = ColoredRenderer()
        self.faces = get_smpl_faces()

    def reset_res(self, resolution):
        """Update ``self.resolution`` and rebuild the 3x3 intrinsics ``self.K``.

        The principal point is placed at the centre of the scaled resolution.
        """
        self.resolution = (resolution[0] * self.ratio, resolution[1] * self.ratio)
        self.K = np.array(
            [
                [self.focal_length, 0., self.resolution[1] / 2.],
                [0., self.focal_length, self.resolution[0] / 2.],
                [0., 0., 1.],
            ]
        )

    def __call__(
        self,
        verts,
        faces=None,
        color=None,
        color_type='white',
        R=None,
        mesh_filename=None,
        img=np.zeros((224, 224, 3)),
        cam=np.array([1, 0, 0]),
        rgba=False,
        addlight=True
    ):
        '''Render mesh using OpenDR

        verts: shape - (V, 3)
        faces: shape - (F, 3); defaults to ``self.faces``
        color: optional colour used both as vertex colour and as light colour
        color_type: key of ``self.colors_dict`` used when ``color`` is None
        R: rotation matrix (used to manipulate verts) shape - [3, 3]
        mesh_filename: accepted but not used by this method
        img: one image or a list of images, range - [0, 255] (np.uint8)
        cam: 3 entries (s, tx, ty) or 4 entries (sx, sy, tx, ty)
        rgba: return RGBA images whose alpha is 255 on visible mesh pixels
        addlight: shade the mesh with three Lambertian point lights
        Return:
            rendered img (or list of them for list input), range - [0, 255] (np.uint8)
        Raises:
            ValueError: if ``cam`` has neither 3 nor 4 entries
        '''
        ## Create OpenDR renderer
        rn = self.renderer
        h, w = self.resolution
        K = self.K

        f = np.array([K[0, 0], K[1, 1]])
        c = np.array([K[0, 2], K[1, 2]])

        if faces is None:
            faces = self.faces

        # Depth is derived from the scale entry cam[0] and the focal length.
        depth = 2 * K[0, 0] / (w * cam[0] + 1e-9)
        if len(cam) == 4:
            t = np.array([cam[2], cam[3], depth])
        elif len(cam) == 3:
            t = np.array([cam[1], cam[2], depth])
        else:
            raise ValueError(
                'cam must have 3 (s, tx, ty) or 4 (sx, sy, tx, ty) entries, got {}'.format(
                    len(cam)
                )
            )

        rn.camera = ProjectPoints(rt=np.array([0, 0, 0]), t=t, f=f, c=c, k=np.zeros(5))
        rn.frustum = {'near': 1., 'far': 1000., 'width': w, 'height': h}

        albedo = np.ones_like(verts) * .9

        if color is not None:
            color0 = np.array(color)
            color1 = np.array(color)
            color2 = np.array(color)
        elif color_type == 'white':
            color0 = np.array([1., 1., 1.])
            color1 = np.array([1., 1., 1.])
            color2 = np.array([0.7, 0.7, 0.7])
            color = np.ones_like(verts) * self.colors_dict[color_type][None, :]
        else:
            color0 = self.colors_dict[color_type] * 1.2
            color1 = self.colors_dict[color_type] * 1.2
            color2 = self.colors_dict[color_type] * 1.2
            color = np.ones_like(verts) * self.colors_dict[color_type][None, :]

        if R is not None:
            assert R.shape == (3, 3), "Shape of rotation matrix should be (3, 3)"
            verts = np.dot(verts, R)

        rn.set(v=verts, f=faces, vc=color, bgcolor=np.zeros(3))

        if addlight:
            yrot = np.radians(120)    # angle of lights
            # Three point lights; the first replaces the vertex colours and the
            # other two are accumulated on top of it.
            rn.vc = LambertianPointLight(
                f=rn.f,
                v=rn.v,
                num_verts=len(rn.v),
                light_pos=rotateY(np.array([-200, -100, -100]), yrot),
                vc=albedo,
                light_color=color0
            )

            # Construct Left Light
            rn.vc += LambertianPointLight(
                f=rn.f,
                v=rn.v,
                num_verts=len(rn.v),
                light_pos=rotateY(np.array([800, 10, 300]), yrot),
                vc=albedo,
                light_color=color1
            )

            # Construct Right Light
            rn.vc += LambertianPointLight(
                f=rn.f,
                v=rn.v,
                num_verts=len(rn.v),
                light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
                vc=albedo,
                light_color=color2
            )

        rendered_image = rn.r
        # Pixels whose visibility id differs from 2**32 - 1 are treated as mesh pixels.
        mesh_pixels = rn.visibility_image != (2**32 - 1)

        image_list = img if isinstance(img, list) else [img]

        return_img = []
        for item in image_list:
            if self.ratio != 1:
                img_resized = resize(
                    item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio),
                    anti_aliasing=True
                )
            else:
                img_resized = item / 255.

            # Best effort: on failure the plain background is returned.
            try:
                img_resized[mesh_pixels] = rendered_image[mesh_pixels]
            except Exception:
                logger.warning('Can not render mesh.')

            img_resized = (img_resized * 255).astype(np.uint8)
            res = img_resized

            if rgba:
                img_resized_rgba = np.zeros((img_resized.shape[0], img_resized.shape[1], 4))
                img_resized_rgba[:, :, :3] = img_resized
                img_resized_rgba[:, :, 3][mesh_pixels] = 255
                res = img_resized_rgba.astype(np.uint8)
            return_img.append(res)

        if not isinstance(img, list):
            return_img = return_img[0]

        return return_img
# https://github.com/classner/up/blob/master/up_tools/camera.py
def rotateY(points, angle):
    """Right-multiply ``points`` by the y-axis rotation matrix for ``angle`` (radians)."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_y = np.array(
        [
            [cos_a, 0., sin_a],
            [0., 1., 0.],
            [-sin_a, 0., cos_a],
        ]
    )
    return np.dot(points, rot_y)
def rotateX(points, angle):
    """Right-multiply ``points`` by the x-axis rotation matrix for ``angle`` (radians)."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_x = np.array(
        [
            [1., 0., 0.],
            [0., cos_a, -sin_a],
            [0., sin_a, cos_a],
        ]
    )
    return np.dot(points, rot_x)
def rotateZ(points, angle):
    """Right-multiply ``points`` by the z-axis rotation matrix for ``angle`` (radians)."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_z = np.array(
        [
            [cos_a, -sin_a, 0.],
            [sin_a, cos_a, 0.],
            [0., 0., 1.],
        ]
    )
    return np.dot(points, rot_z)
class IUV_Renderer(object):
    """Rasterise per-vertex feature maps (IUV, PNCC or part segmentation) with PyTorch3D.

    Each vertex carries a 3-channel feature; ``verts2iuvimg`` renders those
    features with a flat shader under ambient light and returns them as an image.
    """

    def __init__(
        self,
        focal_length=5000.,
        orig_size=224,
        output_size=56,
        mode='iuv',
        device=torch.device('cuda'),
        mesh_type='smpl'
    ):
        """Prepare faces, per-vertex features, camera intrinsics and the renderer.

        Args:
            focal_length: focal length used for the intrinsics and the camera depth.
            orig_size: side length of the (square) image the camera refers to.
            output_size: ``image_size`` passed to the rasteriser.
            mode: ``'iuv'``, ``'pncc'`` or ``'seg'``; selects how the per-vertex
                features are built.
            device: device for the lights and the shader.
            mesh_type: model name; ``'iuv'`` mode only handles ``'smpl'``.

        NOTE(review): for an unrecognised ``mode`` (or ``'iuv'`` with a
        ``mesh_type`` other than ``'smpl'``) no faces/features are set and
        ``verts2iuvimg`` will fail later with ``AttributeError``.
        """
        self.focal_length = focal_length
        self.orig_size = orig_size
        self.output_size = output_size

        if mode == 'iuv':
            if mesh_type == 'smpl':
                DP = DensePoseMethods()

                # Zero-based indices used to gather the DensePose vertices from
                # the input vertices in ``verts2iuvimg``.
                vert_mapping = DP.All_vertices.astype('int64') - 1
                self.vert_mapping = torch.from_numpy(vert_mapping)

                faces = DP.FacesDensePose[None, :, :]
                self.faces = torch.from_numpy(
                    faces.astype(np.int32)
                )    # [1, 13774, 3], torch.int32 (per the original author's note)

                num_part = float(np.max(DP.FaceIndices))
                self.num_part = num_part

                # Part id per DensePose vertex, cached on disk because the
                # search below is a nested loop over vertices and faces.
                dp_vert_pid_fname = 'data/dp_vert_pid.npy'
                if os.path.exists(dp_vert_pid_fname):
                    dp_vert_pid = list(np.load(dp_vert_pid_fname))
                else:
                    print('creating data/dp_vert_pid.npy')
                    dp_vert_pid = []
                    for v in range(len(vert_mapping)):
                        # Take the part id of the first face that contains vertex v.
                        for i, f in enumerate(DP.FacesDensePose):
                            if v in f:
                                dp_vert_pid.append(DP.FaceIndices[i])
                                break
                    # Make sure the cache directory exists before writing.
                    os.makedirs(os.path.dirname(dp_vert_pid_fname), exist_ok=True)
                    np.save(dp_vert_pid_fname, np.array(dp_vert_pid))

                # Feature per vertex: (normalised part id, U, V).
                textures_vts = np.array(
                    [
                        (dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i])
                        for i in range(len(vert_mapping))
                    ]
                )
                self.textures_vts = torch.from_numpy(
                    textures_vts[None].astype(np.float32)
                )    # (1, 7829, 3) (per the original author's note)
        elif mode == 'pncc':
            self.vert_mapping = None
            self.faces = torch.from_numpy(
                get_model_faces(mesh_type)[None].astype(np.int32)
            )    # mano: torch.Size([1, 1538, 3])
            textures_vts = get_model_tpose(mesh_type).unsqueeze(
                0
            )    # mano: torch.Size([1, 778, 3])

            # Normalise the template coordinates with a small margin on both ends.
            texture_min = torch.min(textures_vts) - 0.001
            texture_range = torch.max(textures_vts) - texture_min + 0.001
            self.textures_vts = (textures_vts - texture_min) / texture_range
        elif mode == 'seg':
            self.vert_mapping = None
            body_model = 'smpl'

            self.faces = torch.from_numpy(get_smpl_faces().astype(np.int32)[None])

            seg_path = os.path.join(
                path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)
            )
            with open(seg_path, 'rb') as json_file:
                smpl_part_id = json.load(json_file)

            v_id = []
            for k in smpl_part_id.keys():
                v_id.extend(smpl_part_id[k])

            # Number of distinct vertex ids mentioned in the segmentation file.
            v_id = torch.tensor(v_id)
            n_verts = len(torch.unique(v_id))
            num_part = len(constants.SMPL_PART_ID.keys())
            self.num_part = num_part

            seg_vert_pid = np.zeros(n_verts)
            for k in smpl_part_id.keys():
                seg_vert_pid[smpl_part_id[k]] = constants.SMPL_PART_ID[k]

            print('seg_vert_pid', seg_vert_pid.shape)
            # The normalised part id is replicated over the three channels.
            textures_vts = seg_vert_pid[:, None].repeat(3, axis=1) / num_part
            print('textures_vts', textures_vts.shape)
            self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32))

        K = np.array(
            [
                [self.focal_length, 0., self.orig_size / 2.],
                [0., self.focal_length, self.orig_size / 2.],
                [0., 0., 1.],
            ]
        )

        R = np.array([[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]])

        t = np.array([0, 0, 5])

        if self.orig_size != 224:
            # NOTE(review): the principal point already uses ``orig_size / 2`` and
            # is scaled again here — confirm that this is intended.
            rander_scale = self.orig_size / float(224)
            K[0, 0] *= rander_scale
            K[1, 1] *= rander_scale
            K[0, 2] *= rander_scale
            K[1, 2] *= rander_scale

        self.K = torch.FloatTensor(K[None, :, :])
        self.R = torch.FloatTensor(R[None, :, :])
        self.t = torch.FloatTensor(t[None, None, :])

        # Pad the 3x3 intrinsics to 4x4 and rearrange the last rows/columns into
        # the calibration layout handed to ``PerspectiveCameras`` as ``K``.
        camK = F.pad(self.K, (0, 1, 0, 1), "constant", 0)
        camK[:, 2, 2] = 0
        camK[:, 3, 2] = 1
        camK[:, 2, 3] = 1

        self.K = camK

        self.device = device
        lights = AmbientLights(device=self.device)
        raster_settings = RasterizationSettings(
            image_size=output_size,
            blur_radius=0,
            faces_per_pixel=1,
        )
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings),
            shader=HardFlatShader(
                device=self.device,
                lights=lights,
                blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0)
            )
        )

    def camera_matrix(self, cam):
        """Build batched ``K``, ``R`` and translation ``t`` for ``cam``.

        Args:
            cam: tensor indexed as ``cam[:, 0]`` (scale), ``cam[:, 1]`` and
                ``cam[:, 2]`` (in-plane shifts); the first dimension is the batch.

        Returns:
            ``(K, R, t)`` with ``K`` and ``R`` repeated along the batch dimension
            and every tensor on ``cam.device``.
        """
        batch_size = cam.size(0)
        # Move the constants to the input device unconditionally so that
        # non-CUDA accelerators do not end up with mixed devices.
        K = self.K.repeat(batch_size, 1, 1).to(cam.device)
        R = self.R.repeat(batch_size, 1, 1).to(cam.device)
        depth = 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)
        t = torch.stack([-cam[:, 1], -cam[:, 2], depth], dim=-1)

        return K, R, t

    def verts2iuvimg(self, verts, cam, iwp_mode=True):
        """Render the per-vertex features of ``verts`` as seen by ``cam``.

        Args:
            verts: batched vertices; the first dimension is the batch.
            cam: camera parameters, see ``camera_matrix``.
            iwp_mode: accepted for interface compatibility; not used here.

        Returns:
            The first three channels of the rendering, permuted to channels-first.
        """
        batch_size = verts.size(0)

        K, R, t = self.camera_matrix(cam)

        if self.vert_mapping is None:
            vertices = verts
        else:
            vertices = verts[:, self.vert_mapping, :]

        mesh = Meshes(vertices, self.faces.to(verts.device).expand(batch_size, -1, -1))
        mesh.textures = TexturesVertex(
            verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1)
        )

        cameras = PerspectiveCameras(
            device=verts.device,
            R=R,
            T=t,
            K=K,
            in_ndc=False,
            image_size=[(self.orig_size, self.orig_size)]
        )

        iuv_image = self.renderer(mesh, cameras=cameras)
        iuv_image = iuv_image[..., :3].permute(0, 3, 1, 2)

        return iuv_image