|
import os |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from skimage.transform import resize |
|
|
|
import matplotlib |
|
|
|
matplotlib.use('Agg') |
|
|
|
from .renderer import OpenDRenderer, PyRenderer |
|
|
|
|
|
def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=None):
    """Compose per-part U/V maps and part scores into 3-channel IUV images.

    For every pixel the winning part is the argmax of ``Index_UV`` over dim 1
    (zeroed wherever the argmax of ``AnnIndex`` over dim 1 is 0, when
    ``AnnIndex`` is given). The returned image holds, per sample:

    * channel 0 -- the part index, divided by ``K - 1`` or, when
      ``ind_mapping`` is given, replaced by ``ind_mapping[index] / 24``;
    * channel 1 -- the value of ``U_uv`` taken from the winning part;
    * channel 2 -- the value of ``V_uv`` taken from the winning part.

    Pixels whose winning part is 0 keep U = V = 0.

    Args:
        U_uv: tensor indexed as ``[batch, part, row, col]``.
        V_uv: tensor indexed like ``U_uv``.
        Index_UV: per-part scores; argmax is taken over dim 1.
        AnnIndex: optional scores; argmax over dim 1 equal to 0 masks a pixel
            out (its part index is forced to 0).
        uv_rois: optional per-sample sequences; elements ``[1:]`` are read as
            a box -- presumably ``(x1, y1, x2, y2)``, confirm against caller.
            When given, each image is resized to the box aspect ratio and
            padded back to ``heatmap_size`` x ``heatmap_size``.
        ind_mapping: optional sequence mapping a part index to the value that
            is written (scaled by 1/24) into channel 0. Indices beyond
            ``len(ind_mapping)`` keep their raw index value.

    Returns:
        Tensor of shape ``[batch, 3, H, W]`` on the same device as ``U_uv``.
    """
    batch_size = U_uv.size(0)
    K = U_uv.size(1)
    heatmap_size = U_uv.size(2)

    Index_UV_max = torch.argmax(Index_UV, dim=1)
    if AnnIndex is None:
        Index_UV_max = Index_UV_max.to(torch.int64)
    else:
        AnnIndex_max = torch.argmax(AnnIndex, dim=1)
        Index_UV_max = Index_UV_max * (AnnIndex_max > 0).to(torch.int64)

    outputs = []

    for batch_id in range(batch_size):
        part_index = Index_UV_max[batch_id]

        # Allocate on the input's own device so CPU tensors work as well as CUDA ones.
        output = torch.zeros(
            [3, U_uv.size(2), U_uv.size(3)], dtype=torch.float32, device=U_uv.device
        )
        output[0] = part_index.to(torch.float32)
        if ind_mapping is None:
            # Clamp the divisor so a single-part input (K == 1) does not produce NaN.
            output[0] /= float(max(K - 1, 1))
        else:
            # Build masks from the integer index map, not from the channel being
            # rewritten: otherwise a freshly written value (e.g. 24/24 == 1.0)
            # could be mistaken for a later part index and remapped twice.
            for ind, mapped in enumerate(ind_mapping):
                output[0][part_index == ind] = mapped * (1. / 24.)

        for part_id in range(1, K):
            part_mask = part_index == part_id
            output[1, part_mask] = U_uv[batch_id, part_id][part_mask]
            output[2, part_mask] = V_uv[batch_id, part_id][part_mask]

        if uv_rois is None:
            outputs.append(output.unsqueeze(0))
            continue

        roi_fg = uv_rois[batch_id][1:]
        w = roi_fg[2] - roi_fg[0]
        h = roi_fg[3] - roi_fg[1]

        aspect_ratio = float(w) / h

        if aspect_ratio < 1:
            # Narrow box: shrink the width, then pad left/right back to a square.
            new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)]
            output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest')
            paddingleft = int(0.5 * (heatmap_size - new_size[1]))
            output = F.pad(
                output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0)
            )
        else:
            # Wide (or square) box: shrink the height, then pad top/bottom.
            new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size]
            output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest')
            paddingtop = int(0.5 * (heatmap_size - new_size[0]))
            output = F.pad(
                output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop)
            )

        outputs.append(output)

    return torch.cat(outputs, dim=0)
|
|
|
|
|
def vis_smpl_iuv(
    image,
    cam_pred,
    vert_pred,
    face,
    pred_uv,
    vert_errors_batch,
    image_name,
    save_path,
    opt,
    ratio=1
):
    """Render predicted meshes next to their input images and save them as PNGs.

    For each sample ``i`` in ``cam_pred`` this writes
    ``<save_path>/<err>_<name>.png`` containing the image returned by the
    renderer side by side with the rendered mesh, where ``<err>`` is
    ``int(10 * vert_errors_batch[i])`` zero-padded to six digits and
    ``<name>`` is ``image_name[i]`` without its file extension. When
    ``pred_uv`` is given, ``<save_path>/pred_uv_<err>_<name>.png`` is written
    as well with the predicted IUV map.

    Args:
        image: batch of images; ``image[i]`` has its axis 0 moved to the last
            position before rendering (presumably channel-first -- confirm).
        cam_pred: per-sample camera parameters passed to the renderer.
        vert_pred: per-sample vertices passed to the renderer.
        face: unused; kept for interface compatibility.
        pred_uv: ``None`` or a sequence unpacked as the positional arguments
            of ``iuv_map2img``.
        vert_errors_batch: per-sample error values used in the file name.
        image_name: per-sample file names.
        save_path: output directory; created if missing.
        opt: options object; ``opt.save_obj`` enables writing
            ``<save_path>/mesh/<err>_<name>.obj`` through the renderer.
        ratio: unused; kept for interface compatibility.
    """
    # Imported explicitly: `import matplotlib` alone does not guarantee that
    # the `matplotlib.image` submodule has been loaded.
    from matplotlib.image import imsave

    os.makedirs(save_path, exist_ok=True)

    mesh_dir = os.path.join(save_path, 'mesh')
    if opt.save_obj:
        os.makedirs(mesh_dir, exist_ok=True)

    dr_render = PyRenderer()

    iuv_img = iuv_map2img(*pred_uv) if pred_uv is not None else None

    for draw_i in range(len(cam_pred)):
        err_val = '{:06d}_'.format(int(10 * vert_errors_batch[draw_i]))
        draw_name = err_val + image_name[draw_i]
        # Strip the real extension instead of blindly dropping four characters,
        # so names such as "x.jpeg" or extension-less names stay intact.
        base_name = os.path.splitext(draw_name)[0]

        if opt.save_obj:
            mesh_filename = os.path.join(mesh_dir, base_name + '.obj')
        else:
            mesh_filename = None

        img_orig = np.moveaxis(image[draw_i], 0, -1)
        img_smpl, img_resized = dr_render(
            vert_pred[draw_i],
            img=img_orig,
            cam=cam_pred[draw_i],
            iwp_mode=True,
            scale_ratio=4.,
            mesh_filename=mesh_filename,
        )

        # Append a fully opaque alpha channel to both images.
        ones_img = np.ones(img_smpl.shape[:2]) * 255
        ones_img = ones_img[:, :, None]
        img_smpl_rgba = np.concatenate((img_smpl, ones_img), axis=2)
        img_resized_rgba = np.concatenate((img_resized, ones_img), axis=2)

        # Input image on the left, rendered mesh on the right.
        render_img = np.concatenate((img_resized_rgba, img_smpl_rgba), axis=1)
        render_img = np.clip(render_img, 0, 255)
        imsave(os.path.join(save_path, base_name + '.png'), render_img.astype(np.uint8))

        if iuv_img is not None:
            global_iuv = iuv_img[draw_i].cpu().numpy()
            global_iuv = np.transpose(global_iuv, (1, 2, 0))
            global_iuv = resize(global_iuv, img_resized.shape[:2])
            global_iuv = np.clip(global_iuv, 0, 1)
            imsave(os.path.join(save_path, 'pred_uv_' + base_name + '.png'), global_iuv)
|
|