diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..e549cb95182ed49dd7522f5ab61180b61a6cc81b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/663dcd6db19490de0b790da430bd5681.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..77e0c563b7e7254c3c5159a6eb857a053ee75971
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b8d0ab1bd2b159fbff44e5944841dfa8707cdee
--- /dev/null
+++ b/ORIGINAL_README.md
@@ -0,0 +1,79 @@
+# PSHuman
+
+This is the official implementation of *PSHuman: Photorealistic Single-image 3D Human Reconstruction using Cross-Scale Multiview Diffusion*.
+
+### [Project Page](https://penghtyx.github.io/PSHuman/) | [Arxiv](https://arxiv.org/pdf/2409.10141) | [Weights](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views)
+
+https://github.com/user-attachments/assets/b62e3305-38a7-4b51-aed8-1fde967cca70
+
+https://github.com/user-attachments/assets/76100d2e-4a1a-41ad-815c-816340ac6500
+
+
+Given a single image of a clothed person, **PSHuman** reconstructs detailed geometry and realistic 3D human appearance across various poses within one minute.
+
+### 📝 Update
+- __[2024.11.30]__: Released the SMPL-free [version](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views), which does not require an SMPL condition for multiview generation and performs well on humans in general poses.
+
+
+### Installation
+```
+conda create -n pshuman python=3.10
+conda activate pshuman
+
+# torch
+pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
+
+# other dependencies
+pip install -r requirement.txt
+```
+
+This project also relies on SMPL-X. We borrowed the related models from [ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU) and re-organized them; they can be downloaded from [Onedrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/plibp_connect_ust_hk/EZQphP-2y5BGhEIe8jb03i4BIcqiJ2mUW2JmGC5s0VKOdw?e=qVzBBD).
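As a quick sanity check that the downloaded SMPL-X models are readable, a minimal sketch along the lines below can be used. It assumes the `smplx` pip package is installed, and `$SMPL_DIR` is a placeholder for wherever the OneDrive archive is unpacked; the exact folder layout inside the archive may differ.

```python
import smplx  # pip install smplx

# '$SMPL_DIR' is a placeholder for the unpacked SMPL-X model folder (assumption).
model = smplx.create('$SMPL_DIR/models', model_type='smplx',
                     gender='neutral', use_pca=False)
output = model(return_verts=True)
# A neutral SMPL-X template mesh has 10,475 vertices.
print(output.vertices.shape)  # -> torch.Size([1, 10475, 3])
```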
+
+
+
+### Inference
+1. Given a human image, we use [Clipdrop](https://github.com/xxlong0/Wonder3D?tab=readme-ov-file) or ```rembg``` to remove the background. For the latter, we provide a simple script.
+```
+python utils/remove_bg.py --path $DATA_PATH$
+```
+Then, put the RGBA images in the ```$DATA_PATH$```.
+
+2. Run [inference.py](inference.py); the textured mesh and rendered video will be saved in ```out```.
+```
+CUDA_VISIBLE_DEVICES=$GPU python inference.py --config configs/inference-768-6view.yaml \
+    pretrained_model_name_or_path='pengHTYX/PSHuman_Unclip_768_6views' \
+    validation_dataset.crop_size=740 \
+    with_smpl=false \
+    validation_dataset.root_dir=$DATA_PATH$ \
+    seed=600 \
+    num_views=7 \
+    save_mode='rgb'
+
+```
+You can adjust the ```crop_size``` (720 or 740) and ```seed``` (42 or 600) to obtain the best results for some cases.
+
+### Training
+For data preparation and preprocessing, please refer to our [paper](https://arxiv.org/pdf/2409.10141). Once the data is ready, start training by running
+```
+bash scripts/train_768.sh
+```
+You may need to modify some parameters, such as ```data_common.root_dir``` and ```data_common.object_list```.
+
+### Related projects
+We collect code from the following projects, and we thank the open-source community for its contributions!
+
+[ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU) recover a human mesh from a single human image.
+[Era3D](https://github.com/pengHTYX/Era3D) and [Unique3D](https://github.com/AiuniAI/Unique3D) generate consistent multiview images from a single color image.
+[Continuous-Remeshing](https://github.com/Profactor/continuous-remeshing) for inverse rendering.
+
+
+### Citation
+If you find this codebase useful, please consider citing our work.
+```
+@article{li2024pshuman,
+  title={PSHuman: Photorealistic Single-view Human Reconstruction using Cross-Scale Diffusion},
+  author={Li, Peng and Zheng, Wangguandong and Liu, Yuan and Yu, Tao and Li, Yangguang and Qi, Xingqun and Li, Mengfei and Chi, Xiaowei and Xia, Siyu and Xue, Wei and others},
+  journal={arXiv preprint arXiv:2409.10141},
+  year={2024}
+}
+```
\ No newline at end of file
diff --git a/assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 b/assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3b7021698b6c1da29f60359fed0795568b4f8fd4
Binary files /dev/null and b/assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 differ
diff --git a/assets/result_clr_scale4_pexels-zdmit-6780091.mp4 b/assets/result_clr_scale4_pexels-zdmit-6780091.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..facd6c7bc240046be5019e0678f936cebb111854
Binary files /dev/null and b/assets/result_clr_scale4_pexels-zdmit-6780091.mp4 differ
diff --git a/blender/blender_render_human_ortho.py b/blender/blender_render_human_ortho.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c5579c6e1028ec1b3edcbe9f5d8c12bf60b60c6
--- /dev/null
+++ b/blender/blender_render_human_ortho.py
@@ -0,0 +1,837 @@
+"""Blender script to render images of 3D models.
+
+This script renders images of 3D models. It takes a scan mesh and its SMPL-X mesh
+and renders images of each from viewpoints rotating around the object at the
+origin. The images are saved to the output directory.
+
+Example usage:
+    blender -b -P blender_render_human_ortho.py -- \
+        --object_path my_object.glb \
+        --smpl_path my_smplx.obj \
+        --output_dir ./views \
+        --engine CYCLES \
+        --scale 0.8 \
+        --num_images 12 \
+        --resolution 512
+"""
+import argparse
+import json
+import math
+import os
+import random
+import sys
+import time
+import glob
+import urllib.request
+import uuid
+from typing import Tuple
+from mathutils import Vector, Matrix
+os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+import cv2
+import numpy as np
+from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
+
+import bpy
+from mathutils import Vector
+
+import OpenEXR
+import Imath
+from PIL import Image
+
+# import blenderproc as bproc
+
+bpy.app.debug_value=256
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--object_path",
+    type=str,
+    required=True,
+    help="Path to the object file",
+)
+parser.add_argument("--smpl_path", type=str, required=True, help="Path to the SMPL-X mesh file")
+parser.add_argument("--output_dir", type=str, default="/views_whole_sphere-test2")
+parser.add_argument(
+    "--engine", type=str, default="BLENDER_EEVEE", choices=["CYCLES", "BLENDER_EEVEE"]
+)
+parser.add_argument("--scale", type=float, default=1.0)
+parser.add_argument("--num_images", type=int, default=8)
+parser.add_argument("--random_images", type=int, default=3)
+parser.add_argument("--random_ortho", type=int, default=1)
+parser.add_argument("--device", type=str, default="CUDA")
+parser.add_argument("--resolution", type=int, default=512)
+
+
+argv = sys.argv[sys.argv.index("--") + 1 :]
+args = parser.parse_args(argv)
+
+
+
+print('===================', args.engine, '===================')
+
+context = bpy.context
+scene = context.scene
+render = scene.render
+
+cam = scene.objects["Camera"]
+cam.data.type = 'ORTHO'
+cam.data.ortho_scale = 1.
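# The camera is orthographic by default; ortho_scale is the width of the view volume in
# world units and is reset to the normalized subject width (1.0) per view in save_images().
# The lens/sensor settings below only take effect if cam.data.type is switched back to 'PERSP'.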
+cam.data.lens = 35
+cam.data.sensor_height = 32
+cam.data.sensor_width = 32
+
+cam_constraint = cam.constraints.new(type="TRACK_TO")
+cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
+cam_constraint.up_axis = "UP_Y"
+
+# setup lighting
+# bpy.ops.object.light_add(type="AREA")
+# light2 = bpy.data.lights["Area"]
+# light2.energy = 3000
+# bpy.data.objects["Area"].location[2] = 0.5
+# bpy.data.objects["Area"].scale[0] = 100
+# bpy.data.objects["Area"].scale[1] = 100
+# bpy.data.objects["Area"].scale[2] = 100
+
+render.engine = args.engine
+render.image_settings.file_format = "PNG"
+render.image_settings.color_mode = "RGBA"
+render.resolution_x = args.resolution
+render.resolution_y = args.resolution
+render.resolution_percentage = 100
+render.threads_mode = 'FIXED'  # use a fixed number of render threads
+render.threads = 32  # number of render threads
+
+scene.cycles.device = "GPU"
+scene.cycles.samples = 128  # 128
+scene.cycles.diffuse_bounces = 1
+scene.cycles.glossy_bounces = 1
+scene.cycles.transparent_max_bounces = 3  # 3
+scene.cycles.transmission_bounces = 3  # 3
+# scene.cycles.filter_width = 0.01
+bpy.context.scene.cycles.adaptive_threshold = 0
+scene.cycles.use_denoising = True
+scene.render.film_transparent = True
+
+bpy.context.preferences.addons["cycles"].preferences.get_devices()
+# Set the device_type
+bpy.context.preferences.addons["cycles"].preferences.compute_device_type = 'CUDA'  # or "OPENCL"
+bpy.context.scene.cycles.tile_size = 8192
+
+
+# eevee = scene.eevee
+# eevee.use_soft_shadows = True
+# eevee.use_ssr = True
+# eevee.use_ssr_refraction = True
+# eevee.taa_render_samples = 64
+# eevee.use_gtao = True
+# eevee.gtao_distance = 1
+# eevee.use_volumetric_shadows = True
+# eevee.volumetric_tile_size = '2'
+# eevee.gi_diffuse_bounces = 1
+# eevee.gi_cubemap_resolution = '128'
+# eevee.gi_visibility_resolution = '16'
+# eevee.gi_irradiance_smoothing = 0
+
+
+# for depth & normal
+context.view_layer.use_pass_normal = True
+context.view_layer.use_pass_z = True
+context.scene.use_nodes = True
+
+
+tree = bpy.context.scene.node_tree
+nodes = bpy.context.scene.node_tree.nodes
+links = bpy.context.scene.node_tree.links
+
+# Clear default nodes
+for n in nodes:
+    nodes.remove(n)
+
+# # Create input render layer node.
+render_layers = nodes.new('CompositorNodeRLayers') + +scale_normal = nodes.new(type="CompositorNodeMixRGB") +scale_normal.blend_type = 'MULTIPLY' +scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1) +links.new(render_layers.outputs['Normal'], scale_normal.inputs[1]) +bias_normal = nodes.new(type="CompositorNodeMixRGB") +bias_normal.blend_type = 'ADD' +bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0) +links.new(scale_normal.outputs[0], bias_normal.inputs[1]) +normal_file_output = nodes.new(type="CompositorNodeOutputFile") +normal_file_output.label = 'Normal Output' +links.new(bias_normal.outputs[0], normal_file_output.inputs[0]) + +normal_file_output.format.file_format = "OPEN_EXR" # default is "PNG" +normal_file_output.format.color_mode = "RGB" # default is "BW" + +depth_file_output = nodes.new(type="CompositorNodeOutputFile") +depth_file_output.label = 'Depth Output' +links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0]) +depth_file_output.format.file_format = "OPEN_EXR" # default is "PNG" +depth_file_output.format.color_mode = "RGB" # default is "BW" + +def prepare_depth_outputs(): + tree = bpy.context.scene.node_tree + links = tree.links + render_node = tree.nodes['Render Layers'] + depth_out_node = tree.nodes.new(type="CompositorNodeOutputFile") + depth_map_node = tree.nodes.new(type="CompositorNodeMapRange") + depth_out_node.base_path = '' + depth_out_node.format.file_format = 'OPEN_EXR' + depth_out_node.format.color_depth = '32' + + depth_map_node.inputs[1].default_value = 0.54 + depth_map_node.inputs[2].default_value = 1.96 + depth_map_node.inputs[3].default_value = 0 + depth_map_node.inputs[4].default_value = 1 + depth_map_node.use_clamp = True + links.new(render_node.outputs[2],depth_map_node.inputs[0]) + links.new(depth_map_node.outputs[0], depth_out_node.inputs[0]) + return depth_out_node, depth_map_node + +depth_file_output, depth_map_node = prepare_depth_outputs() + + +def exr_to_png(exr_path): + depth_path = exr_path.replace('.exr', '.png') + exr_image = OpenEXR.InputFile(exr_path) + dw = exr_image.header()['dataWindow'] + (width, height) = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1) + + def read_exr(s, width, height): + mat = np.fromstring(s, dtype=np.float32) + mat = mat.reshape(height, width) + return mat + + dmap, _, _ = [read_exr(s, width, height) for s in exr_image.channels('BGR', Imath.PixelType(Imath.PixelType.FLOAT))] + dmap = np.clip(np.asarray(dmap,np.float64),a_max=1.0, a_min=0.0) * 65535 + dmap = Image.fromarray(dmap.astype(np.uint16)) + dmap.save(depth_path) + exr_image.close() + # os.system('rm {}'.format(exr_path)) + +def extract_depth(directory): + fns = glob.glob(f'{directory}/*.exr') + for fn in fns: exr_to_png(fn) + os.system(f'rm {directory}/*.exr') + +def sample_point_on_sphere(radius: float) -> Tuple[float, float, float]: + theta = random.random() * 2 * math.pi + phi = math.acos(2 * random.random() - 1) + return ( + radius * math.sin(phi) * math.cos(theta), + radius * math.sin(phi) * math.sin(theta), + radius * math.cos(phi), + ) + +def sample_spherical(radius=3.0, maxz=3.0, minz=0.): + correct = False + while not correct: + vec = np.random.uniform(-1, 1, 3) + vec[2] = np.abs(vec[2]) + vec = vec / np.linalg.norm(vec, axis=0) * radius + if maxz > vec[2] > minz: + correct = True + return vec + +def sample_spherical(radius_min=1.5, radius_max=2.0, maxz=1.6, minz=-0.75): + correct = False + while not correct: + vec = np.random.uniform(-1, 1, 3) +# vec[2] = np.abs(vec[2]) + radius = np.random.uniform(radius_min, 
radius_max, 1)
+        vec = vec / np.linalg.norm(vec, axis=0) * radius[0]
+        if maxz > vec[2] > minz:
+            correct = True
+    return vec
+
+def randomize_camera():
+    elevation = random.uniform(0., 90.)
+    azimuth = random.uniform(0., 360)
+    distance = random.uniform(0.8, 1.6)
+    return set_camera_location(elevation, azimuth, distance)
+
+def set_camera_location(elevation, azimuth, distance):
+    # from https://blender.stackexchange.com/questions/18530/
+    x, y, z = sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2)
+    camera = bpy.data.objects["Camera"]
+    camera.location = x, y, z
+
+    direction = - camera.location
+    rot_quat = direction.to_track_quat('-Z', 'Y')
+    camera.rotation_euler = rot_quat.to_euler()
+    return camera
+
+def set_camera_mvdream(azimuth, elevation, distance):
+    # theta, phi = np.deg2rad(azimuth), np.deg2rad(elevation)
+    azimuth, elevation = np.deg2rad(azimuth), np.deg2rad(elevation)
+    point = (
+        distance * math.cos(azimuth) * math.cos(elevation),
+        distance * math.sin(azimuth) * math.cos(elevation),
+        distance * math.sin(elevation),
+    )
+    camera = bpy.data.objects["Camera"]
+    camera.location = point
+
+    direction = -camera.location
+    rot_quat = direction.to_track_quat('-Z', 'Y')
+    camera.rotation_euler = rot_quat.to_euler()
+    return camera
+
+def reset_scene() -> None:
+    """Resets the scene to a clean state.
+
+    Returns:
+        None
+    """
+    # delete everything that isn't part of a camera or a light
+    for obj in bpy.data.objects:
+        if obj.type not in {"CAMERA", "LIGHT"}:
+            bpy.data.objects.remove(obj, do_unlink=True)
+
+    # delete all the materials
+    for material in bpy.data.materials:
+        bpy.data.materials.remove(material, do_unlink=True)
+
+    # delete all the textures
+    for texture in bpy.data.textures:
+        bpy.data.textures.remove(texture, do_unlink=True)
+
+    # delete all the images
+    for image in bpy.data.images:
+        bpy.data.images.remove(image, do_unlink=True)
+def process_ply(obj):
+    # obj = bpy.context.selected_objects[0]
+
+    # create a new material
+    material = bpy.data.materials.new(name="VertexColors")
+    material.use_nodes = True
+    obj.data.materials.append(material)
+
+    # get the material's node tree
+    nodes = material.node_tree.nodes
+    links = material.node_tree.links
+
+    # remove the original 'Principled BSDF' node
+    principled_bsdf_node = nodes.get("Principled BSDF")
+    if principled_bsdf_node:
+        nodes.remove(principled_bsdf_node)
+
+    # create a new 'Emission' node
+    emission_node = nodes.new(type="ShaderNodeEmission")
+    emission_node.location = 0, 0
+
+    # create an 'Attribute' node
+    attribute_node = nodes.new(type="ShaderNodeAttribute")
+    attribute_node.location = -300, 0
+    attribute_node.attribute_name = "Col"  # name of the vertex color attribute
+
+    # get the 'Material Output' node
+    output_node = nodes.get("Material Output")
+
+    # connect the nodes
+    links.new(attribute_node.outputs["Color"], emission_node.inputs["Color"])
+    links.new(emission_node.outputs["Emission"], output_node.inputs["Surface"])
+
+# # load the glb model
+# def load_object(object_path: str) -> None:
+
+#     if object_path.endswith(".glb"):
+#         bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
+#     elif object_path.endswith(".fbx"):
+#         bpy.ops.import_scene.fbx(filepath=object_path)
+#     elif object_path.endswith(".obj"):
+#         bpy.ops.import_scene.obj(filepath=object_path)
+#     elif object_path.endswith(".ply"):
+#         bpy.ops.import_mesh.ply(filepath=object_path)
+#         obj = bpy.context.selected_objects[0]
+#         obj.rotation_euler[0] = 1.5708
+#         # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
+#         process_ply(obj)
+#     else:
+#         raise
ValueError(f"Unsupported file type: {object_path}") + + + +def scene_bbox( + single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False +) -> Tuple[Vector, Vector]: + """Returns the bounding box of the scene. + + Taken from Shap-E rendering script + (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82) + + Args: + single_obj (Optional[bpy.types.Object], optional): If not None, only computes + the bounding box for the given object. Defaults to None. + ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults + to False. + + Raises: + RuntimeError: If there are no objects in the scene. + + Returns: + Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box. + """ + bbox_min = (math.inf,) * 3 + bbox_max = (-math.inf,) * 3 + found = False + for obj in get_scene_meshes() if single_obj is None else [single_obj]: + found = True + for coord in obj.bound_box: + coord = Vector(coord) + if not ignore_matrix: + coord = obj.matrix_world @ coord + bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) + bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) + + if not found: + raise RuntimeError("no objects in scene to compute bounding box for") + + return Vector(bbox_min), Vector(bbox_max) + + +def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]: + """Returns all root objects in the scene. + + Yields: + Generator[bpy.types.Object, None, None]: Generator of all root objects in the + scene. + """ + for obj in bpy.context.scene.objects.values(): + if not obj.parent: + yield obj + + +def get_scene_meshes() -> Generator[bpy.types.Object, None, None]: + """Returns all meshes in the scene. + + Yields: + Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene. 
+ """ + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, (bpy.types.Mesh)): + yield obj + + +# Build intrinsic camera parameters from Blender camera data +# +# See notes on this in +# blender.stackexchange.com/questions/15102/what-is-blenders-camera-projection-matrix-model +def get_calibration_matrix_K_from_blender(camd): + f_in_mm = camd.lens + scene = bpy.context.scene + resolution_x_in_px = scene.render.resolution_x + resolution_y_in_px = scene.render.resolution_y + scale = scene.render.resolution_percentage / 100 + sensor_width_in_mm = camd.sensor_width + sensor_height_in_mm = camd.sensor_height + pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y + if (camd.sensor_fit == 'VERTICAL'): + # the sensor height is fixed (sensor fit is horizontal), + # the sensor width is effectively changed with the pixel aspect ratio + s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio + s_v = resolution_y_in_px * scale / sensor_height_in_mm + else: # 'HORIZONTAL' and 'AUTO' + # the sensor width is fixed (sensor fit is horizontal), + # the sensor height is effectively changed with the pixel aspect ratio + pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y + s_u = resolution_x_in_px * scale / sensor_width_in_mm + s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm + + # Parameters of intrinsic calibration matrix K + alpha_u = f_in_mm * s_u + alpha_v = f_in_mm * s_v + u_0 = resolution_x_in_px * scale / 2 + v_0 = resolution_y_in_px * scale / 2 + skew = 0 # only use rectangular pixels + + K = Matrix( + ((alpha_u, skew, u_0), + ( 0 , alpha_v, v_0), + ( 0 , 0, 1 ))) + return K + + +def get_calibration_matrix_K_from_blender_for_ortho(camd, ortho_scale): + scene = bpy.context.scene + resolution_x_in_px = scene.render.resolution_x + resolution_y_in_px = scene.render.resolution_y + scale = scene.render.resolution_percentage / 100 + pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y + + fx = resolution_x_in_px / ortho_scale + fy = resolution_y_in_px / ortho_scale / pixel_aspect_ratio + + cx = resolution_x_in_px / 2 + cy = resolution_y_in_px / 2 + + K = Matrix( + ((fx, 0, cx), + (0, fy, cy), + (0 , 0, 1))) + return K + + +def get_3x4_RT_matrix_from_blender(cam): + bpy.context.view_layer.update() + location, rotation = cam.matrix_world.decompose()[0:2] + R = np.asarray(rotation.to_matrix()) + t = np.asarray(location) + + cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32) + R = R.T + t = -R @ t + R_world2cv = cam_rec @ R + t_world2cv = cam_rec @ t + + RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1) + return RT + +def delete_invisible_objects() -> None: + """Deletes all invisible objects in the scene. + + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + for obj in scene.objects: + if obj.hide_viewport or obj.hide_render: + obj.hide_viewport = False + obj.hide_render = False + obj.hide_select = False + obj.select_set(True) + bpy.ops.object.delete() + + # Delete invisible collections + invisible_collections = [col for col in bpy.data.collections if col.hide_viewport] + for col in invisible_collections: + bpy.data.collections.remove(col) + + +def normalize_scene(): + """Normalizes the scene by scaling and translating it to fit in a unit cube centered + at the origin. 
+ + Mostly taken from the Point-E / Shap-E rendering script + (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112), + but fix for multiple root objects: (see bug report here: + https://github.com/openai/shap-e/pull/60). + + Returns: + None + """ + if len(list(get_scene_root_objects())) > 1: + print('we have more than one root objects!!') + # create an empty object to be used as a parent for all root objects + parent_empty = bpy.data.objects.new("ParentEmpty", None) + bpy.context.scene.collection.objects.link(parent_empty) + + # parent all root objects to the empty object + for obj in get_scene_root_objects(): + if obj != parent_empty: + obj.parent = parent_empty + + bbox_min, bbox_max = scene_bbox() + dxyz = bbox_max - bbox_min + dist = np.sqrt(dxyz[0]**2+ dxyz[1]**2+dxyz[2]**2) + scale = 1 / dist + for obj in get_scene_root_objects(): + obj.scale = obj.scale * scale + + # Apply scale to matrix_world. + bpy.context.view_layer.update() + bbox_min, bbox_max = scene_bbox() + offset = -(bbox_min + bbox_max) / 2 + for obj in get_scene_root_objects(): + obj.matrix_world.translation += offset + bpy.ops.object.select_all(action="DESELECT") + + # unparent the camera + bpy.data.objects["Camera"].parent = None + return scale, offset + +def download_object(object_url: str) -> str: + """Download the object and return the path.""" + # uid = uuid.uuid4() + uid = object_url.split("/")[-1].split(".")[0] + tmp_local_path = os.path.join("tmp-objects", f"{uid}.glb" + ".tmp") + local_path = os.path.join("tmp-objects", f"{uid}.glb") + # wget the file and put it in local_path + os.makedirs(os.path.dirname(tmp_local_path), exist_ok=True) + urllib.request.urlretrieve(object_url, tmp_local_path) + os.rename(tmp_local_path, local_path) + # get the absolute path + local_path = os.path.abspath(local_path) + return local_path + + +def render_and_save(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False): + # print(view_id) + # render the image + render_path = os.path.join(args.output_dir, 'image', f"{view_id:03d}.png") + scene.render.filepath = render_path + + if not ortho: + cam.data.lens = len_val + + depth_map_node.inputs[1].default_value = distance - 1 + depth_map_node.inputs[2].default_value = distance + 1 + depth_file_output.base_path = os.path.join(args.output_dir, object_uid, 'depth') + + depth_file_output.file_slots[0].path = f"{view_id:03d}" + normal_file_output.file_slots[0].path = f"{view_id:03d}" + + if not os.path.exists(os.path.join(args.output_dir, 'normal', f"{view_id+1:03d}.png")): + bpy.ops.render.render(write_still=True) + + + if os.path.exists(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr")): + os.rename(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr"), + os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}.exr")) + + if os.path.exists(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr")): + normal = cv2.imread(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED) + normal_unit16 = (normal * 65535).astype(np.uint16) + cv2.imwrite(os.path.join(args.output_dir, 'normal', f"{view_id:03d}.png"), normal_unit16) + os.remove(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr")) + + # save camera KRT matrix + if ortho: + K = get_calibration_matrix_K_from_blender_for_ortho(cam.data, ortho_scale=cam.data.ortho_scale) + else: + K = get_calibration_matrix_K_from_blender(cam.data) + + RT = 
get_3x4_RT_matrix_from_blender(cam) + para_path = os.path.join(args.output_dir, 'camera', f"{view_id:03d}.npy") + # np.save(RT_path, RT) + paras = {} + paras['intrinsic'] = np.array(K, np.float32) + paras['extrinsic'] = np.array(RT, np.float32) + paras['fov'] = cam.data.angle + paras['azimuth'] = azimuth + paras['elevation'] = elevation + paras['distance'] = distance + paras['focal'] = cam.data.lens + paras['sensor_width'] = cam.data.sensor_width + paras['near'] = distance - 1 + paras['far'] = distance + 1 + paras['camera'] = 'persp' if not ortho else 'ortho' + np.save(para_path, paras) + +def render_and_save_smpl(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False): + + + if not ortho: + cam.data.lens = len_val + + render_path = os.path.join(args.output_dir, 'smpl_image', f"{view_id:03d}.png") + scene.render.filepath = render_path + + normal_file_output.file_slots[0].path = f"{view_id:03d}" + if not os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png")): + bpy.ops.render.render(write_still=True) + + if os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr")): + normal = cv2.imread(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED) + normal_unit16 = (normal * 65535).astype(np.uint16) + cv2.imwrite(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png"), normal_unit16) + os.remove(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr")) + + + +def scene_meshes(): + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, (bpy.types.Mesh)): + yield obj + +def load_object(object_path: str) -> None: + """Loads a glb model into the scene.""" + if object_path.endswith(".glb"): + bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False) + elif object_path.endswith(".fbx"): + bpy.ops.import_scene.fbx(filepath=object_path) + elif object_path.endswith(".obj"): + bpy.ops.import_scene.obj(filepath=object_path) + obj = bpy.context.selected_objects[0] + obj.rotation_euler[0] = 6.28319 + # obj.rotation_euler[2] = 1.5708 + elif object_path.endswith(".ply"): + bpy.ops.import_mesh.ply(filepath=object_path) + obj = bpy.context.selected_objects[0] + obj.rotation_euler[0] = 1.5708 + obj.rotation_euler[2] = 1.5708 + # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y') + process_ply(obj) + else: + raise ValueError(f"Unsupported file type: {object_path}") + +def save_images(object_file: str, smpl_file: str) -> None: + """Saves rendered images of the object in the scene.""" + object_uid = '' # os.path.basename(object_file).split(".")[0] +# # if we already render this object, we skip it + if os.path.exists(os.path.join(args.output_dir, 'meta.npy')): return + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, 'camera'), exist_ok=True) + + reset_scene() + load_object(object_file) + + lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT'] + for light in lights: + bpy.data.objects.remove(light, do_unlink=True) + +# bproc.init() + + world_tree = bpy.context.scene.world.node_tree + back_node = world_tree.nodes['Background'] + env_light = 0.5 + back_node.inputs['Color'].default_value = Vector([env_light, env_light, env_light, 1.0]) + back_node.inputs['Strength'].default_value = 1.0 + + #Make light just directional, disable shadows. 
+ light_data = bpy.data.lights.new(name=f'Light', type='SUN') + light = bpy.data.objects.new(name=f'Light', object_data=light_data) + bpy.context.collection.objects.link(light) + light = bpy.data.lights['Light'] + light.use_shadow = False + # Possibly disable specular shading: + light.specular_factor = 1.0 + light.energy = 5.0 + + #Add another light source so stuff facing away from light is not completely dark + light_data = bpy.data.lights.new(name=f'Light2', type='SUN') + light = bpy.data.objects.new(name=f'Light2', object_data=light_data) + bpy.context.collection.objects.link(light) + light2 = bpy.data.lights['Light2'] + light2.use_shadow = False + light2.specular_factor = 1.0 + light2.energy = 3 #0.015 + bpy.data.objects['Light2'].rotation_euler = bpy.data.objects['Light2'].rotation_euler + bpy.data.objects['Light2'].rotation_euler[0] += 180 + + #Add another light source so stuff facing away from light is not completely dark + light_data = bpy.data.lights.new(name=f'Light3', type='SUN') + light = bpy.data.objects.new(name=f'Light3', object_data=light_data) + bpy.context.collection.objects.link(light) + light3 = bpy.data.lights['Light3'] + light3.use_shadow = False + light3.specular_factor = 1.0 + light3.energy = 3 #0.015 + bpy.data.objects['Light3'].rotation_euler = bpy.data.objects['Light3'].rotation_euler + bpy.data.objects['Light3'].rotation_euler[0] += 90 + + #Add another light source so stuff facing away from light is not completely dark + light_data = bpy.data.lights.new(name=f'Light4', type='SUN') + light = bpy.data.objects.new(name=f'Light4', object_data=light_data) + bpy.context.collection.objects.link(light) + light4 = bpy.data.lights['Light4'] + light4.use_shadow = False + light4.specular_factor = 1.0 + light4.energy = 3 #0.015 + bpy.data.objects['Light4'].rotation_euler = bpy.data.objects['Light4'].rotation_euler + bpy.data.objects['Light4'].rotation_euler[0] += -90 + + scale, offset = normalize_scene() + + + try: + # some objects' normals are affected by textures + mesh_objects = [obj for obj in scene_meshes()] + main_bsdf_name = 'BsdfPrincipled' + normal_name = 'Normal' + for obj in mesh_objects: + for mat in obj.data.materials: + for node in mat.node_tree.nodes: + if main_bsdf_name in node.bl_idname: + principled_bsdf = node + # remove links, we don't want add normal textures + if principled_bsdf.inputs[normal_name].links: + mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0]) + except: + print("don't know why") + # create an empty object to track + empty = bpy.data.objects.new("Empty", None) + scene.collection.objects.link(empty) + cam_constraint.target = empty + + subject_width = 1.0 + + normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'normal') + for i in range(args.num_images): + # change the camera to orthogonal + cam.data.type = 'ORTHO' + cam.data.ortho_scale = subject_width + distance = 1.5 + azimuth = i * 360 / args.num_images + bpy.context.view_layer.update() + set_camera_mvdream(azimuth, 0, distance) + render_and_save(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True) + extract_depth(os.path.join(args.output_dir, object_uid, 'depth')) +# #### smpl + reset_scene() + load_object(smpl_file) + + lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT'] + for light in lights: + bpy.data.objects.remove(light, do_unlink=True) + + scale, offset = normalize_scene() + + try: + # some objects' normals are affected by textures + mesh_objects = [obj for obj in scene_meshes()] + 
main_bsdf_name = 'BsdfPrincipled' + normal_name = 'Normal' + for obj in mesh_objects: + for mat in obj.data.materials: + for node in mat.node_tree.nodes: + if main_bsdf_name in node.bl_idname: + principled_bsdf = node + # remove links, we don't want add normal textures + if principled_bsdf.inputs[normal_name].links: + mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0]) + except: + print("don't know why") + # create an empty object to track + empty = bpy.data.objects.new("Empty", None) + scene.collection.objects.link(empty) + cam_constraint.target = empty + + subject_width = 1.0 + + normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'smpl_normal') + for i in range(args.num_images): + # change the camera to orthogonal + cam.data.type = 'ORTHO' + cam.data.ortho_scale = subject_width + distance = 1.5 + azimuth = i * 360 / args.num_images + bpy.context.view_layer.update() + set_camera_mvdream(azimuth, 0, distance) + render_and_save_smpl(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True) + + + np.save(os.path.join(args.output_dir, object_uid, 'meta.npy'), np.asarray([scale, offset[0], offset[1], offset[1]],np.float32)) + + +if __name__ == "__main__": + try: + start_i = time.time() + if args.object_path.startswith("http"): + local_path = download_object(args.object_path) + else: + local_path = args.object_path + save_images(local_path, args.smpl_path) + end_i = time.time() + print("Finished", local_path, "in", end_i - start_i, "seconds") + # delete the object if it was downloaded + if args.object_path.startswith("http"): + os.remove(local_path) + except Exception as e: + print("Failed to render", args.object_path) + print(e) diff --git a/blender/check_render.py b/blender/check_render.py new file mode 100644 index 0000000000000000000000000000000000000000..65dde4026b92c8b656ed5fd76451d7c2400ab672 --- /dev/null +++ b/blender/check_render.py @@ -0,0 +1,46 @@ +import os +from tqdm import tqdm +import json +from icecream import ic + + +def check_render(dataset, st=None, end=None): + total_lists = [] + with open(dataset+'.json', 'r') as f: + glb_list = json.load(f) + for x in glb_list: + total_lists.append(x.split('/')[-2] ) + + if st is not None: + end = min(end, len(total_lists)) + total_lists = total_lists[st:end] + glb_list = glb_list[st:end] + + save_dir = '/data/lipeng/human_8view_with_smplx/'+dataset + unrendered = set(total_lists) - set(os.listdir(save_dir)) + + num_finish = 0 + num_failed = len(unrendered) + failed_case = [] + for case in os.listdir(save_dir): + if not os.path.exists(os.path.join(save_dir, case, 'smpl_normal', '007.png')): + failed_case.append(case) + num_failed += 1 + else: + num_finish += 1 + ic(num_failed) + ic(num_finish) + + + need_render = [] + for full_path in glb_list: + for case in failed_case: + if case in full_path: + need_render.append(full_path) + + with open('need_render.json', 'w') as f: + json.dump(need_render, f, indent=4) + +if __name__ == '__main__': + dataset = 'THuman2.1' + check_render(dataset) \ No newline at end of file diff --git a/blender/count.py b/blender/count.py new file mode 100644 index 0000000000000000000000000000000000000000..2a104dce5a823b2029b23e8a5b667b3d5777e8ed --- /dev/null +++ b/blender/count.py @@ -0,0 +1,44 @@ +import os +import json +def find_files(directory, extensions): + results = [] + for foldername, subfolders, filenames in os.walk(directory): + for filename in filenames: + if filename.endswith(extensions): + file_path = os.path.abspath(os.path.join(foldername, 
filename)) + results.append(file_path) + return results + +def count_customhumans(root): + directory_path = ['CustomHumans/mesh'] + + extensions = ('.ply', '.obj') + + lists = [] + for dataset_path in directory_path: + dir = os.path.join(root, dataset_path) + file_paths = find_files(dir, extensions) + # import pdb;pdb.set_trace() + dataset_name = dataset_path.split('/')[0] + for file_path in file_paths: + lists.append(file_path.replace(root, "")) + with open(f'{dataset_name}.json', 'w') as f: + json.dump(lists, f, indent=4) + +def count_thuman21(root): + directory_path = ['THuman2.1/mesh'] + extensions = ('.ply', '.obj') + lists = [] + for dataset_path in directory_path: + dir = os.path.join(root, dataset_path) + file_paths = find_files(dir, extensions) + dataset_name = dataset_path.split('/')[0] + for file_path in file_paths: + lists.append(file_path.replace(root, "")) + with open(f'{dataset_name}.json', 'w') as f: + json.dump(lists, f, indent=4) + +if __name__ == '__main__': + root = '/data/lipeng/human_scan/' + # count_customhumans(root) + count_thuman21(root) \ No newline at end of file diff --git a/blender/distribute.py b/blender/distribute.py new file mode 100644 index 0000000000000000000000000000000000000000..54819ea275ef08b48467600f1a44b51e29039e0e --- /dev/null +++ b/blender/distribute.py @@ -0,0 +1,149 @@ +import glob +import json +import multiprocessing +import shutil +import subprocess +import time +from dataclasses import dataclass +from typing import Optional +import os + +import boto3 + + +from glob import glob + +import argparse + +parser = argparse.ArgumentParser(description='distributed rendering') + +parser.add_argument('--workers_per_gpu', type=int, default=10, + help='number of workers per gpu.') +parser.add_argument('--input_models_path', type=str, default='/data/lipeng/human_scan/', + help='Path to a json file containing a list of 3D object files.') +parser.add_argument('--num_gpus', type=int, default=-1, + help='number of gpus to use. 
-1 means all available gpus.') +parser.add_argument('--gpu_list',nargs='+', type=int, + help='the avalaible gpus') +parser.add_argument('--resolution', type=int, default=512, + help='') +parser.add_argument('--random_images', type=int, default=0) +parser.add_argument('--start_i', type=int, default=0, + help='the index of first object to be rendered.') +parser.add_argument('--end_i', type=int, default=-1, + help='the index of the last object to be rendered.') + +parser.add_argument('--data_dir', type=str, default='/data/lipeng/human_scan/', + help='Path to a json file containing a list of 3D object files.') + +parser.add_argument('--json_path', type=str, default='2K2K.json') + +parser.add_argument('--save_dir', type=str, default='/data/lipeng/human_8view', + help='Path to a json file containing a list of 3D object files.') + +parser.add_argument('--ortho_scale', type=float, default=1., + help='ortho rendering usage; how large the object is') + + +args = parser.parse_args() + +def parse_obj_list(xs): + cases = [] + # print(xs[:2]) + + for x in xs: + if 'THuman3.0' in x: + # print(apath) + splits = x.split('/') + x = os.path.join('THuman3.0', splits[-2]) + elif 'THuman2.1' in x: + splits = x.split('/') + x = os.path.join('THuman2.1', splits[-2]) + elif 'CustomHumans' in x: + splits = x.split('/') + x = os.path.join('CustomHumans', splits[-2]) + elif '1M' in x: + splits = x.split('/') + x = os.path.join('2K2K', splits[-2]) + elif 'realistic_8k_model' in x: + splits = x.split('/') + x = os.path.join('realistic_8k_model', splits[-1].split('.')[0]) + cases.append(f'{args.save_dir}/{x}') + return cases + + +with open(args.json_path, 'r') as f: + glb_list = json.load(f) + +# glb_list = ['THuman2.1/mesh/1618/1618.obj'] +# glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj'] +# glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj'] +# glb_list = ['1M/01968/01968.ply', '1M/00103/00103.ply'] +# glb_list = ['realistic_8k_model/01aab099a2fe4af7be120110a385105d.glb'] + +total_num_glbs = len(glb_list) + + + +def worker( + queue: multiprocessing.JoinableQueue, + count: multiprocessing.Value, + gpu: int, + s3: Optional[boto3.client], +) -> None: + print("Worker started") + while True: + case, save_p = queue.get() + src_path = os.path.join(args.data_dir, case) + smpl_path = src_path.replace('mesh', 'smplx', 1) + + command = ('blender -b -P blender_render_human_ortho.py' + f' -- --object_path {src_path}' + f' --smpl_path {smpl_path}' + f' --output_dir {save_p} --engine CYCLES' + f' --resolution {args.resolution}' + f' --random_images {args.random_images}' + ) + + print(command) + subprocess.run(command, shell=True) + + with count.get_lock(): + count.value += 1 + + queue.task_done() + + +if __name__ == "__main__": + # args = tyro.cli(Args) + + s3 = None + queue = multiprocessing.JoinableQueue() + count = multiprocessing.Value("i", 0) + + # Start worker processes on each of the GPUs + for gpu_i in range(args.num_gpus): + for worker_i in range(args.workers_per_gpu): + worker_i = gpu_i * args.workers_per_gpu + worker_i + process = multiprocessing.Process( + target=worker, args=(queue, count, args.gpu_list[gpu_i], s3) + ) + process.daemon = True + process.start() + + # Add items to the queue + + save_dirs = parse_obj_list(glb_list) + args.end_i = len(save_dirs) if args.end_i > len(save_dirs) or args.end_i==-1 else args.end_i + + for case_sub, save_dir in zip(glb_list[args.start_i:args.end_i], save_dirs[args.start_i:args.end_i]): + queue.put([case_sub, save_dir]) + + + + # Wait for all tasks to be 
completed + queue.join() + + # Add sentinels to the queue to stop the worker processes + for i in range(args.num_gpus * args.workers_per_gpu): + queue.put(None) diff --git a/blender/rename_smpl_files.py b/blender/rename_smpl_files.py new file mode 100644 index 0000000000000000000000000000000000000000..83cc56dbaa27ec74f8ce3f38227c537444cff5c6 --- /dev/null +++ b/blender/rename_smpl_files.py @@ -0,0 +1,25 @@ +import os +from tqdm import tqdm +from glob import glob + +def rename_customhumans(): + root = '/data/lipeng/human_scan/CustomHumans/smplx/' + file_paths = glob(os.path.join(root, '*/*_smpl.obj')) + for file_path in tqdm(file_paths): + new_path = file_path.replace('_smpl', '') + os.rename(file_path, new_path) + +def rename_thuman21(): + root = '/data/lipeng/human_scan/THuman2.1/smplx/' + file_paths = glob(os.path.join(root, '*/*.obj')) + for file_path in tqdm(file_paths): + obj_name = file_path.split('/')[-2] + folder_name = os.path.dirname(file_path) + new_path = os.path.join(folder_name, obj_name+'.obj') + # print(new_path) + # print(file_path) + os.rename(file_path, new_path) + +if __name__ == '__main__': + rename_thuman21() + rename_customhumans() \ No newline at end of file diff --git a/blender/render.sh b/blender/render.sh new file mode 100644 index 0000000000000000000000000000000000000000..8bbcf71148d9a4a4c2b89027f52112b45b02a970 --- /dev/null +++ b/blender/render.sh @@ -0,0 +1,4 @@ +#### install environment +# ~/pkgs/blender-3.6.4/3.6/python/bin/python3.10 -m pip install openexr opencv-python + +python render_human.py \ No newline at end of file diff --git a/blender/render_human.py b/blender/render_human.py new file mode 100644 index 0000000000000000000000000000000000000000..f4358312751647c153ffc42995d102ffb580281b --- /dev/null +++ b/blender/render_human.py @@ -0,0 +1,88 @@ +import os +import json +import math +from concurrent.futures import ProcessPoolExecutor +import threading +from tqdm import tqdm + +# from glcontext import egl +# egl.create_context() +# exit(0) + +LOCAL_RANK = 0 + +num_processes = 4 +NODE_RANK = int(os.getenv("SLURM_PROCID")) +WORLD_SIZE = 1 +NODE_NUM=1 +# NODE_RANK = int(os.getenv("SLURM_NODEID")) +IS_MAIN = False +if NODE_RANK == 0 and LOCAL_RANK == 0: + IS_MAIN = True + +GLOBAL_RANK = NODE_RANK * (WORLD_SIZE//NODE_NUM) + LOCAL_RANK + + +# json_path = "object_lists/Thuman2.0.json" +# json_path = "object_lists/THuman3.0.json" +json_path = "object_lists/CustomHumans.json" +data_dir = '/aifs4su/mmcode/lipeng' +save_dir = '/aifs4su/mmcode/lipeng/human_8view_new' +def parse_obj_list(x): + if 'THuman3.0' in x: + # print(apath) + splits = x.split('/') + x = os.path.join('THuman3.0', splits[-2]) + elif 'Thuman2.0' in x: + splits = x.split('/') + x = os.path.join('Thuman2.0', splits[-2]) + elif 'CustomHumans' in x: + splits = x.split('/') + x = os.path.join('CustomHumans', splits[-2]) + # print(splits[-2]) + elif '1M' in x: + splits = x.split('/') + x = os.path.join('2K2K', splits[-2]) + elif 'realistic_8k_model' in x: + splits = x.split('/') + x = os.path.join('realistic_8k_model', splits[-1].split('.')[0]) + return f'{save_dir}/{x}' + +with open(json_path, 'r') as f: + glb_list = json.load(f) + +# glb_list = ['Thuman2.0/0011/0011.obj'] +# glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj'] +# glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj'] +# glb_list = ['realistic_8k_model/1d41f2a72f994306b80e632f1cc8233f.glb'] + +total_num_glbs = len(glb_list) + +num_glbs_local = int(math.ceil(total_num_glbs / WORLD_SIZE)) +start_idx = GLOBAL_RANK * 
num_glbs_local +end_idx = start_idx + num_glbs_local +# print(start_idx, end_idx) +local_glbs = glb_list[start_idx:end_idx] +if IS_MAIN: + pbar = tqdm(total=len(local_glbs)) + lock = threading.Lock() + +def process_human(glb_path): + src_path = os.path.join(data_dir, glb_path) + save_path = parse_obj_list(glb_path) + # print(save_path) + command = ('blender -b -P blender_render_human_script.py' + f' -- --object_path {src_path}' + f' --output_dir {save_path} ') + # 1>/dev/null + # print(command) + os.system(command) + + if IS_MAIN: + with lock: + pbar.update(1) + +with ProcessPoolExecutor(max_workers=num_processes) as executor: + executor.map(process_human, local_glbs) + + diff --git a/blender/render_single.sh b/blender/render_single.sh new file mode 100644 index 0000000000000000000000000000000000000000..b33d6e2272fa3819f8fe50aefdaadd30f536d799 --- /dev/null +++ b/blender/render_single.sh @@ -0,0 +1,7 @@ +# debug single sample +blender -b -P blender_render_human_ortho.py \ + -- --object_path /data/lipeng/human_scan/THuman2.1/mesh/0011/0011.obj \ + --smpl_path /data/lipeng/human_scan/THuman2.1/smplx/0011/0011.obj \ + --output_dir debug --engine CYCLES \ + --resolution 768 \ + --random_images 0 diff --git a/blender/utils.py b/blender/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e059eb58cea5e0ea9595f1995d7706ddabadd130 --- /dev/null +++ b/blender/utils.py @@ -0,0 +1,128 @@ +import datetime +import pytz +import traceback +from torchvision.utils import make_grid +from PIL import Image, ImageDraw, ImageFont +import numpy as np +import torch +import json +import os +from tqdm import tqdm +import cv2 +import imageio +def get_time_for_log(): + return datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime( + "%Y%m%d %H:%M:%S") + + +def get_trace_for_log(): + return str(traceback.format_exc()) + +def make_grid_(imgs, save_file, nrow=10, pad_value=1): + if isinstance(imgs, list): + if isinstance(imgs[0], Image.Image): + imgs = [torch.from_numpy(np.array(img)/255.) for img in imgs] + elif isinstance(imgs[0], np.ndarray): + imgs = [torch.from_numpy(img/255.) 
for img in imgs] + imgs = torch.stack(imgs, 0).permute(0, 3, 1, 2) + if isinstance(imgs, np.ndarray): + imgs = torch.from_numpy(imgs) + + img_grid = make_grid(imgs, nrow=nrow, padding=2, pad_value=pad_value) + img_grid = img_grid.permute(1, 2, 0).numpy() + img_grid = (img_grid * 255).astype(np.uint8) + img_grid = Image.fromarray(img_grid) + img_grid.save(save_file) + +def draw_caption(img, text, pos, size=100, color=(128, 128, 128)): + draw = ImageDraw.Draw(img) + # font = ImageFont.truetype(size= size) + font = ImageFont.load_default() + font = font.font_variant(size=size) + draw.text(pos, text, color, font=font) + return img + + +def txt2json(txt_file, json_file): + with open(txt_file, 'r') as f: + items = f.readlines() + items = [x.strip() for x in items] + + with open(json_file, 'w') as f: + json.dump(items.tolist(), f) + +def process_thuman_texture(): + path = '/aifs4su/mmcode/lipeng/Thuman2.0' + cases = os.listdir(path) + for case in tqdm(cases): + mtl = os.path.join(path, case, 'material0.mtl') + with open(mtl, 'r') as f: + lines = f.read() + lines = lines.replace('png', 'jpeg') + with open(mtl, 'w') as f: + f.write(lines) + + +#### for debug +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + + +def get_intrinsic_from_fov(fov, H, W, bs=-1): + focal_length = 0.5 * H / np.tan(0.5 * fov) + intrinsic = np.identity(3, dtype=np.float32) + intrinsic[0, 0] = focal_length + intrinsic[1, 1] = focal_length + intrinsic[0, 2] = W / 2.0 + intrinsic[1, 2] = H / 2.0 + + if bs > 0: + intrinsic = intrinsic[None].repeat(bs, axis=0) + + return torch.from_numpy(intrinsic) + +def read_data(data_dir, i): + """ + Return: + rgb: (H, W, 3) torch.float32 + depth: (H, W, 1) torch.float32 + mask: (H, W, 1) torch.float32 + c2w: (4, 4) torch.float32 + intrinsic: (3, 3) torch.float32 + """ + background_color = torch.tensor([0.0, 0.0, 0.0]) + + rgb_name = os.path.join(data_dir, f'render_%04d.webp' % i) + depth_name = os.path.join(data_dir, f'depth_%04d.exr' % i) + + img = torch.from_numpy( + np.asarray( + Image.fromarray(imageio.v2.imread(rgb_name)) + .convert("RGBA") + ) + / 255.0 + ).float() + mask = img[:, :, -1:] + rgb = img[:, :, :3] * mask + background_color[ + None, None, : + ] * (1 - mask) + + depth = torch.from_numpy( + cv2.imread(depth_name, cv2.IMREAD_UNCHANGED)[..., 0, None] + ) + mask[depth > 100.0] = 0.0 + depth[~(mask > 0.5)] = 0.0 # set invalid depth to 0 + + meta_path = os.path.join(data_dir, 'meta.json') + with open(meta_path, 'r') as f: + meta = json.load(f) + + c2w = torch.as_tensor( + meta['locations'][i]["transform_matrix"], + dtype=torch.float32, + ) + + H, W = rgb.shape[:2] + fovy = meta["camera_angle_x"] + intrinsic = get_intrinsic_from_fov(fovy, H=H, W=W) + + return rgb, depth, mask, c2w, intrinsic diff --git a/configs/inference-768-6view.yaml b/configs/inference-768-6view.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c82e9ea458a90bc2e4ecc7bbacd5ee48dfc2091c --- /dev/null +++ b/configs/inference-768-6view.yaml @@ -0,0 +1,72 @@ +pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1-unclip' +revision: null + +num_views: 7 +with_smpl: false +validation_dataset: + prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view + root_dir: 'examples/shhq' + num_views: ${num_views} + bg_color: 'white' + img_wh: [768, 768] + num_validation_samples: 1000 + crop_size: 740 + margin_size: 50 + smpl_folder: 'smpl_image_pymaf' + + +save_dir: 'mv_results' +save_mode: 'rgba' # 'concat', 'rgba', 'rgb' +seed: 42 +validation_batch_size: 1 +dataloader_num_workers: 1 
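# As noted in the README, crop_size (720 or 740) and seed (42 or 600) are worth tuning per
# input, and save_mode selects how the generated views are written out ('concat', 'rgba', 'rgb').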
+local_rank: -1 + +pipe_kwargs: + num_views: ${num_views} + +validation_guidance_scales: 3.0 +pipe_validation_kwargs: + num_inference_steps: 40 + eta: 1.0 + +validation_grid_nrow: ${num_views} + +unet_from_pretrained_kwargs: + unclip: true + sdxl: false + num_views: ${num_views} + sample_size: 96 + zero_init_conv_in: false # modify + + projection_camera_embeddings_input_dim: 2 # 2 for elevation and 6 for focal_length + zero_init_camera_projection: false + num_regress_blocks: 3 + + cd_attention_last: false + cd_attention_mid: false + multiview_attention: true + sparse_mv_attention: true + selfattn_block: self_rowwise + mvcd_attention: true + +recon_opt: + res_path: out + save_glb: False + # camera setting + num_view: 6 + scale: 4 + mode: ortho + resolution: 1024 + cam_path: 'mvdiffusion/data/six_human_pose' + # optimization + iters: 700 + clr_iters: 200 + debug: false + snapshot_step: 50 + lr_clr: 2e-3 + gpu_id: 0 + + replace_hand: false + +enable_xformers_memory_efficient_attention: true \ No newline at end of file diff --git a/configs/remesh.yaml b/configs/remesh.yaml new file mode 100644 index 0000000000000000000000000000000000000000..850d91bf9608ae7f242061a2ce35d8f60e340fda --- /dev/null +++ b/configs/remesh.yaml @@ -0,0 +1,18 @@ +res_path: out +save_glb: False +imgs_path: examples/debug +mv_path: ./ +# camera setting +num_view: 6 +scale: 4 +mode: ortho +resolution: 1024 +cam_path: 'mvdiffusion/data/six_human_pose' +# optimization +iters: 700 +clr_iters: 200 +debug: false +snapshot_step: 50 +lr_clr: 2e-3 +gpu_id: 0 +replace_hand: false \ No newline at end of file diff --git a/configs/train-768-6view-onlyscan_face.yaml b/configs/train-768-6view-onlyscan_face.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae0c503a374b83e0eba2ca383042e7efe11d6827 --- /dev/null +++ b/configs/train-768-6view-onlyscan_face.yaml @@ -0,0 +1,145 @@ +pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip +pretrained_unet_path: null +revision: null +with_smpl: false +data_common: + root_dir: /aifs4su/mmcode/lipeng/human_8view_new/ + predict_relative_views: [0, 1, 2, 4, 6, 7] + num_validation_samples: 8 + img_wh: [768, 768] + read_normal: true + read_color: true + read_depth: false + exten: .png + prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view + object_list: + - data_lists/human_only_scan.json + invalid_list: + - +train_dataset: + root_dir: ${data_common.root_dir} + azi_interval: 45.0 + random_views: 3 + predict_relative_views: ${data_common.predict_relative_views} + bg_color: three_choices + object_list: ${data_common.object_list} + invalid_list: ${data_common.invalid_list} + img_wh: ${data_common.img_wh} + validation: false + num_validation_samples: ${data_common.num_validation_samples} + read_normal: ${data_common.read_normal} + read_color: ${data_common.read_color} + read_depth: ${data_common.read_depth} + load_cache: false + exten: ${data_common.exten} + prompt_embeds_path: ${data_common.prompt_embeds_path} + side_views_rate: 0.3 + elevation_list: null +validation_dataset: + prompt_embeds_path: ${data_common.prompt_embeds_path} + root_dir: examples/debug + num_views: ${num_views} + bg_color: white + img_wh: ${data_common.img_wh} + num_validation_samples: 1000 + crop_size: 740 +validation_train_dataset: + root_dir: ${data_common.root_dir} + azi_interval: 45.0 + random_views: 3 + predict_relative_views: ${data_common.predict_relative_views} + bg_color: white + object_list: ${data_common.object_list} + invalid_list: ${data_common.invalid_list} + img_wh: 
${data_common.img_wh} + validation: false + num_validation_samples: ${data_common.num_validation_samples} + read_normal: ${data_common.read_normal} + read_color: ${data_common.read_color} + read_depth: ${data_common.read_depth} + num_samples: ${data_common.num_validation_samples} + load_cache: false + exten: ${data_common.exten} + prompt_embeds_path: ${data_common.prompt_embeds_path} + elevation_list: null +output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5 +checkpoint_prefix: ../human_checkpoint_backup/ +seed: 42 +train_batch_size: 2 +validation_batch_size: 1 +validation_train_batch_size: 1 +max_train_steps: 30000 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +learning_rate: 0.0001 +scale_lr: false +lr_scheduler: piecewise_constant +step_rules: 1:2000,0.5 +lr_warmup_steps: 10 +snr_gamma: 5.0 +use_8bit_adam: false +allow_tf32: true +use_ema: true +dataloader_num_workers: 32 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 0.01 +adam_epsilon: 1.0e-08 +max_grad_norm: 1.0 +prediction_type: null +logging_dir: logs +vis_dir: vis +mixed_precision: fp16 +report_to: wandb +local_rank: 0 +checkpointing_steps: 2500 +checkpoints_total_limit: 2 +resume_from_checkpoint: latest +enable_xformers_memory_efficient_attention: true +validation_steps: 2500 # +validation_sanity_check: true +tracker_project_name: PSHuman +trainable_modules: null + + +use_classifier_free_guidance: true +condition_drop_rate: 0.05 +scale_input_latents: true +regress_elevation: false +regress_focal_length: false +elevation_loss_weight: 1.0 +focal_loss_weight: 0.0 +pipe_kwargs: + num_views: ${num_views} +pipe_validation_kwargs: + eta: 1.0 + +unet_from_pretrained_kwargs: + unclip: true + num_views: ${num_views} + sample_size: 96 + zero_init_conv_in: true + regress_elevation: ${regress_elevation} + regress_focal_length: ${regress_focal_length} + num_regress_blocks: 2 + camera_embedding_type: e_de_da_sincos + projection_camera_embeddings_input_dim: 2 + zero_init_camera_projection: true # modified + init_mvattn_with_selfattn: false + cd_attention_last: false + cd_attention_mid: false + multiview_attention: true + sparse_mv_attention: true + selfattn_block: self_rowwise + mvcd_attention: true + addition_downsample: false + use_face_adapter: false + +validation_guidance_scales: +- 3.0 +validation_grid_nrow: ${num_views} +camera_embedding_lr_mult: 1.0 +plot_pose_acc: false +num_views: 7 +pred_type: joint +drop_type: drop_as_a_whole diff --git a/configs/train-768-6view-onlyscan_face_smplx.yaml b/configs/train-768-6view-onlyscan_face_smplx.yaml new file mode 100644 index 0000000000000000000000000000000000000000..924c91630adf0e74460ea60fb1ccfecb27a71bc4 --- /dev/null +++ b/configs/train-768-6view-onlyscan_face_smplx.yaml @@ -0,0 +1,154 @@ +pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip +pretrained_unet_path: null +revision: null +with_smpl: true +data_common: + root_dir: /aifs4su/mmcode/lipeng/human_8view_with_smplx/ + predict_relative_views: [0, 1, 2, 4, 6, 7] + num_validation_samples: 8 + img_wh: [768, 768] + read_normal: true + read_color: true + read_depth: false + exten: .png + prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view + object_list: + - data_lists/human_only_scan_with_smplx.json # modified + invalid_list: + - + with_smpl: ${with_smpl} + +train_dataset: + root_dir: ${data_common.root_dir} + azi_interval: 45.0 + random_views: 0 + predict_relative_views: ${data_common.predict_relative_views} + bg_color: three_choices + object_list: 
${data_common.object_list} + invalid_list: ${data_common.invalid_list} + img_wh: ${data_common.img_wh} + validation: false + num_validation_samples: ${data_common.num_validation_samples} + read_normal: ${data_common.read_normal} + read_color: ${data_common.read_color} + read_depth: ${data_common.read_depth} + load_cache: false + exten: ${data_common.exten} + prompt_embeds_path: ${data_common.prompt_embeds_path} + side_views_rate: 0.3 + elevation_list: null + with_smpl: ${with_smpl} + +validation_dataset: + prompt_embeds_path: ${data_common.prompt_embeds_path} + root_dir: examples/debug + num_views: ${num_views} + bg_color: white + img_wh: ${data_common.img_wh} + num_validation_samples: 1000 + margin_size: 10 + # crop_size: 720 + +validation_train_dataset: + root_dir: ${data_common.root_dir} + azi_interval: 45.0 + random_views: 0 + predict_relative_views: ${data_common.predict_relative_views} + bg_color: white + object_list: ${data_common.object_list} + invalid_list: ${data_common.invalid_list} + img_wh: ${data_common.img_wh} + validation: false + num_validation_samples: ${data_common.num_validation_samples} + read_normal: ${data_common.read_normal} + read_color: ${data_common.read_color} + read_depth: ${data_common.read_depth} + num_samples: ${data_common.num_validation_samples} + load_cache: false + exten: ${data_common.exten} + prompt_embeds_path: ${data_common.prompt_embeds_path} + elevation_list: null + with_smpl: ${with_smpl} + +output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5-smplx +checkpoint_prefix: ../human_checkpoint_backup/ +seed: 42 +train_batch_size: 2 +validation_batch_size: 1 +validation_train_batch_size: 1 +max_train_steps: 30000 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +learning_rate: 0.0001 +scale_lr: false +lr_scheduler: piecewise_constant +step_rules: 1:2000,0.5 +lr_warmup_steps: 10 +snr_gamma: 5.0 +use_8bit_adam: false +allow_tf32: true +use_ema: true +dataloader_num_workers: 32 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 0.01 +adam_epsilon: 1.0e-08 +max_grad_norm: 1.0 +prediction_type: null +logging_dir: logs +vis_dir: vis +mixed_precision: fp16 +report_to: wandb +local_rank: 0 +checkpointing_steps: 5000 +checkpoints_total_limit: 2 +resume_from_checkpoint: latest +enable_xformers_memory_efficient_attention: true +validation_steps: 2500 # +validation_sanity_check: true +tracker_project_name: PSHuman +trainable_modules: null + +use_classifier_free_guidance: true +condition_drop_rate: 0.05 +scale_input_latents: true +regress_elevation: false +regress_focal_length: false +elevation_loss_weight: 1.0 +focal_loss_weight: 0.0 +pipe_kwargs: + num_views: ${num_views} +pipe_validation_kwargs: + eta: 1.0 + +unet_from_pretrained_kwargs: + unclip: true + num_views: ${num_views} + sample_size: 96 + zero_init_conv_in: true + regress_elevation: ${regress_elevation} + regress_focal_length: ${regress_focal_length} + num_regress_blocks: 2 + camera_embedding_type: e_de_da_sincos + projection_camera_embeddings_input_dim: 2 + zero_init_camera_projection: true # modified + init_mvattn_with_selfattn: false + cd_attention_last: false + cd_attention_mid: false + multiview_attention: true + sparse_mv_attention: true + selfattn_block: self_rowwise + mvcd_attention: true + addition_downsample: false + use_face_adapter: false + in_channels: 12 + + +validation_guidance_scales: +- 3.0 +validation_grid_nrow: ${num_views} +camera_embedding_lr_mult: 1.0 +plot_pose_acc: false +num_views: 7 +pred_type: joint +drop_type: drop_as_a_whole diff 
--git a/core/opt.py b/core/opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b9964ad155e201b33e33d411908a511578a7fe
--- /dev/null
+++ b/core/opt.py
@@ -0,0 +1,197 @@
+from copy import deepcopy
+import time
+import torch
+import torch_scatter
+from core.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
+
+@torch.no_grad()
+def remesh(
+        vertices_etc:torch.Tensor, #V,D
+        faces:torch.Tensor, #F,3 long
+        min_edgelen:torch.Tensor, #V
+        max_edgelen:torch.Tensor, #V
+        flip:bool,
+        max_vertices=1e6
+        ):
+
+    # dummies
+    vertices_etc,faces = prepend_dummies(vertices_etc,faces)
+    vertices = vertices_etc[:,:3] #V,3
+    nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
+    min_edgelen = torch.concat((nan_tensor,min_edgelen))
+    max_edgelen = torch.concat((nan_tensor,max_edgelen))
+
+    # collapse
+    edges,face_to_edge = calc_edges(faces) #E,2 F,3
+    edge_length = calc_edge_length(vertices,edges) #E
+    face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
+    vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
+    face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
+    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
+    priority = face_collapse.float() + shortness
+    vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
+
+    # split
+    if vertices.shape[0] < max_vertices:
+        edges,face_to_edge = calc_edges(faces) #E,2 F,3
+        edge_length = calc_edge_length(vertices,edges) #E
+        splits = edge_length > max_edgelen[edges].mean(dim=-1)
+        vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
+
+    vertices_etc,faces = pack(vertices_etc,faces)
+    vertices = vertices_etc[:,:3]
+
+    if flip:
+        edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
+        flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
+
+    return remove_dummies(vertices_etc,faces)
+
+def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
+    """lerp with adam's bias correction"""
+    c_prev = 1-weight**(step-1)
+    c = 1-weight**step
+    a_weight = weight*c_prev/c
+    b_weight = (1-weight)/c
+    a.mul_(a_weight).add_(b, alpha=b_weight)
+
+
+class MeshOptimizer:
+    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
+
+    def __init__(self,
+            vertices:torch.Tensor, #V,3
+            faces:torch.Tensor, #F,3
+            lr=0.3, #learning rate
+            betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
+            gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max.
smoothing) + nu_ref=0.3, #reference velocity for edge length controller + edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length + edge_len_tol=.5, #edge length tolerance for split and collapse + gain=.2, #gain value for edge length controller + laplacian_weight=.02, #for laplacian smoothing/regularization + ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0]) + grad_lim=10., #gradients are clipped to m1.abs()*grad_lim + remesh_interval=1, #larger intervals are faster but with worse mesh quality + local_edgelen=True, #set to False to use a global scalar reference edge length instead + remesh_milestones= [500], #list of steps at which to remesh + # total_steps=1000, #total number of steps + ): + self._vertices = vertices + self._faces = faces + self._lr = lr + self._betas = betas + self._gammas = gammas + self._nu_ref = nu_ref + self._edge_len_lims = edge_len_lims + self._edge_len_tol = edge_len_tol + self._gain = gain + self._laplacian_weight = laplacian_weight + self._ramp = ramp + self._grad_lim = grad_lim + # self._remesh_interval = remesh_interval + # self._remseh_milestones = [ for remesh_milestones] + self._local_edgelen = local_edgelen + self._step = 0 + self._start = time.time() + + V = self._vertices.shape[0] + # prepare continuous tensor for all vertex-based data + self._vertices_etc = torch.zeros([V,9],device=vertices.device) + self._split_vertices_etc() + self.vertices.copy_(vertices) #initialize vertices + self._vertices.requires_grad_() + self._ref_len.fill_(edge_len_lims[1]) + + @property + def vertices(self): + return self._vertices + + @property + def faces(self): + return self._faces + + def _split_vertices_etc(self): + self._vertices = self._vertices_etc[:,:3] + self._m2 = self._vertices_etc[:,3] + self._nu = self._vertices_etc[:,4] + self._m1 = self._vertices_etc[:,5:8] + self._ref_len = self._vertices_etc[:,8] + + with_gammas = any(g!=0 for g in self._gammas) + self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3] + + def zero_grad(self): + self._vertices.grad = None + + @torch.no_grad() + def step(self): + + eps = 1e-8 + + self._step += 1 + # spatial smoothing + edges,_ = calc_edges(self._faces) #E,2 + E = edges.shape[0] + edge_smooth = self._smooth[edges] #E,2,S + neighbor_smooth = torch.zeros_like(self._smooth) #V,S + torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth) + #apply optional smoothing of m1,m2,nu + if self._gammas[0]: + self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0]) + if self._gammas[1]: + self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1]) + if self._gammas[2]: + self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2]) + + #add laplace smoothing to gradients + laplace = self._vertices - neighbor_smooth[:,:3] + grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight) + + #gradient clipping + if self._step>1: + grad_lim = self._m1.abs().mul_(self._grad_lim) + grad.clamp_(min=-grad_lim,max=grad_lim) + + # moment updates + lerp_unbiased(self._m1, grad, self._betas[0], self._step) + lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step) + + velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3 + speed = velocity.norm(dim=-1) #V + + if self._betas[2]: + lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V + else: + self._nu.copy_(speed) #V + # update vertices + ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp) + 
self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr) + + # update target edge length + if self._step < 500: + self._remesh_interval = 4 + elif self._step < 800: + self._remesh_interval = 2 + else: + self._remesh_interval = 1 + + if self._step % self._remesh_interval == 0: + if self._local_edgelen: + len_change = (1 + (self._nu - self._nu_ref) * self._gain) + else: + len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain) + self._ref_len *= len_change + self._ref_len.clamp_(*self._edge_len_lims) + + def remesh(self, flip:bool=True)->tuple[torch.Tensor,torch.Tensor]: + min_edge_len = self._ref_len * (1 - self._edge_len_tol) + max_edge_len = self._ref_len * (1 + self._edge_len_tol) + + self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip) + + self._split_vertices_etc() + self._vertices.requires_grad_() + + return self._vertices, self._faces diff --git a/core/remesh.py b/core/remesh.py new file mode 100644 index 0000000000000000000000000000000000000000..96fe5d40066d1de5a6a081b80596d135f070cb4d --- /dev/null +++ b/core/remesh.py @@ -0,0 +1,359 @@ +import torch +import torch.nn.functional as tfunc +import torch_scatter + +def prepend_dummies( + vertices:torch.Tensor, #V,D + faces:torch.Tensor, #F,3 long + )->tuple[torch.Tensor,torch.Tensor]: + """prepend dummy elements to vertices and faces to enable "masked" scatter operations""" + V,D = vertices.shape + vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0) + faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0) + return vertices,faces + +def remove_dummies( + vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced + faces:torch.Tensor, #F,3 long - first face all zeros + )->tuple[torch.Tensor,torch.Tensor]: + """remove dummy elements added with prepend_dummies()""" + return vertices[1:],faces[1:]-1 + + +def calc_edges( + faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros + with_edge_to_face: bool = False + ) -> tuple[torch.Tensor, ...]: + """ + returns tuple of + - edges E,2 long, 0 for unused, lower vertex index first + - face_to_edge F,3 long + - (optional) edge_to_face shape=E,[left,right],[face,side] + + o-<-----e1 e0,e1...edge, e0-o + """ + + F = faces.shape[0] + + # make full edges, lower vertex index first + face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2 + full_edges = face_edges.reshape(F*3,2) + sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 TODO min/max faster? 
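+    # sorting each half-edge so the lower vertex index comes first makes the two copies of every interior edge identical, so torch.unique below can merge them into a single edge entry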
+ + # make unique edges + edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3) + E = edges.shape[0] + face_to_edge = full_to_unique.reshape(F,3) #F,3 + + if not with_edge_to_face: + return edges, face_to_edge + + is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3 + edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2 + scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2 + edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2 + edge_to_face[0] = 0 + return edges, face_to_edge, edge_to_face + +def calc_edge_length( + vertices:torch.Tensor, #V,3 first may be dummy + edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused + )->torch.Tensor: #E + + full_vertices = vertices[edges] #E,2,3 + a,b = full_vertices.unbind(dim=1) #E,3 + return torch.norm(a-b,p=2,dim=-1) + +def calc_face_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + normalize:bool=False, + )->torch.Tensor: #F,3 + """ + n + | + c0 corners ordered counterclockwise when + / \ looking onto surface (in neg normal direction) + c1---c2 + """ + full_vertices = vertices[faces] #F,C=3,3 + v0,v1,v2 = full_vertices.unbind(dim=1) #F,3 + face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3 + if normalize: + face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) #TODO inplace? + return face_normals #F,3 + +def calc_vertex_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + face_normals:torch.Tensor=None, #F,3, not normalized + )->torch.Tensor: #F,3 + + F = faces.shape[0] + + if face_normals is None: + face_normals = calc_face_normals(vertices,faces) + + vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3 + vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) + vertex_normals = vertex_normals.sum(dim=1) #V,3 + return tfunc.normalize(vertex_normals, eps=1e-6, dim=1) + +def calc_face_ref_normals( + faces:torch.Tensor, #F,3 long, 0 for unused + vertex_normals:torch.Tensor, #V,3 first unused + normalize:bool=False, + )->torch.Tensor: #F,3 + """calculate reference normals for face flip detection""" + full_normals = vertex_normals[faces] #F,C=3,3 + ref_normals = full_normals.sum(dim=1) #F,3 + if normalize: + ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1) + return ref_normals + +def pack( + vertices:torch.Tensor, #V,3 first unused and nan + faces:torch.Tensor, #F,3 long, 0 for unused + )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused + """removes unused elements in vertices and faces""" + V = vertices.shape[0] + + # remove unused faces + used_faces = faces[:,0]!=0 + used_faces[0] = True + faces = faces[used_faces] #sync + + # remove unused vertices + used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device) + used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') #TODO int faster? 
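+    # each face corner marks its vertex as used (one column per corner); any(dim=1) below reduces this to a single keep/drop flag per vertex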
+ used_vertices = used_vertices.any(dim=1) + used_vertices[0] = True + vertices = vertices[used_vertices] #sync + + # update used faces + ind = torch.zeros(V,dtype=torch.long,device=vertices.device) + V1 = used_vertices.sum() + ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync + faces = ind[faces] + + return vertices,faces + +def split_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + splits, #E bool + pack_faces:bool=True, + )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + # c2 c2 c...corners = faces + # . . . . s...side_vert, 0 means no split + # . . .N2 . S...shrunk_face + # . . . . Ni...new_faces + # s2 s1 s2|c2...s1|c1 + # . . . . . + # . . . S . . + # . . . . N1 . + # c0...(s0=0)....c1 s0|c0...........c1 + # + # pseudo-code: + # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2] + # split = side_vert!=0 example:[False,True,True] + # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0] + # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0] + # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1] + + V = vertices.shape[0] + F = faces.shape[0] + S = splits.sum().item() #sync + + if S==0: + return vertices,faces + + edge_vert = torch.zeros_like(splits, dtype=torch.long) #E + edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync + side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split + split_edges = edges[splits] #S sync + + #vertices + split_vertices = vertices[split_edges].mean(dim=1) #S,3 + vertices = torch.concat((vertices,split_vertices),dim=0) + + #faces + side_split = side_vert!=0 #F,3 + shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split + new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3 + faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3 + if pack_faces: + mask = faces[:,0]!=0 + mask[0] = True + faces = faces[mask] #F',3 sync + + return vertices,faces + +def collapse_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + priorities:torch.Tensor, #E float + stable:bool=False, #only for unit testing + )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + V = vertices.shape[0] + + # check spacing + _,order = priorities.sort(stable=stable) #E + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + edge_rank = rank #E + for i in range(3): + torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank) + edge_rank,_ = vert_rank[edges].max(dim=-1) #E + candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2 + + # check connectivity + vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + vert_connections[candidates[:,0]] = 1 #start + edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start + vert_connections[candidates] = 0 #clear start and end + edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start + 
vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start + collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end + + # mean vertices + vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) #TODO dim? + + # update faces + dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V + dest[collapses[:,1]] = dest[collapses[:,0]] + faces = dest[faces] #F,3 TODO optimize? + c0,c1,c2 = faces.unbind(dim=-1) + collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2) + faces[collapsed] = 0 + + return vertices,faces + +def calc_face_collapses( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + edge_length:torch.Tensor, #E + face_normals:torch.Tensor, #F,3 + vertex_normals:torch.Tensor, #V,3 first unused + min_edge_length:torch.Tensor=None, #V + area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio + shortest_probability = 0.8 + )->torch.Tensor: #E edges to collapse + + E = edges.shape[0] + F = faces.shape[0] + + # face flips + ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3 + face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F + + # small faces + if min_edge_length is not None: + min_face_length = min_edge_length[faces].mean(dim=-1) #F + min_area = min_face_length**2 * area_ratio #F + face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F + face_collapses[0] = False + + # faces to edges + face_length = edge_length[face_to_edge] #F,3 + + if shortest_probability<1: + #select shortest edge with shortest_probability chance + randlim = round(2/(1-shortest_probability)) + rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face + sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3 + local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None]) + else: + local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face + + edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index + edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device) + edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) #TODO legal for bool? 
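+    # every face flagged for collapse votes for one of its edges (the shortest with probability shortest_probability); edges receiving at least one vote are returned as collapse candidates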
+ + return edge_collapses.bool() + +def flip_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused + edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first + edge_to_face:torch.Tensor, #E,[left,right],[face,side] + with_border:bool=True, #handle border edges (D=4 instead of D=6) + with_normal_check:bool=True, #check face normal flips + stable:bool=False, #only for unit testing + ): + V = vertices.shape[0] + E = edges.shape[0] + device=vertices.device + vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long + vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add') + neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner + neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2 + edge_is_inside = neighbors.all(dim=-1) #E + + if with_border: + # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices + # need to use float for masks in order to use scatter(reduce='multiply') + vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float + src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float + vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply') + vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long + vertex_degree -= 2 * vertex_is_inside #V long + + neighbor_degrees = vertex_degree[neighbors] #E,LR=2 + edge_degrees = vertex_degree[edges] #E,2 + # + # loss = Sum_over_affected_vertices((new_degree-6)**2) + # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2) + # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2) + # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree)) + # + loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E + candidates = torch.logical_and(loss_change<0, edge_is_inside) #E + loss_change = loss_change[candidates] #E' + if loss_change.shape[0]==0: + return + + edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4 + _,order = loss_change.sort(descending=True, stable=stable) #E' + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4 + torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank) + vertex_rank,_ = vertex_rank.max(dim=-1) #V + neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E' + flip = rank==neighborhood_rank #E' + + if with_normal_check: + # cl-<-----e1 e0,e1...edge, e0-cr + v = vertices[edges_neighbors] #E",4,3 + v = v - v[:,0:1] #make relative to e0 + e1 = v[:,1] + cl = v[:,2] + cr = v[:,3] + n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors + flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face + flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face + + flip_edges_neighbors = edges_neighbors[flip] #E",4 + flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2 + flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3 + faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3)) diff --git a/econdataset.py b/econdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b79d141946b22d86074202ce1287e25137222de6 --- /dev/null +++ b/econdataset.py @@ -0,0 +1,370 @@ +# -*- coding: 
utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam +from lib.pixielib.utils.config import cfg as pixie_cfg +from lib.pixielib.pixie import PIXIE +import lib.smplx as smplx +# from lib.pare.pare.core.tester import PARETester +from lib.pymaf.utils.geometry import rot6d_to_rotmat, batch_rodrigues, rotation_matrix_to_angle_axis +from lib.pymaf.utils.imutils import process_image +from lib.common.imutils import econ_process_image +from lib.pymaf.core import path_config +from lib.pymaf.models import pymaf_net +from lib.common.config import cfg +from lib.common.render import Render +from lib.dataset.body_model import TetraSMPLModel +from lib.dataset.mesh_util import get_visibility +from utils.smpl_util import SMPLX +import os.path as osp +import os +import torch +import numpy as np +import random +from termcolor import colored +from PIL import ImageFile +from torchvision.models import detection + + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class SMPLDataset(): + + def __init__(self, cfg, device): + + random.seed(1993) + + self.image_dir = cfg['image_dir'] + self.seg_dir = cfg['seg_dir'] + self.hps_type = cfg['hps_type'] + self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx' + self.smpl_gender = 'neutral' + self.colab = cfg['colab'] + + self.device = device + + keep_lst = [f"{self.image_dir}/{i}" for i in sorted(os.listdir(self.image_dir))] + img_fmts = ['jpg', 'png', 'jpeg', "JPG", 'bmp'] + keep_lst = [item for item in keep_lst if item.split(".")[-1] in img_fmts] + + self.subject_list = [item for item in keep_lst if item.split(".")[-1] in img_fmts] + + if self.colab: + self.subject_list = [self.subject_list[0]] + + # smpl related + self.smpl_data = SMPLX() + + # smpl-smplx correspondence + self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73] + self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [68 + 61, 72 + 68] + self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(model_path=self.smpl_data. 
+ model_dir, + gender=smpl_gender, + model_type=smpl_type, + ext='npz') + + # Load SMPL model + self.smpl_model = self.get_smpl_model(self.smpl_type, self.smpl_gender).to(self.device) + self.faces = self.smpl_model.faces + + if self.hps_type == 'pymaf': + self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device) + self.hps.load_state_dict(torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True) + self.hps.eval() + + elif self.hps_type == 'pare': + self.hps = PARETester(path_config.CFG, path_config.CKPT).model + elif self.hps_type == 'pixie': + self.hps = PIXIE(config=pixie_cfg, device=self.device) + self.smpl_model = self.hps.smplx + elif self.hps_type == 'hybrik': + smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl") + self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG, + smpl_path=smpl_path, + data_path=path_config.hybrik_data_dir) + self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'), + strict=False) + self.hps.to(self.device) + elif self.hps_type == 'bev': + try: + import bev + except: + print('Could not find bev, installing via pip install --upgrade simple-romp') + os.system('pip install simple-romp==1.0.3') + import bev + settings = bev.main.default_settings + # change the argparse settings of bev here if you prefer other settings. + settings.mode = 'image' + settings.GPU = int(str(self.device).split(':')[1]) + settings.show_largest = True + # settings.show = True # uncommit this to show the original BEV predictions + self.hps = bev.BEV(settings) + + self.detector=detection.maskrcnn_resnet50_fpn(pretrained=True) + self.detector.eval() + print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green")) + + self.render = Render(size=512, device=device) + + def __len__(self): + return len(self.subject_list) + + def compute_vis_cmap(self, smpl_verts, smpl_faces): + + (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1) + smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long()) + smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type) + + return { + 'smpl_vis': smpl_vis.unsqueeze(0).to(self.device), + 'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device), + 'smpl_verts': smpl_verts.unsqueeze(0) + } + + def compute_voxel_verts(self, body_pose, global_orient, betas, trans, scale): + + smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl") + tetra_path = osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz') + smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult') + + pose = torch.cat([global_orient[0], body_pose[0]], dim=0) + smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0]) + + verts = np.concatenate([smpl_model.verts, smpl_model.verts_added], + axis=0) * scale.item() + trans.detach().cpu().numpy() + faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir, 'tetrahedrons_neutral_adult.txt'), + dtype=np.int32) - 1 + + pad_v_num = int(8000 - verts.shape[0]) + pad_f_num = int(25100 - faces.shape[0]) + + verts = np.pad(verts, + ((0, pad_v_num), + (0, 0)), mode='constant', constant_values=0.0).astype(np.float32) * 0.5 + faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode='constant', + constant_values=0.0).astype(np.int32) + + verts[:, 2] *= -1.0 + + voxel_dict = { + 'voxel_verts': torch.from_numpy(verts).to(self.device).unsqueeze(0).float(), + 'voxel_faces': torch.from_numpy(faces).to(self.device).unsqueeze(0).long(), + 'pad_v_num': torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(), + 'pad_f_num': 
torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long() + } + + return voxel_dict + + def __getitem__(self, index): + + img_path = self.subject_list[index] + img_name = img_path.split("/")[-1].rsplit(".", 1)[0] + print(img_name) + # smplx_param_path=f'./data/thuman2/smplx/{img_name[:-2]}.pkl' + # smplx_param = np.load(smplx_param_path, allow_pickle=True) + + if self.seg_dir is None: + img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image( + img_path, self.hps_type, 512, self.device) + + data_dict = { + 'name': img_name, + 'image': img_icon.to(self.device).unsqueeze(0), + 'ori_image': img_ori, + 'mask': img_mask, + 'uncrop_param': uncrop_param + } + + else: + img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image( + img_path, + self.hps_type, + 512, + self.device, + seg_path=os.path.join(self.seg_dir, f'{img_name}.json')) + data_dict = { + 'name': img_name, + 'image': img_icon.to(self.device).unsqueeze(0), + 'ori_image': img_ori, + 'mask': img_mask, + 'uncrop_param': uncrop_param, + 'segmentations': segmentations + } + + arr_dict=econ_process_image(img_path,self.hps_type,True,512,self.detector) + data_dict['hands_visibility']=arr_dict['hands_visibility'] + + with torch.no_grad(): + # import ipdb; ipdb.set_trace() + preds_dict = self.hps.forward(img_hps) + + data_dict['smpl_faces'] = torch.Tensor(self.faces.astype(np.int64)).long().unsqueeze(0).to( + self.device) + + if self.hps_type == 'pymaf': + output = preds_dict['smpl_out'][-1] + scale, tranX, tranY = output['theta'][0, :3] + data_dict['betas'] = output['pred_shape'] + data_dict['body_pose'] = output['rotmat'][:, 1:] + data_dict['global_orient'] = output['rotmat'][:, 0:1] + data_dict['smpl_verts'] = output['verts'] # 不确定尺度是否一样 + data_dict["type"] = "smpl" + + elif self.hps_type == 'pare': + data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:] + data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1] + data_dict['betas'] = preds_dict['pred_shape'] + data_dict['smpl_verts'] = preds_dict['smpl_vertices'] + scale, tranX, tranY = preds_dict['pred_cam'][0, :3] + data_dict["type"] = "smpl" + + elif self.hps_type == 'pixie': + data_dict.update(preds_dict) + data_dict['body_pose'] = preds_dict['body_pose'] + data_dict['global_orient'] = preds_dict['global_pose'] + data_dict['betas'] = preds_dict['shape'] + data_dict['smpl_verts'] = preds_dict['vertices'] + scale, tranX, tranY = preds_dict['cam'][0, :3] + data_dict["type"] = "smplx" + + elif self.hps_type == 'hybrik': + data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:] + data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]] + data_dict['betas'] = preds_dict['pred_shape'] + data_dict['smpl_verts'] = preds_dict['pred_vertices'] + scale, tranX, tranY = preds_dict['pred_camera'][0, :3] + scale = scale * 2 + data_dict["type"] = "smpl" + + elif self.hps_type == 'bev': + data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[[0], :10].to( + self.device).float() + pred_thetas = batch_rodrigues( + torch.from_numpy(preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float() + data_dict['body_pose'] = pred_thetas[1:][None].to(self.device) + data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device) + data_dict['smpl_verts'] = torch.from_numpy(preds_dict['verts'][[0]]).to( + self.device).float() + tranX = preds_dict['cam_trans'][0, 0] + tranY = preds_dict['cam'][0, 1] + 0.28 + scale = preds_dict['cam'][0, 0] * 1.1 + data_dict["type"] = "smpl" + + data_dict['scale'] = scale + data_dict['trans'] = torch.tensor([tranX, 
tranY, 0.0]).unsqueeze(0).to(self.device).float() + + # data_dict info (key-shape): + # scale, tranX, tranY - tensor.float + # betas - [1,10] / [1, 200] + # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3] + # global_orient - [1, 1, 3, 3] + # smpl_verts - [1, 6890, 3] / [1, 10475, 3] + + # from rot_mat to rot_6d for better optimization + N_body = data_dict["body_pose"].shape[1] + data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1) + data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1) + + return data_dict + + def render_normal(self, verts, faces): + + # render optimized mesh (normal, T_normal, image [-1,1]) + self.render.load_meshes(verts, faces) + return self.render.get_rgb_image() + + def render_depth(self, verts, faces): + + # render optimized mesh (normal, T_normal, image [-1,1]) + self.render.load_meshes(verts, faces) + return self.render.get_depth_map(cam_ids=[0, 2]) + + def visualize_alignment(self, data): + + import vedo + import trimesh + + if self.hps_type != 'pixie': + smpl_out = self.smpl_model(betas=data['betas'], + body_pose=data['body_pose'], + global_orient=data['global_orient'], + pose2rot=False) + smpl_verts = ((smpl_out.vertices + data['trans']) * + data['scale']).detach().cpu().numpy()[0] + else: + smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'], + expression_params=data['exp'], + body_pose=data['body_pose'], + global_pose=data['global_orient'], + jaw_pose=data['jaw_pose'], + left_hand_pose=data['left_hand_pose'], + right_hand_pose=data['right_hand_pose']) + + smpl_verts = ((smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0] + + smpl_verts *= np.array([1.0, -1.0, -1.0]) + faces = data['smpl_faces'][0].detach().cpu().numpy() + + image_P = data['image'] + image_F, image_B = self.render_normal(smpl_verts, faces) + + # create plot + vp = vedo.Plotter(title="", size=(1500, 1500)) + vis_list = [] + + image_F = (0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0) + image_B = (0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0) + image_P = (0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0) + + vis_list.append( + vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(2.0 / image_P.shape[0]).pos( + -1.0, -1.0, 1.0)) + vis_list.append(vedo.Picture(image_F).scale(2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5)) + vis_list.append(vedo.Picture(image_B).scale(2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0)) + + # create a mesh + mesh = trimesh.Trimesh(smpl_verts, faces, process=False) + mesh.visual.vertex_colors = [200, 200, 0] + vis_list.append(mesh) + + vp.show(*vis_list, bg="white", axes=1, interactive=True) + + +if __name__ == '__main__': + + cfg.merge_from_file("./configs/icon-filter.yaml") + cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml') + + cfg_show_list = ['test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False] + + cfg.merge_from_list(cfg_show_list) + cfg.freeze() + + + device = torch.device('cuda:0') + + dataset = SMPLDataset( + { + 'image_dir': "./examples", + 'has_det': True, # w/ or w/o detection + 'hps_type': 'bev' # pymaf/pare/pixie/hybrik/bev + }, + device) + + for i in range(len(dataset)): + dataset.visualize_alignment(dataset[i]) diff --git a/examples/02986d0998ce01aa0aa67a99fbd1e09a.png b/examples/02986d0998ce01aa0aa67a99fbd1e09a.png new file mode 100644 index 0000000000000000000000000000000000000000..cb5027a1512ce86c9277bb751161745870632846 Binary files /dev/null and 
b/examples/02986d0998ce01aa0aa67a99fbd1e09a.png differ diff --git a/examples/16171.png b/examples/16171.png new file mode 100644 index 0000000000000000000000000000000000000000..d425579ac3d6df5c5a7831aa04d89258d02a5ac8 Binary files /dev/null and b/examples/16171.png differ diff --git a/examples/26d2e846349647ff04c536816e0e8ca1.png b/examples/26d2e846349647ff04c536816e0e8ca1.png new file mode 100644 index 0000000000000000000000000000000000000000..a0fd76c3b667d526a2efff899a6f2a66c9ace221 Binary files /dev/null and b/examples/26d2e846349647ff04c536816e0e8ca1.png differ diff --git a/examples/30755.png b/examples/30755.png new file mode 100644 index 0000000000000000000000000000000000000000..d989d2d19d1bb82ab8d83cf018cc81d8f9d00fe2 Binary files /dev/null and b/examples/30755.png differ diff --git a/examples/3930.png b/examples/3930.png new file mode 100644 index 0000000000000000000000000000000000000000..cf82a094858438f93c73ed1864deebb95ec5e624 Binary files /dev/null and b/examples/3930.png differ diff --git a/examples/4656716-3016170581.png b/examples/4656716-3016170581.png new file mode 100644 index 0000000000000000000000000000000000000000..d61dd293a11e94cf194a81df01c9bd8fa44a55ff Binary files /dev/null and b/examples/4656716-3016170581.png differ diff --git a/examples/663dcd6db19490de0b790da430bd5681.png b/examples/663dcd6db19490de0b790da430bd5681.png new file mode 100644 index 0000000000000000000000000000000000000000..64576a2de9e19a04457590055595dedf1a82ec97 --- /dev/null +++ b/examples/663dcd6db19490de0b790da430bd5681.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b499922b6df6d6874fea68c571ff3271f68aa6bc40420396f4898e5c58d74dc8 +size 1000047 diff --git a/examples/7332.png b/examples/7332.png new file mode 100644 index 0000000000000000000000000000000000000000..09e5943fdc2e0f884ed9dd3b19eea6fcc3962f0e Binary files /dev/null and b/examples/7332.png differ diff --git a/examples/85891251f52a2399e660a63c2a7fdf40.png b/examples/85891251f52a2399e660a63c2a7fdf40.png new file mode 100644 index 0000000000000000000000000000000000000000..c1b20ab06ea3ca3e640093c6c2349d77e0c5f0f0 Binary files /dev/null and b/examples/85891251f52a2399e660a63c2a7fdf40.png differ diff --git a/examples/a689a48d23d6b8d58d67ff5146c6e088.png b/examples/a689a48d23d6b8d58d67ff5146c6e088.png new file mode 100644 index 0000000000000000000000000000000000000000..807b81ba0cdaa8e4e4c0fa1bea4979a5fbbaf3fb Binary files /dev/null and b/examples/a689a48d23d6b8d58d67ff5146c6e088.png differ diff --git a/examples/b0d178743c7e3e09700aaee8d2b1ec47.png b/examples/b0d178743c7e3e09700aaee8d2b1ec47.png new file mode 100644 index 0000000000000000000000000000000000000000..eab8122926ced37d350b1131a5868a7bf79c9723 Binary files /dev/null and b/examples/b0d178743c7e3e09700aaee8d2b1ec47.png differ diff --git a/examples/case5.png b/examples/case5.png new file mode 100644 index 0000000000000000000000000000000000000000..c12614992d55d6edae86323c6bc39472d7ab167f Binary files /dev/null and b/examples/case5.png differ diff --git a/examples/d40776a1e1582179d97907d36f84d776.png b/examples/d40776a1e1582179d97907d36f84d776.png new file mode 100644 index 0000000000000000000000000000000000000000..4b13198ed5dea16c5fc40a7825fc613be4d5217c Binary files /dev/null and b/examples/d40776a1e1582179d97907d36f84d776.png differ diff --git a/examples/durant.png b/examples/durant.png new file mode 100644 index 0000000000000000000000000000000000000000..5c5fb2acd75c90fb1c611b43f9829128b1391cab Binary files /dev/null and b/examples/durant.png differ diff 
--git a/examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png b/examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png new file mode 100644 index 0000000000000000000000000000000000000000..7c5b95813598325e74197fe8e8538840ccac840a Binary files /dev/null and b/examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png differ diff --git a/examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png b/examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png new file mode 100644 index 0000000000000000000000000000000000000000..1fe35ebf6c234820bdc5a1c9c882b0ccd7c82ff5 Binary files /dev/null and b/examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png differ diff --git a/examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png b/examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png new file mode 100644 index 0000000000000000000000000000000000000000..e427b4d226abd655ea9fa747e90d96ef598c3041 Binary files /dev/null and b/examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png differ diff --git a/examples/pexels-barbara-olsen-7869640.png b/examples/pexels-barbara-olsen-7869640.png new file mode 100644 index 0000000000000000000000000000000000000000..7bb87aba4b7c8b5a459acd40094b85cf03c4218c Binary files /dev/null and b/examples/pexels-barbara-olsen-7869640.png differ diff --git a/examples/pexels-julia-m-cameron-4145040.png b/examples/pexels-julia-m-cameron-4145040.png new file mode 100644 index 0000000000000000000000000000000000000000..ca708ea668cfe011089fc111b71538bd39b1b52f Binary files /dev/null and b/examples/pexels-julia-m-cameron-4145040.png differ diff --git a/examples/pexels-marta-wave-6437749.png b/examples/pexels-marta-wave-6437749.png new file mode 100644 index 0000000000000000000000000000000000000000..60cf71c5fabdc04896bb5b709f44543da7398edf Binary files /dev/null and b/examples/pexels-marta-wave-6437749.png differ diff --git a/examples/pexels-photo-6311555-removebg.png b/examples/pexels-photo-6311555-removebg.png new file mode 100644 index 0000000000000000000000000000000000000000..4e6cdcb2bcf606df05919de32b03ef72489353d5 Binary files /dev/null and b/examples/pexels-photo-6311555-removebg.png differ diff --git a/examples/pexels-zdmit-6780091.png b/examples/pexels-zdmit-6780091.png new file mode 100644 index 0000000000000000000000000000000000000000..c41077185288414f081765881a3017ec6b9965e7 Binary files /dev/null and b/examples/pexels-zdmit-6780091.png differ diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b10a36d7254e6b1870de472a629931298f0bf8 --- /dev/null +++ b/inference.py @@ -0,0 +1,221 @@ +import argparse +import os +from typing import Dict, Optional, Tuple, List +from omegaconf import OmegaConf +from PIL import Image +from dataclasses import dataclass +from collections import defaultdict +import torch +import torch.utils.checkpoint +from torchvision.utils import make_grid, save_image +from accelerate.utils import set_seed +from tqdm.auto import tqdm +import torch.nn.functional as F +from einops import rearrange +from rembg import remove, new_session +import pdb +from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline +from econdataset import SMPLDataset +from reconstruct import ReMesh +providers = [ + ('CUDAExecutionProvider', { + 'device_id': 0, + 'arena_extend_strategy': 'kSameAsRequested', + 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, + 'cudnn_conv_algo_search': 'HEURISTIC', + }) +] +session = new_session(providers=providers) + +weight_dtype = torch.float16 
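+# the diffusion pipeline weights are loaded in fp16 (see load_pshuman_pipeline below), and the rembg CUDA session configured above is reused later to matte the generated color and normal views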
+def tensor_to_numpy(tensor): + return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + + +@dataclass +class TestConfig: + pretrained_model_name_or_path: str + revision: Optional[str] + validation_dataset: Dict + save_dir: str + seed: Optional[int] + validation_batch_size: int + dataloader_num_workers: int + # save_single_views: bool + save_mode: str + local_rank: int + + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_guidance_scales: float + validation_grid_nrow: int + + num_views: int + enable_xformers_memory_efficient_attention: bool + with_smpl: Optional[bool] + + recon_opt: Dict + + +def convert_to_numpy(tensor): + return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + +def convert_to_pil(tensor): + return Image.fromarray(convert_to_numpy(tensor)) + +def save_image(tensor, fp): + ndarr = convert_to_numpy(tensor) + # pdb.set_trace() + save_image_numpy(ndarr, fp) + return ndarr + +def save_image_numpy(ndarr, fp): + im = Image.fromarray(ndarr) + im.save(fp) + +def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir): + pipeline.set_progress_bar_config(disable=True) + + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed) + + images_cond, pred_cat = [], defaultdict(list) + for case_id, batch in tqdm(enumerate(dataloader)): + images_cond.append(batch['imgs_in'][:, 0]) + + imgs_in = torch.cat([batch['imgs_in']]*2, dim=0) + num_views = imgs_in.shape[1] + imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W) + if cfg.with_smpl: + smpl_in = torch.cat([batch['smpl_imgs_in']]*2, dim=0) + smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W") + else: + smpl_in = None + + normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings'] + prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0) + prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") + + with torch.autocast("cuda"): + # B*Nv images + guidance_scale = cfg.validation_guidance_scales + unet_out = pipeline( + imgs_in, None, prompt_embeds=prompt_embeddings, + dino_feature=None, smpl_in=smpl_in, + generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1, + **cfg.pipe_validation_kwargs + ) + + out = unet_out.images + bsz = out.shape[0] // 2 + + normals_pred = out[:bsz] + images_pred = out[bsz:] + if cfg.save_mode == 'concat': ## save concatenated color and normal--------------------- + pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1)) # b, 3, h, w + cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}") + os.makedirs(cur_dir, exist_ok=True) + for i in range(bsz//num_views): + scene = batch['filename'][i].split('.')[0] + + img_in_ = images_cond[-1][i].to(out.device) + vis_ = [img_in_] + for j in range(num_views): + idx = i*num_views + j + normal = normals_pred[idx] + color = images_pred[idx] + + vis_.append(color) + vis_.append(normal) + + out_filename = f"{cur_dir}/{scene}.png" + vis_ = torch.stack(vis_, dim=0) + vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1)) + save_image(vis_, out_filename) + elif cfg.save_mode == 'rgb': + for i in range(bsz//num_views): + scene = 
batch['filename'][i].split('.')[0] + + img_in_ = images_cond[-1][i].to(out.device) + normals, colors = [], [] + for j in range(num_views): + idx = i*num_views + j + normal = normals_pred[idx] + if j == 0: + color = imgs_in[0].to(out.device) + else: + color = images_pred[idx] + if j in [3, 4]: + normal = torch.flip(normal, dims=[2]) + color = torch.flip(color, dims=[2]) + + colors.append(color) + if j == 6: + normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0) + normals.append(normal) + + ## save color and normal--------------------- + # normal_filename = f"normals_{view}_masked.png" + # rgb_filename = f"color_{view}_masked.png" + # save_image(normal, os.path.join(scene_dir, normal_filename)) + # save_image(color, os.path.join(scene_dir, rgb_filename)) + normals[0][:, :256, 256:512] = normals[-1] + + colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]] + normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]] + pose = econdata.__getitem__(case_id) + carving.optimize_case(scene, pose, colors, normals) + torch.cuda.empty_cache() + + + +def load_pshuman_pipeline(cfg): + pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype) + pipeline.unet.enable_xformers_memory_efficient_attention() + if torch.cuda.is_available(): + pipeline.to('cuda') + return pipeline + +def main( + cfg: TestConfig +): + + # If passed along, set the training seed now. + if cfg.seed is not None: + set_seed(cfg.seed) + pipeline = load_pshuman_pipeline(cfg) + + + if cfg.with_smpl: + from mvdiffusion.data.testdata_with_smpl import SingleImageDataset + else: + from mvdiffusion.data.single_image_dataset import SingleImageDataset + + # Get the dataset + validation_dataset = SingleImageDataset( + **cfg.validation_dataset + ) + validation_dataloader = torch.utils.data.DataLoader( + validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers + ) + dataset_param = {'image_dir': validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'} + econdata = SMPLDataset(dataset_param, device='cuda') + + carving = ReMesh(cfg.recon_opt, econ_dataset=econdata) + run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + args, extras = parser.parse_known_args() + from utils.misc import load_config + + # parse YAML config to OmegaConf + cfg = load_config(args.config, cli_args=extras) + schema = OmegaConf.structured(TestConfig) + cfg = OmegaConf.merge(schema, cfg) + main(cfg) diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/common/__init__.py b/lib/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/common/cloth_extraction.py b/lib/common/cloth_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..7661a8c7adb3e02dfdf251aafbaed1ef98fa00a9 --- /dev/null +++ b/lib/common/cloth_extraction.py @@ -0,0 +1,182 @@ +import numpy as np +import json +import os +import itertools +import trimesh +from matplotlib.path import Path +from collections import Counter +from sklearn.neighbors import 
KNeighborsClassifier + + +def load_segmentation(path, shape): + """ + Get a segmentation mask for a given image + Arguments: + path: path to the segmentation json file + shape: shape of the output mask + Returns: + Returns a segmentation mask + """ + with open(path) as json_file: + dict = json.load(json_file) + segmentations = [] + for key, val in dict.items(): + if not key.startswith('item'): + continue + + # Each item can have multiple polygons. Combine them to one + # segmentation_coord = list(itertools.chain.from_iterable(val['segmentation'])) + # segmentation_coord = np.round(np.array(segmentation_coord)).astype(int) + + coordinates = [] + for segmentation_coord in val['segmentation']: + # The format before is [x1,y1, x2, y2, ....] + x = segmentation_coord[::2] + y = segmentation_coord[1::2] + xy = np.vstack((x, y)).T + coordinates.append(xy) + + segmentations.append({ + 'type': val['category_name'], + 'type_id': val['category_id'], + 'coordinates': coordinates + }) + + return segmentations + + +def smpl_to_recon_labels(recon, smpl, k=1): + """ + Get the bodypart labels for the recon object by using the labels from the corresponding smpl object + Arguments: + recon: trimesh object (fully clothed model) + shape: trimesh object (smpl model) + k: number of nearest neighbours to use + Returns: + Returns a dictionary containing the bodypart and the corresponding indices + """ + smpl_vert_segmentation = json.load( + open( + os.path.join(os.path.dirname(__file__), + 'smpl_vert_segmentation.json'))) + n = smpl.vertices.shape[0] + y = np.array([None] * n) + for key, val in smpl_vert_segmentation.items(): + y[val] = key + + classifier = KNeighborsClassifier(n_neighbors=1) + classifier.fit(smpl.vertices, y) + + y_pred = classifier.predict(recon.vertices) + + recon_labels = {} + for key in smpl_vert_segmentation.keys(): + recon_labels[key] = list( + np.argwhere(y_pred == key).flatten().astype(int)) + + return recon_labels + + +def extract_cloth(recon, segmentation, K, R, t, smpl=None): + """ + Extract a portion of a mesh using 2d segmentation coordinates + Arguments: + recon: fully clothed mesh + seg_coord: segmentation coordinates in 2D (NDC) + K: intrinsic matrix of the projection + R: rotation matrix of the projection + t: translation vector of the projection + Returns: + Returns a submesh using the segmentation coordinates + """ + seg_coord = segmentation['coord_normalized'] + mesh = trimesh.Trimesh(recon.vertices, recon.faces) + extrinsic = np.zeros((3, 4)) + extrinsic[:3, :3] = R + extrinsic[:, 3] = t + P = K[:3, :3] @ extrinsic + + P_inv = np.linalg.pinv(P) + + # Each segmentation can contain multiple polygons + # We need to check them separately + points_so_far = [] + faces = recon.faces + for polygon in seg_coord: + n = len(polygon) + coords_h = np.hstack((polygon, np.ones((n, 1)))) + # Apply the inverse projection on homogeneus 2D coordinates to get the corresponding 3d Coordinates + XYZ = P_inv @ coords_h[:, :, None] + XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1])) + XYZ = XYZ[:, :3] / XYZ[:, 3, None] + + p = Path(XYZ[:, :2]) + + grid = p.contains_points(recon.vertices[:, :2]) + indeces = np.argwhere(grid == True) + points_so_far += list(indeces.flatten()) + + if smpl is not None: + num_verts = recon.vertices.shape[0] + recon_labels = smpl_to_recon_labels(recon, smpl) + body_parts_to_remove = [ + 'rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head', + 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', + 'rightHand' + ] + type = segmentation['type_id'] + + # Remove 
additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso) + # https://github.com/switchablenorms/DeepFashion2 + # Short sleeve clothes + if type == 1 or type == 3 or type == 10: + body_parts_to_remove += ['leftForeArm', 'rightForeArm'] + # No sleeves at all or lower body clothes + elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9: + body_parts_to_remove += [ + 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm' + ] + # Shorts + elif type == 7: + body_parts_to_remove += [ + 'leftLeg', 'rightLeg', 'leftForeArm', 'rightForeArm', + 'leftArm', 'rightArm' + ] + + verts_to_remove = list( + itertools.chain.from_iterable( + [recon_labels[part] for part in body_parts_to_remove])) + + label_mask = np.zeros(num_verts, dtype=bool) + label_mask[verts_to_remove] = True + + seg_mask = np.zeros(num_verts, dtype=bool) + seg_mask[points_so_far] = True + + # Remove points that belong to other bodyparts + # If a vertice in pointsSoFar is included in the bodyparts to remove, then these points should be removed + extra_verts_to_remove = np.array(list(seg_mask) and list(label_mask)) + + combine_mask = np.zeros(num_verts, dtype=bool) + combine_mask[points_so_far] = True + combine_mask[extra_verts_to_remove] = False + + all_indices = np.argwhere(combine_mask == True).flatten() + + i_x = np.where(np.in1d(faces[:, 0], all_indices))[0] + i_y = np.where(np.in1d(faces[:, 1], all_indices))[0] + i_z = np.where(np.in1d(faces[:, 2], all_indices))[0] + + faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z))) + mask = np.zeros(len(recon.faces), dtype=bool) + if len(faces_to_keep) > 0: + mask[faces_to_keep] = True + + mesh.update_faces(mask) + mesh.remove_unreferenced_vertices() + + # mesh.rezero() + + return mesh + + return None diff --git a/lib/common/config.py b/lib/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e60dfa2c7e56a890c08da0327cdb69ef62dc0b5e --- /dev/null +++ b/lib/common/config.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from yacs.config import CfgNode as CN +import os + +_C = CN(new_allowed=True) + +# needed by trainer +_C.name = 'default' +_C.gpus = [0] +_C.test_gpus = [1] +_C.root = "./data/" +_C.ckpt_dir = './data/ckpt/' +_C.resume_path = '' +_C.normal_path = '' +_C.corr_path = '' +_C.results_path = './data/results/' +_C.projection_mode = 'orthogonal' +_C.num_views = 1 +_C.sdf = False +_C.sdf_clip = 5.0 + +_C.lr_G = 1e-3 +_C.lr_C = 1e-3 +_C.lr_N = 2e-4 +_C.weight_decay = 0.0 +_C.momentum = 0.0 +_C.optim = 'Adam' +_C.schedule = [5, 10, 15] +_C.gamma = 0.1 + +_C.overfit = False +_C.resume = False +_C.test_mode = False +_C.test_uv = False +_C.draw_geo_thres = 0.60 +_C.num_sanity_val_steps = 2 +_C.fast_dev = 0 +_C.get_fit = False +_C.agora = False +_C.optim_cloth = False +_C.optim_body = False +_C.mcube_res = 256 +_C.clean_mesh = True +_C.remesh = False + +_C.batch_size = 4 +_C.num_threads = 8 + +_C.num_epoch = 10 +_C.freq_plot = 0.01 +_C.freq_show_train = 0.1 +_C.freq_show_val = 0.2 +_C.freq_eval = 0.5 +_C.accu_grad_batch = 4 + +_C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt'] + +_C.net = CN() +_C.net.gtype = 'HGPIFuNet' +_C.net.ctype = 'resnet18' +_C.net.classifierIMF = 'MultiSegClassifier' +_C.net.netIMF = 'resnet18' +_C.net.norm = 'group' +_C.net.norm_mlp = 'group' +_C.net.norm_color = 'group' +_C.net.hg_down = 'conv128' #'ave_pool' +_C.net.num_views = 1 + +# kernel_size, stride, dilation, padding + +_C.net.conv1 = [7, 2, 1, 3] +_C.net.conv3x3 = [3, 1, 1, 1] + +_C.net.num_stack = 4 +_C.net.num_hourglass = 2 +_C.net.hourglass_dim = 256 +_C.net.voxel_dim = 32 +_C.net.resnet_dim = 120 +_C.net.mlp_dim = [320, 1024, 512, 256, 128, 1] +_C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3] +_C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3] +_C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500] +_C.net.res_layers = [2, 3, 4] +_C.net.filter_dim = 256 +_C.net.smpl_dim = 3 + +_C.net.cly_dim = 3 +_C.net.soft_dim = 64 +_C.net.z_size = 200.0 +_C.net.N_freqs = 10 +_C.net.geo_w = 0.1 +_C.net.norm_w = 0.1 +_C.net.dc_w = 0.1 +_C.net.C_cat_to_G = False + +_C.net.skip_hourglass = True +_C.net.use_tanh = False +_C.net.soft_onehot = True +_C.net.no_residual = False +_C.net.use_attention = False + +_C.net.prior_type = "sdf" +_C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis'] +_C.net.use_filter = True +_C.net.use_cc = False +_C.net.use_PE = False +_C.net.use_IGR = False +_C.net.in_geo = () +_C.net.in_nml = () + +_C.dataset = CN() +_C.dataset.root = '' +_C.dataset.set_splits = [0.95, 0.04] +_C.dataset.types = [ + "3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy" +] +_C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37] +_C.dataset.rp_type = "pifu900" +_C.dataset.th_type = 'train' +_C.dataset.input_size = 512 +_C.dataset.rotation_num = 3 +_C.dataset.num_sample_ray=128 # volume rendering +_C.dataset.num_precomp = 10 # Number of segmentation classifiers +_C.dataset.num_multiseg = 500 # Number of categories per classifier +_C.dataset.num_knn = 10 # for loss/error +_C.dataset.num_knn_dis = 20 # for accuracy +_C.dataset.num_verts_max = 20000 +_C.dataset.zray_type = False +_C.dataset.online_smpl = False +_C.dataset.noise_type = ['z-trans', 'pose', 'beta'] +_C.dataset.noise_scale = [0.0, 0.0, 0.0] +_C.dataset.num_sample_geo = 10000 +_C.dataset.num_sample_color = 0 +_C.dataset.num_sample_seg = 0 +_C.dataset.num_sample_knn = 10000 + +_C.dataset.sigma_geo = 5.0 +_C.dataset.sigma_color = 0.10 +_C.dataset.sigma_seg = 0.10 +_C.dataset.thickness_threshold = 
20.0 +_C.dataset.ray_sample_num = 2 +_C.dataset.semantic_p = False +_C.dataset.remove_outlier = False + +_C.dataset.train_bsize = 1.0 +_C.dataset.val_bsize = 1.0 +_C.dataset.test_bsize = 1.0 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() + + +# Alternatively, provide a way to import the defaults as +# a global singleton: +cfg = _C # users can `from config import cfg` + +# cfg = get_cfg_defaults() +# cfg.merge_from_file('./configs/example.yaml') + +# # Now override from a list (opts could come from the command line) +# opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2'] +# cfg.merge_from_list(opts) + + +def update_cfg(cfg_file): + # cfg = get_cfg_defaults() + _C.merge_from_file(cfg_file) + # return cfg.clone() + return _C + + +def parse_args(args): + cfg_file = args.cfg_file + if args.cfg_file is not None: + cfg = update_cfg(args.cfg_file) + else: + cfg = get_cfg_defaults() + + # if args.misc is not None: + # cfg.merge_from_list(args.misc) + + return cfg + + +def parse_args_extend(args): + if args.resume: + if not os.path.exists(args.log_dir): + raise ValueError( + 'Experiment are set to resume mode, but log directory does not exist.' + ) + + # load log's cfg + cfg_file = os.path.join(args.log_dir, 'cfg.yaml') + cfg = update_cfg(cfg_file) + + if args.misc is not None: + cfg.merge_from_list(args.misc) + else: + parse_args(args) diff --git a/lib/common/imutils.py b/lib/common/imutils.py new file mode 100644 index 0000000000000000000000000000000000000000..c44b4bd30fce1b030be1df4b445b53ce001d62a7 --- /dev/null +++ b/lib/common/imutils.py @@ -0,0 +1,364 @@ +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" +import cv2 +import mediapipe as mp +import torch +import numpy as np +import torch.nn.functional as F +from PIL import Image +from lib.pymafx.core import constants +from rembg import remove +# from rembg.session_factory import new_session +from torchvision import transforms +from kornia.geometry.transform import get_affine_matrix2d, warp_affine + + +def transform_to_tensor(res, mean=None, std=None, is_tensor=False): + all_ops = [] + if res is not None: + all_ops.append(transforms.Resize(size=res)) + if not is_tensor: + all_ops.append(transforms.ToTensor()) + if mean is not None and std is not None: + all_ops.append(transforms.Normalize(mean=mean, std=std)) + return transforms.Compose(all_ops) + + +def get_affine_matrix_wh(w1, h1, w2, h2): + + transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0) + center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0) + scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0) + M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.])) + + return M + + +def get_affine_matrix_box(boxes, w2, h2): + + # boxes [left, top, right, bottom] + width = boxes[:, 2] - boxes[:, 0] #(N,) + height = boxes[:, 3] - boxes[:, 1] #(N,) + center = torch.tensor( + [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0] + ).T #(N,2) + scale = torch.min(torch.tensor([w2 / width, h2 / height]), + dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9 #(N,2) + transl = torch.cat([w2 / 2.0 - center[:, 0:1], h2 / 2.0 - center[:, 1:2]], dim=1) #(N,2) + M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.,]*transl.shape[0])) + + return M + + +def load_img(img_file): + + if img_file.endswith("exr"): + img = cv2.imread(img_file, 
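The `yacs` helpers above follow the usual clone-then-merge pattern: start from the defaults defined in this file, overlay a YAML file, then overlay command-line style key/value pairs. A short usage sketch; the YAML path is illustrative and the override keys are taken from the defaults above:
```
from lib.common.config import get_cfg_defaults

cfg = get_cfg_defaults()                      # fresh copy of the defaults above
cfg.merge_from_file("configs/example.yaml")   # illustrative path
cfg.merge_from_list(["batch_size", 2, "dataset.input_size", 1024])
cfg.freeze()                                  # optional: make the node read-only
print(cfg.batch_size, cfg.dataset.input_size)
```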
cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + else : + img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED) + + # considering non 8-bit image + if img.dtype != np.uint8 : + img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) + + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if not img_file.endswith("png"): + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + + return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2] + + +def get_keypoints(image): + def collect_xyv(x, body=True): + lmk = x.landmark + all_lmks = [] + for i in range(len(lmk)): + visibility = lmk[i].visibility if body else 1.0 + all_lmks.append(torch.Tensor([lmk[i].x, lmk[i].y, lmk[i].z, visibility])) + return torch.stack(all_lmks).view(-1, 4) + + mp_holistic = mp.solutions.holistic + + with mp_holistic.Holistic( + static_image_mode=True, + model_complexity=2, + ) as holistic: + results = holistic.process(image) + + fake_kps = torch.zeros(33, 4) + + result = {} + result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps + result["lhand"] = collect_xyv( + results.left_hand_landmarks, False + ) if results.left_hand_landmarks else fake_kps + result["rhand"] = collect_xyv( + results.right_hand_landmarks, False + ) if results.right_hand_landmarks else fake_kps + result["face"] = collect_xyv( + results.face_landmarks, False + ) if results.face_landmarks else fake_kps + + return result + + +def get_pymafx(image, landmarks): + + # image [3,512,512] + + item = { + 'img_body': + F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0] + } + + for part in ['lhand', 'rhand', 'face']: + kp2d = landmarks[part] + kp2d_valid = kp2d[kp2d[:, 3] > 0.] + if len(kp2d_valid) > 0: + bbox = [ + min(kp2d_valid[:, 0]), + min(kp2d_valid[:, 1]), + max(kp2d_valid[:, 0]), + max(kp2d_valid[:, 1]) + ] + center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.] + scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + + # handle invalid part keypoints + if len(kp2d_valid) < 1 or scale_part < 0.01: + center_part = [0, 0] + scale_part = 0.5 + kp2d[:, 3] = 0 + + center_part = torch.tensor(center_part).float() + + theta_part = torch.zeros(1, 2, 3) + theta_part[:, 0, 0] = scale_part + theta_part[:, 1, 1] = scale_part + theta_part[:, :, -1] = center_part + + grid = F.affine_grid(theta_part, torch.Size([1, 3, 224, 224]), align_corners=False) + img_part = F.grid_sample(image.unsqueeze(0), grid, align_corners=False).squeeze(0).float() + + item[f'img_{part}'] = img_part + + theta_i_inv = torch.zeros_like(theta_part) + theta_i_inv[:, 0, 0] = 1. / theta_part[:, 0, 0] + theta_i_inv[:, 1, 1] = 1. / theta_part[:, 1, 1] + theta_i_inv[:, :, -1] = -theta_part[:, :, -1] / theta_part[:, 0, 0].unsqueeze(-1) + item[f'{part}_theta_inv'] = theta_i_inv[0] + + return item + + +def remove_floats(mask): + + # 1. find all the contours + # 2. fillPoly "True" for the largest one + # 3. 
fillPoly "False" for its childrens + + new_mask = np.zeros(mask.shape) + cnts, hier = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cnt_index = sorted(range(len(cnts)), key=lambda k: cv2.contourArea(cnts[k]), reverse=True) + body_cnt = cnts[cnt_index[0]] + childs_cnt_idx = np.where(np.array(hier)[0, :, -1] == cnt_index[0])[0] + childs_cnt = [cnts[idx] for idx in childs_cnt_idx] + cv2.fillPoly(new_mask, [body_cnt], 1) + cv2.fillPoly(new_mask, childs_cnt, 0) + + return new_mask + + +def econ_process_image(img_file, hps_type, single, input_res, detector): + + img_raw, (in_height, in_width) = load_img(img_file) + tgt_res = input_res * 2 + M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res) + img_square = warp_affine( + img_raw, + M_square[:, :2], (tgt_res, ) * 2, + mode='bilinear', + padding_mode='zeros', + align_corners=True + ) + + # detection for bbox + predictions = detector(img_square / 255.)[0] + + if single: + top_score = predictions["scores"][predictions["labels"] == 1].max() + human_ids = torch.where(predictions["scores"] == top_score)[0] + else: + human_ids = torch.logical_and(predictions["labels"] == 1, + predictions["scores"] > 0.9).nonzero().squeeze(1) + + boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy() + masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy() + + M_crop = get_affine_matrix_box(boxes, input_res, input_res) + + img_icon_lst = [] + img_crop_lst = [] + img_hps_lst = [] + img_mask_lst = [] + landmark_lst = [] + hands_visibility_lst = [] + img_pymafx_lst = [] + + uncrop_param = { + "ori_shape": [in_height, in_width], + "box_shape": [input_res, input_res], + "square_shape": [tgt_res, tgt_res], + "M_square": M_square, + "M_crop": M_crop + } + + for idx in range(len(boxes)): + + # mask out the pixels of others + if len(masks) > 1: + mask_detection = (masks[np.arange(len(masks)) != idx]).max(axis=0) + else: + mask_detection = masks[0] * 0. 
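`remove_floats` keeps only the largest contour of a binary mask (plus its holes) and discards every disconnected blob. A toy usage sketch with a synthetic mask; the shapes are made up:
```
import cv2
import numpy as np
from lib.common.imutils import remove_floats

mask = np.zeros((128, 128), dtype=np.uint8)
cv2.circle(mask, (64, 64), 40, 1, -1)      # main silhouette
cv2.circle(mask, (64, 64), 8, 0, -1)       # a hole inside it (kept as a child contour)
cv2.circle(mask, (10, 10), 3, 1, -1)       # floating speck that should disappear

clean = remove_floats(mask)                # float mask, same shape: 1 on the body,
                                           # 0 on the speck and inside the hole
```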
+ + img_square_rgba = torch.cat( + [img_square.squeeze(0).permute(1, 2, 0), + torch.tensor(mask_detection < 0.4) * 255], + dim=2 + ) + + img_crop = warp_affine( + img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2), + M_crop[idx:idx + 1, :2], (input_res, ) * 2, + mode='bilinear', + padding_mode='zeros', + align_corners=True + ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) + + # get accurate person segmentation mask + img_rembg = remove(img_crop) #post_process_mask=True) + img_mask = remove_floats(img_rembg[:, :, [3]]) + + mean_icon = std_icon = (0.5, 0.5, 0.5) + img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8) + img_icon = transform_to_tensor(512, mean_icon, std_icon)( + Image.fromarray(img_np) + ) * torch.tensor(img_mask).permute(2, 0, 1) + img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN, + constants.IMG_NORM_STD)(Image.fromarray(img_np)) + + landmarks = get_keypoints(img_np) + + # get hands visibility + hands_visibility = [True, True] + if landmarks['lhand'][:, -1].mean() == 0.: + hands_visibility[0] = False + if landmarks['rhand'][:, -1].mean() == 0.: + hands_visibility[1] = False + hands_visibility_lst.append(hands_visibility) + + if hps_type == 'pymafx': + img_pymafx_lst.append( + get_pymafx( + transform_to_tensor(512, constants.IMG_NORM_MEAN, + constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks + ) + ) + + img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0) + img_icon_lst.append(img_icon) + img_hps_lst.append(img_hps) + img_mask_lst.append(torch.tensor(img_mask[..., 0])) + landmark_lst.append(landmarks['body']) + + # required image tensors / arrays + + # img_icon (tensor): (-1, 1), [3,512,512] + # img_hps (tensor): (-2.11, 2.44), [3,224,224] + + # img_np (array): (0, 255), [512,512,3] + # img_rembg (array): (0, 255), [512,512,4] + # img_mask (array): (0, 1), [512,512,1] + # img_crop (array): (0, 255), [512,512,4] + + return_dict = { + "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res] + "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res] + "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res] + "img_raw": img_raw, #[1, 3, H, W] + "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res] + "uncrop_param": uncrop_param, + "landmark": torch.stack(landmark_lst), #[N, 33, 4] + "hands_visibility": hands_visibility_lst, + } + + img_pymafx = {} + + if len(img_pymafx_lst) > 0: + for idx in range(len(img_pymafx_lst)): + for key in img_pymafx_lst[idx].keys(): + if key not in img_pymafx.keys(): + img_pymafx[key] = [img_pymafx_lst[idx][key]] + else: + img_pymafx[key] += [img_pymafx_lst[idx][key]] + + for key in img_pymafx.keys(): + img_pymafx[key] = torch.stack(img_pymafx[key]).float() + + return_dict.update({"img_pymafx": img_pymafx}) + + return return_dict + + +def blend_rgb_norm(norms, data): + + # norms [N, 3, res, res] + masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1) + norm_mask = F.interpolate( + torch.cat([norms, masks], dim=1).detach(), + size=data["uncrop_param"]["box_shape"], + mode="bilinear", + align_corners=False + ) + final = data["img_raw"].type_as(norm_mask) + + for idx in range(len(norms)): + + norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0 + mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1) + + norm_ori = unwrap(norm_pred, data["uncrop_param"], idx) + mask_ori = unwrap(mask_pred, data["uncrop_param"], idx) + + final = final * (1.0 - mask_ori) + norm_ori * mask_ori + + return final.detach().cpu() + + +def unwrap(image, 
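`econ_process_image` expects a torchvision-style instance-segmentation detector that returns `boxes`, `labels`, `scores` and `masks`. A hedged driver sketch, assuming Mask R-CNN as that detector and an illustrative input path; neither is mandated by the code above:
```
import torch
import torchvision
from lib.common.imutils import econ_process_image

detector = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT").eval()

with torch.no_grad():
    data = econ_process_image(
        "examples/person.png",      # illustrative RGB(A) image path
        hps_type="pymafx",
        single=True,
        input_res=512,
        detector=detector,
    )

print(data["img_icon"].shape)       # [N, 3, 512, 512], normalized to (-1, 1)
print(data["landmark"].shape)       # [N, 33, 4] mediapipe body landmarks
```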
uncrop_param, idx): + + device = image.device + + img_square = warp_affine( + image, + torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device), + uncrop_param["square_shape"], + mode='bilinear', + padding_mode='zeros', + align_corners=True + ) + + img_ori = warp_affine( + img_square, + torch.inverse(uncrop_param["M_square"])[:, :2].to(device), + uncrop_param["ori_shape"], + mode='bilinear', + padding_mode='zeros', + align_corners=True + ) + + return img_ori diff --git a/lib/common/render.py b/lib/common/render.py new file mode 100644 index 0000000000000000000000000000000000000000..a5fde83cd307827732972f92fd637da97d6d3cfd --- /dev/null +++ b/lib/common/render.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from pytorch3d.renderer import ( + BlendParams, + blending, + look_at_view_transform, + FoVOrthographicCameras, + PointLights, + RasterizationSettings, + PointsRasterizationSettings, + PointsRenderer, + AlphaCompositor, + PointsRasterizer, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + SoftSilhouetteShader, + TexturesVertex, +) +from pytorch3d.renderer.mesh import TexturesVertex +from pytorch3d.structures import Meshes +from lib.dataset.mesh_util import get_visibility, get_visibility_color + +import lib.common.render_utils as util +import torch +import numpy as np +from PIL import Image +from tqdm import tqdm +import os +import cv2 +import math +from termcolor import colored + + +def image2vid(images, vid_path): + + w, h = images[0].size + videodims = (w, h) + fourcc = cv2.VideoWriter_fourcc(*'XVID') + video = cv2.VideoWriter(vid_path, fourcc, len(images) / 5.0, videodims) + for image in images: + video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)) + video.release() + + +def query_color(verts, faces, image, device, predicted_color): + """query colors from points and image + + Args: + verts ([B, 3]): [query verts] + faces ([M, 3]): [query faces] + image ([B, 3, H, W]): [full image] + + Returns: + [np.float]: [return colors] + """ + + verts = verts.float().to(device) + faces = faces.long().to(device) + predicted_color=predicted_color.to(device) + (xy, z) = verts.split([2, 1], dim=1) + visibility = get_visibility_color(xy, z, faces[:, [0, 2, 1]]).flatten() + uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2] + uv = uv * torch.tensor([1.0, -1.0]).type_as(uv) + colors = (torch.nn.functional.grid_sample( + image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) + + 1.0) * 0.5 * 255.0 + colors[visibility == 0.0]=(predicted_color* 255.0)[visibility == 0.0] + + return colors.detach().cpu() + + +class cleanShader(torch.nn.Module): + + def __init__(self, device="cpu", cameras=None, blend_params=None): + super().__init__() + self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams( + ) + + def forward(self, fragments, meshes, **kwargs): + cameras = kwargs.get("cameras", 
self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of TexturedSoftPhongShader" + + raise ValueError(msg) + + # get renderer output + blend_params = kwargs.get("blend_params", self.blend_params) + texels = meshes.sample_textures(fragments) + images = blending.softmax_rgb_blend(texels, + fragments, + blend_params, + znear=-256, + zfar=256) + + return images + + +class Render: + + def __init__(self, size=512, device=torch.device("cuda:0")): + self.device = device + self.size = size + + # camera setting + self.dis = 100.0 + self.scale = 100.0 + self.mesh_y_center = 0.0 + + self.reload_cam() + + self.type = "color" + + self.mesh = None + self.deform_mesh = None + self.pcd = None + self.renderer = None + self.meshRas = None + + self.uv_rasterizer = util.Pytorch3dRasterizer(self.size) + + def reload_cam(self): + + self.cam_pos = [ + (0, self.mesh_y_center, self.dis), + (self.dis, self.mesh_y_center, 0), + (0, self.mesh_y_center, -self.dis), + (-self.dis, self.mesh_y_center, 0), + (0,self.mesh_y_center+self.dis,0), + (0,self.mesh_y_center-self.dis,0), + ] + + def get_camera(self, cam_id): + + if cam_id == 4: + R, T = look_at_view_transform( + eye=[self.cam_pos[cam_id]], + at=((0, self.mesh_y_center, 0), ), + up=((0, 0, 1), ), + ) + elif cam_id == 5: + R, T = look_at_view_transform( + eye=[self.cam_pos[cam_id]], + at=((0, self.mesh_y_center, 0), ), + up=((0, 0, 1), ), + ) + + else: + R, T = look_at_view_transform( + eye=[self.cam_pos[cam_id]], + at=((0, self.mesh_y_center, 0), ), + up=((0, 1, 0), ), + ) + + camera = FoVOrthographicCameras( + device=self.device, + R=R, + T=T, + znear=100.0, + zfar=-100.0, + max_y=100.0, + min_y=-100.0, + max_x=100.0, + min_x=-100.0, + scale_xyz=(self.scale * np.ones(3), ), + ) + + return camera + + def init_renderer(self, camera, type="clean_mesh", bg="gray"): + + if "mesh" in type: + + # rasterizer + self.raster_settings_mesh = RasterizationSettings( + image_size=self.size, + blur_radius=np.log(1.0 / 1e-4) * 1e-7, + faces_per_pixel=30, + ) + self.meshRas = MeshRasterizer( + cameras=camera, raster_settings=self.raster_settings_mesh) + + if bg == "black": + blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0)) + elif bg == "white": + blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0)) + elif bg == "gray": + blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5)) + + if type == "ori_mesh": + + lights = PointLights( + device=self.device, + ambient_color=((0.8, 0.8, 0.8), ), + diffuse_color=((0.2, 0.2, 0.2), ), + specular_color=((0.0, 0.0, 0.0), ), + location=[[0.0, 200.0, 0.0]], + ) + + self.renderer = MeshRenderer( + rasterizer=self.meshRas, + shader=SoftPhongShader( + device=self.device, + cameras=camera, + lights=None, + blend_params=blendparam, + ), + ) + + if type == "silhouette": + self.raster_settings_silhouette = RasterizationSettings( + image_size=self.size, + blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5, + faces_per_pixel=50, + cull_backfaces=True, + ) + + self.silhouetteRas = MeshRasterizer( + cameras=camera, + raster_settings=self.raster_settings_silhouette) + self.renderer = MeshRenderer(rasterizer=self.silhouetteRas, + shader=SoftSilhouetteShader()) + + if type == "pointcloud": + self.raster_settings_pcd = PointsRasterizationSettings( + image_size=self.size, radius=0.006, points_per_pixel=10) + + self.pcdRas = PointsRasterizer( + cameras=camera, raster_settings=self.raster_settings_pcd) + self.renderer = PointsRenderer( + rasterizer=self.pcdRas, + 
compositor=AlphaCompositor(background_color=(0, 0, 0)), + ) + + if type == "clean_mesh": + + self.renderer = MeshRenderer( + rasterizer=self.meshRas, + shader=cleanShader(device=self.device, + cameras=camera, + blend_params=blendparam), + ) + + def VF2Mesh(self, verts, faces, vertex_texture = None): + + if not torch.is_tensor(verts): + verts = torch.tensor(verts) + if not torch.is_tensor(faces): + faces = torch.tensor(faces) + + if verts.ndimension() == 2: + verts = verts.unsqueeze(0).float() + if faces.ndimension() == 2: + faces = faces.unsqueeze(0).long() + + verts = verts.to(self.device) + faces = faces.to(self.device) + if vertex_texture is not None: + vertex_texture = vertex_texture.to(self.device) + + mesh = Meshes(verts, faces).to(self.device) + + if vertex_texture is None: + mesh.textures = TexturesVertex( + verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5)#modify + else: + mesh.textures = TexturesVertex( + verts_features = vertex_texture.unsqueeze(0))#modify + return mesh + + def load_meshes(self, verts, faces,offset=None, vertex_texture = None): + """load mesh into the pytorch3d renderer + + Args: + verts ([N,3]): verts + faces ([N,3]): faces + offset ([N,3]): offset + """ + if offset is not None: + verts = verts + offset + + if isinstance(verts, list): + self.meshes = [] + for V, F in zip(verts, faces): + if vertex_texture is None: + self.meshes.append(self.VF2Mesh(V, F)) + else: + self.meshes.append(self.VF2Mesh(V, F, vertex_texture)) + else: + if vertex_texture is None: + self.meshes = [self.VF2Mesh(verts, faces)] + else: + self.meshes = [self.VF2Mesh(verts, faces, vertex_texture)] + + def get_depth_map(self, cam_ids=[0, 2]): + + depth_maps = [] + for cam_id in cam_ids: + self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray") + fragments = self.meshRas(self.meshes[0]) + depth_map = fragments.zbuf[..., 0].squeeze(0) + if cam_id == 2: + depth_map = torch.fliplr(depth_map) + depth_maps.append(depth_map) + + return depth_maps + + def get_rgb_image(self, cam_ids=[0, 2], bg='gray'): + + images = [] + for cam_id in range(len(self.cam_pos)): + if cam_id in cam_ids: + self.init_renderer(self.get_camera(cam_id), "clean_mesh", bg) + if len(cam_ids) == 4: + rendered_img = (self.renderer( + self.meshes[0])[0:1, :, :, :3].permute(0, 3, 1, 2) - + 0.5) * 2.0 + else: + rendered_img = (self.renderer( + self.meshes[0])[0:1, :, :, :3].permute(0, 3, 1, 2) - + 0.5) * 2.0 + if cam_id == 2 and len(cam_ids) == 2: + rendered_img = torch.flip(rendered_img, dims=[3]) + images.append(rendered_img) + + return images + + def get_rendered_video(self, images, save_path): + + self.cam_pos = [] + for angle in range(360): + self.cam_pos.append(( + 100.0 * math.cos(np.pi / 180 * angle), + self.mesh_y_center, + 100.0 * math.sin(np.pi / 180 * angle), + )) + + old_shape = np.array(images[0].shape[:2]) + new_shape = np.around( + (self.size / old_shape[0]) * old_shape).astype(np.int) + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video = cv2.VideoWriter(save_path, fourcc, 10, + (self.size * len(self.meshes) + + new_shape[1] * len(images), self.size)) + + pbar = tqdm(range(len(self.cam_pos))) + pbar.set_description( + colored(f"exporting video {os.path.basename(save_path)}...", + "blue")) + for cam_id in pbar: + self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray") + + img_lst = [ + np.array(Image.fromarray(img).resize(new_shape[::-1])).astype( + np.uint8)[:, :, [2, 1, 0]] for img in images + ] + + for mesh in self.meshes: + rendered_img = ((self.renderer(mesh)[0, :, :, :3] * + 
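A hedged usage sketch of the `Render` wrapper: it assumes a CUDA device with pytorch3d installed and a mesh roughly in the [-1, 1] space the orthographic cameras above expect; the mesh path is illustrative:
```
import torch
import trimesh
from lib.common.render import Render

mesh = trimesh.load("out/demo/demo_mesh.obj")                  # illustrative path
render = Render(size=512, device=torch.device("cuda:0"))
render.load_meshes(mesh.vertices, mesh.faces)

front, back = render.get_rgb_image(cam_ids=[0, 2], bg="gray")  # each [1, 3, 512, 512] in (-1, 1)
depth_f, depth_b = render.get_depth_map(cam_ids=[0, 2])        # z-buffers, back view mirrored
```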
255.0).detach().cpu().numpy().astype( + np.uint8)) + + img_lst.append(rendered_img) + final_img = np.concatenate(img_lst, axis=1) + video.write(final_img) + + video.release() + self.reload_cam() + + def get_silhouette_image(self, cam_ids=[0, 2]): + + images = [] + for cam_id in range(len(self.cam_pos)): + if cam_id in cam_ids: + self.init_renderer(self.get_camera(cam_id), "silhouette") + rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3] + if cam_id == 2 and len(cam_ids) == 2: + rendered_img = torch.flip(rendered_img, dims=[2]) + images.append(rendered_img) + + return images diff --git a/lib/common/render_utils.py b/lib/common/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1137d12fefdf11a1c5b2f436b9abcad2e6aa51f8 --- /dev/null +++ b/lib/common/render_utils.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import torch +from torch import nn +import trimesh +import math +from typing import NewType +from pytorch3d.structures import Meshes +from pytorch3d.renderer.mesh import rasterize_meshes + +Tensor = NewType('Tensor', torch.Tensor) + + +def solid_angles(points: Tensor, + triangles: Tensor, + thresh: float = 1e-8) -> Tensor: + ''' Compute solid angle between the input points and triangles + Follows the method described in: + The Solid Angle of a Plane Triangle + A. VAN OOSTEROM AND J. STRACKEE + IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING, + VOL. BME-30, NO. 2, FEBRUARY 1983 + Parameters + ----------- + points: BxQx3 + Tensor of input query points + triangles: BxFx3x3 + Target triangles + thresh: float + float threshold + Returns + ------- + solid_angles: BxQxF + A tensor containing the solid angle between all query points + and input triangles + ''' + # Center the triangles on the query points. 
Size should be BxQxFx3x3 + centered_tris = triangles[:, None] - points[:, :, None, None] + + # BxQxFx3 + norms = torch.norm(centered_tris, dim=-1) + + # Should be BxQxFx3 + cross_prod = torch.cross(centered_tris[:, :, :, 1], + centered_tris[:, :, :, 2], + dim=-1) + # Should be BxQxF + numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1) + del cross_prod + + dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1) + dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1) + dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1) + del centered_tris + + denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] + + dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0]) + del dot01, dot12, dot02, norms + + # Should be BxQ + solid_angle = torch.atan2(numerator, denominator) + del numerator, denominator + + torch.cuda.empty_cache() + + return 2 * solid_angle + + +def winding_numbers(points: Tensor, + triangles: Tensor, + thresh: float = 1e-8) -> Tensor: + ''' Uses winding_numbers to compute inside/outside + Robust inside-outside segmentation using generalized winding numbers + Alec Jacobson, + Ladislav Kavan, + Olga Sorkine-Hornung + Fast Winding Numbers for Soups and Clouds SIGGRAPH 2018 + Gavin Barill + NEIL G. Dickson + Ryan Schmidt + David I.W. Levin + and Alec Jacobson + Parameters + ----------- + points: BxQx3 + Tensor of input query points + triangles: BxFx3x3 + Target triangles + thresh: float + float threshold + Returns + ------- + winding_numbers: BxQ + A tensor containing the Generalized winding numbers + ''' + # The generalized winding number is the sum of solid angles of the point + # with respect to all triangles. + return 1 / (4 * math.pi) * solid_angles(points, triangles, + thresh=thresh).sum(dim=-1) + + +def batch_contains(verts, faces, points): + + B = verts.shape[0] + N = points.shape[1] + + verts = verts.detach().cpu() + faces = faces.detach().cpu() + points = points.detach().cpu() + contains = torch.zeros(B, N) + + for i in range(B): + contains[i] = torch.as_tensor( + trimesh.Trimesh(verts[i], faces[i]).contains(points[i])) + + return 2.0 * (contains - 0.5) + + +def dict2obj(d): + # if isinstance(d, list): + # d = [dict2obj(x) for x in d] + if not isinstance(d, dict): + return d + + class C(object): + pass + + o = C() + for k in d: + o.__dict__[k] = dict2obj(d[k]) + return o + + +def face_vertices(vertices, faces): + """ + :param vertices: [batch size, number of vertices, 3] + :param faces: [batch size, number of faces, 3] + :return: [batch size, number of faces, 3, 3] + """ + + bs, nv = vertices.shape[:2] + bs, nf = faces.shape[:2] + device = vertices.device + faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * + nv)[:, None, None] + vertices = vertices.reshape((bs * nv, vertices.shape[-1])) + + return vertices[faces.long()] + + +class Pytorch3dRasterizer(nn.Module): + """ Borrowed from https://github.com/facebookresearch/pytorch3d + Notice: + x,y,z are in image space, normalized + can only render squared image now + """ + + def __init__(self, image_size=224): + """ + use fixed raster_settings for rendering faces + """ + super().__init__() + raster_settings = { + 'image_size': image_size, + 'blur_radius': 0.0, + 'faces_per_pixel': 1, + 'bin_size': None, + 'max_faces_per_bin': None, + 'perspective_correct': True, + 'cull_backfaces': True, + } + raster_settings = dict2obj(raster_settings) + self.raster_settings = raster_settings + + def forward(self, vertices, faces, attributes=None): + 
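`winding_numbers` gives a soft inside/outside test for closed meshes: points inside accumulate a total solid angle of ±4π (so |w| ≈ 1), while points outside sum to ≈ 0. A small self-check on a trimesh box; the sign depends on face orientation, hence the `abs`:
```
import torch
import trimesh
from lib.common.render_utils import winding_numbers

box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
verts = torch.tensor(box.vertices, dtype=torch.float32)
tris = verts[torch.tensor(box.faces, dtype=torch.long)]        # [F, 3, 3] triangle soup

points = torch.tensor([[0.0, 0.0, 0.0],                        # inside
                       [2.0, 2.0, 2.0]])                       # outside
w = winding_numbers(points.unsqueeze(0), tris.unsqueeze(0))     # [1, Q]
print(w.abs())                                                  # ~[1.0, 0.0]
```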
fixed_vertices = vertices.clone() + fixed_vertices[..., :2] = -fixed_vertices[..., :2] + meshes_screen = Meshes(verts=fixed_vertices.float(), + faces=faces.long()) + raster_settings = self.raster_settings + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + ) + vismask = (pix_to_face > -1).float() + D = attributes.shape[-1] + attributes = attributes.clone() + attributes = attributes.view(attributes.shape[0] * attributes.shape[1], + 3, attributes.shape[-1]) + N, H, W, K, _ = bary_coords.shape + mask = pix_to_face == -1 + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) + pixel_vals = torch.cat( + [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + return pixel_vals diff --git a/lib/common/seg3d_lossless.py b/lib/common/seg3d_lossless.py new file mode 100644 index 0000000000000000000000000000000000000000..4787720f533fbb3a66d5712ffe898a52d1ce08b1 --- /dev/null +++ b/lib/common/seg3d_lossless.py @@ -0,0 +1,603 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from .seg3d_utils import ( + create_grid3D, + plot_mask3D, + SmoothConv3D, +) + +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import mcubes +from kaolin.ops.conversions import voxelgrids_to_trianglemeshes +import logging + +logging.getLogger("lightning").setLevel(logging.ERROR) + + +class Seg3dLossless(nn.Module): + + def __init__(self, + query_func, + b_min, + b_max, + resolutions, + channels=1, + balance_value=0.5, + align_corners=False, + visualize=False, + debug=False, + use_cuda_impl=False, + faster=False, + use_shadow=False, + **kwargs): + """ + align_corners: same with how you process gt. 
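`Pytorch3dRasterizer` splats arbitrary per-vertex attributes, packed per face with `face_vertices`, into an image and appends a coverage mask as the last channel. A shape-level sketch with one made-up triangle in normalized screen space (only the output layout is asserted here):
```
import torch
from lib.common.render_utils import Pytorch3dRasterizer, face_vertices

verts = torch.tensor([[[-0.5, -0.5, 0.1],
                       [ 0.5, -0.5, 0.1],
                       [ 0.0,  0.5, 0.1]]])            # [1, V, 3] in normalized screen space
faces = torch.tensor([[[0, 1, 2]]])                    # [1, F, 3]
colors = torch.tensor([[[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0]]])             # [1, V, 3] per-vertex attributes

attrs = face_vertices(colors, faces)                   # [1, F, 3, 3]
out = Pytorch3dRasterizer(image_size=224)(verts, faces, attrs)
print(out.shape)                                       # [1, 4, 224, 224]: RGB + visibility mask
```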
(grid_sample / interpolate) + """ + super().__init__() + self.query_func = query_func + self.register_buffer( + 'b_min', + torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3] + self.register_buffer( + 'b_max', + torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3] + + # ti.init(arch=ti.cuda) + # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1) + + if type(resolutions[0]) is int: + resolutions = torch.tensor([(res, res, res) + for res in resolutions]) + else: + resolutions = torch.tensor(resolutions) + self.register_buffer('resolutions', resolutions) + self.batchsize = self.b_min.size(0) + assert self.batchsize == 1 + self.balance_value = balance_value + self.channels = channels + assert self.channels == 1 + self.align_corners = align_corners + self.visualize = visualize + self.debug = debug + self.use_cuda_impl = use_cuda_impl + self.faster = faster + self.use_shadow = use_shadow + + for resolution in resolutions: + assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \ + f"resolution {resolution} need to be odd becuase of align_corner." + + # init first resolution + init_coords = create_grid3D(0, + resolutions[-1] - 1, + steps=resolutions[0]) # [N, 3] + init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, + 1) # [bz, N, 3] + self.register_buffer('init_coords', init_coords) + + # some useful tensors + calculated = torch.zeros( + (self.resolutions[-1][2], self.resolutions[-1][1], + self.resolutions[-1][0]), + dtype=torch.bool) + self.register_buffer('calculated', calculated) + + gird8_offsets = torch.stack( + torch.meshgrid([ + torch.tensor([-1, 0, 1]), + torch.tensor([-1, 0, 1]), + torch.tensor([-1, 0, 1]) + ])).int().view(3, -1).t() # [27, 3] + self.register_buffer('gird8_offsets', gird8_offsets) + + # smooth convs + self.smooth_conv3x3 = SmoothConv3D(in_channels=1, + out_channels=1, + kernel_size=3) + self.smooth_conv5x5 = SmoothConv3D(in_channels=1, + out_channels=1, + kernel_size=5) + self.smooth_conv7x7 = SmoothConv3D(in_channels=1, + out_channels=1, + kernel_size=7) + self.smooth_conv9x9 = SmoothConv3D(in_channels=1, + out_channels=1, + kernel_size=9) + + def batch_eval(self, coords, **kwargs): + """ + coords: in the coordinates of last resolution + **kwargs: for query_func + """ + coords = coords.detach() + # normalize coords to fit in [b_min, b_max] + if self.align_corners: + coords2D = coords.float() / (self.resolutions[-1] - 1) + else: + step = 1.0 / self.resolutions[-1].float() + coords2D = coords.float() / self.resolutions[-1] + step / 2 + coords2D = coords2D * (self.b_max - self.b_min) + self.b_min + # query function + occupancys = self.query_func(**kwargs, points=coords2D) + if type(occupancys) is list: + occupancys = torch.stack(occupancys) # [bz, C, N] + assert len(occupancys.size()) == 3, \ + "query_func should return a occupancy with shape of [bz, C, N]" + return occupancys + + def forward(self, **kwargs): + if self.faster: + return self._forward_faster(**kwargs) + else: + return self._forward(**kwargs) + + def _forward_faster(self, **kwargs): + """ + In faster mode, we make following changes to exchange accuracy for speed: + 1. no conflict checking: 4.88 fps -> 6.56 fps + 2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution + 3. 
last step no examine + """ + final_W = self.resolutions[-1][0] + final_H = self.resolutions[-1][1] + final_D = self.resolutions[-1][2] + + for resolution in self.resolutions: + W, H, D = resolution + stride = (self.resolutions[-1] - 1) / (resolution - 1) + + # first step + if torch.equal(resolution, self.resolutions[0]): + coords = self.init_coords.clone() # torch.long + occupancys = self.batch_eval(coords, **kwargs) + occupancys = occupancys.view(self.batchsize, self.channels, D, + H, W) + if (occupancys > 0.5).sum() == 0: + # return F.interpolate( + # occupancys, size=(final_D, final_H, final_W), + # mode="linear", align_corners=True) + return None + + if self.visualize: + self.plot(occupancys, coords, final_D, final_H, final_W) + + with torch.no_grad(): + coords_accum = coords / stride + + # last step + elif torch.equal(resolution, self.resolutions[-1]): + + with torch.no_grad(): + # here true is correct! + valid = F.interpolate( + (occupancys > self.balance_value).float(), + size=(D, H, W), + mode="trilinear", + align_corners=True) + + # here true is correct! + occupancys = F.interpolate(occupancys.float(), + size=(D, H, W), + mode="trilinear", + align_corners=True) + + # is_boundary = (valid > 0.0) & (valid < 1.0) + is_boundary = valid == 0.5 + + # next steps + else: + coords_accum *= 2 + + with torch.no_grad(): + # here true is correct! + valid = F.interpolate( + (occupancys > self.balance_value).float(), + size=(D, H, W), + mode="trilinear", + align_corners=True) + + # here true is correct! + occupancys = F.interpolate(occupancys.float(), + size=(D, H, W), + mode="trilinear", + align_corners=True) + + is_boundary = (valid > 0.0) & (valid < 1.0) + + with torch.no_grad(): + if torch.equal(resolution, self.resolutions[1]): + is_boundary = (self.smooth_conv9x9(is_boundary.float()) + > 0)[0, 0] + elif torch.equal(resolution, self.resolutions[2]): + is_boundary = (self.smooth_conv7x7(is_boundary.float()) + > 0)[0, 0] + else: + is_boundary = (self.smooth_conv3x3(is_boundary.float()) + > 0)[0, 0] + + coords_accum = coords_accum.long() + is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1], + coords_accum[0, :, 0]] = False + point_coords = is_boundary.permute( + 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0) + point_indices = (point_coords[:, :, 2] * H * W + + point_coords[:, :, 1] * W + + point_coords[:, :, 0]) + + R, C, D, H, W = occupancys.shape + + # inferred value + coords = point_coords * stride + + if coords.size(1) == 0: + continue + occupancys_topk = self.batch_eval(coords, **kwargs) + + # put mask point predictions to the right places on the upsampled grid. 
+ R, C, D, H, W = occupancys.shape + point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) + occupancys = (occupancys.reshape(R, C, D * H * W).scatter_( + 2, point_indices, occupancys_topk).view(R, C, D, H, W)) + + with torch.no_grad(): + voxels = coords / stride + coords_accum = torch.cat([voxels, coords_accum], + dim=1).unique(dim=1) + + return occupancys[0, 0] + + def _forward(self, **kwargs): + """ + output occupancy field would be: + (bz, C, res, res) + """ + final_W = self.resolutions[-1][0] + final_H = self.resolutions[-1][1] + final_D = self.resolutions[-1][2] + + calculated = self.calculated.clone() + + for resolution in self.resolutions: + W, H, D = resolution + stride = (self.resolutions[-1] - 1) / (resolution - 1) + + if self.visualize: + this_stage_coords = [] + + # first step + if torch.equal(resolution, self.resolutions[0]): + coords = self.init_coords.clone() # torch.long + occupancys = self.batch_eval(coords, **kwargs) + occupancys = occupancys.view(self.batchsize, self.channels, D, + H, W) + + if self.visualize: + self.plot(occupancys, coords, final_D, final_H, final_W) + + with torch.no_grad(): + coords_accum = coords / stride + calculated[coords[0, :, 2], coords[0, :, 1], + coords[0, :, 0]] = True + + # next steps + else: + coords_accum *= 2 + + with torch.no_grad(): + # here true is correct! + valid = F.interpolate( + (occupancys > self.balance_value).float(), + size=(D, H, W), + mode="trilinear", + align_corners=True) + + # here true is correct! + occupancys = F.interpolate(occupancys.float(), + size=(D, H, W), + mode="trilinear", + align_corners=True) + + is_boundary = (valid > 0.0) & (valid < 1.0) + + with torch.no_grad(): + # TODO + if self.use_shadow and torch.equal(resolution, + self.resolutions[-1]): + # larger z means smaller depth here + depth_res = resolution[2].item() + depth_index = torch.linspace(0, + depth_res - 1, + steps=depth_res).type_as( + occupancys.device) + depth_index_max = torch.max( + (occupancys > self.balance_value) * + (depth_index + 1), + dim=-1, + keepdim=True)[0] - 1 + shadow = depth_index < depth_index_max + is_boundary[shadow] = False + is_boundary = is_boundary[0, 0] + else: + is_boundary = (self.smooth_conv3x3(is_boundary.float()) + > 0)[0, 0] + # is_boundary = is_boundary[0, 0] + + is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1], + coords_accum[0, :, 0]] = False + point_coords = is_boundary.permute( + 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0) + point_indices = (point_coords[:, :, 2] * H * W + + point_coords[:, :, 1] * W + + point_coords[:, :, 0]) + + R, C, D, H, W = occupancys.shape + # interpolated value + occupancys_interp = torch.gather( + occupancys.reshape(R, C, D * H * W), 2, + point_indices.unsqueeze(1)) + + # inferred value + coords = point_coords * stride + + if coords.size(1) == 0: + continue + occupancys_topk = self.batch_eval(coords, **kwargs) + if self.visualize: + this_stage_coords.append(coords) + + # put mask point predictions to the right places on the upsampled grid. 
+ R, C, D, H, W = occupancys.shape + point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) + occupancys = (occupancys.reshape(R, C, D * H * W).scatter_( + 2, point_indices, occupancys_topk).view(R, C, D, H, W)) + + with torch.no_grad(): + # conflicts + conflicts = ((occupancys_interp - self.balance_value) * + (occupancys_topk - self.balance_value) < 0)[0, + 0] + + if self.visualize: + self.plot(occupancys, coords, final_D, final_H, + final_W) + + voxels = coords / stride + coords_accum = torch.cat([voxels, coords_accum], + dim=1).unique(dim=1) + calculated[coords[0, :, 2], coords[0, :, 1], + coords[0, :, 0]] = True + + while conflicts.sum() > 0: + if self.use_shadow and torch.equal(resolution, + self.resolutions[-1]): + break + + with torch.no_grad(): + conflicts_coords = coords[0, conflicts, :] + + if self.debug: + self.plot(occupancys, + conflicts_coords.unsqueeze(0), + final_D, + final_H, + final_W, + title='conflicts') + + conflicts_boundary = (conflicts_coords.int() + + self.gird8_offsets.unsqueeze(1) * + stride.int()).reshape( + -1, 3).long().unique(dim=0) + conflicts_boundary[:, 0] = ( + conflicts_boundary[:, 0].clamp( + 0, + calculated.size(2) - 1)) + conflicts_boundary[:, 1] = ( + conflicts_boundary[:, 1].clamp( + 0, + calculated.size(1) - 1)) + conflicts_boundary[:, 2] = ( + conflicts_boundary[:, 2].clamp( + 0, + calculated.size(0) - 1)) + + coords = conflicts_boundary[calculated[ + conflicts_boundary[:, 2], conflicts_boundary[:, 1], + conflicts_boundary[:, 0]] == False] + + if self.debug: + self.plot(occupancys, + coords.unsqueeze(0), + final_D, + final_H, + final_W, + title='coords') + + coords = coords.unsqueeze(0) + point_coords = coords / stride + point_indices = (point_coords[:, :, 2] * H * W + + point_coords[:, :, 1] * W + + point_coords[:, :, 0]) + + R, C, D, H, W = occupancys.shape + # interpolated value + occupancys_interp = torch.gather( + occupancys.reshape(R, C, D * H * W), 2, + point_indices.unsqueeze(1)) + + # inferred value + coords = point_coords * stride + + if coords.size(1) == 0: + break + occupancys_topk = self.batch_eval(coords, **kwargs) + if self.visualize: + this_stage_coords.append(coords) + + with torch.no_grad(): + # conflicts + conflicts = ((occupancys_interp - self.balance_value) * + (occupancys_topk - self.balance_value) < + 0)[0, 0] + + # put mask point predictions to the right places on the upsampled grid. + point_indices = point_indices.unsqueeze(1).expand( + -1, C, -1) + occupancys = (occupancys.reshape(R, C, D * H * W).scatter_( + 2, point_indices, occupancys_topk).view(R, C, D, H, W)) + + with torch.no_grad(): + voxels = coords / stride + coords_accum = torch.cat([voxels, coords_accum], + dim=1).unique(dim=1) + calculated[coords[0, :, 2], coords[0, :, 1], + coords[0, :, 0]] = True + + if self.visualize: + this_stage_coords = torch.cat(this_stage_coords, dim=1) + self.plot(occupancys, this_stage_coords, final_D, final_H, + final_W) + + return occupancys[0, 0] + + def plot(self, + occupancys, + coords, + final_D, + final_H, + final_W, + title='', + **kwargs): + final = F.interpolate(occupancys.float(), + size=(final_D, final_H, final_W), + mode="trilinear", + align_corners=True) # here true is correct! 
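`Seg3dLossless` evaluates an occupancy function coarse-to-fine and only re-queries points near the 0.5 iso-surface at the finer resolutions. A toy reconstruction of a sphere, assuming a CUDA device and kaolin for `export_mesh`; the query function and resolution schedule are illustrative:
```
import torch
from lib.common.seg3d_lossless import Seg3dLossless

def sphere_occupancy(points, **kwargs):
    # points: [1, N, 3] world coordinates in [b_min, b_max]; return [1, 1, N]
    return (points.norm(dim=-1) < 0.5).float().unsqueeze(1)

recon = Seg3dLossless(
    query_func=sphere_occupancy,
    b_min=[[-1.0, -1.0, -1.0]],
    b_max=[[1.0, 1.0, 1.0]],
    resolutions=[17, 33, 65, 129],   # each resolution must be odd (align_corners)
    align_corners=True,
    faster=True,
).to("cuda:0")

occupancy = recon()                  # [129, 129, 129] occupancy grid
verts, faces = recon.export_mesh(occupancy)
```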
+ x = coords[0, :, 0].to("cpu") + y = coords[0, :, 1].to("cpu") + z = coords[0, :, 2].to("cpu") + + plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs) + + def find_vertices(self, sdf, direction="front"): + ''' + - direction: "front" | "back" | "left" | "right" + ''' + resolution = sdf.size(2) + if direction == "front": + pass + elif direction == "left": + sdf = sdf.permute(2, 1, 0) + elif direction == "back": + inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long() + sdf = sdf[inv_idx, :, :] + elif direction == "right": + inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long() + sdf = sdf[:, :, inv_idx] + sdf = sdf.permute(2, 1, 0) + + inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long() + sdf = sdf[inv_idx, :, :] + sdf_all = sdf.permute(2, 1, 0) + + # shadow + grad_v = (sdf_all > 0.5) * torch.linspace( + resolution, 1, steps=resolution).to(sdf.device) + grad_c = torch.ones_like(sdf_all) * torch.linspace( + 0, resolution - 1, steps=resolution).to(sdf.device) + max_v, max_c = grad_v.max(dim=2) + shadow = grad_c > max_c.view(resolution, resolution, 1) + keep = (sdf_all > 0.5) & (~shadow) + + p1 = keep.nonzero(as_tuple=False).t() # [3, N] + p2 = p1.clone() # z + p2[2, :] = (p2[2, :] - 2).clamp(0, resolution) + p3 = p1.clone() # y + p3[1, :] = (p3[1, :] - 2).clamp(0, resolution) + p4 = p1.clone() # x + p4[0, :] = (p4[0, :] - 2).clamp(0, resolution) + + v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]] + v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]] + v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]] + v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]] + + X = p1[0, :].long() # [N,] + Y = p1[1, :].long() # [N,] + Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \ + p1[2, :].float() * (v2 - 0.5) / (v2 - v1) # [N,] + Z = Z.clamp(0, resolution) + + # normal + norm_z = v2 - v1 + norm_y = v3 - v1 + norm_x = v4 - v1 + # print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0]) + + norm = torch.stack([norm_x, norm_y, norm_z], dim=1) + norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True) + + return X, Y, Z, norm + + def render_normal(self, resolution, X, Y, Z, norm): + image = torch.ones((1, 3, resolution, resolution), + dtype=torch.float32).to(norm.device) + color = (norm + 1) / 2.0 + color = color.clamp(0, 1) + image[0, :, Y, X] = color.t() + return image + + def display(self, sdf): + + # render + X, Y, Z, norm = self.find_vertices(sdf, direction="front") + image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm) + X, Y, Z, norm = self.find_vertices(sdf, direction="left") + image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm) + X, Y, Z, norm = self.find_vertices(sdf, direction="right") + image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm) + X, Y, Z, norm = self.find_vertices(sdf, direction="back") + image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm) + + image = torch.cat([image1, image2, image3, image4], axis=3) + image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0 + + return np.uint8(image) + + def export_mesh(self, occupancys): + + final = occupancys[1:, 1:, 1:].contiguous() + + if final.shape[0] > 256: + # for voxelgrid larger than 256^3, the required GPU memory will be > 9GB + # thus we use CPU marching_cube to avoid "CUDA out of memory" + occu_arr = final.detach().cpu().numpy() # non-smooth surface + # occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface + vertices, triangles = mcubes.marching_cubes( + occu_arr, self.balance_value) + verts = torch.as_tensor(vertices[:, [2, 1, 0]]) + faces 
= torch.as_tensor(triangles.astype(np.longlong), + dtype=torch.long)[:, [0, 2, 1]] + else: + torch.cuda.empty_cache() + vertices, triangles = voxelgrids_to_trianglemeshes( + final.unsqueeze(0)) + verts = vertices[0][:, [2, 1, 0]].cpu() + faces = triangles[0][:, [0, 2, 1]].cpu() + + return verts, faces diff --git a/lib/common/seg3d_utils.py b/lib/common/seg3d_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48d78ac0278ba97185ca12aec3b014ab5fef401c --- /dev/null +++ b/lib/common/seg3d_utils.py @@ -0,0 +1,393 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +def plot_mask2D(mask, + title="", + point_coords=None, + figsize=10, + point_marker_size=5): + ''' + Simple plotting tool to show intermediate mask predictions and points + where PointRend is applied. + + Args: + mask (Tensor): mask prediction of shape HxW + title (str): title for the plot + point_coords ((Tensor, Tensor)): x and y point coordinates + figsize (int): size of the figure to plot + point_marker_size (int): marker size for points + ''' + + H, W = mask.shape + plt.figure(figsize=(figsize, figsize)) + if title: + title += ", " + plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30) + plt.ylabel(H, fontsize=30) + plt.xlabel(W, fontsize=30) + plt.xticks([], []) + plt.yticks([], []) + plt.imshow(mask.detach(), + interpolation="nearest", + cmap=plt.get_cmap('gray')) + if point_coords is not None: + plt.scatter(x=point_coords[0], + y=point_coords[1], + color="red", + s=point_marker_size, + clip_on=True) + plt.xlim(-0.5, W - 0.5) + plt.ylim(H - 0.5, -0.5) + plt.show() + + +def plot_mask3D(mask=None, + title="", + point_coords=None, + figsize=1500, + point_marker_size=8, + interactive=True): + ''' + Simple plotting tool to show intermediate mask predictions and points + where PointRend is applied. 
+ + Args: + mask (Tensor): mask prediction of shape DxHxW + title (str): title for the plot + point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates + figsize (int): size of the figure to plot + point_marker_size (int): marker size for points + ''' + import trimesh + import vtkplotter + from skimage import measure + + vp = vtkplotter.Plotter(title=title, size=(figsize, figsize)) + vis_list = [] + + if mask is not None: + mask = mask.detach().to("cpu").numpy() + mask = mask.transpose(2, 1, 0) + + # marching cube to find surface + verts, faces, normals, values = measure.marching_cubes_lewiner( + mask, 0.5, gradient_direction='ascent') + + # create a mesh + mesh = trimesh.Trimesh(verts, faces) + mesh.visual.face_colors = [200, 200, 250, 100] + vis_list.append(mesh) + + if point_coords is not None: + point_coords = torch.stack(point_coords, 1).to("cpu").numpy() + + # import numpy as np + # select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112) + # select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272) + # select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112) + # select = np.logical_and(np.logical_and(select_x, select_y), select_z) + # point_coords = point_coords[select, :] + + pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red') + vis_list.append(pc) + + vp.show(*vis_list, + bg="white", + axes=1, + interactive=interactive, + azimuth=30, + elevation=30) + + +def create_grid3D(min, max, steps): + if type(min) is int: + min = (min, min, min) # (x, y, z) + if type(max) is int: + max = (max, max, max) # (x, y) + if type(steps) is int: + steps = (steps, steps, steps) # (x, y, z) + arrangeX = torch.linspace(min[0], max[0], steps[0]).long() + arrangeY = torch.linspace(min[1], max[1], steps[1]).long() + arrangeZ = torch.linspace(min[2], max[2], steps[2]).long() + gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX]) + coords = torch.stack([gridW, girdH, + gridD]) # [2, steps[0], steps[1], steps[2]] + coords = coords.view(3, -1).t() # [N, 3] + return coords + + +def create_grid2D(min, max, steps): + if type(min) is int: + min = (min, min) # (x, y) + if type(max) is int: + max = (max, max) # (x, y) + if type(steps) is int: + steps = (steps, steps) # (x, y) + arrangeX = torch.linspace(min[0], max[0], steps[0]).long() + arrangeY = torch.linspace(min[1], max[1], steps[1]).long() + girdH, gridW = torch.meshgrid([arrangeY, arrangeX]) + coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]] + coords = coords.view(2, -1).t() # [N, 2] + return coords + + +class SmoothConv2D(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__() + assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}" + self.padding = (kernel_size - 1) // 2 + + weight = torch.ones( + (in_channels, out_channels, kernel_size, kernel_size), + dtype=torch.float32) / (kernel_size**2) + self.register_buffer('weight', weight) + + def forward(self, input): + return F.conv2d(input, self.weight, padding=self.padding) + + +class SmoothConv3D(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__() + assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}" + self.padding = (kernel_size - 1) // 2 + + weight = torch.ones( + (in_channels, out_channels, kernel_size, kernel_size, kernel_size), + dtype=torch.float32) / (kernel_size**3) + self.register_buffer('weight', weight) + + def forward(self, 
input): + return F.conv3d(input, self.weight, padding=self.padding) + + +def build_smooth_conv3D(in_channels=1, + out_channels=1, + kernel_size=3, + padding=1): + smooth_conv = torch.nn.Conv3d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding) + smooth_conv.weight.data = torch.ones( + (in_channels, out_channels, kernel_size, kernel_size, kernel_size), + dtype=torch.float32) / (kernel_size**3) + smooth_conv.bias.data = torch.zeros(out_channels) + return smooth_conv + + +def build_smooth_conv2D(in_channels=1, + out_channels=1, + kernel_size=3, + padding=1): + smooth_conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding) + smooth_conv.weight.data = torch.ones( + (in_channels, out_channels, kernel_size, kernel_size), + dtype=torch.float32) / (kernel_size**2) + smooth_conv.bias.data = torch.zeros(out_channels) + return smooth_conv + + +def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, + **kwargs): + """ + Find `num_points` most uncertain points from `uncertainty_map` grid. + Args: + uncertainty_map (Tensor): A tensor of shape (N, 1, H, W, D) that contains uncertainty + values for a set of points on a regular H x W x D grid. + num_points (int): The number of points P to select. + Returns: + point_indices (Tensor): A tensor of shape (N, P) that contains indices from + [0, H x W x D) of the most uncertain points. + point_coords (Tensor): A tensor of shape (N, P, 3) that contains [0, 1] x [0, 1] normalized + coordinates of the most uncertain points from the H x W x D grid. + """ + R, _, D, H, W = uncertainty_map.shape + # h_step = 1.0 / float(H) + # w_step = 1.0 / float(W) + # d_step = 1.0 / float(D) + + num_points = min(D * H * W, num_points) + point_scores, point_indices = torch.topk(uncertainty_map.view( + R, D * H * W), + k=num_points, + dim=1) + point_coords = torch.zeros(R, + num_points, + 3, + dtype=torch.float, + device=uncertainty_map.device) + # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step + # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step + # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step + point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x + point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y + point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z + print(f"resolution {D} x {H} x {W}", point_scores.min(), + point_scores.max()) + return point_indices, point_coords + + +def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, + clip_min): + """ + Find `num_points` most uncertain points from `uncertainty_map` grid. + Args: + uncertainty_map (Tensor): A tensor of shape (N, 1, H, W, D) that contains uncertainty + values for a set of points on a regular H x W x D grid. + num_points (int): The number of points P to select. + Returns: + point_indices (Tensor): A tensor of shape (N, P) that contains indices from + [0, H x W x D) of the most uncertain points. + point_coords (Tensor): A tensor of shape (N, P, 3) that contains [0, 1] x [0, 1] normalized + coordinates of the most uncertain points from the H x W x D grid. + """ + R, _, D, H, W = uncertainty_map.shape + # h_step = 1.0 / float(H) + # w_step = 1.0 / float(W) + # d_step = 1.0 / float(D) + + assert R == 1, "batchsize > 1 is not implemented!" 
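+ # NOTE: the flat top-k indices below are decoded back to integer grid coordinates (x = i % W, y = (i % (H * W)) // W, z = i // (H * W)); the returned point_coords are voxel indices, not normalized [0, 1] coordinates.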
+ uncertainty_map = uncertainty_map.view(D * H * W) + indices = (uncertainty_map >= clip_min).nonzero().squeeze(1) + num_points = min(num_points, indices.size(0)) + point_scores, point_indices = torch.topk(uncertainty_map[indices], + k=num_points, + dim=0) + point_indices = indices[point_indices].unsqueeze(0) + + point_coords = torch.zeros(R, + num_points, + 3, + dtype=torch.float, + device=uncertainty_map.device) + # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step + # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step + # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step + point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x + point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y + point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z + # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max()) + return point_indices, point_coords + + +def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, + **kwargs): + """ + Find `num_points` most uncertain points from `uncertainty_map` grid. + Args: + uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty + values for a set of points on a regular H x W grid. + num_points (int): The number of points P to select. + Returns: + point_indices (Tensor): A tensor of shape (N, P) that contains indices from + [0, H x W) of the most uncertain points. + point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized + coordinates of the most uncertain points from the H x W grid. + """ + R, _, H, W = uncertainty_map.shape + # h_step = 1.0 / float(H) + # w_step = 1.0 / float(W) + + num_points = min(H * W, num_points) + point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W), + k=num_points, + dim=1) + point_coords = torch.zeros(R, + num_points, + 2, + dtype=torch.long, + device=uncertainty_map.device) + # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step + # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step + point_coords[:, :, 0] = (point_indices % W).to(torch.long) + point_coords[:, :, 1] = (point_indices // W).to(torch.long) + # print (point_scores.min(), point_scores.max()) + return point_indices, point_coords + + +def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, + clip_min): + """ + Find `num_points` most uncertain points from `uncertainty_map` grid. + Args: + uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty + values for a set of points on a regular H x W grid. + num_points (int): The number of points P to select. + Returns: + point_indices (Tensor): A tensor of shape (N, P) that contains indices from + [0, H x W) of the most uncertain points. + point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized + coordinates of the most uncertain points from the H x W grid. + """ + R, _, H, W = uncertainty_map.shape + # h_step = 1.0 / float(H) + # w_step = 1.0 / float(W) + + assert R == 1, "batchsize > 1 is not implemented!" 
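+ # NOTE: as in the 3D variant, the flat top-k indices below are decoded to integer pixel coordinates (x = i % W, y = i // W), so point_coords holds grid indices rather than normalized [0, 1] coordinates.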
+ uncertainty_map = uncertainty_map.view(H * W) + indices = (uncertainty_map >= clip_min).nonzero().squeeze(1) + num_points = min(num_points, indices.size(0)) + point_scores, point_indices = torch.topk(uncertainty_map[indices], + k=num_points, + dim=0) + point_indices = indices[point_indices].unsqueeze(0) + + point_coords = torch.zeros(R, + num_points, + 2, + dtype=torch.long, + device=uncertainty_map.device) + # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step + # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step + point_coords[:, :, 0] = (point_indices % W).to(torch.long) + point_coords[:, :, 1] = (point_indices // W).to(torch.long) + # print (point_scores.min(), point_scores.max()) + return point_indices, point_coords + + +def calculate_uncertainty(logits, classes=None, balance_value=0.5): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + classes (list): A list of length R that contains either predicted of ground truth class + for eash predicted mask. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + if logits.shape[1] == 1: + gt_class_logits = logits + else: + gt_class_logits = logits[ + torch.arange(logits.shape[0], device=logits.device), + classes].unsqueeze(1) + return -torch.abs(gt_class_logits - balance_value) diff --git a/lib/common/smpl_vert_segmentation.json b/lib/common/smpl_vert_segmentation.json new file mode 100644 index 0000000000000000000000000000000000000000..b3244cce450e13f1095a1c3af676f4c8fdea5633 --- /dev/null +++ b/lib/common/smpl_vert_segmentation.json @@ -0,0 +1,7440 @@ +{ + "rightHand": [ + 5442, + 5443, + 5444, + 5445, + 5446, + 5447, + 5448, + 5449, + 5450, + 5451, + 5452, + 5453, + 5454, + 5455, + 5456, + 5457, + 5458, + 5459, + 5460, + 5461, + 5462, + 5463, + 5464, + 5465, + 5466, + 5467, + 5468, + 5469, + 5470, + 5471, + 5472, + 5473, + 5474, + 5475, + 5476, + 5477, + 5478, + 5479, + 5480, + 5481, + 5482, + 5483, + 5484, + 5485, + 5486, + 5487, + 5492, + 5493, + 5494, + 5495, + 5496, + 5497, + 5502, + 5503, + 5504, + 5505, + 5506, + 5507, + 5508, + 5509, + 5510, + 5511, + 5512, + 5513, + 5514, + 5515, + 5516, + 5517, + 5518, + 5519, + 5520, + 5521, + 5522, + 5523, + 5524, + 5525, + 5526, + 5527, + 5530, + 5531, + 5532, + 5533, + 5534, + 5535, + 5536, + 5537, + 5538, + 5539, + 5540, + 5541, + 5542, + 5543, + 5544, + 5545, + 5546, + 5547, + 5548, + 5549, + 5550, + 5551, + 5552, + 5553, + 5554, + 5555, + 5556, + 5557, + 5558, + 5559, + 5560, + 5561, + 5562, + 5569, + 5571, + 5574, + 5575, + 5576, + 5577, + 5578, + 5579, + 5580, + 5581, + 5582, + 5583, + 5588, + 5589, + 5592, + 5593, + 5594, + 5595, + 5596, + 5597, + 5598, + 5599, + 5600, + 5601, + 5602, + 5603, + 5604, + 5605, + 5610, + 5611, + 5612, + 5613, + 5614, + 5621, + 5622, + 5625, + 5631, + 5632, + 5633, + 5634, + 5635, + 5636, + 5637, + 5638, + 5639, + 5640, + 5641, + 5643, + 5644, + 5645, + 5646, + 5649, + 5650, + 5652, + 5653, + 5654, + 5655, + 5656, + 5657, + 5658, + 5659, + 5660, + 5661, + 5662, + 5663, + 5664, + 5667, + 5670, + 5671, + 5672, + 5673, + 5674, + 5675, + 5682, + 5683, + 
5684, + 5685, + 5686, + 5687, + 5688, + 5689, + 5690, + 5692, + 5695, + 5697, + 5698, + 5699, + 5700, + 5701, + 5707, + 5708, + 5709, + 5710, + 5711, + 5712, + 5713, + 5714, + 5715, + 5716, + 5717, + 5718, + 5719, + 5720, + 5721, + 5723, + 5724, + 5725, + 5726, + 5727, + 5728, + 5729, + 5730, + 5731, + 5732, + 5735, + 5736, + 5737, + 5738, + 5739, + 5740, + 5745, + 5746, + 5748, + 5749, + 5750, + 5751, + 5752, + 6056, + 6057, + 6066, + 6067, + 6158, + 6159, + 6160, + 6161, + 6162, + 6163, + 6164, + 6165, + 6166, + 6167, + 6168, + 6169, + 6170, + 6171, + 6172, + 6173, + 6174, + 6175, + 6176, + 6177, + 6178, + 6179, + 6180, + 6181, + 6182, + 6183, + 6184, + 6185, + 6186, + 6187, + 6188, + 6189, + 6190, + 6191, + 6192, + 6193, + 6194, + 6195, + 6196, + 6197, + 6198, + 6199, + 6200, + 6201, + 6202, + 6203, + 6204, + 6205, + 6206, + 6207, + 6208, + 6209, + 6210, + 6211, + 6212, + 6213, + 6214, + 6215, + 6216, + 6217, + 6218, + 6219, + 6220, + 6221, + 6222, + 6223, + 6224, + 6225, + 6226, + 6227, + 6228, + 6229, + 6230, + 6231, + 6232, + 6233, + 6234, + 6235, + 6236, + 6237, + 6238, + 6239 + ], + "rightUpLeg": [ + 4320, + 4321, + 4323, + 4324, + 4333, + 4334, + 4335, + 4336, + 4337, + 4338, + 4339, + 4340, + 4356, + 4357, + 4358, + 4359, + 4360, + 4361, + 4362, + 4363, + 4364, + 4365, + 4366, + 4367, + 4383, + 4384, + 4385, + 4386, + 4387, + 4388, + 4389, + 4390, + 4391, + 4392, + 4393, + 4394, + 4395, + 4396, + 4397, + 4398, + 4399, + 4400, + 4401, + 4419, + 4420, + 4421, + 4422, + 4430, + 4431, + 4432, + 4433, + 4434, + 4435, + 4436, + 4437, + 4438, + 4439, + 4440, + 4441, + 4442, + 4443, + 4444, + 4445, + 4446, + 4447, + 4448, + 4449, + 4450, + 4451, + 4452, + 4453, + 4454, + 4455, + 4456, + 4457, + 4458, + 4459, + 4460, + 4461, + 4462, + 4463, + 4464, + 4465, + 4466, + 4467, + 4468, + 4469, + 4470, + 4471, + 4472, + 4473, + 4474, + 4475, + 4476, + 4477, + 4478, + 4479, + 4480, + 4481, + 4482, + 4483, + 4484, + 4485, + 4486, + 4487, + 4488, + 4489, + 4490, + 4491, + 4492, + 4493, + 4494, + 4495, + 4496, + 4497, + 4498, + 4499, + 4500, + 4501, + 4502, + 4503, + 4504, + 4505, + 4506, + 4507, + 4508, + 4509, + 4510, + 4511, + 4512, + 4513, + 4514, + 4515, + 4516, + 4517, + 4518, + 4519, + 4520, + 4521, + 4522, + 4523, + 4524, + 4525, + 4526, + 4527, + 4528, + 4529, + 4530, + 4531, + 4532, + 4623, + 4624, + 4625, + 4626, + 4627, + 4628, + 4629, + 4630, + 4631, + 4632, + 4633, + 4634, + 4645, + 4646, + 4647, + 4648, + 4649, + 4650, + 4651, + 4652, + 4653, + 4654, + 4655, + 4656, + 4657, + 4658, + 4659, + 4660, + 4670, + 4671, + 4672, + 4673, + 4704, + 4705, + 4706, + 4707, + 4708, + 4709, + 4710, + 4711, + 4712, + 4713, + 4745, + 4746, + 4757, + 4758, + 4759, + 4760, + 4801, + 4802, + 4829, + 4834, + 4835, + 4836, + 4837, + 4838, + 4839, + 4840, + 4841, + 4924, + 4925, + 4926, + 4928, + 4929, + 4930, + 4931, + 4932, + 4933, + 4934, + 4935, + 4936, + 4948, + 4949, + 4950, + 4951, + 4952, + 4970, + 4971, + 4972, + 4973, + 4983, + 4984, + 4985, + 4986, + 4987, + 4988, + 4989, + 4990, + 4991, + 4992, + 4993, + 5004, + 5005, + 6546, + 6547, + 6548, + 6549, + 6552, + 6553, + 6554, + 6555, + 6556, + 6873, + 6877 + ], + "leftArm": [ + 626, + 627, + 628, + 629, + 634, + 635, + 680, + 681, + 716, + 717, + 718, + 719, + 769, + 770, + 771, + 772, + 773, + 774, + 775, + 776, + 777, + 778, + 779, + 780, + 784, + 785, + 786, + 787, + 788, + 789, + 790, + 791, + 792, + 793, + 1231, + 1232, + 1233, + 1234, + 1258, + 1259, + 1260, + 1261, + 1271, + 1281, + 1282, + 1310, + 1311, + 1314, + 1315, + 1340, + 1341, + 
1342, + 1343, + 1355, + 1356, + 1357, + 1358, + 1376, + 1377, + 1378, + 1379, + 1380, + 1381, + 1382, + 1383, + 1384, + 1385, + 1386, + 1387, + 1388, + 1389, + 1390, + 1391, + 1392, + 1393, + 1394, + 1395, + 1396, + 1397, + 1398, + 1399, + 1400, + 1402, + 1403, + 1405, + 1406, + 1407, + 1408, + 1409, + 1410, + 1411, + 1412, + 1413, + 1414, + 1415, + 1416, + 1428, + 1429, + 1430, + 1431, + 1432, + 1433, + 1438, + 1439, + 1440, + 1441, + 1442, + 1443, + 1444, + 1445, + 1502, + 1505, + 1506, + 1507, + 1508, + 1509, + 1510, + 1538, + 1541, + 1542, + 1543, + 1545, + 1619, + 1620, + 1621, + 1622, + 1631, + 1632, + 1633, + 1634, + 1635, + 1636, + 1637, + 1638, + 1639, + 1640, + 1641, + 1642, + 1645, + 1646, + 1647, + 1648, + 1649, + 1650, + 1651, + 1652, + 1653, + 1654, + 1655, + 1656, + 1658, + 1659, + 1661, + 1662, + 1664, + 1666, + 1667, + 1668, + 1669, + 1670, + 1671, + 1672, + 1673, + 1674, + 1675, + 1676, + 1677, + 1678, + 1679, + 1680, + 1681, + 1682, + 1683, + 1684, + 1696, + 1697, + 1698, + 1703, + 1704, + 1705, + 1706, + 1707, + 1708, + 1709, + 1710, + 1711, + 1712, + 1713, + 1714, + 1715, + 1716, + 1717, + 1718, + 1719, + 1720, + 1725, + 1731, + 1732, + 1733, + 1734, + 1735, + 1737, + 1739, + 1740, + 1745, + 1746, + 1747, + 1748, + 1749, + 1751, + 1761, + 1830, + 1831, + 1844, + 1845, + 1846, + 1850, + 1851, + 1854, + 1855, + 1858, + 1860, + 1865, + 1866, + 1867, + 1869, + 1870, + 1871, + 1874, + 1875, + 1876, + 1877, + 1878, + 1882, + 1883, + 1888, + 1889, + 1892, + 1900, + 1901, + 1902, + 1903, + 1904, + 1909, + 2819, + 2820, + 2821, + 2822, + 2895, + 2896, + 2897, + 2898, + 2899, + 2900, + 2901, + 2902, + 2903, + 2945, + 2946, + 2974, + 2975, + 2976, + 2977, + 2978, + 2979, + 2980, + 2981, + 2982, + 2983, + 2984, + 2985, + 2986, + 2987, + 2988, + 2989, + 2990, + 2991, + 2992, + 2993, + 2994, + 2995, + 2996, + 3002, + 3013 + ], + "leftLeg": [ + 995, + 998, + 999, + 1002, + 1004, + 1005, + 1008, + 1010, + 1012, + 1015, + 1016, + 1018, + 1019, + 1043, + 1044, + 1047, + 1048, + 1049, + 1050, + 1051, + 1052, + 1053, + 1054, + 1055, + 1056, + 1057, + 1058, + 1059, + 1060, + 1061, + 1062, + 1063, + 1064, + 1065, + 1066, + 1067, + 1068, + 1069, + 1070, + 1071, + 1072, + 1073, + 1074, + 1075, + 1076, + 1077, + 1078, + 1079, + 1080, + 1081, + 1082, + 1083, + 1084, + 1085, + 1086, + 1087, + 1088, + 1089, + 1090, + 1091, + 1092, + 1093, + 1094, + 1095, + 1096, + 1097, + 1098, + 1099, + 1100, + 1101, + 1102, + 1103, + 1104, + 1105, + 1106, + 1107, + 1108, + 1109, + 1110, + 1111, + 1112, + 1113, + 1114, + 1115, + 1116, + 1117, + 1118, + 1119, + 1120, + 1121, + 1122, + 1123, + 1124, + 1125, + 1126, + 1127, + 1128, + 1129, + 1130, + 1131, + 1132, + 1133, + 1134, + 1135, + 1136, + 1148, + 1149, + 1150, + 1151, + 1152, + 1153, + 1154, + 1155, + 1156, + 1157, + 1158, + 1175, + 1176, + 1177, + 1178, + 1179, + 1180, + 1181, + 1182, + 1183, + 1369, + 1370, + 1371, + 1372, + 1373, + 1374, + 1375, + 1464, + 1465, + 1466, + 1467, + 1468, + 1469, + 1470, + 1471, + 1472, + 1473, + 1474, + 1522, + 1523, + 1524, + 1525, + 1526, + 1527, + 1528, + 1529, + 1530, + 1531, + 1532, + 3174, + 3175, + 3176, + 3177, + 3178, + 3179, + 3180, + 3181, + 3182, + 3183, + 3184, + 3185, + 3186, + 3187, + 3188, + 3189, + 3190, + 3191, + 3192, + 3193, + 3194, + 3195, + 3196, + 3197, + 3198, + 3199, + 3200, + 3201, + 3202, + 3203, + 3204, + 3205, + 3206, + 3207, + 3208, + 3209, + 3210, + 3319, + 3320, + 3321, + 3322, + 3323, + 3324, + 3325, + 3326, + 3327, + 3328, + 3329, + 3330, + 3331, + 3332, + 3333, + 3334, + 3335, + 3432, + 
3433, + 3434, + 3435, + 3436, + 3469, + 3472, + 3473, + 3474 + ], + "leftToeBase": [ + 3211, + 3212, + 3213, + 3214, + 3215, + 3216, + 3217, + 3218, + 3219, + 3220, + 3221, + 3222, + 3223, + 3224, + 3225, + 3226, + 3227, + 3228, + 3229, + 3230, + 3231, + 3232, + 3233, + 3234, + 3235, + 3236, + 3237, + 3238, + 3239, + 3240, + 3241, + 3242, + 3243, + 3244, + 3245, + 3246, + 3247, + 3248, + 3249, + 3250, + 3251, + 3252, + 3253, + 3254, + 3255, + 3256, + 3257, + 3258, + 3259, + 3260, + 3261, + 3262, + 3263, + 3264, + 3265, + 3266, + 3267, + 3268, + 3269, + 3270, + 3271, + 3272, + 3273, + 3274, + 3275, + 3276, + 3277, + 3278, + 3279, + 3280, + 3281, + 3282, + 3283, + 3284, + 3285, + 3286, + 3287, + 3288, + 3289, + 3290, + 3291, + 3292, + 3293, + 3294, + 3295, + 3296, + 3297, + 3298, + 3299, + 3300, + 3301, + 3302, + 3303, + 3304, + 3305, + 3306, + 3307, + 3308, + 3309, + 3310, + 3311, + 3312, + 3313, + 3314, + 3315, + 3316, + 3317, + 3318, + 3336, + 3337, + 3340, + 3342, + 3344, + 3346, + 3348, + 3350, + 3352, + 3354, + 3357, + 3358, + 3360, + 3362 + ], + "leftFoot": [ + 3327, + 3328, + 3329, + 3330, + 3331, + 3332, + 3333, + 3334, + 3335, + 3336, + 3337, + 3338, + 3339, + 3340, + 3341, + 3342, + 3343, + 3344, + 3345, + 3346, + 3347, + 3348, + 3349, + 3350, + 3351, + 3352, + 3353, + 3354, + 3355, + 3356, + 3357, + 3358, + 3359, + 3360, + 3361, + 3362, + 3363, + 3364, + 3365, + 3366, + 3367, + 3368, + 3369, + 3370, + 3371, + 3372, + 3373, + 3374, + 3375, + 3376, + 3377, + 3378, + 3379, + 3380, + 3381, + 3382, + 3383, + 3384, + 3385, + 3386, + 3387, + 3388, + 3389, + 3390, + 3391, + 3392, + 3393, + 3394, + 3395, + 3396, + 3397, + 3398, + 3399, + 3400, + 3401, + 3402, + 3403, + 3404, + 3405, + 3406, + 3407, + 3408, + 3409, + 3410, + 3411, + 3412, + 3413, + 3414, + 3415, + 3416, + 3417, + 3418, + 3419, + 3420, + 3421, + 3422, + 3423, + 3424, + 3425, + 3426, + 3427, + 3428, + 3429, + 3430, + 3431, + 3432, + 3433, + 3434, + 3435, + 3436, + 3437, + 3438, + 3439, + 3440, + 3441, + 3442, + 3443, + 3444, + 3445, + 3446, + 3447, + 3448, + 3449, + 3450, + 3451, + 3452, + 3453, + 3454, + 3455, + 3456, + 3457, + 3458, + 3459, + 3460, + 3461, + 3462, + 3463, + 3464, + 3465, + 3466, + 3467, + 3468, + 3469 + ], + "spine1": [ + 598, + 599, + 600, + 601, + 610, + 611, + 612, + 613, + 614, + 615, + 616, + 617, + 618, + 619, + 620, + 621, + 642, + 645, + 646, + 647, + 652, + 653, + 658, + 659, + 660, + 661, + 668, + 669, + 670, + 671, + 684, + 685, + 686, + 687, + 688, + 689, + 690, + 691, + 692, + 722, + 723, + 724, + 725, + 736, + 750, + 751, + 761, + 764, + 766, + 767, + 794, + 795, + 891, + 892, + 893, + 894, + 925, + 926, + 927, + 928, + 929, + 940, + 941, + 942, + 943, + 1190, + 1191, + 1192, + 1193, + 1194, + 1195, + 1196, + 1197, + 1200, + 1201, + 1202, + 1212, + 1236, + 1252, + 1253, + 1254, + 1255, + 1268, + 1269, + 1270, + 1329, + 1330, + 1348, + 1349, + 1351, + 1420, + 1421, + 1423, + 1424, + 1425, + 1426, + 1436, + 1437, + 1756, + 1757, + 1758, + 2839, + 2840, + 2841, + 2842, + 2843, + 2844, + 2845, + 2846, + 2847, + 2848, + 2849, + 2850, + 2851, + 2870, + 2871, + 2883, + 2906, + 2908, + 3014, + 3017, + 3025, + 3030, + 3033, + 3034, + 3037, + 3039, + 3040, + 3041, + 3042, + 3043, + 3044, + 3076, + 3077, + 3079, + 3480, + 3505, + 3511, + 4086, + 4087, + 4088, + 4089, + 4098, + 4099, + 4100, + 4101, + 4102, + 4103, + 4104, + 4105, + 4106, + 4107, + 4108, + 4109, + 4130, + 4131, + 4134, + 4135, + 4140, + 4141, + 4146, + 4147, + 4148, + 4149, + 4156, + 4157, + 4158, + 4159, + 4172, + 4173, + 4174, + 4175, 
+ 4176, + 4177, + 4178, + 4179, + 4180, + 4210, + 4211, + 4212, + 4213, + 4225, + 4239, + 4240, + 4249, + 4250, + 4255, + 4256, + 4282, + 4283, + 4377, + 4378, + 4379, + 4380, + 4411, + 4412, + 4413, + 4414, + 4415, + 4426, + 4427, + 4428, + 4429, + 4676, + 4677, + 4678, + 4679, + 4680, + 4681, + 4682, + 4683, + 4686, + 4687, + 4688, + 4695, + 4719, + 4735, + 4736, + 4737, + 4740, + 4751, + 4752, + 4753, + 4824, + 4825, + 4828, + 4893, + 4894, + 4895, + 4897, + 4898, + 4899, + 4908, + 4909, + 5223, + 5224, + 5225, + 6300, + 6301, + 6302, + 6303, + 6304, + 6305, + 6306, + 6307, + 6308, + 6309, + 6310, + 6311, + 6312, + 6331, + 6332, + 6342, + 6366, + 6367, + 6475, + 6477, + 6478, + 6481, + 6482, + 6485, + 6487, + 6488, + 6489, + 6490, + 6491, + 6878 + ], + "spine2": [ + 570, + 571, + 572, + 573, + 584, + 585, + 586, + 587, + 588, + 589, + 590, + 591, + 592, + 593, + 594, + 595, + 596, + 597, + 602, + 603, + 604, + 605, + 606, + 607, + 608, + 609, + 622, + 623, + 624, + 625, + 638, + 639, + 640, + 641, + 643, + 644, + 648, + 649, + 650, + 651, + 666, + 667, + 672, + 673, + 674, + 675, + 680, + 681, + 682, + 683, + 693, + 694, + 695, + 696, + 697, + 698, + 699, + 700, + 701, + 702, + 703, + 704, + 713, + 714, + 715, + 716, + 717, + 726, + 727, + 728, + 729, + 730, + 731, + 732, + 733, + 735, + 737, + 738, + 739, + 740, + 741, + 742, + 743, + 744, + 745, + 746, + 747, + 748, + 749, + 752, + 753, + 754, + 755, + 756, + 757, + 758, + 759, + 760, + 762, + 763, + 803, + 804, + 805, + 806, + 811, + 812, + 813, + 814, + 817, + 818, + 819, + 820, + 821, + 824, + 825, + 826, + 827, + 828, + 895, + 896, + 930, + 931, + 1198, + 1199, + 1213, + 1214, + 1215, + 1216, + 1217, + 1218, + 1219, + 1220, + 1235, + 1237, + 1256, + 1257, + 1271, + 1272, + 1273, + 1279, + 1280, + 1283, + 1284, + 1285, + 1286, + 1287, + 1288, + 1289, + 1290, + 1291, + 1292, + 1293, + 1294, + 1295, + 1296, + 1297, + 1298, + 1299, + 1300, + 1301, + 1302, + 1303, + 1304, + 1305, + 1306, + 1307, + 1308, + 1309, + 1312, + 1313, + 1319, + 1320, + 1346, + 1347, + 1350, + 1352, + 1401, + 1417, + 1418, + 1419, + 1422, + 1427, + 1434, + 1435, + 1503, + 1504, + 1536, + 1537, + 1544, + 1545, + 1753, + 1754, + 1755, + 1759, + 1760, + 1761, + 1762, + 1763, + 1808, + 1809, + 1810, + 1811, + 1816, + 1817, + 1818, + 1819, + 1820, + 1834, + 1835, + 1836, + 1837, + 1838, + 1839, + 1868, + 1879, + 1880, + 2812, + 2813, + 2852, + 2853, + 2854, + 2855, + 2856, + 2857, + 2858, + 2859, + 2860, + 2861, + 2862, + 2863, + 2864, + 2865, + 2866, + 2867, + 2868, + 2869, + 2872, + 2875, + 2876, + 2877, + 2878, + 2881, + 2882, + 2884, + 2885, + 2886, + 2904, + 2905, + 2907, + 2931, + 2932, + 2933, + 2934, + 2935, + 2936, + 2937, + 2941, + 2950, + 2951, + 2952, + 2953, + 2954, + 2955, + 2956, + 2957, + 2958, + 2959, + 2960, + 2961, + 2962, + 2963, + 2964, + 2965, + 2966, + 2967, + 2968, + 2969, + 2970, + 2971, + 2972, + 2973, + 2997, + 2998, + 3006, + 3007, + 3012, + 3015, + 3026, + 3027, + 3028, + 3029, + 3031, + 3032, + 3035, + 3036, + 3038, + 3059, + 3060, + 3061, + 3062, + 3063, + 3064, + 3065, + 3066, + 3067, + 3073, + 3074, + 3075, + 3078, + 3168, + 3169, + 3171, + 3470, + 3471, + 3482, + 3483, + 3495, + 3496, + 3497, + 3498, + 3506, + 3508, + 4058, + 4059, + 4060, + 4061, + 4072, + 4073, + 4074, + 4075, + 4076, + 4077, + 4078, + 4079, + 4080, + 4081, + 4082, + 4083, + 4084, + 4085, + 4090, + 4091, + 4092, + 4093, + 4094, + 4095, + 4096, + 4097, + 4110, + 4111, + 4112, + 4113, + 4126, + 4127, + 4128, + 4129, + 4132, + 4133, + 4136, + 4137, + 4138, + 4139, + 
4154, + 4155, + 4160, + 4161, + 4162, + 4163, + 4168, + 4169, + 4170, + 4171, + 4181, + 4182, + 4183, + 4184, + 4185, + 4186, + 4187, + 4188, + 4189, + 4190, + 4191, + 4192, + 4201, + 4202, + 4203, + 4204, + 4207, + 4214, + 4215, + 4216, + 4217, + 4218, + 4219, + 4220, + 4221, + 4223, + 4224, + 4226, + 4227, + 4228, + 4229, + 4230, + 4231, + 4232, + 4233, + 4234, + 4235, + 4236, + 4237, + 4238, + 4241, + 4242, + 4243, + 4244, + 4245, + 4246, + 4247, + 4248, + 4251, + 4252, + 4291, + 4292, + 4293, + 4294, + 4299, + 4300, + 4301, + 4302, + 4305, + 4306, + 4307, + 4308, + 4309, + 4312, + 4313, + 4314, + 4315, + 4381, + 4382, + 4416, + 4417, + 4684, + 4685, + 4696, + 4697, + 4698, + 4699, + 4700, + 4701, + 4702, + 4703, + 4718, + 4720, + 4738, + 4739, + 4754, + 4755, + 4756, + 4761, + 4762, + 4765, + 4766, + 4767, + 4768, + 4769, + 4770, + 4771, + 4772, + 4773, + 4774, + 4775, + 4776, + 4777, + 4778, + 4779, + 4780, + 4781, + 4782, + 4783, + 4784, + 4785, + 4786, + 4787, + 4788, + 4789, + 4792, + 4793, + 4799, + 4800, + 4822, + 4823, + 4826, + 4827, + 4874, + 4890, + 4891, + 4892, + 4896, + 4900, + 4907, + 4910, + 4975, + 4976, + 5007, + 5008, + 5013, + 5014, + 5222, + 5226, + 5227, + 5228, + 5229, + 5230, + 5269, + 5270, + 5271, + 5272, + 5277, + 5278, + 5279, + 5280, + 5281, + 5295, + 5296, + 5297, + 5298, + 5299, + 5300, + 5329, + 5340, + 5341, + 6273, + 6274, + 6313, + 6314, + 6315, + 6316, + 6317, + 6318, + 6319, + 6320, + 6321, + 6322, + 6323, + 6324, + 6325, + 6326, + 6327, + 6328, + 6329, + 6330, + 6333, + 6336, + 6337, + 6340, + 6341, + 6343, + 6344, + 6345, + 6363, + 6364, + 6365, + 6390, + 6391, + 6392, + 6393, + 6394, + 6395, + 6396, + 6398, + 6409, + 6410, + 6411, + 6412, + 6413, + 6414, + 6415, + 6416, + 6417, + 6418, + 6419, + 6420, + 6421, + 6422, + 6423, + 6424, + 6425, + 6426, + 6427, + 6428, + 6429, + 6430, + 6431, + 6432, + 6456, + 6457, + 6465, + 6466, + 6476, + 6479, + 6480, + 6483, + 6484, + 6486, + 6496, + 6497, + 6498, + 6499, + 6500, + 6501, + 6502, + 6503, + 6879 + ], + "leftShoulder": [ + 591, + 604, + 605, + 606, + 609, + 634, + 635, + 636, + 637, + 674, + 706, + 707, + 708, + 709, + 710, + 711, + 712, + 713, + 715, + 717, + 730, + 733, + 734, + 735, + 781, + 782, + 783, + 1238, + 1239, + 1240, + 1241, + 1242, + 1243, + 1244, + 1245, + 1290, + 1291, + 1294, + 1316, + 1317, + 1318, + 1401, + 1402, + 1403, + 1404, + 1509, + 1535, + 1545, + 1808, + 1810, + 1811, + 1812, + 1813, + 1814, + 1815, + 1818, + 1819, + 1821, + 1822, + 1823, + 1824, + 1825, + 1826, + 1827, + 1828, + 1829, + 1830, + 1831, + 1832, + 1833, + 1837, + 1840, + 1841, + 1842, + 1843, + 1844, + 1845, + 1846, + 1847, + 1848, + 1849, + 1850, + 1851, + 1852, + 1853, + 1854, + 1855, + 1856, + 1857, + 1858, + 1859, + 1861, + 1862, + 1863, + 1864, + 1872, + 1873, + 1880, + 1881, + 1884, + 1885, + 1886, + 1887, + 1890, + 1891, + 1893, + 1894, + 1895, + 1896, + 1897, + 1898, + 1899, + 2879, + 2880, + 2881, + 2886, + 2887, + 2888, + 2889, + 2890, + 2891, + 2892, + 2893, + 2894, + 2903, + 2938, + 2939, + 2940, + 2941, + 2942, + 2943, + 2944, + 2945, + 2946, + 2947, + 2948, + 2949, + 2965, + 2967, + 2969, + 2999, + 3000, + 3001, + 3002, + 3003, + 3004, + 3005, + 3008, + 3009, + 3010, + 3011 + ], + "rightShoulder": [ + 4077, + 4091, + 4092, + 4094, + 4095, + 4122, + 4123, + 4124, + 4125, + 4162, + 4194, + 4195, + 4196, + 4197, + 4198, + 4199, + 4200, + 4201, + 4203, + 4207, + 4218, + 4219, + 4222, + 4223, + 4269, + 4270, + 4271, + 4721, + 4722, + 4723, + 4724, + 4725, + 4726, + 4727, + 4728, + 4773, + 4774, + 
4778, + 4796, + 4797, + 4798, + 4874, + 4875, + 4876, + 4877, + 4982, + 5006, + 5014, + 5269, + 5271, + 5272, + 5273, + 5274, + 5275, + 5276, + 5279, + 5281, + 5282, + 5283, + 5284, + 5285, + 5286, + 5287, + 5288, + 5289, + 5290, + 5291, + 5292, + 5293, + 5294, + 5298, + 5301, + 5302, + 5303, + 5304, + 5305, + 5306, + 5307, + 5308, + 5309, + 5310, + 5311, + 5312, + 5313, + 5314, + 5315, + 5316, + 5317, + 5318, + 5319, + 5320, + 5322, + 5323, + 5324, + 5325, + 5333, + 5334, + 5341, + 5342, + 5345, + 5346, + 5347, + 5348, + 5351, + 5352, + 5354, + 5355, + 5356, + 5357, + 5358, + 5359, + 5360, + 6338, + 6339, + 6340, + 6345, + 6346, + 6347, + 6348, + 6349, + 6350, + 6351, + 6352, + 6353, + 6362, + 6397, + 6398, + 6399, + 6400, + 6401, + 6402, + 6403, + 6404, + 6405, + 6406, + 6407, + 6408, + 6424, + 6425, + 6428, + 6458, + 6459, + 6460, + 6461, + 6462, + 6463, + 6464, + 6467, + 6468, + 6469, + 6470 + ], + "rightFoot": [ + 6727, + 6728, + 6729, + 6730, + 6731, + 6732, + 6733, + 6734, + 6735, + 6736, + 6737, + 6738, + 6739, + 6740, + 6741, + 6742, + 6743, + 6744, + 6745, + 6746, + 6747, + 6748, + 6749, + 6750, + 6751, + 6752, + 6753, + 6754, + 6755, + 6756, + 6757, + 6758, + 6759, + 6760, + 6761, + 6762, + 6763, + 6764, + 6765, + 6766, + 6767, + 6768, + 6769, + 6770, + 6771, + 6772, + 6773, + 6774, + 6775, + 6776, + 6777, + 6778, + 6779, + 6780, + 6781, + 6782, + 6783, + 6784, + 6785, + 6786, + 6787, + 6788, + 6789, + 6790, + 6791, + 6792, + 6793, + 6794, + 6795, + 6796, + 6797, + 6798, + 6799, + 6800, + 6801, + 6802, + 6803, + 6804, + 6805, + 6806, + 6807, + 6808, + 6809, + 6810, + 6811, + 6812, + 6813, + 6814, + 6815, + 6816, + 6817, + 6818, + 6819, + 6820, + 6821, + 6822, + 6823, + 6824, + 6825, + 6826, + 6827, + 6828, + 6829, + 6830, + 6831, + 6832, + 6833, + 6834, + 6835, + 6836, + 6837, + 6838, + 6839, + 6840, + 6841, + 6842, + 6843, + 6844, + 6845, + 6846, + 6847, + 6848, + 6849, + 6850, + 6851, + 6852, + 6853, + 6854, + 6855, + 6856, + 6857, + 6858, + 6859, + 6860, + 6861, + 6862, + 6863, + 6864, + 6865, + 6866, + 6867, + 6868, + 6869 + ], + "head": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 220, + 221, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 
247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 286, + 287, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 303, + 304, + 306, + 307, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 329, + 330, + 331, + 332, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 383, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 439, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 449, + 450, + 454, + 455, + 456, + 457, + 458, + 459, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 472, + 473, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 482, + 483, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 492, + 493, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 503, + 504, + 505, + 506, + 507, + 508, + 509, + 510, + 511, + 512, + 513, + 514, + 515, + 516, + 517, + 518, + 519, + 520, + 521, + 522, + 523, + 524, + 525, + 526, + 527, + 528, + 529, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + 538, + 539, + 540, + 541, + 542, + 543, + 544, + 545, + 546, + 547, + 548, + 549, + 550, + 551, + 552, + 553, + 554, + 555, + 556, + 557, + 558, + 559, + 560, + 561, + 562, + 563, + 564, + 565, + 566, + 567, + 568, + 569, + 574, + 575, + 576, + 577, + 578, + 579, + 580, + 581, + 582, + 583, + 1764, + 1765, + 1766, + 1770, + 1771, + 1772, + 1773, + 1774, + 1775, + 1776, + 1777, + 1778, + 1905, + 1906, + 1907, + 1908, + 2779, + 2780, + 2781, + 2782, + 2783, + 2784, + 2785, + 2786, + 2787, + 2788, + 2789, + 2790, + 2791, + 2792, + 2793, + 2794, + 2795, + 2796, + 2797, + 2798, + 2799, + 2800, + 2801, + 2802, + 2803, + 2804, + 2805, + 2806, + 2807, + 2808, + 2809, + 2810, + 2811, + 2814, + 2815, + 2816, + 2817, + 2818, + 3045, + 3046, + 3047, + 3048, + 3051, + 3052, + 3053, + 3054, + 3055, + 3056, + 3058, + 3069, + 3070, + 3071, + 3072, + 3161, + 3162, + 3163, + 3165, + 3166, + 3167, + 3485, + 3486, + 3487, + 3488, + 3489, + 3490, + 3491, + 3492, + 3493, + 3494, + 3499, + 3512, + 3513, + 3514, + 3515, + 3516, + 3517, + 3518, + 3519, + 3520, + 3521, + 3522, + 3523, + 3524, + 3525, + 3526, + 3527, + 3528, + 3529, + 3530, + 3531, + 3532, + 3533, + 3534, + 3535, + 3536, + 3537, + 3538, + 3539, + 3540, + 3541, + 3542, + 3543, + 3544, + 3545, + 3546, + 3547, + 3548, + 3549, + 3550, + 3551, + 3552, + 3553, + 3554, + 3555, + 3556, + 3557, + 3558, + 3559, + 3560, + 3561, + 3562, + 3563, + 3564, + 3565, + 3566, + 3567, + 3568, + 3569, + 3570, + 3571, + 3572, + 3573, + 3574, + 3575, + 3576, + 3577, + 3578, + 3579, + 3580, + 3581, + 3582, + 3583, + 3584, + 3585, + 3586, + 3587, + 3588, + 3589, + 3590, + 3591, + 3592, + 3593, + 3594, + 3595, + 3596, + 3597, + 3598, + 3599, + 3600, + 3601, 
+ 3602, + 3603, + 3604, + 3605, + 3606, + 3607, + 3608, + 3609, + 3610, + 3611, + 3612, + 3613, + 3614, + 3615, + 3616, + 3617, + 3618, + 3619, + 3620, + 3621, + 3622, + 3623, + 3624, + 3625, + 3626, + 3627, + 3628, + 3629, + 3630, + 3631, + 3632, + 3633, + 3634, + 3635, + 3636, + 3637, + 3638, + 3639, + 3640, + 3641, + 3642, + 3643, + 3644, + 3645, + 3646, + 3647, + 3648, + 3649, + 3650, + 3651, + 3652, + 3653, + 3654, + 3655, + 3656, + 3657, + 3658, + 3659, + 3660, + 3661, + 3666, + 3667, + 3668, + 3669, + 3670, + 3671, + 3672, + 3673, + 3674, + 3675, + 3676, + 3677, + 3678, + 3679, + 3680, + 3681, + 3682, + 3683, + 3684, + 3685, + 3688, + 3689, + 3690, + 3691, + 3692, + 3693, + 3694, + 3695, + 3696, + 3697, + 3698, + 3699, + 3700, + 3701, + 3702, + 3703, + 3704, + 3705, + 3706, + 3707, + 3708, + 3709, + 3710, + 3711, + 3712, + 3713, + 3714, + 3715, + 3716, + 3717, + 3732, + 3733, + 3737, + 3738, + 3739, + 3740, + 3741, + 3742, + 3743, + 3744, + 3745, + 3746, + 3747, + 3748, + 3749, + 3750, + 3751, + 3752, + 3753, + 3754, + 3755, + 3756, + 3757, + 3758, + 3759, + 3760, + 3761, + 3762, + 3763, + 3764, + 3765, + 3766, + 3767, + 3770, + 3771, + 3772, + 3773, + 3774, + 3775, + 3776, + 3777, + 3778, + 3779, + 3780, + 3781, + 3782, + 3783, + 3784, + 3785, + 3786, + 3787, + 3788, + 3789, + 3790, + 3791, + 3792, + 3793, + 3794, + 3795, + 3798, + 3799, + 3800, + 3801, + 3802, + 3803, + 3804, + 3805, + 3806, + 3807, + 3815, + 3816, + 3819, + 3820, + 3821, + 3822, + 3823, + 3824, + 3825, + 3826, + 3827, + 3828, + 3829, + 3830, + 3831, + 3832, + 3833, + 3834, + 3835, + 3836, + 3837, + 3838, + 3841, + 3842, + 3843, + 3844, + 3845, + 3846, + 3847, + 3848, + 3849, + 3850, + 3851, + 3852, + 3853, + 3854, + 3855, + 3856, + 3857, + 3858, + 3859, + 3860, + 3861, + 3862, + 3863, + 3864, + 3865, + 3866, + 3867, + 3868, + 3869, + 3870, + 3871, + 3872, + 3873, + 3874, + 3875, + 3876, + 3877, + 3878, + 3879, + 3880, + 3881, + 3882, + 3883, + 3884, + 3885, + 3886, + 3887, + 3888, + 3889, + 3890, + 3891, + 3892, + 3893, + 3894, + 3895, + 3896, + 3897, + 3898, + 3899, + 3900, + 3901, + 3902, + 3903, + 3904, + 3905, + 3906, + 3907, + 3908, + 3909, + 3910, + 3911, + 3912, + 3913, + 3914, + 3915, + 3916, + 3917, + 3922, + 3923, + 3924, + 3925, + 3926, + 3927, + 3928, + 3929, + 3930, + 3931, + 3932, + 3933, + 3936, + 3937, + 3938, + 3939, + 3940, + 3941, + 3945, + 3946, + 3947, + 3948, + 3949, + 3950, + 3951, + 3952, + 3953, + 3954, + 3955, + 3956, + 3957, + 3958, + 3959, + 3960, + 3961, + 3962, + 3963, + 3964, + 3965, + 3966, + 3967, + 3968, + 3969, + 3970, + 3971, + 3972, + 3973, + 3974, + 3975, + 3976, + 3977, + 3978, + 3979, + 3980, + 3981, + 3982, + 3983, + 3984, + 3985, + 3986, + 3987, + 3988, + 3989, + 3990, + 3991, + 3992, + 3993, + 3994, + 3995, + 3996, + 3997, + 3998, + 3999, + 4000, + 4001, + 4002, + 4003, + 4004, + 4005, + 4006, + 4007, + 4008, + 4009, + 4010, + 4011, + 4012, + 4013, + 4014, + 4015, + 4016, + 4017, + 4018, + 4019, + 4020, + 4021, + 4022, + 4023, + 4024, + 4025, + 4026, + 4027, + 4028, + 4029, + 4030, + 4031, + 4032, + 4033, + 4034, + 4035, + 4036, + 4037, + 4038, + 4039, + 4040, + 4041, + 4042, + 4043, + 4044, + 4045, + 4046, + 4047, + 4048, + 4049, + 4050, + 4051, + 4052, + 4053, + 4054, + 4055, + 4056, + 4057, + 4062, + 4063, + 4064, + 4065, + 4066, + 4067, + 4068, + 4069, + 4070, + 4071, + 5231, + 5232, + 5233, + 5235, + 5236, + 5237, + 5238, + 5239, + 5240, + 5241, + 5242, + 5243, + 5366, + 5367, + 5368, + 5369, + 6240, + 6241, + 6242, + 6243, + 6244, + 6245, + 6246, + 6247, + 6248, + 
6249, + 6250, + 6251, + 6252, + 6253, + 6254, + 6255, + 6256, + 6257, + 6258, + 6259, + 6260, + 6261, + 6262, + 6263, + 6264, + 6265, + 6266, + 6267, + 6268, + 6269, + 6270, + 6271, + 6272, + 6275, + 6276, + 6277, + 6278, + 6279, + 6492, + 6493, + 6494, + 6495, + 6880, + 6881, + 6882, + 6883, + 6884, + 6885, + 6886, + 6887, + 6888, + 6889 + ], + "rightArm": [ + 4114, + 4115, + 4116, + 4117, + 4122, + 4125, + 4168, + 4171, + 4204, + 4205, + 4206, + 4207, + 4257, + 4258, + 4259, + 4260, + 4261, + 4262, + 4263, + 4264, + 4265, + 4266, + 4267, + 4268, + 4272, + 4273, + 4274, + 4275, + 4276, + 4277, + 4278, + 4279, + 4280, + 4281, + 4714, + 4715, + 4716, + 4717, + 4741, + 4742, + 4743, + 4744, + 4756, + 4763, + 4764, + 4790, + 4791, + 4794, + 4795, + 4816, + 4817, + 4818, + 4819, + 4830, + 4831, + 4832, + 4833, + 4849, + 4850, + 4851, + 4852, + 4853, + 4854, + 4855, + 4856, + 4857, + 4858, + 4859, + 4860, + 4861, + 4862, + 4863, + 4864, + 4865, + 4866, + 4867, + 4868, + 4869, + 4870, + 4871, + 4872, + 4873, + 4876, + 4877, + 4878, + 4879, + 4880, + 4881, + 4882, + 4883, + 4884, + 4885, + 4886, + 4887, + 4888, + 4889, + 4901, + 4902, + 4903, + 4904, + 4905, + 4906, + 4911, + 4912, + 4913, + 4914, + 4915, + 4916, + 4917, + 4918, + 4974, + 4977, + 4978, + 4979, + 4980, + 4981, + 4982, + 5009, + 5010, + 5011, + 5012, + 5014, + 5088, + 5089, + 5090, + 5091, + 5100, + 5101, + 5102, + 5103, + 5104, + 5105, + 5106, + 5107, + 5108, + 5109, + 5110, + 5111, + 5114, + 5115, + 5116, + 5117, + 5118, + 5119, + 5120, + 5121, + 5122, + 5123, + 5124, + 5125, + 5128, + 5129, + 5130, + 5131, + 5134, + 5135, + 5136, + 5137, + 5138, + 5139, + 5140, + 5141, + 5142, + 5143, + 5144, + 5145, + 5146, + 5147, + 5148, + 5149, + 5150, + 5151, + 5152, + 5153, + 5165, + 5166, + 5167, + 5172, + 5173, + 5174, + 5175, + 5176, + 5177, + 5178, + 5179, + 5180, + 5181, + 5182, + 5183, + 5184, + 5185, + 5186, + 5187, + 5188, + 5189, + 5194, + 5200, + 5201, + 5202, + 5203, + 5204, + 5206, + 5208, + 5209, + 5214, + 5215, + 5216, + 5217, + 5218, + 5220, + 5229, + 5292, + 5293, + 5303, + 5306, + 5309, + 5311, + 5314, + 5315, + 5318, + 5319, + 5321, + 5326, + 5327, + 5328, + 5330, + 5331, + 5332, + 5335, + 5336, + 5337, + 5338, + 5339, + 5343, + 5344, + 5349, + 5350, + 5353, + 5361, + 5362, + 5363, + 5364, + 5365, + 5370, + 6280, + 6281, + 6282, + 6283, + 6354, + 6355, + 6356, + 6357, + 6358, + 6359, + 6360, + 6361, + 6362, + 6404, + 6405, + 6433, + 6434, + 6435, + 6436, + 6437, + 6438, + 6439, + 6440, + 6441, + 6442, + 6443, + 6444, + 6445, + 6446, + 6447, + 6448, + 6449, + 6450, + 6451, + 6452, + 6453, + 6454, + 6455, + 6461, + 6471 + ], + "leftHandIndex1": [ + 2027, + 2028, + 2029, + 2030, + 2037, + 2038, + 2039, + 2040, + 2057, + 2067, + 2068, + 2123, + 2124, + 2125, + 2126, + 2127, + 2128, + 2129, + 2130, + 2132, + 2145, + 2146, + 2152, + 2153, + 2154, + 2156, + 2157, + 2158, + 2159, + 2160, + 2161, + 2162, + 2163, + 2164, + 2165, + 2166, + 2167, + 2168, + 2169, + 2177, + 2178, + 2179, + 2181, + 2186, + 2187, + 2190, + 2191, + 2204, + 2205, + 2215, + 2216, + 2217, + 2218, + 2219, + 2220, + 2232, + 2233, + 2245, + 2246, + 2247, + 2258, + 2259, + 2261, + 2262, + 2263, + 2269, + 2270, + 2272, + 2273, + 2274, + 2276, + 2277, + 2280, + 2281, + 2282, + 2283, + 2291, + 2292, + 2293, + 2294, + 2295, + 2296, + 2297, + 2298, + 2299, + 2300, + 2301, + 2302, + 2303, + 2304, + 2305, + 2306, + 2307, + 2308, + 2309, + 2310, + 2311, + 2312, + 2313, + 2314, + 2315, + 2316, + 2317, + 2318, + 2319, + 2320, + 2321, + 2322, + 2323, + 2324, + 2325, + 
2326, + 2327, + 2328, + 2329, + 2330, + 2331, + 2332, + 2333, + 2334, + 2335, + 2336, + 2337, + 2338, + 2339, + 2340, + 2341, + 2342, + 2343, + 2344, + 2345, + 2346, + 2347, + 2348, + 2349, + 2350, + 2351, + 2352, + 2353, + 2354, + 2355, + 2356, + 2357, + 2358, + 2359, + 2360, + 2361, + 2362, + 2363, + 2364, + 2365, + 2366, + 2367, + 2368, + 2369, + 2370, + 2371, + 2372, + 2373, + 2374, + 2375, + 2376, + 2377, + 2378, + 2379, + 2380, + 2381, + 2382, + 2383, + 2384, + 2385, + 2386, + 2387, + 2388, + 2389, + 2390, + 2391, + 2392, + 2393, + 2394, + 2395, + 2396, + 2397, + 2398, + 2399, + 2400, + 2401, + 2402, + 2403, + 2404, + 2405, + 2406, + 2407, + 2408, + 2409, + 2410, + 2411, + 2412, + 2413, + 2414, + 2415, + 2416, + 2417, + 2418, + 2419, + 2420, + 2421, + 2422, + 2423, + 2424, + 2425, + 2426, + 2427, + 2428, + 2429, + 2430, + 2431, + 2432, + 2433, + 2434, + 2435, + 2436, + 2437, + 2438, + 2439, + 2440, + 2441, + 2442, + 2443, + 2444, + 2445, + 2446, + 2447, + 2448, + 2449, + 2450, + 2451, + 2452, + 2453, + 2454, + 2455, + 2456, + 2457, + 2458, + 2459, + 2460, + 2461, + 2462, + 2463, + 2464, + 2465, + 2466, + 2467, + 2468, + 2469, + 2470, + 2471, + 2472, + 2473, + 2474, + 2475, + 2476, + 2477, + 2478, + 2479, + 2480, + 2481, + 2482, + 2483, + 2484, + 2485, + 2486, + 2487, + 2488, + 2489, + 2490, + 2491, + 2492, + 2493, + 2494, + 2495, + 2496, + 2497, + 2498, + 2499, + 2500, + 2501, + 2502, + 2503, + 2504, + 2505, + 2506, + 2507, + 2508, + 2509, + 2510, + 2511, + 2512, + 2513, + 2514, + 2515, + 2516, + 2517, + 2518, + 2519, + 2520, + 2521, + 2522, + 2523, + 2524, + 2525, + 2526, + 2527, + 2528, + 2529, + 2530, + 2531, + 2532, + 2533, + 2534, + 2535, + 2536, + 2537, + 2538, + 2539, + 2540, + 2541, + 2542, + 2543, + 2544, + 2545, + 2546, + 2547, + 2548, + 2549, + 2550, + 2551, + 2552, + 2553, + 2554, + 2555, + 2556, + 2557, + 2558, + 2559, + 2560, + 2561, + 2562, + 2563, + 2564, + 2565, + 2566, + 2567, + 2568, + 2569, + 2570, + 2571, + 2572, + 2573, + 2574, + 2575, + 2576, + 2577, + 2578, + 2579, + 2580, + 2581, + 2582, + 2583, + 2584, + 2585, + 2586, + 2587, + 2588, + 2589, + 2590, + 2591, + 2592, + 2593, + 2594, + 2596, + 2597, + 2599, + 2600, + 2601, + 2602, + 2603, + 2604, + 2606, + 2607, + 2609, + 2610, + 2611, + 2612, + 2613, + 2614, + 2615, + 2616, + 2617, + 2618, + 2619, + 2620, + 2621, + 2622, + 2623, + 2624, + 2625, + 2626, + 2627, + 2628, + 2629, + 2630, + 2631, + 2632, + 2633, + 2634, + 2635, + 2636, + 2637, + 2638, + 2639, + 2640, + 2641, + 2642, + 2643, + 2644, + 2645, + 2646, + 2647, + 2648, + 2649, + 2650, + 2651, + 2652, + 2653, + 2654, + 2655, + 2656, + 2657, + 2658, + 2659, + 2660, + 2661, + 2662, + 2663, + 2664, + 2665, + 2666, + 2667, + 2668, + 2669, + 2670, + 2671, + 2672, + 2673, + 2674, + 2675, + 2676, + 2677, + 2678, + 2679, + 2680, + 2681, + 2682, + 2683, + 2684, + 2685, + 2686, + 2687, + 2688, + 2689, + 2690, + 2691, + 2692, + 2693, + 2694, + 2695, + 2696 + ], + "rightLeg": [ + 4481, + 4482, + 4485, + 4486, + 4491, + 4492, + 4493, + 4495, + 4498, + 4500, + 4501, + 4505, + 4506, + 4529, + 4532, + 4533, + 4534, + 4535, + 4536, + 4537, + 4538, + 4539, + 4540, + 4541, + 4542, + 4543, + 4544, + 4545, + 4546, + 4547, + 4548, + 4549, + 4550, + 4551, + 4552, + 4553, + 4554, + 4555, + 4556, + 4557, + 4558, + 4559, + 4560, + 4561, + 4562, + 4563, + 4564, + 4565, + 4566, + 4567, + 4568, + 4569, + 4570, + 4571, + 4572, + 4573, + 4574, + 4575, + 4576, + 4577, + 4578, + 4579, + 4580, + 4581, + 4582, + 4583, + 4584, + 4585, + 4586, + 4587, + 4588, + 4589, + 4590, + 4591, + 4592, 
+ 4593, + 4594, + 4595, + 4596, + 4597, + 4598, + 4599, + 4600, + 4601, + 4602, + 4603, + 4604, + 4605, + 4606, + 4607, + 4608, + 4609, + 4610, + 4611, + 4612, + 4613, + 4614, + 4615, + 4616, + 4617, + 4618, + 4619, + 4620, + 4621, + 4622, + 4634, + 4635, + 4636, + 4637, + 4638, + 4639, + 4640, + 4641, + 4642, + 4643, + 4644, + 4661, + 4662, + 4663, + 4664, + 4665, + 4666, + 4667, + 4668, + 4669, + 4842, + 4843, + 4844, + 4845, + 4846, + 4847, + 4848, + 4937, + 4938, + 4939, + 4940, + 4941, + 4942, + 4943, + 4944, + 4945, + 4946, + 4947, + 4993, + 4994, + 4995, + 4996, + 4997, + 4998, + 4999, + 5000, + 5001, + 5002, + 5003, + 6574, + 6575, + 6576, + 6577, + 6578, + 6579, + 6580, + 6581, + 6582, + 6583, + 6584, + 6585, + 6586, + 6587, + 6588, + 6589, + 6590, + 6591, + 6592, + 6593, + 6594, + 6595, + 6596, + 6597, + 6598, + 6599, + 6600, + 6601, + 6602, + 6603, + 6604, + 6605, + 6606, + 6607, + 6608, + 6609, + 6610, + 6719, + 6720, + 6721, + 6722, + 6723, + 6724, + 6725, + 6726, + 6727, + 6728, + 6729, + 6730, + 6731, + 6732, + 6733, + 6734, + 6735, + 6832, + 6833, + 6834, + 6835, + 6836, + 6869, + 6870, + 6871, + 6872 + ], + "rightHandIndex1": [ + 5488, + 5489, + 5490, + 5491, + 5498, + 5499, + 5500, + 5501, + 5518, + 5528, + 5529, + 5584, + 5585, + 5586, + 5587, + 5588, + 5589, + 5590, + 5591, + 5592, + 5606, + 5607, + 5613, + 5615, + 5616, + 5617, + 5618, + 5619, + 5620, + 5621, + 5622, + 5623, + 5624, + 5625, + 5626, + 5627, + 5628, + 5629, + 5630, + 5638, + 5639, + 5640, + 5642, + 5647, + 5648, + 5650, + 5651, + 5665, + 5666, + 5676, + 5677, + 5678, + 5679, + 5680, + 5681, + 5693, + 5694, + 5706, + 5707, + 5708, + 5719, + 5721, + 5722, + 5723, + 5724, + 5730, + 5731, + 5733, + 5734, + 5735, + 5737, + 5738, + 5741, + 5742, + 5743, + 5744, + 5752, + 5753, + 5754, + 5755, + 5756, + 5757, + 5758, + 5759, + 5760, + 5761, + 5762, + 5763, + 5764, + 5765, + 5766, + 5767, + 5768, + 5769, + 5770, + 5771, + 5772, + 5773, + 5774, + 5775, + 5776, + 5777, + 5778, + 5779, + 5780, + 5781, + 5782, + 5783, + 5784, + 5785, + 5786, + 5787, + 5788, + 5789, + 5790, + 5791, + 5792, + 5793, + 5794, + 5795, + 5796, + 5797, + 5798, + 5799, + 5800, + 5801, + 5802, + 5803, + 5804, + 5805, + 5806, + 5807, + 5808, + 5809, + 5810, + 5811, + 5812, + 5813, + 5814, + 5815, + 5816, + 5817, + 5818, + 5819, + 5820, + 5821, + 5822, + 5823, + 5824, + 5825, + 5826, + 5827, + 5828, + 5829, + 5830, + 5831, + 5832, + 5833, + 5834, + 5835, + 5836, + 5837, + 5838, + 5839, + 5840, + 5841, + 5842, + 5843, + 5844, + 5845, + 5846, + 5847, + 5848, + 5849, + 5850, + 5851, + 5852, + 5853, + 5854, + 5855, + 5856, + 5857, + 5858, + 5859, + 5860, + 5861, + 5862, + 5863, + 5864, + 5865, + 5866, + 5867, + 5868, + 5869, + 5870, + 5871, + 5872, + 5873, + 5874, + 5875, + 5876, + 5877, + 5878, + 5879, + 5880, + 5881, + 5882, + 5883, + 5884, + 5885, + 5886, + 5887, + 5888, + 5889, + 5890, + 5891, + 5892, + 5893, + 5894, + 5895, + 5896, + 5897, + 5898, + 5899, + 5900, + 5901, + 5902, + 5903, + 5904, + 5905, + 5906, + 5907, + 5908, + 5909, + 5910, + 5911, + 5912, + 5913, + 5914, + 5915, + 5916, + 5917, + 5918, + 5919, + 5920, + 5921, + 5922, + 5923, + 5924, + 5925, + 5926, + 5927, + 5928, + 5929, + 5930, + 5931, + 5932, + 5933, + 5934, + 5935, + 5936, + 5937, + 5938, + 5939, + 5940, + 5941, + 5942, + 5943, + 5944, + 5945, + 5946, + 5947, + 5948, + 5949, + 5950, + 5951, + 5952, + 5953, + 5954, + 5955, + 5956, + 5957, + 5958, + 5959, + 5960, + 5961, + 5962, + 5963, + 5964, + 5965, + 5966, + 5967, + 5968, + 5969, + 5970, + 5971, + 5972, + 5973, + 5974, 
+ 5975, + 5976, + 5977, + 5978, + 5979, + 5980, + 5981, + 5982, + 5983, + 5984, + 5985, + 5986, + 5987, + 5988, + 5989, + 5990, + 5991, + 5992, + 5993, + 5994, + 5995, + 5996, + 5997, + 5998, + 5999, + 6000, + 6001, + 6002, + 6003, + 6004, + 6005, + 6006, + 6007, + 6008, + 6009, + 6010, + 6011, + 6012, + 6013, + 6014, + 6015, + 6016, + 6017, + 6018, + 6019, + 6020, + 6021, + 6022, + 6023, + 6024, + 6025, + 6026, + 6027, + 6028, + 6029, + 6030, + 6031, + 6032, + 6033, + 6034, + 6035, + 6036, + 6037, + 6038, + 6039, + 6040, + 6041, + 6042, + 6043, + 6044, + 6045, + 6046, + 6047, + 6048, + 6049, + 6050, + 6051, + 6052, + 6053, + 6054, + 6055, + 6058, + 6059, + 6060, + 6061, + 6062, + 6063, + 6064, + 6065, + 6068, + 6069, + 6070, + 6071, + 6072, + 6073, + 6074, + 6075, + 6076, + 6077, + 6078, + 6079, + 6080, + 6081, + 6082, + 6083, + 6084, + 6085, + 6086, + 6087, + 6088, + 6089, + 6090, + 6091, + 6092, + 6093, + 6094, + 6095, + 6096, + 6097, + 6098, + 6099, + 6100, + 6101, + 6102, + 6103, + 6104, + 6105, + 6106, + 6107, + 6108, + 6109, + 6110, + 6111, + 6112, + 6113, + 6114, + 6115, + 6116, + 6117, + 6118, + 6119, + 6120, + 6121, + 6122, + 6123, + 6124, + 6125, + 6126, + 6127, + 6128, + 6129, + 6130, + 6131, + 6132, + 6133, + 6134, + 6135, + 6136, + 6137, + 6138, + 6139, + 6140, + 6141, + 6142, + 6143, + 6144, + 6145, + 6146, + 6147, + 6148, + 6149, + 6150, + 6151, + 6152, + 6153, + 6154, + 6155, + 6156, + 6157 + ], + "leftForeArm": [ + 1546, + 1547, + 1548, + 1549, + 1550, + 1551, + 1552, + 1553, + 1554, + 1555, + 1556, + 1557, + 1558, + 1559, + 1560, + 1561, + 1562, + 1563, + 1564, + 1565, + 1566, + 1567, + 1568, + 1569, + 1570, + 1571, + 1572, + 1573, + 1574, + 1575, + 1576, + 1577, + 1578, + 1579, + 1580, + 1581, + 1582, + 1583, + 1584, + 1585, + 1586, + 1587, + 1588, + 1589, + 1590, + 1591, + 1592, + 1593, + 1594, + 1595, + 1596, + 1597, + 1598, + 1599, + 1600, + 1601, + 1602, + 1603, + 1604, + 1605, + 1606, + 1607, + 1608, + 1609, + 1610, + 1611, + 1612, + 1613, + 1614, + 1615, + 1616, + 1617, + 1618, + 1620, + 1621, + 1623, + 1624, + 1625, + 1626, + 1627, + 1628, + 1629, + 1630, + 1643, + 1644, + 1646, + 1647, + 1650, + 1651, + 1654, + 1655, + 1657, + 1658, + 1659, + 1660, + 1661, + 1662, + 1663, + 1664, + 1665, + 1666, + 1685, + 1686, + 1687, + 1688, + 1689, + 1690, + 1691, + 1692, + 1693, + 1694, + 1695, + 1699, + 1700, + 1701, + 1702, + 1721, + 1722, + 1723, + 1724, + 1725, + 1726, + 1727, + 1728, + 1729, + 1730, + 1732, + 1736, + 1738, + 1741, + 1742, + 1743, + 1744, + 1750, + 1752, + 1900, + 1909, + 1910, + 1911, + 1912, + 1913, + 1914, + 1915, + 1916, + 1917, + 1918, + 1919, + 1920, + 1921, + 1922, + 1923, + 1924, + 1925, + 1926, + 1927, + 1928, + 1929, + 1930, + 1931, + 1932, + 1933, + 1934, + 1935, + 1936, + 1937, + 1938, + 1939, + 1940, + 1941, + 1942, + 1943, + 1944, + 1945, + 1946, + 1947, + 1948, + 1949, + 1950, + 1951, + 1952, + 1953, + 1954, + 1955, + 1956, + 1957, + 1958, + 1959, + 1960, + 1961, + 1962, + 1963, + 1964, + 1965, + 1966, + 1967, + 1968, + 1969, + 1970, + 1971, + 1972, + 1973, + 1974, + 1975, + 1976, + 1977, + 1978, + 1979, + 1980, + 2019, + 2059, + 2060, + 2073, + 2089, + 2098, + 2099, + 2100, + 2101, + 2102, + 2103, + 2104, + 2105, + 2106, + 2107, + 2108, + 2109, + 2110, + 2111, + 2112, + 2147, + 2148, + 2206, + 2207, + 2208, + 2209, + 2228, + 2230, + 2234, + 2235, + 2241, + 2242, + 2243, + 2244, + 2279, + 2286, + 2873, + 2874 + ], + "rightForeArm": [ + 5015, + 5016, + 5017, + 5018, + 5019, + 5020, + 5021, + 5022, + 5023, + 5024, + 5025, + 5026, + 5027, + 
5028, + 5029, + 5030, + 5031, + 5032, + 5033, + 5034, + 5035, + 5036, + 5037, + 5038, + 5039, + 5040, + 5041, + 5042, + 5043, + 5044, + 5045, + 5046, + 5047, + 5048, + 5049, + 5050, + 5051, + 5052, + 5053, + 5054, + 5055, + 5056, + 5057, + 5058, + 5059, + 5060, + 5061, + 5062, + 5063, + 5064, + 5065, + 5066, + 5067, + 5068, + 5069, + 5070, + 5071, + 5072, + 5073, + 5074, + 5075, + 5076, + 5077, + 5078, + 5079, + 5080, + 5081, + 5082, + 5083, + 5084, + 5085, + 5086, + 5087, + 5090, + 5091, + 5092, + 5093, + 5094, + 5095, + 5096, + 5097, + 5098, + 5099, + 5112, + 5113, + 5116, + 5117, + 5120, + 5121, + 5124, + 5125, + 5126, + 5127, + 5128, + 5129, + 5130, + 5131, + 5132, + 5133, + 5134, + 5135, + 5154, + 5155, + 5156, + 5157, + 5158, + 5159, + 5160, + 5161, + 5162, + 5163, + 5164, + 5168, + 5169, + 5170, + 5171, + 5190, + 5191, + 5192, + 5193, + 5194, + 5195, + 5196, + 5197, + 5198, + 5199, + 5202, + 5205, + 5207, + 5210, + 5211, + 5212, + 5213, + 5219, + 5221, + 5361, + 5370, + 5371, + 5372, + 5373, + 5374, + 5375, + 5376, + 5377, + 5378, + 5379, + 5380, + 5381, + 5382, + 5383, + 5384, + 5385, + 5386, + 5387, + 5388, + 5389, + 5390, + 5391, + 5392, + 5393, + 5394, + 5395, + 5396, + 5397, + 5398, + 5399, + 5400, + 5401, + 5402, + 5403, + 5404, + 5405, + 5406, + 5407, + 5408, + 5409, + 5410, + 5411, + 5412, + 5413, + 5414, + 5415, + 5416, + 5417, + 5418, + 5419, + 5420, + 5421, + 5422, + 5423, + 5424, + 5425, + 5426, + 5427, + 5428, + 5429, + 5430, + 5431, + 5432, + 5433, + 5434, + 5435, + 5436, + 5437, + 5438, + 5439, + 5440, + 5441, + 5480, + 5520, + 5521, + 5534, + 5550, + 5559, + 5560, + 5561, + 5562, + 5563, + 5564, + 5565, + 5566, + 5567, + 5568, + 5569, + 5570, + 5571, + 5572, + 5573, + 5608, + 5609, + 5667, + 5668, + 5669, + 5670, + 5689, + 5691, + 5695, + 5696, + 5702, + 5703, + 5704, + 5705, + 5740, + 5747, + 6334, + 6335 + ], + "neck": [ + 148, + 150, + 151, + 152, + 153, + 172, + 174, + 175, + 201, + 202, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + 217, + 218, + 219, + 222, + 223, + 224, + 225, + 256, + 257, + 284, + 285, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 333, + 334, + 423, + 424, + 425, + 426, + 440, + 441, + 451, + 452, + 453, + 460, + 461, + 571, + 572, + 824, + 825, + 826, + 827, + 828, + 829, + 1279, + 1280, + 1312, + 1313, + 1319, + 1320, + 1331, + 3049, + 3050, + 3057, + 3058, + 3059, + 3068, + 3164, + 3661, + 3662, + 3663, + 3664, + 3665, + 3685, + 3686, + 3687, + 3714, + 3715, + 3716, + 3717, + 3718, + 3719, + 3720, + 3721, + 3722, + 3723, + 3724, + 3725, + 3726, + 3727, + 3728, + 3729, + 3730, + 3731, + 3734, + 3735, + 3736, + 3737, + 3768, + 3769, + 3796, + 3797, + 3807, + 3808, + 3809, + 3810, + 3811, + 3812, + 3813, + 3814, + 3815, + 3816, + 3817, + 3818, + 3819, + 3839, + 3840, + 3918, + 3919, + 3920, + 3921, + 3934, + 3935, + 3942, + 3943, + 3944, + 3950, + 4060, + 4061, + 4312, + 4313, + 4314, + 4315, + 4761, + 4762, + 4792, + 4793, + 4799, + 4800, + 4807 + ], + "rightToeBase": [ + 6611, + 6612, + 6613, + 6614, + 6615, + 6616, + 6617, + 6618, + 6619, + 6620, + 6621, + 6622, + 6623, + 6624, + 6625, + 6626, + 6627, + 6628, + 6629, + 6630, + 6631, + 6632, + 6633, + 6634, + 6635, + 6636, + 6637, + 6638, + 6639, + 6640, + 6641, + 6642, + 6643, + 6644, + 6645, + 6646, + 6647, + 6648, + 6649, + 6650, + 6651, + 6652, + 6653, + 6654, + 6655, + 6656, + 6657, + 6658, + 6659, + 6660, + 6661, + 6662, + 6663, + 6664, + 6665, + 6666, + 6667, + 6668, + 6669, + 
6670, + 6671, + 6672, + 6673, + 6674, + 6675, + 6676, + 6677, + 6678, + 6679, + 6680, + 6681, + 6682, + 6683, + 6684, + 6685, + 6686, + 6687, + 6688, + 6689, + 6690, + 6691, + 6692, + 6693, + 6694, + 6695, + 6696, + 6697, + 6698, + 6699, + 6700, + 6701, + 6702, + 6703, + 6704, + 6705, + 6706, + 6707, + 6708, + 6709, + 6710, + 6711, + 6712, + 6713, + 6714, + 6715, + 6716, + 6717, + 6718, + 6736, + 6739, + 6741, + 6743, + 6745, + 6747, + 6749, + 6750, + 6752, + 6754, + 6757, + 6758, + 6760, + 6762 + ], + "spine": [ + 616, + 617, + 630, + 631, + 632, + 633, + 654, + 655, + 656, + 657, + 662, + 663, + 664, + 665, + 720, + 721, + 765, + 766, + 767, + 768, + 796, + 797, + 798, + 799, + 889, + 890, + 916, + 917, + 918, + 919, + 921, + 922, + 923, + 924, + 925, + 926, + 1188, + 1189, + 1211, + 1212, + 1248, + 1249, + 1250, + 1251, + 1264, + 1265, + 1266, + 1267, + 1323, + 1324, + 1325, + 1326, + 1327, + 1328, + 1332, + 1333, + 1334, + 1335, + 1336, + 1344, + 1345, + 1481, + 1482, + 1483, + 1484, + 1485, + 1486, + 1487, + 1488, + 1489, + 1490, + 1491, + 1492, + 1493, + 1494, + 1495, + 1496, + 1767, + 2823, + 2824, + 2825, + 2826, + 2827, + 2828, + 2829, + 2830, + 2831, + 2832, + 2833, + 2834, + 2835, + 2836, + 2837, + 2838, + 2839, + 2840, + 2841, + 2842, + 2843, + 2844, + 2845, + 2847, + 2848, + 2851, + 3016, + 3017, + 3018, + 3019, + 3020, + 3023, + 3024, + 3124, + 3173, + 3476, + 3477, + 3478, + 3480, + 3500, + 3501, + 3502, + 3504, + 3509, + 3511, + 4103, + 4104, + 4118, + 4119, + 4120, + 4121, + 4142, + 4143, + 4144, + 4145, + 4150, + 4151, + 4152, + 4153, + 4208, + 4209, + 4253, + 4254, + 4255, + 4256, + 4284, + 4285, + 4286, + 4287, + 4375, + 4376, + 4402, + 4403, + 4405, + 4406, + 4407, + 4408, + 4409, + 4410, + 4411, + 4412, + 4674, + 4675, + 4694, + 4695, + 4731, + 4732, + 4733, + 4734, + 4747, + 4748, + 4749, + 4750, + 4803, + 4804, + 4805, + 4806, + 4808, + 4809, + 4810, + 4811, + 4812, + 4820, + 4821, + 4953, + 4954, + 4955, + 4956, + 4957, + 4958, + 4959, + 4960, + 4961, + 4962, + 4963, + 4964, + 4965, + 4966, + 4967, + 4968, + 5234, + 6284, + 6285, + 6286, + 6287, + 6288, + 6289, + 6290, + 6291, + 6292, + 6293, + 6294, + 6295, + 6296, + 6297, + 6298, + 6299, + 6300, + 6301, + 6302, + 6303, + 6304, + 6305, + 6306, + 6308, + 6309, + 6312, + 6472, + 6473, + 6474, + 6545, + 6874, + 6875, + 6876, + 6878 + ], + "leftUpLeg": [ + 833, + 834, + 838, + 839, + 847, + 848, + 849, + 850, + 851, + 852, + 853, + 854, + 870, + 871, + 872, + 873, + 874, + 875, + 876, + 877, + 878, + 879, + 880, + 881, + 897, + 898, + 899, + 900, + 901, + 902, + 903, + 904, + 905, + 906, + 907, + 908, + 909, + 910, + 911, + 912, + 913, + 914, + 915, + 933, + 934, + 935, + 936, + 944, + 945, + 946, + 947, + 948, + 949, + 950, + 951, + 952, + 953, + 954, + 955, + 956, + 957, + 958, + 959, + 960, + 961, + 962, + 963, + 964, + 965, + 966, + 967, + 968, + 969, + 970, + 971, + 972, + 973, + 974, + 975, + 976, + 977, + 978, + 979, + 980, + 981, + 982, + 983, + 984, + 985, + 986, + 987, + 988, + 989, + 990, + 991, + 992, + 993, + 994, + 995, + 996, + 997, + 998, + 999, + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, + 1011, + 1012, + 1013, + 1014, + 1015, + 1016, + 1017, + 1018, + 1019, + 1020, + 1021, + 1022, + 1023, + 1024, + 1025, + 1026, + 1027, + 1028, + 1029, + 1030, + 1031, + 1032, + 1033, + 1034, + 1035, + 1036, + 1037, + 1038, + 1039, + 1040, + 1041, + 1042, + 1043, + 1044, + 1045, + 1046, + 1137, + 1138, + 1139, + 1140, + 1141, + 1142, + 1143, + 1144, + 1145, + 1146, + 1147, + 
1148, + 1159, + 1160, + 1161, + 1162, + 1163, + 1164, + 1165, + 1166, + 1167, + 1168, + 1169, + 1170, + 1171, + 1172, + 1173, + 1174, + 1184, + 1185, + 1186, + 1187, + 1221, + 1222, + 1223, + 1224, + 1225, + 1226, + 1227, + 1228, + 1229, + 1230, + 1262, + 1263, + 1274, + 1275, + 1276, + 1277, + 1321, + 1322, + 1354, + 1359, + 1360, + 1361, + 1362, + 1365, + 1366, + 1367, + 1368, + 1451, + 1452, + 1453, + 1455, + 1456, + 1457, + 1458, + 1459, + 1460, + 1461, + 1462, + 1463, + 1475, + 1477, + 1478, + 1479, + 1480, + 1498, + 1499, + 1500, + 1501, + 1511, + 1512, + 1513, + 1514, + 1516, + 1517, + 1518, + 1519, + 1520, + 1521, + 1522, + 1533, + 1534, + 3125, + 3126, + 3127, + 3128, + 3131, + 3132, + 3133, + 3134, + 3135, + 3475, + 3479 + ], + "leftHand": [ + 1981, + 1982, + 1983, + 1984, + 1985, + 1986, + 1987, + 1988, + 1989, + 1990, + 1991, + 1992, + 1993, + 1994, + 1995, + 1996, + 1997, + 1998, + 1999, + 2000, + 2001, + 2002, + 2003, + 2004, + 2005, + 2006, + 2007, + 2008, + 2009, + 2010, + 2011, + 2012, + 2013, + 2014, + 2015, + 2016, + 2017, + 2018, + 2019, + 2020, + 2021, + 2022, + 2023, + 2024, + 2025, + 2026, + 2031, + 2032, + 2033, + 2034, + 2035, + 2036, + 2041, + 2042, + 2043, + 2044, + 2045, + 2046, + 2047, + 2048, + 2049, + 2050, + 2051, + 2052, + 2053, + 2054, + 2055, + 2056, + 2057, + 2058, + 2059, + 2060, + 2061, + 2062, + 2063, + 2064, + 2065, + 2066, + 2069, + 2070, + 2071, + 2072, + 2073, + 2074, + 2075, + 2076, + 2077, + 2078, + 2079, + 2080, + 2081, + 2082, + 2083, + 2084, + 2085, + 2086, + 2087, + 2088, + 2089, + 2090, + 2091, + 2092, + 2093, + 2094, + 2095, + 2096, + 2097, + 2098, + 2099, + 2100, + 2101, + 2107, + 2111, + 2113, + 2114, + 2115, + 2116, + 2117, + 2118, + 2119, + 2120, + 2121, + 2122, + 2127, + 2130, + 2131, + 2132, + 2133, + 2134, + 2135, + 2136, + 2137, + 2138, + 2139, + 2140, + 2141, + 2142, + 2143, + 2144, + 2149, + 2150, + 2151, + 2152, + 2155, + 2160, + 2163, + 2164, + 2170, + 2171, + 2172, + 2173, + 2174, + 2175, + 2176, + 2177, + 2178, + 2179, + 2180, + 2182, + 2183, + 2184, + 2185, + 2188, + 2189, + 2191, + 2192, + 2193, + 2194, + 2195, + 2196, + 2197, + 2198, + 2199, + 2200, + 2201, + 2202, + 2203, + 2207, + 2209, + 2210, + 2211, + 2212, + 2213, + 2214, + 2221, + 2222, + 2223, + 2224, + 2225, + 2226, + 2227, + 2228, + 2229, + 2231, + 2234, + 2236, + 2237, + 2238, + 2239, + 2240, + 2246, + 2247, + 2248, + 2249, + 2250, + 2251, + 2252, + 2253, + 2254, + 2255, + 2256, + 2257, + 2258, + 2259, + 2260, + 2262, + 2263, + 2264, + 2265, + 2266, + 2267, + 2268, + 2269, + 2270, + 2271, + 2274, + 2275, + 2276, + 2277, + 2278, + 2279, + 2284, + 2285, + 2287, + 2288, + 2289, + 2290, + 2293, + 2595, + 2598, + 2605, + 2608, + 2697, + 2698, + 2699, + 2700, + 2701, + 2702, + 2703, + 2704, + 2705, + 2706, + 2707, + 2708, + 2709, + 2710, + 2711, + 2712, + 2713, + 2714, + 2715, + 2716, + 2717, + 2718, + 2719, + 2720, + 2721, + 2722, + 2723, + 2724, + 2725, + 2726, + 2727, + 2728, + 2729, + 2730, + 2731, + 2732, + 2733, + 2734, + 2735, + 2736, + 2737, + 2738, + 2739, + 2740, + 2741, + 2742, + 2743, + 2744, + 2745, + 2746, + 2747, + 2748, + 2749, + 2750, + 2751, + 2752, + 2753, + 2754, + 2755, + 2756, + 2757, + 2758, + 2759, + 2760, + 2761, + 2762, + 2763, + 2764, + 2765, + 2766, + 2767, + 2768, + 2769, + 2770, + 2771, + 2772, + 2773, + 2774, + 2775, + 2776, + 2777, + 2778 + ], + "hips": [ + 631, + 632, + 654, + 657, + 662, + 665, + 676, + 677, + 678, + 679, + 705, + 720, + 796, + 799, + 800, + 801, + 802, + 807, + 808, + 809, + 810, + 815, + 816, + 822, + 823, + 830, + 
831, + 832, + 833, + 834, + 835, + 836, + 837, + 838, + 839, + 840, + 841, + 842, + 843, + 844, + 845, + 846, + 855, + 856, + 857, + 858, + 859, + 860, + 861, + 862, + 863, + 864, + 865, + 866, + 867, + 868, + 869, + 871, + 878, + 881, + 882, + 883, + 884, + 885, + 886, + 887, + 888, + 889, + 890, + 912, + 915, + 916, + 917, + 918, + 919, + 920, + 932, + 937, + 938, + 939, + 1163, + 1166, + 1203, + 1204, + 1205, + 1206, + 1207, + 1208, + 1209, + 1210, + 1246, + 1247, + 1262, + 1263, + 1276, + 1277, + 1278, + 1321, + 1336, + 1337, + 1338, + 1339, + 1353, + 1354, + 1361, + 1362, + 1363, + 1364, + 1446, + 1447, + 1448, + 1449, + 1450, + 1454, + 1476, + 1497, + 1511, + 1513, + 1514, + 1515, + 1533, + 1534, + 1539, + 1540, + 1768, + 1769, + 1779, + 1780, + 1781, + 1782, + 1783, + 1784, + 1785, + 1786, + 1787, + 1788, + 1789, + 1790, + 1791, + 1792, + 1793, + 1794, + 1795, + 1796, + 1797, + 1798, + 1799, + 1800, + 1801, + 1802, + 1803, + 1804, + 1805, + 1806, + 1807, + 2909, + 2910, + 2911, + 2912, + 2913, + 2914, + 2915, + 2916, + 2917, + 2918, + 2919, + 2920, + 2921, + 2922, + 2923, + 2924, + 2925, + 2926, + 2927, + 2928, + 2929, + 2930, + 3018, + 3019, + 3021, + 3022, + 3080, + 3081, + 3082, + 3083, + 3084, + 3085, + 3086, + 3087, + 3088, + 3089, + 3090, + 3091, + 3092, + 3093, + 3094, + 3095, + 3096, + 3097, + 3098, + 3099, + 3100, + 3101, + 3102, + 3103, + 3104, + 3105, + 3106, + 3107, + 3108, + 3109, + 3110, + 3111, + 3112, + 3113, + 3114, + 3115, + 3116, + 3117, + 3118, + 3119, + 3120, + 3121, + 3122, + 3123, + 3124, + 3128, + 3129, + 3130, + 3136, + 3137, + 3138, + 3139, + 3140, + 3141, + 3142, + 3143, + 3144, + 3145, + 3146, + 3147, + 3148, + 3149, + 3150, + 3151, + 3152, + 3153, + 3154, + 3155, + 3156, + 3157, + 3158, + 3159, + 3160, + 3170, + 3172, + 3481, + 3484, + 3500, + 3502, + 3503, + 3507, + 3510, + 4120, + 4121, + 4142, + 4143, + 4150, + 4151, + 4164, + 4165, + 4166, + 4167, + 4193, + 4208, + 4284, + 4285, + 4288, + 4289, + 4290, + 4295, + 4296, + 4297, + 4298, + 4303, + 4304, + 4310, + 4311, + 4316, + 4317, + 4318, + 4319, + 4320, + 4321, + 4322, + 4323, + 4324, + 4325, + 4326, + 4327, + 4328, + 4329, + 4330, + 4331, + 4332, + 4341, + 4342, + 4343, + 4344, + 4345, + 4346, + 4347, + 4348, + 4349, + 4350, + 4351, + 4352, + 4353, + 4354, + 4355, + 4356, + 4364, + 4365, + 4368, + 4369, + 4370, + 4371, + 4372, + 4373, + 4374, + 4375, + 4376, + 4398, + 4399, + 4402, + 4403, + 4404, + 4405, + 4406, + 4418, + 4423, + 4424, + 4425, + 4649, + 4650, + 4689, + 4690, + 4691, + 4692, + 4693, + 4729, + 4730, + 4745, + 4746, + 4759, + 4760, + 4801, + 4812, + 4813, + 4814, + 4815, + 4829, + 4836, + 4837, + 4919, + 4920, + 4921, + 4922, + 4923, + 4927, + 4969, + 4983, + 4984, + 4986, + 5004, + 5005, + 5244, + 5245, + 5246, + 5247, + 5248, + 5249, + 5250, + 5251, + 5252, + 5253, + 5254, + 5255, + 5256, + 5257, + 5258, + 5259, + 5260, + 5261, + 5262, + 5263, + 5264, + 5265, + 5266, + 5267, + 5268, + 6368, + 6369, + 6370, + 6371, + 6372, + 6373, + 6374, + 6375, + 6376, + 6377, + 6378, + 6379, + 6380, + 6381, + 6382, + 6383, + 6384, + 6385, + 6386, + 6387, + 6388, + 6389, + 6473, + 6474, + 6504, + 6505, + 6506, + 6507, + 6508, + 6509, + 6510, + 6511, + 6512, + 6513, + 6514, + 6515, + 6516, + 6517, + 6518, + 6519, + 6520, + 6521, + 6522, + 6523, + 6524, + 6525, + 6526, + 6527, + 6528, + 6529, + 6530, + 6531, + 6532, + 6533, + 6534, + 6535, + 6536, + 6537, + 6538, + 6539, + 6540, + 6541, + 6542, + 6543, + 6544, + 6545, + 6549, + 6550, + 6551, + 6557, + 6558, + 6559, + 6560, + 6561, + 6562, + 6563, + 
6564, + 6565, + 6566, + 6567, + 6568, + 6569, + 6570, + 6571, + 6572, + 6573 + ] +} \ No newline at end of file diff --git a/lib/common/train_util.py b/lib/common/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8b25fdb7747d32cf6d221e44bd5ed877f70179 --- /dev/null +++ b/lib/common/train_util.py @@ -0,0 +1,690 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import yaml +import os.path as osp +import torch +import numpy as np +from ..dataset.mesh_util import * +from ..net.geometry import orthogonal +from pytorch3d.renderer.mesh import rasterize_meshes +from .render_utils import Pytorch3dRasterizer +from pytorch3d.structures import Meshes +import cv2 +from PIL import Image +from tqdm import tqdm +import os +from termcolor import colored + +import pytorch_lightning as pl +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities import rank_zero_warn + + +def rename(old_dict, old_name, new_name): + new_dict = {} + for key, value in zip(old_dict.keys(), old_dict.values()): + new_key = key if key != old_name else new_name + new_dict[new_key] = old_dict[key] + return new_dict + + +class SubTrainer(pl.Trainer): + + def save_checkpoint(self, filepath, weights_only=False): + """Save model/training states as a checkpoint file through state-dump and file-write. + Args: + filepath: write-target file's path + weights_only: saving model weights only + """ + _checkpoint = self.checkpoint_connector.dump_checkpoint(weights_only) + + del_keys = [] + for key in _checkpoint["state_dict"].keys(): + for ig_key in ["normal_filter", "voxelization", "reconEngine"]: + if ig_key in key: + del_keys.append(key) + for key in del_keys: + del _checkpoint["state_dict"][key] + + if self.is_global_zero: + # write the checkpoint dictionary on the file + + if self.training_type_plugin: + checkpoint = self.training_type_plugin.on_save(_checkpoint) + try: + atomic_save(checkpoint, filepath) + except AttributeError as err: + if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + rank_zero_warn( + "Warning, `hyper_parameters` dropped from checkpoint." 
+ f" An attribute is not picklable {err}") + atomic_save(checkpoint, filepath) + + +def load_networks(cfg, model, mlp_path, normal_path): + + model_dict = model.state_dict() + main_dict = {} + normal_dict = {} + + # MLP part loading + if os.path.exists(mlp_path) and mlp_path.endswith("ckpt"): + main_dict = torch.load( + mlp_path, + map_location=torch.device(f"cuda:{cfg.gpus[0]}"))["state_dict"] + + main_dict = { + k: v + for k, v in main_dict.items() + if k in model_dict and v.shape == model_dict[k].shape and ( + "reconEngine" not in k) and ("normal_filter" not in k) and ( + "voxelization" not in k) + } + print(colored(f"Resume MLP weights from {mlp_path}", "green")) + + # normal network part loading + if os.path.exists(normal_path) and normal_path.endswith("ckpt"): + normal_dict = torch.load( + normal_path, + map_location=torch.device(f"cuda:{cfg.gpus[0]}"))["state_dict"] + + for key in normal_dict.keys(): + normal_dict = rename(normal_dict, key, + key.replace("netG", "netG.normal_filter")) + + normal_dict = { + k: v + for k, v in normal_dict.items() + if k in model_dict and v.shape == model_dict[k].shape + } + print(colored(f"Resume normal model from {normal_path}", "green")) + + model_dict.update(main_dict) + model_dict.update(normal_dict) + model.load_state_dict(model_dict) + + # clean unused GPU memory + del main_dict + del normal_dict + del model_dict + torch.cuda.empty_cache() + + +def reshape_sample_tensor(sample_tensor, num_views): + if num_views == 1: + return sample_tensor + # Need to repeat sample_tensor along the batch dim num_views times + sample_tensor = sample_tensor.unsqueeze(dim=1) + sample_tensor = sample_tensor.repeat(1, num_views, 1, 1) + sample_tensor = sample_tensor.view( + sample_tensor.shape[0] * sample_tensor.shape[1], + sample_tensor.shape[2], sample_tensor.shape[3]) + return sample_tensor + + +def gen_mesh_eval(opt, net, cuda, data, resolution=None): + resolution = opt.resolution if resolution is None else resolution + image_tensor = data['img'].to(device=cuda) + calib_tensor = data['calib'].to(device=cuda) + + net.filter(image_tensor) + + b_min = data['b_min'] + b_max = data['b_max'] + try: + verts, faces, _, _ = reconstruction_faster(net, + cuda, + calib_tensor, + resolution, + b_min, + b_max, + use_octree=False) + + except Exception as e: + print(e) + print('Can not create marching cubes at this time.') + verts, faces = None, None + return verts, faces + + +def gen_mesh(opt, net, cuda, data, save_path, resolution=None): + resolution = opt.resolution if resolution is None else resolution + image_tensor = data['img'].to(device=cuda) + calib_tensor = data['calib'].to(device=cuda) + + net.filter(image_tensor) + + b_min = data['b_min'] + b_max = data['b_max'] + try: + save_img_path = save_path[:-4] + '.png' + save_img_list = [] + for v in range(image_tensor.shape[0]): + save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), + (1, 2, 0)) * 0.5 + + 0.5)[:, :, ::-1] * 255.0 + save_img_list.append(save_img) + save_img = np.concatenate(save_img_list, axis=1) + Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path) + + verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor, + resolution, b_min, b_max) + verts_tensor = torch.from_numpy( + verts.T).unsqueeze(0).to(device=cuda).float() + xyz_tensor = net.projection(verts_tensor, calib_tensor[:1]) + uv = xyz_tensor[:, :2, :] + color = netG.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T + color = color * 0.5 + 0.5 + save_obj_mesh_with_color(save_path, verts, faces, color) + except 
Exception as e: + print(e) + print('Can not create marching cubes at this time.') + verts, faces, color = None, None, None + return verts, faces, color + + +def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True): + image_tensor = data['img'].to(device=cuda) + calib_tensor = data['calib'].to(device=cuda) + + netG.filter(image_tensor) + netC.filter(image_tensor) + netC.attach(netG.get_im_feat()) + + b_min = data['b_min'] + b_max = data['b_max'] + try: + save_img_path = save_path[:-4] + '.png' + save_img_list = [] + for v in range(image_tensor.shape[0]): + save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), + (1, 2, 0)) * 0.5 + + 0.5)[:, :, ::-1] * 255.0 + save_img_list.append(save_img) + save_img = np.concatenate(save_img_list, axis=1) + Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path) + + verts, faces, _, _ = reconstruction_faster(netG, + cuda, + calib_tensor, + opt.resolution, + b_min, + b_max, + use_octree=use_octree) + + # Now Getting colors + verts_tensor = torch.from_numpy( + verts.T).unsqueeze(0).to(device=cuda).float() + verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views) + color = np.zeros(verts.shape) + interval = 10000 + for i in range(len(color) // interval): + left = i * interval + right = i * interval + interval + if i == len(color) // interval - 1: + right = -1 + netC.query(verts_tensor[:, :, left:right], calib_tensor) + rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5 + color[left:right] = rgb.T + + save_obj_mesh_with_color(save_path, verts, faces, color) + except Exception as e: + print(e) + print('Can not create marching cubes at this time.') + verts, faces, color = None, None, None + return verts, faces, color + + +def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): + """Sets the learning rate to the initial LR decayed by schedule""" + if epoch in schedule: + lr *= gamma + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr + + +def compute_acc(pred, gt, thresh=0.5): + ''' + return: + IOU, precision, and recall + ''' + with torch.no_grad(): + vol_pred = pred > thresh + vol_gt = gt > thresh + + union = vol_pred | vol_gt + inter = vol_pred & vol_gt + + true_pos = inter.sum().float() + + union = union.sum().float() + if union == 0: + union = 1 + vol_pred = vol_pred.sum().float() + if vol_pred == 0: + vol_pred = 1 + vol_gt = vol_gt.sum().float() + if vol_gt == 0: + vol_gt = 1 + return true_pos / union, true_pos / vol_pred, true_pos / vol_gt + + +# def calc_metrics(opt, net, cuda, dataset, num_tests, +# resolution=128, sampled_points=1000, use_kaolin=True): +# if num_tests > len(dataset): +# num_tests = len(dataset) +# with torch.no_grad(): +# chamfer_arr, p2s_arr = [], [] +# for idx in tqdm(range(num_tests)): +# data = dataset[idx * len(dataset) // num_tests] + +# verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution) +# if verts is None: +# continue + +# mesh_gt = trimesh.load(data['mesh_path']) +# mesh_gt = mesh_gt.split(only_watertight=False) +# comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt] +# mesh_gt = mesh_gt[comp_num.index(max(comp_num))] + +# mesh_pred = trimesh.Trimesh(verts, faces) + +# gt_surface_pts, _ = trimesh.sample.sample_surface_even( +# mesh_gt, sampled_points) +# pred_surface_pts, _ = trimesh.sample.sample_surface_even( +# mesh_pred, sampled_points) + +# if use_kaolin and has_kaolin: +# kal_mesh_gt = kal.rep.TriangleMesh.from_tensors( +# torch.tensor(mesh_gt.vertices).float().to(device=cuda), +# 
torch.tensor(mesh_gt.faces).long().to(device=cuda)) +# kal_mesh_pred = kal.rep.TriangleMesh.from_tensors( +# torch.tensor(mesh_pred.vertices).float().to(device=cuda), +# torch.tensor(mesh_pred.faces).long().to(device=cuda)) + +# kal_distance_0 = kal.metrics.mesh.point_to_surface( +# torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt) +# kal_distance_1 = kal.metrics.mesh.point_to_surface( +# torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred) + +# dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy() +# dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy() +# else: +# try: +# _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts) +# _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts) +# except Exception as e: +# print (e) +# continue + +# chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean()) +# p2s_dist = dist_pred_gt.mean() + +# chamfer_arr.append(chamfer_dist) +# p2s_arr.append(p2s_dist) + +# return np.average(chamfer_arr), np.average(p2s_arr) + + +def calc_error(opt, net, cuda, dataset, num_tests): + if num_tests > len(dataset): + num_tests = len(dataset) + with torch.no_grad(): + erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], [] + for idx in tqdm(range(num_tests)): + data = dataset[idx * len(dataset) // num_tests] + # retrieve the data + image_tensor = data['img'].to(device=cuda) + calib_tensor = data['calib'].to(device=cuda) + sample_tensor = data['samples'].to(device=cuda).unsqueeze(0) + if opt.num_views > 1: + sample_tensor = reshape_sample_tensor(sample_tensor, + opt.num_views) + label_tensor = data['labels'].to(device=cuda).unsqueeze(0) + + res, error = net.forward(image_tensor, + sample_tensor, + calib_tensor, + labels=label_tensor) + + IOU, prec, recall = compute_acc(res, label_tensor) + + # print( + # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}' + # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item())) + erorr_arr.append(error.item()) + IOU_arr.append(IOU.item()) + prec_arr.append(prec.item()) + recall_arr.append(recall.item()) + + return np.average(erorr_arr), np.average(IOU_arr), np.average( + prec_arr), np.average(recall_arr) + + +def calc_error_color(opt, netG, netC, cuda, dataset, num_tests): + if num_tests > len(dataset): + num_tests = len(dataset) + with torch.no_grad(): + error_color_arr = [] + + for idx in tqdm(range(num_tests)): + data = dataset[idx * len(dataset) // num_tests] + # retrieve the data + image_tensor = data['img'].to(device=cuda) + calib_tensor = data['calib'].to(device=cuda) + color_sample_tensor = data['color_samples'].to( + device=cuda).unsqueeze(0) + + if opt.num_views > 1: + color_sample_tensor = reshape_sample_tensor( + color_sample_tensor, opt.num_views) + + rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0) + + netG.filter(image_tensor) + _, errorC = netC.forward(image_tensor, + netG.get_im_feat(), + color_sample_tensor, + calib_tensor, + labels=rgb_tensor) + + # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}' + # .format(idx, num_tests, errorG.item(), errorC.item())) + error_color_arr.append(errorC.item()) + + return np.average(error_color_arr) + + +# pytorch lightning training related fucntions + + +def query_func(opt, netG, features, points, proj_matrix=None): + ''' + - points: size of (bz, N, 3) + - proj_matrix: size of (bz, 4, 4) + return: size of (bz, 1, N) + ''' + assert len(points) == 1 + samples = points.repeat(opt.num_views, 1, 1) + samples = samples.permute(0, 
2, 1) # [bz, 3, N] + + # view specific query + if proj_matrix is not None: + samples = orthogonal(samples, proj_matrix) + + calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples) + + preds = netG.query(features=features, + points=samples, + calibs=calib_tensor) + + if type(preds) is list: + preds = preds[0] + + return preds + + +def isin(ar1, ar2): + return (ar1[..., None] == ar2).any(-1) + + +def in1d(ar1, ar2): + mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool) + mask[ar2.unique()] = True + return mask[ar1] + + +def get_visibility(xy, z, faces): + """get the visibility of vertices + + Args: + xy (torch.tensor): [N,2] + z (torch.tensor): [N,1] + faces (torch.tensor): [N,3] + size (int): resolution of rendered image + """ + + xyz = torch.cat((xy, -z), dim=1) + xyz = (xyz + 1.0) / 2.0 + faces = faces.long() + + rasterizer = Pytorch3dRasterizer(image_size=2**12) + meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...]) + raster_settings = rasterizer.raster_settings + + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, + ) + + vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :]) + vis_mask = torch.zeros(size=(z.shape[0], 1)) + vis_mask[vis_vertices_id] = 1.0 + + # print("------------------------\n") + # print(f"keep points : {vis_mask.sum()/len(vis_mask)}") + + return vis_mask + + +def batch_mean(res, key): + # recursive mean for multilevel dicts + return torch.stack([ + x[key] if isinstance(x, dict) else batch_mean(x, key) for x in res + ]).mean() + + +def tf_log_convert(log_dict): + new_log_dict = log_dict.copy() + for k, v in log_dict.items(): + new_log_dict[k.replace("_", "/")] = v + del new_log_dict[k] + + return new_log_dict + + +def bar_log_convert(log_dict, name=None, rot=None): + from decimal import Decimal + + new_log_dict = {} + + if name is not None: + new_log_dict['name'] = name[0] + if rot is not None: + new_log_dict['rot'] = rot[0] + + for k, v in log_dict.items(): + color = "yellow" + if 'loss' in k: + color = "red" + k = k.replace("loss", "L") + elif 'acc' in k: + color = "green" + k = k.replace("acc", "A") + elif 'iou' in k: + color = "green" + k = k.replace("iou", "I") + elif 'prec' in k: + color = "green" + k = k.replace("prec", "P") + elif 'recall' in k: + color = "green" + k = k.replace("recall", "R") + + if 'lr' not in k: + new_log_dict[colored(k.split("_")[1], + color)] = colored(f"{v:.3f}", color) + else: + new_log_dict[colored(k.split("_")[1], + color)] = colored(f"{Decimal(str(v)):.1E}", + color) + + if 'loss' in new_log_dict.keys(): + del new_log_dict['loss'] + + return new_log_dict + + +def accumulate(outputs, rot_num, split): + + hparam_log_dict = {} + + metrics = outputs[0].keys() + datasets = split.keys() + + for dataset in datasets: + for metric in metrics: + keyword = f"{dataset}-{metric}" + if keyword not in hparam_log_dict.keys(): + hparam_log_dict[keyword] = 0 + for idx in range(split[dataset][0] * rot_num, + split[dataset][1] * rot_num): + hparam_log_dict[keyword] += outputs[idx][metric] + hparam_log_dict[keyword] /= (split[dataset][1] - + split[dataset][0]) * rot_num + + print(colored(hparam_log_dict, "green")) + + return 
hparam_log_dict + + +def calc_error_N(outputs, targets): + """calculate the error of normal (IGR) + + Args: + outputs (torch.tensor): [B, 3, N] + target (torch.tensor): [B, N, 3] + + # manifold loss and grad_loss in IGR paper + grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() + normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean() + + Returns: + torch.tensor: error of valid normals on the surface + """ + # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3)) + outputs = -outputs.permute(0, 2, 1).reshape(-1, 1) + targets = targets.reshape(-1, 3)[:, 2:3] + with_normals = targets.sum(dim=1).abs() > 0.0 + + # eikonal loss + grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean() + # normals loss + normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean() + + return grad_loss * 0.0 + normal_loss + + +def calc_knn_acc(preds, carn_verts, labels, pick_num): + """calculate knn accuracy + + Args: + preds (torch.tensor): [B, 3, N] + carn_verts (torch.tensor): [SMPLX_V_num, 3] + labels (torch.tensor): [B, N_knn, N] + """ + N_knn_full = labels.shape[1] + preds = preds.permute(0, 2, 1).reshape(-1, 3) + labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn] + labels = labels[:, :pick_num] + + dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num] + knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn] + cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0] + bool_col = torch.zeros_like(cat_mat)[:, 0] + for i in range(pick_num * 2 - 1): + bool_col += cat_mat[:, i] == cat_mat[:, i + 1] + acc = (bool_col > 0).sum() / len(bool_col) + + return acc + + +def calc_acc_seg(output, target, num_multiseg): + from pytorch_lightning.metrics import Accuracy + return Accuracy()(output.reshape(-1, num_multiseg).cpu(), + target.flatten().cpu()) + + +def add_watermark(imgs, titles): + + # Write some Text + + font = cv2.FONT_HERSHEY_SIMPLEX + bottomLeftCornerOfText = (350, 50) + bottomRightCornerOfText = (800, 50) + fontScale = 1 + fontColor = (1.0, 1.0, 1.0) + lineType = 2 + + for i in range(len(imgs)): + + title = titles[i + 1] + cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale, + fontColor, lineType) + + if i == 0: + cv2.putText(imgs[i], str(titles[i][0]), bottomRightCornerOfText, + font, fontScale, fontColor, lineType) + + result = np.concatenate(imgs, axis=0).transpose(2, 0, 1) + + return result + + +def make_test_gif(img_dir): + + if img_dir is not None and len(os.listdir(img_dir)) > 0: + for dataset in os.listdir(img_dir): + for subject in sorted(os.listdir(osp.join(img_dir, dataset))): + img_lst = [] + im1 = None + for file in sorted( + os.listdir(osp.join(img_dir, dataset, subject))): + if file[-3:] not in ['obj', 'gif']: + img_path = os.path.join(img_dir, dataset, subject, + file) + if im1 == None: + im1 = Image.open(img_path) + else: + img_lst.append(Image.open(img_path)) + + print(os.path.join(img_dir, dataset, subject, "out.gif")) + im1.save(os.path.join(img_dir, dataset, subject, "out.gif"), + save_all=True, + append_images=img_lst, + duration=500, + loop=0) + + +def export_cfg(logger, cfg): + + cfg_export_file = osp.join(logger.save_dir, logger.name, + f"version_{logger.version}", "cfg.yaml") + + if not osp.exists(cfg_export_file): + os.makedirs(osp.dirname(cfg_export_file), exist_ok=True) + with open(cfg_export_file, "w+") as file: + _ = yaml.dump(cfg, file) diff --git a/lib/dataloader_demo.py b/lib/dataloader_demo.py new file mode 100644 index 
0000000000000000000000000000000000000000..a752876d8c3e7026cfb562670cdeeb4b9bc94216 --- /dev/null +++ b/lib/dataloader_demo.py @@ -0,0 +1,80 @@ + +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) + +import argparse +from lib.common.config import get_cfg_defaults +from lib.dataset.PIFuDataset import PIFuDataset + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--show', + action='store_true', + help='vis sampler 3D') + parser.add_argument('-s', + '--speed', + action='store_true', + help='vis sampler 3D') + parser.add_argument('-l', + '--list', + action='store_true', + help='vis sampler 3D') + parser.add_argument('-c', + '--config', + default='./configs/train/icon-filter.yaml', + help='vis sampler 3D') + parser.add_argument('-d', '--dataset', default='thuman') + args_c = parser.parse_args() + + args = get_cfg_defaults() + args.merge_from_file(args_c.config) + print(args_c.dataset) + if args_c.dataset == 'cape': + + # for cape test set + cfg_test_mode = [ + "test_mode", True, "dataset.types", ["cape"], "dataset.scales", + [100.0], "dataset.rotation_num", 3,"root","./data/" + ] + args.merge_from_list(cfg_test_mode) + + # dataset sampler + dataset = PIFuDataset(args, split='test', vis=args_c.show) + print(f"Number of subjects :{len(dataset.subject_list)}") + data_dict = dataset[1] + + if args_c.list: + for k in data_dict.keys(): + if not hasattr(data_dict[k], "shape"): + print(f"{k}: {data_dict[k]}") + else: + print(f"{k}: {data_dict[k].shape}") + + if args_c.show: + # for item in dataset: + item = dataset[0] + dataset.visualize_sampling3D(item, mode='cmap') + # dataset.visualize_sampling3D(item, mode='occ') + # dataset.visualize_sampling3D(item, mode='normal') + # dataset.visualize_sampling3D(item, mode='sdf') + # dataset.visualize_sampling3D(item, mode='vis') + + if args_c.speed: + # original: 2 it/s + # smpl online compute: 2 it/s + # normal online compute: 1.5 it/s + from tqdm import tqdm + for item in tqdm(dataset): + # pass + for k in item.keys(): + if 'voxel' in k: + if not hasattr(item[k], "shape"): + print(f"{k}: {item[k]}") + else: + print(f"{k}: {item[k].shape}") + print("--------------------") diff --git a/lib/dataset/ECON_Evaluator.py b/lib/dataset/ECON_Evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..899a03ac11115a89f0f9cfa8a5460e270cf90107 --- /dev/null +++ b/lib/dataset/ECON_Evaluator.py @@ -0,0 +1,360 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from lib.dataset.mesh_util import projection +from lib.common.render import Render +import numpy as np +import torch +from torchvision.utils import make_grid +from pytorch3d import _C +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from pytorch3d.structures import Pointclouds +from PIL import Image + +from typing import Tuple +from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals +from pytorch3d.ops.packed_to_padded import packed_to_padded + +_DEFAULT_MIN_TRIANGLE_AREA: float = 5e-3 + + +# PointFaceDistance +class _PointFaceDistance(Function): + """ + Torch autograd Function wrapper PointFaceDistance Cuda implementation + """ + @staticmethod + def forward( + ctx, + points, + points_first_idx, + tris, + tris_first_idx, + max_points, + min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA, + ): + """ + Args: + ctx: Context object used to calculate gradients. + points: FloatTensor of shape `(P, 3)` + points_first_idx: LongTensor of shape `(N,)` indicating the first point + index in each example in the batch + tris: FloatTensor of shape `(T, 3, 3)` of triangular faces. The `t`-th + triangular face is spanned by `(tris[t, 0], tris[t, 1], tris[t, 2])` + tris_first_idx: LongTensor of shape `(N,)` indicating the first face + index in each example in the batch + max_points: Scalar equal to maximum number of points in the batch + min_triangle_area: (float, defaulted) Triangles of area less than this + will be treated as points/lines. + Returns: + dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared + euclidean distance of `p`-th point to the closest triangular face + in the corresponding example in the batch + idxs: LongTensor of shape `(P,)` indicating the closest triangular face + in the corresponding example in the batch. + + `dists[p]` is + `d(points[p], tris[idxs[p], 0], tris[idxs[p], 1], tris[idxs[p], 2])` + where `d(u, v0, v1, v2)` is the distance of point `u` from the triangular + face `(v0, v1, v2)` + + """ + dists, idxs = _C.point_face_dist_forward( + points, + points_first_idx, + tris, + tris_first_idx, + max_points, + min_triangle_area, + ) + ctx.save_for_backward(points, tris, idxs) + ctx.min_triangle_area = min_triangle_area + return dists, idxs + + @staticmethod + @once_differentiable + def backward(ctx, grad_dists): + grad_dists = grad_dists.contiguous() + points, tris, idxs = ctx.saved_tensors + min_triangle_area = ctx.min_triangle_area + grad_points, grad_tris = _C.point_face_dist_backward( + points, tris, idxs, grad_dists, min_triangle_area + ) + return grad_points, None, grad_tris, None, None, None + + +def _rand_barycentric_coords( + size1, size2, dtype: torch.dtype, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Helper function to generate random barycentric coordinates which are uniformly + distributed over a triangle. + + Args: + size1, size2: The number of coordinates generated will be size1*size2. + Output tensors will each be of shape (size1, size2). + dtype: Datatype to generate. + device: A torch.device object on which the outputs will be allocated. 
+ + Returns: + w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric + coordinates + """ + uv = torch.rand(2, size1, size2, dtype=dtype, device=device) + u, v = uv[0], uv[1] + u_sqrt = u.sqrt() + w0 = 1.0 - u_sqrt + w1 = u_sqrt * (1.0 - v) + w2 = u_sqrt * v + w = torch.cat([w0[..., None], w1[..., None], w2[..., None]], dim=2) + + return w + + +def sample_points_from_meshes(meshes, num_samples: int = 10000): + """ + Convert a batch of meshes to a batch of pointclouds by uniformly sampling + points on the surface of the mesh with probability proportional to the + face area. + + Args: + meshes: A Meshes object with a batch of N meshes. + num_samples: Integer giving the number of point samples per mesh. + return_normals: If True, return normals for the sampled points. + return_textures: If True, return textures for the sampled points. + + Returns: + 3-element tuple containing + + - **samples**: FloatTensor of shape (N, num_samples, 3) giving the + coordinates of sampled points for each mesh in the batch. For empty + meshes the corresponding row in the samples array will be filled with 0. + - **normals**: FloatTensor of shape (N, num_samples, 3) giving a normal vector + to each sampled point. Only returned if return_normals is True. + For empty meshes the corresponding row in the normals array will + be filled with 0. + - **textures**: FloatTensor of shape (N, num_samples, C) giving a C-dimensional + texture vector to each sampled point. Only returned if return_textures is True. + For empty meshes the corresponding row in the textures array will + be filled with 0. + + Note that in a future releases, we will replace the 3-element tuple output + with a `Pointclouds` datastructure, as follows + + .. code-block:: python + + Pointclouds(samples, normals=normals, features=textures) + """ + if meshes.isempty(): + raise ValueError("Meshes are empty.") + + verts = meshes.verts_packed() + if not torch.isfinite(verts).all(): + raise ValueError("Meshes contain nan or inf.") + + faces = meshes.faces_packed() + mesh_to_face = meshes.mesh_to_faces_packed_first_idx() + num_meshes = len(meshes) + num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes. + + # Initialize samples tensor with fill value 0 for empty meshes. + samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device) + + # Only compute samples for non empty meshes + with torch.no_grad(): + areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero. + max_faces = meshes.num_faces_per_mesh().max().item() + areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F) + + # TODO (gkioxari) Confirm multinomial bug is not present with real data. + samples_face_idxs = areas_padded.multinomial( + num_samples, replacement=True + ) # (N, num_samples) + samples_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1) + + # Randomly generate barycentric coords. 
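+        # The sampled point on each chosen face (v0, v1, v2) is the convex combination
+        #     p = w0 * v0 + w1 * v1 + w2 * v2,   with w0 + w1 + w2 = 1 and wi >= 0,
+        # where (w0, w1, w2) come from _rand_barycentric_coords above (the square-root
+        # trick: w0 = 1 - sqrt(u), w1 = sqrt(u) * (1 - v), w2 = sqrt(u) * v for u, v ~ U(0, 1)),
+        # which distributes the samples uniformly over each triangle.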
+ # w (N, num_samples, 3) + # sample_face_idxs (N, num_samples) + # samples_verts (N, num_samples, 3, 3) + + samples_bw = _rand_barycentric_coords(num_valid_meshes, num_samples, verts.dtype, verts.device) + sample_verts = verts[faces][samples_face_idxs] + samples[meshes.valid] = (sample_verts * samples_bw[..., None]).sum(dim=-2) + + return samples, samples_face_idxs, samples_bw + + +def econ_point_mesh_distance(meshes, pcls, weighted=True): + + if len(meshes) != len(pcls): + raise ValueError("meshes and pointclouds must be equal sized batches") + + # packed representation for pointclouds + points = pcls.points_packed() # (P, 3) + points_first_idx = pcls.cloud_to_packed_first_idx() + max_points = pcls.num_points_per_cloud().max().item() + + # packed representation for faces + verts_packed = meshes.verts_packed() + faces_packed = meshes.faces_packed() + tris = verts_packed[faces_packed] # (T, 3, 3) + tris_first_idx = meshes.mesh_to_faces_packed_first_idx() + + # point to face distance: shape (P,) + point_to_face, idxs = _PointFaceDistance.apply( + points, points_first_idx, tris, tris_first_idx, max_points, 5e-3 + ) + + if weighted: + # weight each example by the inverse of number of points in the example + point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),) + num_points_per_cloud = pcls.num_points_per_cloud() # (N,) + weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) + weights_p = 1.0 / weights_p.float() + point_to_face = torch.sqrt(point_to_face) * weights_p + + return point_to_face, idxs + + +class Evaluator: + def __init__(self, device): + + self.render = Render(size=512, device=device) + self.device = device + + def set_mesh(self, result_dict, scale=True): + + for k, v in result_dict.items(): + setattr(self, k, v) + if scale: + self.verts_pr -= self.recon_size / 2.0 + self.verts_pr /= self.recon_size / 2.0 + self.verts_gt = projection(self.verts_gt, self.calib) + self.verts_gt[:, 1] *= -1 + + self.render.load_meshes(self.verts_pr, self.faces_pr) + self.src_mesh = self.render.meshes + self.render.load_meshes(self.verts_gt, self.faces_gt) + self.tgt_mesh = self.render.meshes + + def calculate_normal_consist(self, normal_path): + + self.render.meshes = self.src_mesh + src_normal_imgs = self.render.get_image(cam_type="all", bg="black") + self.render.meshes = self.tgt_mesh + tgt_normal_imgs = self.render.get_image(cam_type="all", bg="black") + error_list = [] + if len(src_normal_imgs)>4: + # for i in range(len(src_normal_imgs)): + src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=6,padding=1) # [0,1] + tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=6,padding=1) # [0,1] + # src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + # tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True) + tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True) + + src_norm[src_norm == 0.0] = 1.0 + tgt_norm[tgt_norm == 0.0] = 1.0 + + src_normal_arr /= src_norm + tgt_normal_arr /= tgt_norm + + # sim_mask = self.get_laplacian_2d(tgt_normal_arr).to(self.device) + + src_normal_arr = (src_normal_arr + 1.0) * 0.5 + tgt_normal_arr = (tgt_normal_arr + 1.0) * 0.5 + + error = (( + (src_normal_arr - tgt_normal_arr)**2).sum(dim=0).mean()) * 4 + + #error_list.append(error) + + normal_img = Image.fromarray( + (torch.cat([src_normal_arr, tgt_normal_arr], dim=1).permute( + 1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8)) + 
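+            # normal_img stacks the two make_grid outputs (prediction on top, ground
+            # truth below), mapped from [0, 1] back to 0-255 RGB; it is written to
+            # normal_path purely for visual inspection of the normal consistency error.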
normal_img.save(normal_path) + + return error + else: + src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True) + tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True) + + src_norm[src_norm == 0.0] = 1.0 + tgt_norm[tgt_norm == 0.0] = 1.0 + + src_normal_arr /= src_norm + tgt_normal_arr /= tgt_norm + + # sim_mask = self.get_laplacian_2d(tgt_normal_arr).to(self.device) + + src_normal_arr = (src_normal_arr + 1.0) * 0.5 + tgt_normal_arr = (tgt_normal_arr + 1.0) * 0.5 + + error = (( + (src_normal_arr - tgt_normal_arr)**2).sum(dim=0).mean()) * 4 + return error + + def calculate_chamfer_p2s(self, num_samples=1000): + + samples_tgt, _, _ = sample_points_from_meshes(self.tgt_mesh, num_samples) + samples_src, _, _ = sample_points_from_meshes(self.src_mesh, num_samples) + + tgt_points = Pointclouds(samples_tgt) + src_points = Pointclouds(samples_src) + + p2s_dist = point_mesh_distance(self.src_mesh, tgt_points)[0].sum() * 100.0 + + chamfer_dist = ( + point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist + ) * 0.5 + + return chamfer_dist, p2s_dist + + def calc_acc(self, output, target, thres=0.5, use_sdf=False): + + # # remove the surface points with thres + # non_surf_ids = (target != thres) + # output = output[non_surf_ids] + # target = target[non_surf_ids] + + with torch.no_grad(): + output = output.masked_fill(output < thres, 0.0) + output = output.masked_fill(output > thres, 1.0) + + if use_sdf: + target = target.masked_fill(target < thres, 0.0) + target = target.masked_fill(target > thres, 1.0) + + acc = output.eq(target).float().mean() + + # iou, precison, recall + output = output > thres + target = target > thres + + union = output | target + inter = output & target + + _max = torch.tensor(1.0).to(output.device) + + union = max(union.sum().float(), _max) + true_pos = max(inter.sum().float(), _max) + vol_pred = max(output.sum().float(), _max) + vol_gt = max(target.sum().float(), _max) + + return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt diff --git a/lib/dataset/Evaluator.py b/lib/dataset/Evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f259f93fcdfd8838bc4d1560ece34561eada50 --- /dev/null +++ b/lib/dataset/Evaluator.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from lib.dataset.mesh_util import projection +from lib.common.render import Render +import numpy as np +import torch +import os.path as osp +from torchvision.utils import make_grid +from pytorch3d.io import IO +from pytorch3d.ops import sample_points_from_meshes +from pytorch3d.loss.point_mesh_distance import _PointFaceDistance +from pytorch3d.structures import Pointclouds +from PIL import Image + + +def point_mesh_distance(meshes, pcls): + + if len(meshes) != len(pcls): + raise ValueError("meshes and pointclouds must be equal sized batches") + N = len(meshes) + + # packed representation for pointclouds + points = pcls.points_packed() # (P, 3) + points_first_idx = pcls.cloud_to_packed_first_idx() + max_points = pcls.num_points_per_cloud().max().item() + + # packed representation for faces + verts_packed = meshes.verts_packed() + faces_packed = meshes.faces_packed() + tris = verts_packed[faces_packed] # (T, 3, 3) + tris_first_idx = meshes.mesh_to_faces_packed_first_idx() + + # point to face distance: shape (P,) + point_to_face = _PointFaceDistance.apply(points, points_first_idx, tris, + tris_first_idx, max_points, 5e-3) + + # weight each example by the inverse of number of points in the example + point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),) + num_points_per_cloud = pcls.num_points_per_cloud() # (N,) + weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) + weights_p = 1.0 / weights_p.float() + point_to_face = torch.sqrt(point_to_face) * weights_p + point_dist = point_to_face.sum() / N + + return point_dist + + +class Evaluator: + + def __init__(self, device): + + self.render = Render(size=512, device=device) + self.device = device + + def set_mesh(self, result_dict): + + for k, v in result_dict.items(): + setattr(self, k, v) + + self.verts_pr -= self.recon_size / 2.0 + self.verts_pr /= self.recon_size / 2.0 + self.verts_gt = projection(self.verts_gt, self.calib) + self.verts_gt[:, 1] *= -1 + + self.src_mesh = self.render.VF2Mesh(self.verts_pr, self.faces_pr) + self.tgt_mesh = self.render.VF2Mesh(self.verts_gt, self.faces_gt) + + def calculate_normal_consist(self, normal_path): + + self.render.meshes = self.src_mesh + src_normal_imgs = self.render.get_rgb_image(cam_ids=[ 0,1,2, 3], + bg='black') + self.render.meshes = self.tgt_mesh + tgt_normal_imgs = self.render.get_rgb_image(cam_ids=[0,1,2, 3], + bg='black') + + src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True) + tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True) + + src_norm[src_norm == 0.0] = 1.0 + tgt_norm[tgt_norm == 0.0] = 1.0 + + src_normal_arr /= src_norm + tgt_normal_arr /= tgt_norm + + src_normal_arr = (src_normal_arr + 1.0) * 0.5 + tgt_normal_arr = (tgt_normal_arr + 1.0) * 0.5 + error = (( + (src_normal_arr - tgt_normal_arr)**2).sum(dim=0).mean()) * 4 + #print('normal error:', error) + + normal_img = Image.fromarray( + (torch.cat([src_normal_arr, tgt_normal_arr], dim=1).permute( + 1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8)) + normal_img.save(normal_path) + + error_list = [] + if len(src_normal_imgs) > 4: + for i in range(len(src_normal_imgs)): + src_normal_arr = src_normal_imgs[i] # Get each source normal image + tgt_normal_arr = tgt_normal_imgs[i] # Get corresponding target normal image + + src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True) + 
tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True) + + src_norm[src_norm == 0.0] = 1.0 + tgt_norm[tgt_norm == 0.0] = 1.0 + + src_normal_arr /= src_norm + tgt_normal_arr /= tgt_norm + + src_normal_arr = (src_normal_arr + 1.0) * 0.5 + tgt_normal_arr = (tgt_normal_arr + 1.0) * 0.5 + + error = ((src_normal_arr - tgt_normal_arr) ** 2).sum(dim=0).mean() * 4.0 + error_list.append(error) + + + return error_list + else: + src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4,padding=0) # [0,1] + src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True) + tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True) + + src_norm[src_norm == 0.0] = 1.0 + tgt_norm[tgt_norm == 0.0] = 1.0 + + src_normal_arr /= src_norm + tgt_normal_arr /= tgt_norm + + # sim_mask = self.get_laplacian_2d(tgt_normal_arr).to(self.device) + + src_normal_arr = (src_normal_arr + 1.0) * 0.5 + tgt_normal_arr = (tgt_normal_arr + 1.0) * 0.5 + + error = (( + (src_normal_arr - tgt_normal_arr)**2).sum(dim=0).mean()) * 4 + #print('normal error:', error) + return error + + + def export_mesh(self, dir, name): + + IO().save_mesh(self.src_mesh, osp.join(dir, f"{name}_src.obj")) + IO().save_mesh(self.tgt_mesh, osp.join(dir, f"{name}_tgt.obj")) + + def calculate_chamfer_p2s(self, num_samples=1000): + + tgt_points = Pointclouds( + sample_points_from_meshes(self.tgt_mesh, num_samples)) + src_points = Pointclouds( + sample_points_from_meshes(self.src_mesh, num_samples)) + p2s_dist = point_mesh_distance(self.src_mesh, tgt_points) * 100.0 + chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points) * 100.0 + + p2s_dist) * 0.5 + + return chamfer_dist, p2s_dist + + def calc_acc(self, output, target, thres=0.5, use_sdf=False): + + # # remove the surface points with thres + # non_surf_ids = (target != thres) + # output = output[non_surf_ids] + # target = target[non_surf_ids] + + with torch.no_grad(): + output = output.masked_fill(output < thres, 0.0) + output = output.masked_fill(output > thres, 1.0) + + if use_sdf: + target = target.masked_fill(target < thres, 0.0) + target = target.masked_fill(target > thres, 1.0) + + acc = output.eq(target).float().mean() + + # iou, precison, recall + output = output > thres + target = target > thres + + union = output | target + inter = output & target + + _max = torch.tensor(1.0).to(output.device) + + union = max(union.sum().float(), _max) + true_pos = max(inter.sum().float(), _max) + vol_pred = max(output.sum().float(), _max) + vol_gt = max(target.sum().float(), _max) + + return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt diff --git a/lib/dataset/NormalDataset.py b/lib/dataset/NormalDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4e2db7690662d2b4c60859dba66da569dcacc816 --- /dev/null +++ b/lib/dataset/NormalDataset.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import random +import os.path as osp +import numpy as np +from PIL import Image +from termcolor import colored +import torchvision.transforms as transforms + + +class NormalDataset(): + + def __init__(self, cfg, split='train'): + + self.split = split + self.root = cfg.root + self.bsize = cfg.batch_size + self.overfit = cfg.overfit + + self.opt = cfg.dataset + self.datasets = self.opt.types + self.input_size = self.opt.input_size + self.scales = self.opt.scales + + # input data types and dimensions + self.in_nml = [item[0] for item in cfg.net.in_nml] + self.in_nml_dim = [item[1] for item in cfg.net.in_nml] + self.in_total = self.in_nml + ['render_B', 'render_L'] + self.in_total_dim = self.in_nml_dim + [3, 3] + + if self.split != 'train': + self.rotations = range(0, 360, 120) + else: + self.rotations = np.arange(0, 360, 360 // + self.opt.rotation_num).astype(np.int) + + self.datasets_dict = {} + + for dataset_id, dataset in enumerate(self.datasets): + + dataset_dir = osp.join(self.root, dataset) + + self.datasets_dict[dataset] = { + "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), + dtype=str), + "scale": self.scales[dataset_id] + } + + self.subject_list = self.get_subject_list(split) + + # PIL to tensor + self.image_to_tensor = transforms.Compose([ + transforms.Resize(self.input_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + # PIL to tensor + self.mask_to_tensor = transforms.Compose([ + transforms.Resize(self.input_size), + transforms.ToTensor(), + transforms.Normalize((0.0, ), (1.0, )) + ]) + + def get_subject_list(self, split): + + subject_list = [] + + for dataset in self.datasets: + + split_txt = osp.join(self.root, dataset, f'{split}.txt') + + if osp.exists(split_txt): + print(f"load from {split_txt}") + subject_list += np.loadtxt(split_txt, dtype=str).tolist() + else: + full_txt = osp.join(self.root, dataset, 'all.txt') + print(f"split {full_txt} into train/val/test") + + full_lst = np.loadtxt(full_txt, dtype=str) + full_lst = [dataset + "/" + item for item in full_lst] + [train_lst, test_lst, + val_lst] = np.split(full_lst, [ + 500, + 500 + 5, + ]) + + np.savetxt(full_txt.replace("all", "train"), + train_lst, + fmt="%s") + np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s") + np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s") + + print(f"load from {split_txt}") + subject_list += np.loadtxt(split_txt, dtype=str).tolist() + + if self.split != 'test': + subject_list += subject_list[:self.bsize - + len(subject_list) % self.bsize] + print(colored(f"total: {len(subject_list)}", "yellow")) + random.shuffle(subject_list) + + # subject_list = ["thuman2/0008"] + return subject_list + + def __len__(self): + return len(self.subject_list) * len(self.rotations) + + def __getitem__(self, index): + + # only pick the first data if overfitting + if self.overfit: + index = 0 + + rid = index % len(self.rotations) + mid = index // len(self.rotations) + + rotation = self.rotations[rid] + subject = self.subject_list[mid].split("/")[1] + dataset = self.subject_list[mid].split("/")[0] + render_folder = "/".join( + [dataset + f"_{self.opt.rotation_num}views", subject]) + + # setup paths + data_dict = { + 'dataset': + dataset, + 'subject': + subject, + 'rotation': + rotation, + 'scale': + self.datasets_dict[dataset]["scale"], + 'image_path': + osp.join(self.root, render_folder, 'render', 
f'{rotation:03d}.png') + } + + # image/normal/depth loader + for name, channel in zip(self.in_total, self.in_total_dim): + + if f'{name}_path' not in data_dict.keys(): + data_dict.update({ + f'{name}_path': + osp.join(self.root, render_folder, name, + f'{rotation:03d}.png') + }) + + # tensor update + data_dict.update({ + name: + self.imagepath2tensor(data_dict[f'{name}_path'], + channel, + inv=False) + }) + + path_keys = [ + key for key in data_dict.keys() if '_path' in key or '_dir' in key + ] + + for key in path_keys: + del data_dict[key] + + return data_dict + + def imagepath2tensor(self, path, channel=3, inv=False): + + rgba = Image.open(path).convert('RGBA') + mask = rgba.split()[-1] + image = rgba.convert('RGB') + image = self.image_to_tensor(image) + mask = self.mask_to_tensor(mask) + image = (image * mask)[:channel] + + return (image * (0.5 - inv) * 2.0).float() diff --git a/lib/dataset/NormalModule.py b/lib/dataset/NormalModule.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed1b1786624ddc1362b800d6ccf622779ea94bb --- /dev/null +++ b/lib/dataset/NormalModule.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np +from torch.utils.data import DataLoader +from .NormalDataset import NormalDataset + +# pytorch lightning related libs +import pytorch_lightning as pl + + +class NormalModule(pl.LightningDataModule): + + def __init__(self, cfg): + super(NormalModule, self).__init__() + self.cfg = cfg + self.overfit = self.cfg.overfit + + if self.overfit: + self.batch_size = 1 + else: + self.batch_size = self.cfg.batch_size + + self.data_size = {} + + def prepare_data(self): + + pass + + @staticmethod + def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + def setup(self, stage): + + if stage == 'fit' or stage is None: + self.train_dataset = NormalDataset(cfg=self.cfg, split="train") + self.val_dataset = NormalDataset(cfg=self.cfg, split="val") + self.data_size = { + 'train': len(self.train_dataset), + 'val': len(self.val_dataset) + } + + if stage == 'test' or stage is None: + self.test_dataset = NormalDataset(cfg=self.cfg, split="test") + + def train_dataloader(self): + + train_data_loader = DataLoader(self.train_dataset, + batch_size=self.batch_size, + shuffle=not self.overfit, + num_workers=self.cfg.num_threads, + pin_memory=True, + worker_init_fn=self.worker_init_fn) + + return train_data_loader + + def val_dataloader(self): + + if self.overfit: + current_dataset = self.train_dataset + else: + current_dataset = self.val_dataset + + val_data_loader = DataLoader(current_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.cfg.num_threads, + pin_memory=True) + + return val_data_loader + + def test_dataloader(self): + + test_data_loader = DataLoader(self.test_dataset, + batch_size=1, + shuffle=False, + num_workers=self.cfg.num_threads, + pin_memory=True) + + return test_data_loader diff --git a/lib/dataset/PIFuDataModule.py b/lib/dataset/PIFuDataModule.py new file mode 100644 index 0000000000000000000000000000000000000000..f9546d68d39e6118354a59d109674cced380d4f9 --- /dev/null +++ b/lib/dataset/PIFuDataModule.py @@ -0,0 +1,77 @@ +import numpy as np +from torch.utils.data import DataLoader +from .PIFuDataset import PIFuDataset +import pytorch_lightning as pl + + +class PIFuDataModule(pl.LightningDataModule): + + def __init__(self, cfg): + super(PIFuDataModule, self).__init__() + self.cfg = cfg + self.overfit = self.cfg.overfit + + if self.overfit: + self.batch_size = 1 + else: + self.batch_size = self.cfg.batch_size + + self.data_size = {} + + def prepare_data(self): + + pass + + @staticmethod + def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + def setup(self, stage): + + if stage == 'fit': + self.train_dataset = PIFuDataset(cfg=self.cfg, split="train") + self.val_dataset = PIFuDataset(cfg=self.cfg, split="val") + self.data_size = { + 'train': len(self.train_dataset), + 'val': len(self.val_dataset) + } + + if stage == 'test': + self.test_dataset = PIFuDataset(cfg=self.cfg, split="test") + + def train_dataloader(self): + + train_data_loader = DataLoader(self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.cfg.num_threads, + pin_memory=True, + worker_init_fn=self.worker_init_fn) + + return train_data_loader + + def val_dataloader(self): + + if self.overfit: + current_dataset = self.train_dataset + else: + current_dataset = self.val_dataset + + val_data_loader = DataLoader(current_dataset, + batch_size=1, + shuffle=False, + num_workers=self.cfg.num_threads, + pin_memory=True, + worker_init_fn=self.worker_init_fn) 
+ + return val_data_loader + + def test_dataloader(self): + + test_data_loader = DataLoader(self.test_dataset, + batch_size=1, + shuffle=False, + num_workers=self.cfg.num_threads, + pin_memory=True) + + return test_data_loader diff --git a/lib/dataset/PIFuDataset.py b/lib/dataset/PIFuDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1cfc93f4b1626a6d9541930238cac9a4f798e4 --- /dev/null +++ b/lib/dataset/PIFuDataset.py @@ -0,0 +1,835 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.renderer.mesh import load_fit_body, compute_normal_batch +from lib.dataset.body_model import TetraSMPLModel +from lib.common.render import Render +from lib.dataset.mesh_util import * +from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis +from lib.net.nerf_util import sample_ray_h36m, get_wsampling_points +from termcolor import colored +import os.path as osp +import numpy as np +from PIL import Image +import random +import os, cv2 +import trimesh +import torch +import vedo +import torchvision.transforms as transforms +import matplotlib.pyplot as plt +import trimesh +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" + +cape_gender = { + "male": [ + '00032', '00096', '00122', '00127', '00145', '00215', '02474', '03284', + '03375', '03394' + ], + "female": ['00134', '00159', '03223', '03331', '03383'] +} + + + +class PIFuDataset(): + + def __init__(self, cfg, split='train', vis=False): + + self.split = split + self.root = cfg.root + self.bsize = cfg.batch_size + self.overfit = cfg.overfit + + # for debug, only used in visualize_sampling3D + self.vis = vis + + self.opt = cfg.dataset + self.datasets = self.opt.types + self.input_size = self.opt.input_size + self.scales = self.opt.scales + self.workers = cfg.num_threads + self.prior_type = cfg.net.prior_type + + self.noise_type = self.opt.noise_type + self.noise_scale = self.opt.noise_scale + + noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21] + + self.noise_smpl_idx = [] + self.noise_smplx_idx = [] + + for idx in noise_joints: + self.noise_smpl_idx.append(idx * 3) + self.noise_smpl_idx.append(idx * 3 + 1) + self.noise_smpl_idx.append(idx * 3 + 2) + + self.noise_smplx_idx.append((idx - 1) * 3) + self.noise_smplx_idx.append((idx - 1) * 3 + 1) + self.noise_smplx_idx.append((idx - 1) * 3 + 2) + + self.use_sdf = cfg.sdf + self.sdf_clip = cfg.sdf_clip + + # [(feat_name, channel_num),...] 
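+        # cfg.net.in_geo / cfg.net.in_nml are lists of (feature_name, channel) pairs,
+        # e.g. [("image", 3), ("T_normal_F", 3)] (illustrative values; the real lists
+        # come from the config file). They are split below into the feature names used
+        # to locate "<render_folder>/<name>/<rotation:03d>.png" on disk and the channel
+        # counts used when the images are loaded as tensors.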
+ self.in_geo = [item[0] for item in cfg.net.in_geo] + self.in_nml = [item[0] for item in cfg.net.in_nml] + + self.in_geo_dim = [item[1] for item in cfg.net.in_geo] + self.in_nml_dim = [item[1] for item in cfg.net.in_nml] + + self.in_total = self.in_geo + self.in_nml + self.in_total_dim = self.in_geo_dim + self.in_nml_dim + + self.base_keys = ["smpl_verts", "smpl_faces"] + self.feat_names = cfg.net.smpl_feats + + self.feat_keys = self.base_keys + [f"smpl_{feat_name}" for feat_name in self.feat_names] + + if self.split == 'train': + self.rotations = np.arange(0, 360, 360 / self.opt.rotation_num).astype(np.int32) + else: + self.rotations = range(0, 360, 120) + + self.datasets_dict = {} + + for dataset_id, dataset in enumerate(self.datasets): + + mesh_dir = None + smplx_dir = None + + dataset_dir = osp.join(self.root, dataset) + + mesh_dir = osp.join(dataset_dir, "scans") + smplx_dir = osp.join(dataset_dir, "smplx") + smpl_dir = osp.join(dataset_dir, "smpl") + + self.datasets_dict[dataset] = { + "smplx_dir": smplx_dir, + "smpl_dir": smpl_dir, + "mesh_dir": mesh_dir, + "scale": self.scales[dataset_id] + } + + if split == 'train': + self.datasets_dict[dataset].update( + {"subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str)}) + else: + self.datasets_dict[dataset].update( + {"subjects": np.loadtxt(osp.join(dataset_dir, "test.txt"), dtype=str)}) + + self.subject_list = self.get_subject_list(split) + self.smplx = SMPLX() + + # PIL to tensor + self.image_to_tensor = transforms.Compose([ + transforms.Resize(self.input_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + # PIL to tensor + self.mask_to_tensor = transforms.Compose([ + transforms.Resize(self.input_size), + transforms.ToTensor(), + transforms.Normalize((0.0,), (1.0,)) + ]) + + self.device = torch.device(f"cuda:{cfg.gpus[0]}") + self.render = Render(size=512, device=self.device) + + self.UV_RENDER='/sdb/zzc/zzc/paper_models/PIFu-master/PIFu-master/training_data/UV_RENDER' + self.UV_MASK='/sdb/zzc/zzc/paper_models/PIFu-master/PIFu-master/training_data/UV_MASK' + self.UV_POS='/sdb/zzc/zzc/paper_models/PIFu-master/PIFu-master/training_data/UV_POS' + self.UV_NORMAL='/sdb/zzc/zzc/paper_models/PIFu-master/PIFu-master/training_data/UV_NORMAL' + self.IMAGE_MASK='/sdb/zzc/zzc/paper_models/PIFu-master/PIFu-master/training_data/MASK' + self.PARAM='/sdb/zzc/zzc/paper_models/PIFu-master/PIFu-master/training_data/PARAM' + self.depth='./data/thuman2_36views' + + def render_normal(self, verts, faces): + + # render optimized mesh (normal, T_normal, image [-1,1]) + self.render.load_meshes(verts, faces) + return self.render.get_rgb_image() + + def get_subject_list(self, split): + + subject_list = [] + + for dataset in self.datasets: + + split_txt = osp.join(self.root, dataset, f'{split}.txt') + + if osp.exists(split_txt): + print(f"load from {split_txt}") + subject_list += np.loadtxt(split_txt, dtype=str).tolist() + else: + full_txt = osp.join(self.root, dataset, 'all.txt') + print(f"split {full_txt} into train/val/test") + + full_lst = np.loadtxt(full_txt, dtype=str) + full_lst = [dataset + "/" + item for item in full_lst] + [train_lst, test_lst, val_lst] = np.split(full_lst, [ + 500, + 500 + 5, + ]) + + np.savetxt(full_txt.replace("all", "train"), train_lst, fmt="%s") + np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s") + np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s") + + print(f"load from {split_txt}") + subject_list += np.loadtxt(split_txt, dtype=str).tolist() + + if 
self.split != 'test': + subject_list += subject_list[:self.bsize - len(subject_list) % self.bsize] + print(colored(f"total: {len(subject_list)}", "yellow")) + random.shuffle(subject_list) + + # subject_list = ["thuman2/0008"] + return subject_list + + def __len__(self): + return len(self.subject_list) * len(self.rotations) + + def __getitem__(self, index): + + # only pick the first data if overfitting + if self.overfit: + index = 0 + + rid = index % len(self.rotations) + mid = index // len(self.rotations) + + rotation = self.rotations[rid] + subject = self.subject_list[mid].split("/")[1] + dataset = self.subject_list[mid].split("/")[0] + render_folder = "/".join([dataset + f"_{self.opt.rotation_num}views", subject]) + if dataset=='thuman2': + old_folder="/".join([dataset + f"_{self.opt.rotation_num}views_nosideview", subject]) + else: + old_folder=render_folder + # add uv map path + # use pifu dataset + + uv_render_path=os.path.join(self.UV_RENDER,subject,'%d_%d_%02d.jpg'%(rotation,0,0)) + uv_mask_path = os.path.join(self.UV_MASK, subject, '%02d.png' % (0)) + uv_pos_path = os.path.join(self.UV_POS, subject, '%02d.exr' % (0)) + uv_normal_path = os.path.join(self.UV_NORMAL, subject, '%02d.png' % (0)) + image_mask_path=os.path.join(self.IMAGE_MASK,subject,'%d_%d_%02d.png'%(rotation,0,0)) + param_path=os.path.join(self.PARAM, subject,'%d_%d_%02d.npy'%(rotation,0,0)) + depth_path=os.path.join(self.depth,subject,"depth_F",'%03d.png'%(rotation)) + + + + # setup paths + data_dict = { + 'dataset': dataset, + 'subject': subject, + 'rotation': rotation, + 'scale': self.datasets_dict[dataset]["scale"], + 'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'), + 'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png'), + 'smpl_path': osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.obj"), + 'vis_path': osp.join(self.root, old_folder, 'vis', f'{rotation:03d}.pt'), + 'uv_render_path': osp.join(self.root, render_folder, 'uv_color', f'{rotation:03d}.png'), + 'uv_mask_path': uv_mask_path, + 'uv_pos_path': uv_pos_path, + 'uv_normal_path': osp.join(self.root, render_folder, 'uv_normal', '%02d.png' % (0)), + 'image_mask_path':image_mask_path, + 'param_path':param_path, + 'depth_path':depth_path, + } + + if dataset == 'thuman2': + data_dict.update({ + 'mesh_path': + osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}/{subject}.obj"), + 'smplx_path': + #osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}.obj"), + osp.join("./data/thuman2/smplx/", f"{subject}.obj"), + 'smpl_param': + osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.pkl"), + 'smplx_param': + osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}.pkl"), + }) + elif dataset == 'cape': + data_dict.update({ + 'mesh_path': osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}.obj"), + 'smpl_param': osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.npz"), + }) + + # load training data + data_dict.update(self.load_calib(data_dict)) + + # image/normal/depth loader + for name, channel in zip(self.in_total, self.in_total_dim): + + if f'{name}_path' not in data_dict.keys(): + data_dict.update({ + f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png') + }) + + # tensor update + if os.path.exists(data_dict[f'{name}_path']): + data_dict.update( + {name: self.imagepath2tensor(data_dict[f'{name}_path'], channel, inv=False)}) + # if name=='normal_F' and dataset == 'thuman2': + # right_angle=(rotation+270)%360 + # 
left_angle=(rotation+90)%360 + # normal_right_path=osp.join(self.root, render_folder, name, f'{right_angle:03d}.png') + # normal_left_path=osp.join(self.root, render_folder, name, f'{left_angle:03d}.png') + # data_dict.update( + # {'normal_R': self.imagepath2tensor(normal_right_path, channel, inv=False)}) + # data_dict.update( + # {'normal_L': self.imagepath2tensor(normal_left_path, channel, inv=False)}) + + if name=='T_normal_F' and dataset == 'thuman2': + + normal_right_path=osp.join(self.root, render_folder, "T_normal_R", f'{rotation:03d}.png') + normal_left_path=osp.join(self.root, render_folder, "T_normal_L", f'{rotation:03d}.png') + data_dict.update( + {'T_normal_R': self.imagepath2tensor(normal_right_path, channel, inv=False)}) + data_dict.update( + {'T_normal_L': self.imagepath2tensor(normal_left_path, channel, inv=False)}) + + data_dict.update(self.load_mesh(data_dict)) + + data_dict.update( + self.get_sampling_geo(data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf)) + data_dict.update(self.load_smpl(data_dict, self.vis)) + + if self.prior_type == 'pamir': + data_dict.update(self.load_smpl_voxel(data_dict)) + + if (self.split != 'test') and (not self.vis): + + del data_dict['verts'] + del data_dict['faces'] + + if not self.vis: + del data_dict['mesh'] + + path_keys = [key for key in data_dict.keys() if '_path' in key or '_dir' in key] + for key in path_keys: + del data_dict[key] + + return data_dict + + def imagepath2tensor(self, path, channel=3, inv=False): + + rgba = Image.open(path).convert('RGBA') + + # remove CAPE's noisy outliers using OpenCV's inpainting + if "cape" in path and 'T_' not in path: + mask = (cv2.imread(path.replace(path.split("/")[-2], "mask"), 0) > 1) + img = np.asarray(rgba)[:, :, :3] + fill_mask = ((mask & (img.sum(axis=2) == 0))).astype(np.uint8) + image = Image.fromarray( + cv2.inpaint(img * mask[..., None], fill_mask, 3, cv2.INPAINT_TELEA)) + mask = Image.fromarray(mask) + else: + mask = rgba.split()[-1] + image = rgba.convert('RGB') + image = self.image_to_tensor(image) + mask = self.mask_to_tensor(mask) + image = (image * mask)[:channel] + + return (image * (0.5 - inv) * 2.0).float() + + def load_calib(self, data_dict): + calib_data = np.loadtxt(data_dict['calib_path'], dtype=float) + extrinsic = calib_data[:4, :4] + intrinsic = calib_data[4:8, :4] + calib_mat = np.matmul(intrinsic, extrinsic) + calib_mat = torch.from_numpy(calib_mat).float() + return {'calib': calib_mat} + + def load_mesh(self, data_dict): + + mesh_path = data_dict['mesh_path'] + scale = data_dict['scale'] + + verts, faces = obj_loader(mesh_path) + + mesh = HoppeMesh(verts * scale, faces) + + return { + 'mesh': mesh, + 'verts': torch.as_tensor(verts * scale).float(), + 'faces': torch.as_tensor(faces).long() + } + + def add_noise(self, beta_num, smpl_pose, smpl_betas, noise_type, noise_scale, type, hashcode): + + np.random.seed(hashcode) + + if type == 'smplx': + noise_idx = self.noise_smplx_idx + else: + noise_idx = self.noise_smpl_idx + + if 'beta' in noise_type and noise_scale[noise_type.index("beta")] > 0.0: + smpl_betas += (np.random.rand(beta_num) - + 0.5) * 2.0 * noise_scale[noise_type.index("beta")] + smpl_betas = smpl_betas.astype(np.float32) + + if 'pose' in noise_type and noise_scale[noise_type.index("pose")] > 0.0: + smpl_pose[noise_idx] += (np.random.rand(len(noise_idx)) - + 0.5) * 2.0 * np.pi * noise_scale[noise_type.index("pose")] + smpl_pose = smpl_pose.astype(np.float32) + if type == 'smplx': + return torch.as_tensor(smpl_pose[None, ...]), 
torch.as_tensor(smpl_betas[None, ...]) + else: + return smpl_pose, smpl_betas + + def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None): + + dataset = data_dict['dataset'] + smplx_dict = {} + + smplx_param = np.load(data_dict['smplx_param'], allow_pickle=True) + smplx_pose = smplx_param["body_pose"] # [1,63] + smplx_betas = smplx_param["betas"] # [1,10] + smplx_pose, smplx_betas = self.add_noise( + smplx_betas.shape[1], + smplx_pose[0], + smplx_betas[0], + noise_type, + noise_scale, + type='smplx', + hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8)) + + smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_param'], + scale=self.datasets_dict[dataset]['scale'], + smpl_type='smplx', + smpl_gender='male', + noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose)) + + smplx_dict.update({ + "type": "smplx", + "gender": 'male', + "body_pose": torch.as_tensor(smplx_pose), + "betas": torch.as_tensor(smplx_betas) + }) + + return smplx_out.vertices, smplx_dict + + def compute_voxel_verts(self, data_dict, noise_type=None, noise_scale=None): + + smpl_param = np.load(data_dict["smpl_param"], allow_pickle=True) + + if data_dict['dataset'] == 'cape': + pid = data_dict['subject'].split("-")[0] + gender = "male" if pid in cape_gender["male"] else "female" + smpl_pose = smpl_param['pose'].flatten() + smpl_betas = np.zeros((1, 10)) + else: + gender = 'male' + smpl_pose = rotation_matrix_to_angle_axis(torch.as_tensor( + smpl_param["full_pose"][0])).numpy() + smpl_betas = smpl_param["betas"] + + smpl_path = osp.join(self.smplx.model_dir, f"smpl/SMPL_{gender.upper()}.pkl") + tetra_path = osp.join(self.smplx.tedra_dir, f"tetra_{gender}_adult_smpl.npz") + + smpl_model = TetraSMPLModel(smpl_path, tetra_path, "adult") + + smpl_pose, smpl_betas = self.add_noise( + smpl_model.beta_shape[0], + smpl_pose.flatten(), + smpl_betas[0], + noise_type, + noise_scale, + type="smpl", + hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8), + ) + + smpl_model.set_params(pose=smpl_pose.reshape(-1, 3), + beta=smpl_betas, + trans=smpl_param["transl"]) + if data_dict['dataset'] == 'cape': + verts = np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * 100.0 + else: + verts = (np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * + smpl_param["scale"] + + smpl_param["translation"]) * self.datasets_dict[data_dict["dataset"]]["scale"] + + faces = (np.loadtxt( + osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"), + dtype=np.int32, + ) - 1) + + pad_v_num = int(8000 - verts.shape[0]) + pad_f_num = int(25100 - faces.shape[0]) + + verts = np.pad(verts, ((0, pad_v_num), (0, 0)), mode="constant", + constant_values=0.0).astype(np.float32) + faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant", + constant_values=0.0).astype(np.int32) + + return verts, faces, pad_v_num, pad_f_num + + def densely_sample(self, verts, faces): + # TODO: subdivided the triangular mesh + new_vertices,new_faces,index=trimesh.remesh.subdivide(verts,faces) + + return new_vertices, torch.LongTensor(new_faces) + ... 
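The per-sample noise in `add_noise` / `compute_smpl_verts` above is made deterministic by seeding NumPy with a hash of the subject/rotation pair before drawing the perturbation. A minimal standalone sketch of that idea (the helper name is hypothetical and not part of the repository; note that Python salts `str` hashes per interpreter run unless `PYTHONHASHSEED` is fixed, so values only repeat within one process):

```python
import numpy as np

def seeded_jitter(subject: str, rotation: int, size: int, scale: float) -> np.ndarray:
    """Uniform noise in [-scale, scale], reproducible for a given (subject, rotation)."""
    np.random.seed(hash(f"{subject}_{rotation}") % (10 ** 8))
    return (np.random.rand(size) - 0.5) * 2.0 * scale

# Same subject and view -> identical SMPL-style jitter within this process.
a = seeded_jitter("0008", 120, size=10, scale=0.05)
b = seeded_jitter("0008", 120, size=10, scale=0.05)
assert np.allclose(a, b)
```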
+ + def load_smpl(self, data_dict, vis=False, densely_sample=False): + + smpl_type = "smplx" if ('smplx_path' in data_dict.keys() and + os.path.exists(data_dict['smplx_path'])) else "smpl" + + return_dict = {} + + if 'smplx_param' in data_dict.keys() and \ + os.path.exists(data_dict['smplx_param']) and \ + sum(self.noise_scale) > 0.0: + smplx_verts, smplx_dict = self.compute_smpl_verts(data_dict, self.noise_type, + self.noise_scale) + smplx_faces = torch.as_tensor(self.smplx.smplx_faces).long() + smplx_cmap = torch.as_tensor(np.load(self.smplx.cmap_vert_path)).float() + + else: + smplx_vis = torch.load(data_dict['vis_path']).float() + return_dict.update({'smpl_vis': smplx_vis}) + + # noise_factor = 20 + # noise = np.random.normal(loc=0, scale=noise_factor) + smplx_verts = rescale_smpl(data_dict[f"{smpl_type}_path"], scale=100.0) + smplx_faces = torch.as_tensor(getattr(self.smplx, f"{smpl_type}_faces")).long() + smplx_cmap = self.smplx.cmap_smpl_vids(smpl_type) + + if densely_sample: + smplx_verts,smplx_faces=self.densely_sample(smplx_verts,smplx_faces) + + smplx_verts = projection(smplx_verts, data_dict['calib']).float() + + # get smpl_vis + if "smpl_vis" not in return_dict.keys() and "smpl_vis" in self.feat_keys: + (xy, z) = torch.as_tensor(smplx_verts).to(self.device).split([2, 1], dim=1) + smplx_vis = get_visibility(xy, z, torch.as_tensor(smplx_faces).to(self.device).long()) + return_dict['smpl_vis'] = smplx_vis + + if "smpl_norm" not in return_dict.keys() and "smpl_norm" in self.feat_keys: + # get smpl_norms + smplx_norms = compute_normal_batch(smplx_verts.unsqueeze(0), + smplx_faces.unsqueeze(0))[0] + return_dict["smpl_norm"] = smplx_norms + + if "smpl_cmap" not in return_dict.keys() and "smpl_cmap" in self.feat_keys: + return_dict["smpl_cmap"] = smplx_cmap + + sample_num=smplx_verts.shape[0] + verts_ids=np.arange(smplx_verts.shape[0]) + + sample_ids=torch.LongTensor(verts_ids) + return_dict.update({ + 'smpl_verts': smplx_verts, + 'smpl_faces': smplx_faces, + 'smpl_cmap': smplx_cmap, + 'smpl_sample_id':sample_ids, + }) + + if vis: + + (xy, z) = torch.as_tensor(smplx_verts).to(self.device).split([2, 1], dim=1) + smplx_vis = get_visibility(xy, z, torch.as_tensor(smplx_faces).to(self.device).long()) + + T_normal_F, T_normal_B = self.render_normal( + (smplx_verts * torch.tensor(np.array([1.0, -1.0, 1.0]))).to(self.device), + smplx_faces.to(self.device)) + + return_dict.update({ + "T_normal_F": T_normal_F.squeeze(0), + "T_normal_B": T_normal_B.squeeze(0) + }) + query_points = projection(data_dict['samples_geo'], data_dict['calib']).float() + + smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch( + smplx_verts.unsqueeze(0).to(self.device), + smplx_faces.unsqueeze(0).to(self.device), + smplx_cmap.unsqueeze(0).to(self.device), + smplx_vis.unsqueeze(0).to(self.device), + query_points.unsqueeze(0).contiguous().to(self.device)) + + return_dict.update({ + 'smpl_feat': + torch.cat((smplx_sdf[0].detach().cpu(), smplx_cmap[0].detach().cpu(), + smplx_norm[0].detach().cpu(), smplx_vis[0].detach().cpu()), + dim=1) + }) + + return return_dict + + def load_smpl_voxel(self, data_dict): + + smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts( + data_dict, self.noise_type, self.noise_scale) # compute using smpl model + smpl_verts = projection(smpl_verts, data_dict['calib']) + + smpl_verts *= 0.5 + + return { + 'voxel_verts': smpl_verts, + 'voxel_faces': smpl_faces, + 'pad_v_num': pad_v_num, + 'pad_f_num': pad_f_num + } + + def get_sampling_geo(self, data_dict, is_valid=False, 
is_sdf=False): + + #assert 0 + mesh = data_dict['mesh'] + calib = data_dict['calib'] + + + # Samples are around the true surface with an offset + n_samples_surface = 4*self.opt.num_sample_geo + vert_ids = np.arange(mesh.verts.shape[0]) + + samples_surface_ids = np.random.choice(vert_ids, n_samples_surface, replace=True) + + samples_surface = mesh.verts[samples_surface_ids, :] + + + # Sampling offsets are random noise with constant scale (15cm - 20cm) + offset = np.random.normal(scale=self.opt.sigma_geo, size=(n_samples_surface, 1)) + samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset + + # samples=np.concatenate([samples_surface], 0) + # np.random.shuffle(samples) + # Uniform samples in [-1, 1] + calib_inv = np.linalg.inv(calib) + n_samples_space = self.opt.num_sample_geo // 4 + samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0 + samples_space = projection(samples_space_img, calib_inv) + + samples = np.concatenate([samples_surface, samples_space], 0) + np.random.shuffle(samples) + + # labels: in->1.0; out->0.0. + inside = mesh.contains(samples) + inside_samples = samples[inside >= 0.5] + outside_samples = samples[inside < 0.5] + + nin = inside_samples.shape[0] + + if nin > self.opt.num_sample_geo // 2: + inside_samples = inside_samples[:self.opt.num_sample_geo // 2] + outside_samples = outside_samples[:self.opt.num_sample_geo // 2] + else: + outside_samples = outside_samples[:(self.opt.num_sample_geo - nin)] + + samples = np.concatenate([inside_samples, outside_samples]) + labels = np.concatenate( + [np.ones(inside_samples.shape[0]), + np.zeros(outside_samples.shape[0])]) + + samples = torch.from_numpy(samples).float() + labels = torch.from_numpy(labels).float() + + + + # sample color + if not self.datasets[0]=='cape': + # get color + uv_render_path = data_dict['uv_render_path'] + uv_mask_path = data_dict['uv_mask_path'] + uv_pos_path = data_dict['uv_pos_path'] + uv_normal_path = data_dict['uv_normal_path'] + + # Segmentation mask for the uv render. + # [H, W] bool + uv_mask = cv2.imread(uv_mask_path) + uv_mask = uv_mask[:, :, 0] != 0 + # UV render. each pixel is the color of the point. + # [H, W, 3] 0 ~ 1 float + uv_render = cv2.imread(uv_render_path) + uv_render = cv2.cvtColor(uv_render, cv2.COLOR_BGR2RGB) / 255.0 + + # Normal render. each pixel is the surface normal of the point. + # [H, W, 3] -1 ~ 1 float + uv_normal = cv2.imread(uv_normal_path) + uv_normal = cv2.cvtColor(uv_normal, cv2.COLOR_BGR2RGB) / 255.0 + uv_normal = 2.0 * uv_normal - 1.0 + + # Position render. 
each pixel is the xyz coordinates of the point + uv_pos = cv2.imread(uv_pos_path, 2 | 4)[:, :, ::-1] + + ### In these few lines we flattern the masks, positions, and normals + uv_mask = uv_mask.reshape((-1)) # 512*512 + uv_pos = uv_pos.reshape((-1, 3)) + uv_render = uv_render.reshape((-1, 3)) # 512*512,3 + uv_normal = uv_normal.reshape((-1, 3)) + + surface_points = uv_pos[uv_mask] + surface_colors = uv_render[uv_mask] + surface_normal = uv_normal[uv_mask] + + # Samples are around the true surface with an offset + n_samples_surface = self.opt.num_sample_color + + if n_samples_space>surface_points.shape[0]: + print(surface_points.shape[0]) + print( uv_pos_path) + assert 0 + sample_list = random.sample(range(0, surface_points.shape[0] - 1), n_samples_surface) + surface_points=surface_points[sample_list].T + surface_colors=surface_colors[sample_list].T + surface_normal=surface_normal[sample_list].T + + # Samples are around the true surface with an offset + normal = torch.Tensor(surface_normal).float() + samples_surface = torch.Tensor(surface_points).float() \ + + torch.normal(mean=torch.zeros((1, normal.size(1))), std=self.opt.sigma_color).expand_as(normal) * normal + + sample_color=samples_surface.T + rgbs_color=(surface_colors-0.5)*2 # range -1 - 1 + rgbs_color=rgbs_color.T + colors = torch.from_numpy(rgbs_color).float() + + + + + # center.unsqueeze(0).float() + + return {'samples_geo': samples, 'labels_geo': labels,"samples_color":sample_color,"color_labels":colors} + else: + return {'samples_geo': samples, 'labels_geo': labels} + def get_param(self,data_dict): + W=512 + H=512 + + param_path = data_dict['param_path'] + # loading calibration data + param = np.load(param_path, allow_pickle=True) + # pixel unit / world unit + ortho_ratio = param.item().get('ortho_ratio') + # world unit / model unit + scale = param.item().get('scale') + # camera center world coordinate + center = param.item().get('center') + # model rotation + R = param.item().get('R') + translate = -np.matmul(R, center).reshape(3, 1) + extrinsic = np.concatenate([R, translate], axis=1) + extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0) + + + # Match camera space to image pixel space + scale_intrinsic = np.identity(4) + scale_intrinsic[0, 0] = scale / ortho_ratio + scale_intrinsic[1, 1] = -scale / ortho_ratio + scale_intrinsic[2, 2] = scale / ortho_ratio + # Match image pixel space to image uv space + uv_intrinsic = np.identity(4) + uv_intrinsic[0, 0] = 1.0 / float(W // 2) + uv_intrinsic[1, 1] = 1.0 / float(W // 2) + uv_intrinsic[2, 2] = 1.0 / float(W // 2) + + + return scale_intrinsic[:3,:3],R,translate.numpy(),center + + def get_extrinsics(self,data_dict): + calib_data = np.loadtxt(data_dict['calib_path'], dtype=float) + extrinsic = calib_data[:4, :4] + intrinsic = calib_data[4:8, :4] + return intrinsic.astype(np.float32) + + def visualize_sampling3D(self, data_dict, mode='vis'): + + # create plot + vp = vedo.Plotter(title="", size=(1500, 1500), axes=0, bg='white') + vis_list = [] + + assert mode in ['vis', 'sdf', 'normal', 'cmap', 'occ'] + + # sdf-1 cmap-3 norm-3 vis-1 + if mode == 'vis': + labels = data_dict[f'smpl_feat'][:, [-1]] # visibility + colors = np.concatenate([labels, labels, labels], axis=1) + elif mode == 'occ': + labels = data_dict[f'labels_geo'][..., None] # occupancy + colors = np.concatenate([labels, labels, labels], axis=1) + elif mode == 'sdf': + labels = data_dict[f'smpl_feat'][:, [0]] # sdf + labels -= labels.min() + labels /= labels.max() + colors = 
np.concatenate([labels, labels, labels], axis=1) + elif mode == 'normal': + labels = data_dict[f'smpl_feat'][:, -4:-1] # normal + colors = (labels + 1.0) * 0.5 + elif mode == 'cmap': + labels = data_dict[f'smpl_feat'][:, -7:-4] # colormap + colors = np.array(labels) + + points = projection(data_dict['samples_geo'], data_dict['calib']) + verts = projection(data_dict['verts'], data_dict['calib']) + points[:, 1] *= -1 + + verts[:, 1] *= -1 + + # create a mesh + mesh = trimesh.Trimesh(verts, data_dict['faces'], process=True) + mesh.visual.vertex_colors = [128.0, 128.0, 128.0, 255.0] + vis_list.append(mesh) + + if 'voxel_verts' in data_dict.keys(): + print(colored("voxel verts", "green")) + voxel_verts = data_dict['voxel_verts'] * 2.0 + voxel_faces = data_dict['voxel_faces'] + voxel_verts[:, 1] *= -1 + voxel = trimesh.Trimesh(voxel_verts, + voxel_faces[:, [0, 2, 1]], + process=False, + maintain_order=True) + voxel.visual.vertex_colors = [0.0, 128.0, 0.0, 255.0] + vis_list.append(voxel) + + if 'smpl_verts' in data_dict.keys(): + print(colored("smpl verts", "green")) + smplx_verts = data_dict['smpl_verts'] + smplx_faces = data_dict['smpl_faces'] + smplx_verts[:, 1] *= -1 + smplx = trimesh.Trimesh(smplx_verts, + smplx_faces[:, [0, 2, 1]], + process=False, + maintain_order=True) + smplx.visual.vertex_colors = [128.0, 128.0, 0.0, 255.0] + vis_list.append(smplx) + + # create a picure + img_pos = [1.0, 0.0, -1.0,-1.0,1.0] + for img_id, img_key in enumerate(['normal_F', 'image', 'T_normal_B','T_normal_L','T_normal_R']): + image_arr = (data_dict[img_key].detach().cpu().permute(1, 2, 0).numpy() + + 1.0) * 0.5 * 255.0 + image_dim = image_arr.shape[0] + + if img_id==3: + image=vedo.Picture(image_arr).scale(2.0 / image_dim).pos(-1.0, -1.0, -1.0).rotateY(90) + elif img_id==4: + image=vedo.Picture(image_arr).scale(2.0 / image_dim).pos(-1.0, -1.0, 1.0).rotateY(90) + else: + image = vedo.Picture(image_arr).scale(2.0 / image_dim).pos(-1.0, -1.0, img_pos[img_id]) + vis_list.append(image) + + # create a pointcloud + pc = vedo.Points(points, r=1) + vis_list.append(pc) + + vp.show(*vis_list, bg="white", axes=1.0, interactive=True) diff --git a/lib/dataset/PointFeat.py b/lib/dataset/PointFeat.py new file mode 100644 index 0000000000000000000000000000000000000000..10c129cf496987133cb0c4bd2d275d54a4e7b8e3 --- /dev/null +++ b/lib/dataset/PointFeat.py @@ -0,0 +1,251 @@ +from pytorch3d.structures import Meshes, Pointclouds +import torch.nn.functional as F +import torch +from lib.common.render_utils import face_vertices +from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection +from kaolin.ops.mesh import check_sign, face_normals +from kaolin.metrics.trianglemesh import point_to_mesh_distance +from lib.dataset.Evaluator import point_mesh_distance +from lib.dataset.ECON_Evaluator import econ_point_mesh_distance + + +def distance_matrix(x, y=None, p = 2): #pairwise distance of vectors + + y = x if type(y) == type(None) else y + + n = x.size(0) + m = y.size(0) + d = x.size(1) + + x = x.unsqueeze(1).expand(n, m, d) + y = y.unsqueeze(0).expand(n, m, d) + + dist = torch.norm(x - y, dim=-1) if torch.__version__ >= '1.7.0' else torch.pow(x - y, p).sum(2)**(1/p) + + return dist + +class NN(): + + def __init__(self, X = None, Y = None, p = 2): + self.p = p + self.train(X, Y) + + def train(self, X, Y): + self.train_pts = X + self.train_label = Y + + def __call__(self, x): + return self.predict(x) + + def predict(self, x): + if type(self.train_pts) == type(None) or type(self.train_label) == type(None): + name = 
self.__class__.__name__ + raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first") + + dist=[] + chunk=10000 + for i in range(0,x.shape[0],chunk): + dist.append(distance_matrix(x[i:i+chunk], self.train_pts, self.p)) + + dist = torch.cat(dist, dim=0) + labels = torch.argmin(dist, dim=1) + return self.train_label[labels],labels + +class PointFeat: + + def __init__(self, verts, faces): + + # verts [B, N_vert, 3] + # faces [B, N_face, 3] + # triangles [B, N_face, 3, 3] + + self.Bsize = verts.shape[0] + self.mesh = Meshes(verts, faces) + self.device = verts.device + self.faces = faces + + # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth + # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929 + # 2. fill mouth holes with 30 more faces + + if verts.shape[1] == 10475: + faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask] + mouth_faces = (torch.as_tensor( + SMPLX().smplx_mouth_fid).unsqueeze(0).repeat( + self.Bsize, 1, 1).to(self.device)) + self.faces = torch.cat([faces, mouth_faces], dim=1).long() + + self.verts = verts + self.triangles = face_vertices(self.verts, self.faces) + + def get_face_normals(self): + return face_normals(self.verts, self.faces) + + def get_nearest_point(self,points): + # points [1, N, 3] + # find nearest point on mesh + + #devices = points.device + points=points.squeeze(0) + nn_class=NN(X=self.verts.squeeze(0),Y=self.verts.squeeze(0),p=2) + nearest_points,nearest_points_ind=nn_class.predict(points) + + # closest_triangles = torch.gather( + # self.triangles, 1, + # pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3) + # bary_weights = barycentric_coordinates_of_projection( + # points.view(-1, 3), closest_triangles) + + # bary_weights=F.normalize(bary_weights, p=2, dim=1) + + # normals = face_normals(self.triangles) + + # # make the lenght of the normal is 1 + # normals = F.normalize(normals, p=2, dim=2) + + + # # get the normal of the closest triangle + # closest_normals = torch.gather( + # normals, 1, + # pts_ind[:, :, None].expand(-1, -1, 3)).view(-1, 3) + + + return nearest_points,nearest_points_ind # on cpu + + def query_barycentirc_feats(self,points,feats): + # feats [B,N,C] + + residues, pts_ind, _ = point_to_mesh_distance(points, self.triangles) + closest_triangles = torch.gather( + self.triangles, 1, + pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3) + bary_weights = barycentric_coordinates_of_projection( + points.view(-1, 3), closest_triangles) + + feat_arr=feats + feat_dim = feat_arr.shape[-1] + feat_tri = face_vertices(feat_arr, self.faces) + closest_feats = torch.gather( # query点距离最近的face的三个点的feature + feat_tri, 1, + pts_ind[:, :, None, + None].expand(-1, -1, 3, + feat_dim)).view(-1, 3, feat_dim) + pts_feats = ((closest_feats * + bary_weights[:, :, None]).sum(1).unsqueeze(0)) # 用barycentric weight加权求和 + return pts_feats.view(self.Bsize,-1,feat_dim) + + def query(self, points, feats={}): + + # points [B, N, 3] + # feats {'feat_name': [B, N, C]} + + del_keys = ["smpl_verts", "smpl_faces", "smpl_joint","smpl_sample_id"] + + residues, pts_ind, _ = point_to_mesh_distance(points, self.triangles) + closest_triangles = torch.gather( + self.triangles, 1, + pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3) + bary_weights = barycentric_coordinates_of_projection( + points.view(-1, 3), closest_triangles) + + out_dict = {} + + for feat_key in feats.keys(): + + if feat_key in del_keys: + continue + + elif feats[feat_key] is not None: + feat_arr = feats[feat_key] + feat_dim = 
feat_arr.shape[-1] + feat_tri = face_vertices(feat_arr, self.faces) + closest_feats = torch.gather( # query点距离最近的face的三个点的feature + feat_tri, 1, + pts_ind[:, :, None, + None].expand(-1, -1, 3, + feat_dim)).view(-1, 3, feat_dim) + pts_feats = ((closest_feats * + bary_weights[:, :, None]).sum(1).unsqueeze(0)) # 用barycentric weight加权求和 + out_dict[feat_key.split("_")[1]] = pts_feats + + else: + out_dict[feat_key.split("_")[1]] = None + + if "sdf" in out_dict.keys(): + pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3)) + pts_signs = 2.0 * ( + check_sign(self.verts, self.faces[0], points).float() - 0.5) + pts_sdf = (pts_dist * pts_signs).unsqueeze(-1) + out_dict["sdf"] = pts_sdf + + if "vis" in out_dict.keys(): + out_dict["vis"] = out_dict["vis"].ge(1e-1).float() + + if "norm" in out_dict.keys(): + pts_norm = out_dict["norm"] * torch.tensor([-1.0, 1.0, -1.0]).to( + self.device) + out_dict["norm"] = F.normalize(pts_norm, dim=2) + + if "cmap" in out_dict.keys(): + out_dict["cmap"] = out_dict["cmap"].clamp_(min=0.0, max=1.0) + + for out_key in out_dict.keys(): + out_dict[out_key] = out_dict[out_key].view( + self.Bsize, -1, out_dict[out_key].shape[-1]) + + return out_dict + + + + +class ECON_PointFeat: + def __init__(self, verts, faces): + + # verts [B, N_vert, 3] + # faces [B, N_face, 3] + # triangles [B, N_face, 3, 3] + + self.Bsize = verts.shape[0] + self.device = verts.device + self.faces = faces + + # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth + # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929 + # 2. fill mouth holes with 30 more faces + + if verts.shape[1] == 10475: + faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask] + mouth_faces = ( + torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1, + 1).to(self.device) + ) + self.faces = torch.cat([faces, mouth_faces], dim=1).long() + + self.verts = verts.float() + self.triangles = face_vertices(self.verts, self.faces) + self.mesh = Meshes(self.verts, self.faces).to(self.device) + + def query(self, points): + + points = points.float() + residues, pts_ind = econ_point_mesh_distance(self.mesh, Pointclouds(points), weighted=False) # 这个和ECON的不太一样 + + closest_triangles = torch.gather( + self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3) + ).view(-1, 3, 3) + bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles) + + feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces) + closest_normals = torch.gather( + feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3) + ).view(-1, 3, 3) + shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0)) + + pts2shoot_normals = points - shoot_verts + pts2shoot_normals = pts2shoot_normals / torch.norm(pts2shoot_normals, dim=-1, keepdim=True) + + shoot_normals = ((closest_normals * bary_weights[:, :, None]).sum(1).unsqueeze(0)) + shoot_normals = shoot_normals / torch.norm(shoot_normals, dim=-1, keepdim=True) + angles = (pts2shoot_normals * shoot_normals).sum(dim=-1).abs() + + return (torch.sqrt(residues).unsqueeze(0), angles) \ No newline at end of file diff --git a/lib/dataset/TestDataset.py b/lib/dataset/TestDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4f67b2c00ce566c775ef93e13e22c9d0d5a0351c --- /dev/null +++ b/lib/dataset/TestDataset.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam +from lib.pixielib.utils.config import cfg as pixie_cfg +from lib.pixielib.pixie import PIXIE +import lib.smplx as smplx +from lib.pare.pare.core.tester import PARETester +from lib.pymaf.utils.geometry import rot6d_to_rotmat, batch_rodrigues, rotation_matrix_to_angle_axis +from lib.pymaf.utils.imutils import process_image +from lib.common.imutils import econ_process_image +from lib.pymaf.core import path_config +from lib.pymaf.models import pymaf_net +from lib.common.config import cfg +from lib.common.render import Render +from lib.dataset.body_model import TetraSMPLModel +from lib.dataset.mesh_util import get_visibility, SMPLX +import os.path as osp +import os +import torch +import glob +import numpy as np +import random +from termcolor import colored +from PIL import ImageFile +from torchvision.models import detection + + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class TestDataset(): + + def __init__(self, cfg, device): + + random.seed(1993) + + self.image_dir = cfg['image_dir'] + self.seg_dir = cfg['seg_dir'] + self.hps_type = cfg['hps_type'] + self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx' + self.smpl_gender = 'neutral' + self.colab = cfg['colab'] + + self.device = device + + keep_lst = sorted(glob.glob(f"{self.image_dir}/*")) + img_fmts = ['jpg', 'png', 'jpeg', "JPG", 'bmp'] + keep_lst = [item for item in keep_lst if item.split(".")[-1] in img_fmts] + + self.subject_list = sorted([item for item in keep_lst if item.split(".")[-1] in img_fmts]) + + if self.colab: + self.subject_list = [self.subject_list[0]] + + # smpl related + self.smpl_data = SMPLX() + + # smpl-smplx correspondence + self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73] + self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [68 + 61, 72 + 68] + self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(model_path=self.smpl_data. 
+ model_dir, + gender=smpl_gender, + model_type=smpl_type, + ext='npz') + + # Load SMPL model + self.smpl_model = self.get_smpl_model(self.smpl_type, self.smpl_gender).to(self.device) + self.faces = self.smpl_model.faces + + if self.hps_type == 'pymaf': + self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device) + self.hps.load_state_dict(torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True) + self.hps.eval() + + elif self.hps_type == 'pare': + self.hps = PARETester(path_config.CFG, path_config.CKPT).model + elif self.hps_type == 'pixie': + self.hps = PIXIE(config=pixie_cfg, device=self.device) + self.smpl_model = self.hps.smplx + elif self.hps_type == 'hybrik': + smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl") + self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG, + smpl_path=smpl_path, + data_path=path_config.hybrik_data_dir) + self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'), + strict=False) + self.hps.to(self.device) + elif self.hps_type == 'bev': + try: + import bev + except: + print('Could not find bev, installing via pip install --upgrade simple-romp') + os.system('pip install simple-romp==1.0.3') + import bev + settings = bev.main.default_settings + # change the argparse settings of bev here if you prefer other settings. + settings.mode = 'image' + settings.GPU = int(str(self.device).split(':')[1]) + settings.show_largest = True + # settings.show = True # uncommit this to show the original BEV predictions + self.hps = bev.BEV(settings) + + self.detector=detection.maskrcnn_resnet50_fpn(pretrained=True) + self.detector.eval() + print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green")) + + self.render = Render(size=512, device=device) + + def __len__(self): + return len(self.subject_list) + + def compute_vis_cmap(self, smpl_verts, smpl_faces): + + (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1) + smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long()) + smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type) + + return { + 'smpl_vis': smpl_vis.unsqueeze(0).to(self.device), + 'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device), + 'smpl_verts': smpl_verts.unsqueeze(0) + } + + def compute_voxel_verts(self, body_pose, global_orient, betas, trans, scale): + + smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl") + tetra_path = osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz') + smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult') + + pose = torch.cat([global_orient[0], body_pose[0]], dim=0) + smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0]) + + verts = np.concatenate([smpl_model.verts, smpl_model.verts_added], + axis=0) * scale.item() + trans.detach().cpu().numpy() + faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir, 'tetrahedrons_neutral_adult.txt'), + dtype=np.int32) - 1 + + pad_v_num = int(8000 - verts.shape[0]) + pad_f_num = int(25100 - faces.shape[0]) + + verts = np.pad(verts, + ((0, pad_v_num), + (0, 0)), mode='constant', constant_values=0.0).astype(np.float32) * 0.5 + faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode='constant', + constant_values=0.0).astype(np.int32) + + verts[:, 2] *= -1.0 + + voxel_dict = { + 'voxel_verts': torch.from_numpy(verts).to(self.device).unsqueeze(0).float(), + 'voxel_faces': torch.from_numpy(faces).to(self.device).unsqueeze(0).long(), + 'pad_v_num': torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(), + 'pad_f_num': 
torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long() + } + + return voxel_dict + + def __getitem__(self, index): + + img_path = self.subject_list[index] + img_name = img_path.split("/")[-1].rsplit(".", 1)[0] + print(img_name) + # smplx_param_path=f'./data/thuman2/smplx/{img_name[:-2]}.pkl' + # smplx_param = np.load(smplx_param_path, allow_pickle=True) + import pdb;pdb.set_trace() + if self.seg_dir is None: + img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image( + img_path, self.hps_type, 512, self.device) + + data_dict = { + 'name': img_name, + 'image': img_icon.to(self.device).unsqueeze(0), + 'ori_image': img_ori, + 'mask': img_mask, + 'uncrop_param': uncrop_param + } + + else: + img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image( + img_path, + self.hps_type, + 512, + self.device, + seg_path=os.path.join(self.seg_dir, f'{img_name}.json')) + data_dict = { + 'name': img_name, + 'image': img_icon.to(self.device).unsqueeze(0), + 'ori_image': img_ori, + 'mask': img_mask, + 'uncrop_param': uncrop_param, + 'segmentations': segmentations + } + + arr_dict=econ_process_image(img_path,self.hps_type,True,512,self.detector) + data_dict['hands_visibility']=arr_dict['hands_visibility'] + + with torch.no_grad(): + # import ipdb; ipdb.set_trace() + preds_dict = self.hps.forward(img_hps) + + data_dict['smpl_faces'] = torch.Tensor(self.faces.astype(np.int64)).long().unsqueeze(0).to( + self.device) + + if self.hps_type == 'pymaf': + output = preds_dict['smpl_out'][-1] + scale, tranX, tranY = output['theta'][0, :3] + data_dict['betas'] = output['pred_shape'] + data_dict['body_pose'] = output['rotmat'][:, 1:] + data_dict['global_orient'] = output['rotmat'][:, 0:1] + data_dict['smpl_verts'] = output['verts'] # 不确定尺度是否一样 + data_dict["type"] = "smpl" + + elif self.hps_type == 'pare': + data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:] + data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1] + data_dict['betas'] = preds_dict['pred_shape'] + data_dict['smpl_verts'] = preds_dict['smpl_vertices'] + scale, tranX, tranY = preds_dict['pred_cam'][0, :3] + data_dict["type"] = "smpl" + + elif self.hps_type == 'pixie': + data_dict.update(preds_dict) + data_dict['body_pose'] = preds_dict['body_pose'] + data_dict['global_orient'] = preds_dict['global_pose'] + data_dict['betas'] = preds_dict['shape'] + data_dict['smpl_verts'] = preds_dict['vertices'] + scale, tranX, tranY = preds_dict['cam'][0, :3] + data_dict["type"] = "smplx" + + elif self.hps_type == 'hybrik': + data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:] + data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]] + data_dict['betas'] = preds_dict['pred_shape'] + data_dict['smpl_verts'] = preds_dict['pred_vertices'] + scale, tranX, tranY = preds_dict['pred_camera'][0, :3] + scale = scale * 2 + data_dict["type"] = "smpl" + + elif self.hps_type == 'bev': + data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[[0], :10].to( + self.device).float() + pred_thetas = batch_rodrigues( + torch.from_numpy(preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float() + data_dict['body_pose'] = pred_thetas[1:][None].to(self.device) + data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device) + data_dict['smpl_verts'] = torch.from_numpy(preds_dict['verts'][[0]]).to( + self.device).float() + tranX = preds_dict['cam_trans'][0, 0] + tranY = preds_dict['cam'][0, 1] + 0.28 + scale = preds_dict['cam'][0, 0] * 1.1 + data_dict["type"] = "smpl" + + data_dict['scale'] = scale + data_dict['trans'] 
= torch.tensor([tranX, tranY, 0.0]).unsqueeze(0).to(self.device).float() + + # data_dict info (key-shape): + # scale, tranX, tranY - tensor.float + # betas - [1,10] / [1, 200] + # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3] + # global_orient - [1, 1, 3, 3] + # smpl_verts - [1, 6890, 3] / [1, 10475, 3] + + # from rot_mat to rot_6d for better optimization + N_body = data_dict["body_pose"].shape[1] + data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1) + data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1) + + return data_dict + + def render_normal(self, verts, faces): + + # render optimized mesh (normal, T_normal, image [-1,1]) + self.render.load_meshes(verts, faces) + return self.render.get_rgb_image() + + def render_depth(self, verts, faces): + + # render optimized mesh (normal, T_normal, image [-1,1]) + self.render.load_meshes(verts, faces) + return self.render.get_depth_map(cam_ids=[0, 2]) + + def visualize_alignment(self, data): + + import vedo + import trimesh + + if self.hps_type != 'pixie': + smpl_out = self.smpl_model(betas=data['betas'], + body_pose=data['body_pose'], + global_orient=data['global_orient'], + pose2rot=False) + smpl_verts = ((smpl_out.vertices + data['trans']) * + data['scale']).detach().cpu().numpy()[0] + else: + smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'], + expression_params=data['exp'], + body_pose=data['body_pose'], + global_pose=data['global_orient'], + jaw_pose=data['jaw_pose'], + left_hand_pose=data['left_hand_pose'], + right_hand_pose=data['right_hand_pose']) + + smpl_verts = ((smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0] + + smpl_verts *= np.array([1.0, -1.0, -1.0]) + faces = data['smpl_faces'][0].detach().cpu().numpy() + + image_P = data['image'] + image_F, image_B = self.render_normal(smpl_verts, faces) + + # create plot + vp = vedo.Plotter(title="", size=(1500, 1500)) + vis_list = [] + + image_F = (0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0) + image_B = (0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0) + image_P = (0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0) + + vis_list.append( + vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(2.0 / image_P.shape[0]).pos( + -1.0, -1.0, 1.0)) + vis_list.append(vedo.Picture(image_F).scale(2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5)) + vis_list.append(vedo.Picture(image_B).scale(2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0)) + + # create a mesh + mesh = trimesh.Trimesh(smpl_verts, faces, process=False) + mesh.visual.vertex_colors = [200, 200, 0] + vis_list.append(mesh) + + vp.show(*vis_list, bg="white", axes=1, interactive=True) + + +if __name__ == '__main__': + + cfg.merge_from_file("./configs/icon-filter.yaml") + cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml') + + cfg_show_list = ['test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False] + + cfg.merge_from_list(cfg_show_list) + cfg.freeze() + + + device = torch.device('cuda:0') + + dataset = TestDataset( + { + 'image_dir': "./examples", + 'has_det': True, # w/ or w/o detection + 'hps_type': 'bev' # pymaf/pare/pixie/hybrik/bev + }, + device) + + for i in range(len(dataset)): + dataset.visualize_alignment(dataset[i]) diff --git a/lib/dataset/__init__.py b/lib/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/dataset/body_model.py b/lib/dataset/body_model.py new file mode 
100644 index 0000000000000000000000000000000000000000..8a3e8549fa9bee813b9db74f352aa4b089298dc8 --- /dev/null +++ b/lib/dataset/body_model.py @@ -0,0 +1,495 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np +import pickle +import torch +import os + + +class SMPLModel(): + + def __init__(self, model_path, age): + """ + SMPL model. + + Parameter: + --------- + model_path: Path to the SMPL model parameters, pre-processed by + `preprocess.py`. + + """ + with open(model_path, 'rb') as f: + params = pickle.load(f, encoding='latin1') + + self.J_regressor = params['J_regressor'] + self.weights = np.asarray(params['weights']) + self.posedirs = np.asarray(params['posedirs']) + self.v_template = np.asarray(params['v_template']) + self.shapedirs = np.asarray(params['shapedirs']) + self.faces = np.asarray(params['f']) + self.kintree_table = np.asarray(params['kintree_table']) + + self.pose_shape = [24, 3] + self.beta_shape = [10] + self.trans_shape = [3] + + if age == 'kid': + v_template_smil = np.load( + os.path.join(os.path.dirname(model_path), + "smpl/smpl_kid_template.npy")) + v_template_smil -= np.mean(v_template_smil, axis=0) + v_template_diff = np.expand_dims(v_template_smil - self.v_template, + axis=2) + self.shapedirs = np.concatenate( + (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), + axis=2) + self.beta_shape[0] += 1 + + id_to_col = { + self.kintree_table[1, i]: i + for i in range(self.kintree_table.shape[1]) + } + self.parent = { + i: id_to_col[self.kintree_table[0, i]] + for i in range(1, self.kintree_table.shape[1]) + } + + self.pose = np.zeros(self.pose_shape) + self.beta = np.zeros(self.beta_shape) + self.trans = np.zeros(self.trans_shape) + + self.verts = None + self.J = None + self.R = None + self.G = None + + self.update() + + def set_params(self, pose=None, beta=None, trans=None): + """ + Set pose, shape, and/or translation parameters of SMPL model. Verices of the + model will be updated and returned. + + Prameters: + --------- + pose: Also known as 'theta', a [24,3] matrix indicating child joint rotation + relative to parent joint. For root joint it's global orientation. + Represented in a axis-angle format. + + beta: Parameter for model shape. A vector of shape [10]. Coefficients for + PCA component. Only 10 components were released by MPI. + + trans: Global translation of shape [3]. + + Return: + ------ + Updated vertices. + + """ + if pose is not None: + self.pose = pose + if beta is not None: + self.beta = beta + if trans is not None: + self.trans = trans + self.update() + return self.verts + + def update(self): + """ + Called automatically when parameters are updated. 
+ + """ + # how beta affect body shape + v_shaped = self.shapedirs.dot(self.beta) + self.v_template + # joints location + self.J = self.J_regressor.dot(v_shaped) + pose_cube = self.pose.reshape((-1, 1, 3)) + # rotation matrix for each joint + self.R = self.rodrigues(pose_cube) + I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), + (self.R.shape[0] - 1, 3, 3)) + lrotmin = (self.R[1:] - I_cube).ravel() + # how pose affect body shape in zero pose + v_posed = v_shaped + self.posedirs.dot(lrotmin) + # world transformation of each joint + G = np.empty((self.kintree_table.shape[1], 4, 4)) + G[0] = self.with_zeros( + np.hstack((self.R[0], self.J[0, :].reshape([3, 1])))) + for i in range(1, self.kintree_table.shape[1]): + G[i] = G[self.parent[i]].dot( + self.with_zeros( + np.hstack([ + self.R[i], + ((self.J[i, :] - self.J[self.parent[i], :]).reshape( + [3, 1])) + ]))) + # remove the transformation due to the rest pose + G = G - self.pack( + np.matmul( + G, + np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1]))) + # transformation of each vertex + T = np.tensordot(self.weights, G, axes=[[1], [0]]) + rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1]))) + v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, + 4])[:, :3] + self.verts = v + self.trans.reshape([1, 3]) + self.G = G + + def rodrigues(self, r): + """ + Rodrigues' rotation formula that turns axis-angle vector into rotation + matrix in a batch-ed manner. + + Parameter: + ---------- + r: Axis-angle rotation vector of shape [batch_size, 1, 3]. + + Return: + ------- + Rotation matrix of shape [batch_size, 3, 3]. + + """ + theta = np.linalg.norm(r, axis=(1, 2), keepdims=True) + # avoid zero divide + theta = np.maximum(theta, np.finfo(np.float64).tiny) + r_hat = r / theta + cos = np.cos(theta) + z_stick = np.zeros(theta.shape[0]) + m = np.dstack([ + z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick, + -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick + ]).reshape([-1, 3, 3]) + i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), + [theta.shape[0], 3, 3]) + A = np.transpose(r_hat, axes=[0, 2, 1]) + B = r_hat + dot = np.matmul(A, B) + R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m + return R + + def with_zeros(self, x): + """ + Append a [0, 0, 0, 1] vector to a [3, 4] matrix. + + Parameter: + --------- + x: Matrix to be appended. + + Return: + ------ + Matrix after appending of shape [4,4] + + """ + return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]]))) + + def pack(self, x): + """ + Append zero matrices of shape [4, 3] to vectors of [4, 1] shape in a batched + manner. + + Parameter: + ---------- + x: Matrices to be appended of shape [batch_size, 4, 1] + + Return: + ------ + Matrix of shape [batch_size, 4, 4] after appending. + + """ + return np.dstack((np.zeros((x.shape[0], 4, 3)), x)) + + def save_to_obj(self, path): + """ + Save the SMPL model into .obj file. + + Parameter: + --------- + path: Path to save. + + """ + with open(path, 'w') as fp: + for v in self.verts: + fp.write('v %f %f %f\n' % (v[0], v[1], v[2])) + for f in self.faces + 1: + fp.write('f %d %d %d\n' % (f[0], f[1], f[2])) + + +class TetraSMPLModel(): + + def __init__(self, + model_path, + model_addition_path, + age='adult', + v_template=None): + """ + SMPL model. + + Parameter: + --------- + model_path: Path to the SMPL model parameters, pre-processed by + `preprocess.py`. 
+ + """ + with open(model_path, 'rb') as f: + params = pickle.load(f, encoding='latin1') + + self.J_regressor = params['J_regressor'] + self.weights = np.asarray(params['weights']) + self.posedirs = np.asarray(params['posedirs']) + + if v_template is not None: + self.v_template = v_template + else: + self.v_template = np.asarray(params['v_template']) + + self.shapedirs = np.asarray(params['shapedirs']) + self.faces = np.asarray(params['f']) + self.kintree_table = np.asarray(params['kintree_table']) + + params_added = np.load(model_addition_path) + self.v_template_added = params_added['v_template_added'] + self.weights_added = params_added['weights_added'] + self.shapedirs_added = params_added['shapedirs_added'] + self.posedirs_added = params_added['posedirs_added'] + self.tetrahedrons = params_added['tetrahedrons'] + + id_to_col = { + self.kintree_table[1, i]: i + for i in range(self.kintree_table.shape[1]) + } + self.parent = { + i: id_to_col[self.kintree_table[0, i]] + for i in range(1, self.kintree_table.shape[1]) + } + + self.pose_shape = [24, 3] + self.beta_shape = [10] + self.trans_shape = [3] + + if age == 'kid': + v_template_smil = np.load( + os.path.join(os.path.dirname(model_path), + "smpl/smpl_kid_template.npy")) + v_template_smil -= np.mean(v_template_smil, axis=0) + v_template_diff = np.expand_dims(v_template_smil - self.v_template, + axis=2) + self.shapedirs = np.concatenate( + (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), + axis=2) + self.beta_shape[0] += 1 + + self.pose = np.zeros(self.pose_shape) + self.beta = np.zeros(self.beta_shape) + self.trans = np.zeros(self.trans_shape) + + self.verts = None + self.verts_added = None + self.J = None + self.R = None + self.G = None + + self.update() + + def set_params(self, pose=None, beta=None, trans=None): + """ + Set pose, shape, and/or translation parameters of SMPL model. Verices of the + model will be updated and returned. + + Prameters: + --------- + pose: Also known as 'theta', a [24,3] matrix indicating child joint rotation + relative to parent joint. For root joint it's global orientation. + Represented in a axis-angle format. + + beta: Parameter for model shape. A vector of shape [10]. Coefficients for + PCA component. Only 10 components were released by MPI. + + trans: Global translation of shape [3]. + + Return: + ------ + Updated vertices. + + """ + + if torch.is_tensor(pose): + pose = pose.detach().cpu().numpy() + if torch.is_tensor(beta): + beta = beta.detach().cpu().numpy() + + if pose is not None: + self.pose = pose + if beta is not None: + self.beta = beta + if trans is not None: + self.trans = trans + self.update() + return self.verts + + def update(self): + """ + Called automatically when parameters are updated. 
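# Unlike SMPLModel, TetraSMPLModel.set_params also accepts torch tensors and converts
# them to NumPy internally. A short sketch; both paths are placeholders for the
# pre-processed SMPL pickle and its tetrahedron add-on archive.
import torch

tetra = TetraSMPLModel('smpl/SMPL_NEUTRAL.pkl', 'smpl/tetra_smpl_addition.npz')
verts = tetra.set_params(pose=torch.zeros(24, 3), beta=torch.zeros(10))
print(verts.shape, tetra.verts_added.shape)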
+ + """ + # how beta affect body shape + v_shaped = self.shapedirs.dot(self.beta) + self.v_template + v_shaped_added = self.shapedirs_added.dot( + self.beta) + self.v_template_added + # joints location + self.J = self.J_regressor.dot(v_shaped) + pose_cube = self.pose.reshape((-1, 1, 3)) + # rotation matrix for each joint + self.R = self.rodrigues(pose_cube) + I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), + (self.R.shape[0] - 1, 3, 3)) + lrotmin = (self.R[1:] - I_cube).ravel() + # how pose affect body shape in zero pose + v_posed = v_shaped + self.posedirs.dot(lrotmin) + v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin) + # world transformation of each joint + G = np.empty((self.kintree_table.shape[1], 4, 4)) + G[0] = self.with_zeros( + np.hstack((self.R[0], self.J[0, :].reshape([3, 1])))) + for i in range(1, self.kintree_table.shape[1]): + G[i] = G[self.parent[i]].dot( + self.with_zeros( + np.hstack([ + self.R[i], + ((self.J[i, :] - self.J[self.parent[i], :]).reshape( + [3, 1])) + ]))) + # remove the transformation due to the rest pose + G = G - self.pack( + np.matmul( + G, + np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1]))) + self.G = G + # transformation of each vertex + T = np.tensordot(self.weights, G, axes=[[1], [0]]) + rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1]))) + v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, + 4])[:, :3] + self.verts = v + self.trans.reshape([1, 3]) + T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]]) + rest_shape_added_h = np.hstack( + (v_posed_added, np.ones([v_posed_added.shape[0], 1]))) + v_added = np.matmul(T_added, + rest_shape_added_h.reshape([-1, 4, + 1])).reshape([-1, 4 + ])[:, :3] + self.verts_added = v_added + self.trans.reshape([1, 3]) + + def rodrigues(self, r): + """ + Rodrigues' rotation formula that turns axis-angle vector into rotation + matrix in a batch-ed manner. + + Parameter: + ---------- + r: Axis-angle rotation vector of shape [batch_size, 1, 3]. + + Return: + ------- + Rotation matrix of shape [batch_size, 3, 3]. + + """ + theta = np.linalg.norm(r, axis=(1, 2), keepdims=True) + # avoid zero divide + theta = np.maximum(theta, np.finfo(np.float64).tiny) + r_hat = r / theta + cos = np.cos(theta) + z_stick = np.zeros(theta.shape[0]) + m = np.dstack([ + z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick, + -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick + ]).reshape([-1, 3, 3]) + i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), + [theta.shape[0], 3, 3]) + A = np.transpose(r_hat, axes=[0, 2, 1]) + B = r_hat + dot = np.matmul(A, B) + R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m + return R + + def with_zeros(self, x): + """ + Append a [0, 0, 0, 1] vector to a [3, 4] matrix. + + Parameter: + --------- + x: Matrix to be appended. + + Return: + ------ + Matrix after appending of shape [4,4] + + """ + return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]]))) + + def pack(self, x): + """ + Append zero matrices of shape [4, 3] to vectors of [4, 1] shape in a batched + manner. + + Parameter: + ---------- + x: Matrices to be appended of shape [batch_size, 4, 1] + + Return: + ------ + Matrix of shape [batch_size, 4, 4] after appending. + + """ + return np.dstack((np.zeros((x.shape[0], 4, 3)), x)) + + def save_mesh_to_obj(self, path): + """ + Save the SMPL model into .obj file. + + Parameter: + --------- + path: Path to save. 
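# Sketch of the two small helpers used above: with_zeros promotes a [3, 4]
# (rotation | translation) block to a 4x4 homogeneous transform, and pack left-pads a
# batch of [4, 1] columns with zeros so they line up with 4x4 matrices. Neither helper
# uses self, so the unbound functions can be called directly for a quick check.
import numpy as np

Rt = np.hstack([np.eye(3), np.array([[1.0], [2.0], [3.0]])])     # [3, 4]
G = TetraSMPLModel.with_zeros(None, Rt)                          # [4, 4]
cols = np.ones([5, 4, 1])
print(G[3], TetraSMPLModel.pack(None, cols).shape)               # [0. 0. 0. 1.] (5, 4, 4)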
+ + """ + with open(path, 'w') as fp: + for v in self.verts: + fp.write('v %f %f %f\n' % (v[0], v[1], v[2])) + for f in self.faces + 1: + fp.write('f %d %d %d\n' % (f[0], f[1], f[2])) + + def save_tetrahedron_to_obj(self, path): + """ + Save the tetrahedron SMPL model into .obj file. + + Parameter: + --------- + path: Path to save. + + """ + + with open(path, 'w') as fp: + for v in self.verts: + fp.write('v %f %f %f 1 0 0\n' % (v[0], v[1], v[2])) + for va in self.verts_added: + fp.write('v %f %f %f 0 0 1\n' % (va[0], va[1], va[2])) + for t in self.tetrahedrons + 1: + fp.write('f %d %d %d\n' % (t[0], t[2], t[1])) + fp.write('f %d %d %d\n' % (t[0], t[3], t[2])) + fp.write('f %d %d %d\n' % (t[0], t[1], t[3])) + fp.write('f %d %d %d\n' % (t[1], t[2], t[3])) diff --git a/lib/dataset/hoppeMesh.py b/lib/dataset/hoppeMesh.py new file mode 100644 index 0000000000000000000000000000000000000000..09498319dd806fa6adcb3599b2dfcd5a34278789 --- /dev/null +++ b/lib/dataset/hoppeMesh.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np +from scipy.spatial import cKDTree +import trimesh + +import logging + +logging.getLogger("trimesh").setLevel(logging.ERROR) + + +def save_obj_mesh(mesh_path, verts, faces): + file = open(mesh_path, 'w') + for v in verts: + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def save_obj_mesh_with_color(mesh_path, verts, faces, colors): + file = open(mesh_path, 'w') + + for idx, v in enumerate(verts): + c = colors[idx] + file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % + (v[0], v[1], v[2], c[0], c[1], c[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def save_ply(mesh_path, points, rgb): + ''' + Save the visualization of sampling to a ply file. + Red points represent positive predictions. + Green points represent negative predictions. 
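# Usage sketch for save_obj_mesh_with_color: per-vertex colors are written inline on the
# 'v' lines (an OBJ extension that viewers such as MeshLab understand). Data is synthetic.
import numpy as np

verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])              # zero-based; +1 is applied when writing
colors = np.array([[1.0, 0.0, 0.0]] * 3)   # per-vertex RGB in [0, 1]
save_obj_mesh_with_color('triangle_colored.obj', verts, faces, colors)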
+ :param mesh_path: File name to save + :param points: [N, 3] array of points + :param rgb: [N, 3] array of rgb values in the range [0~1] + :return: + ''' + to_save = np.concatenate([points, rgb * 255], axis=-1) + return np.savetxt( + mesh_path, + to_save, + fmt='%.6f %.6f %.6f %d %d %d', + comments='', + header=( + 'ply\nformat ascii 1.0\nelement vertex {:d}\n' + + 'property float x\nproperty float y\nproperty float z\n' + + 'property uchar red\nproperty uchar green\nproperty uchar blue\n' + + 'end_header').format(points.shape[0])) + + +class HoppeMesh: + + def __init__(self, verts, faces, vert_normals, face_normals): + ''' + The HoppeSDF calculates signed distance towards a predefined oriented point cloud + http://hhoppe.com/recon.pdf + For clean and high-resolution pcl data, this is the fastest and accurate approximation of sdf + :param points: pts + :param normals: normals + ''' + self.verts = verts # [n, 3] + self.faces = faces # [m, 3] + self.vert_normals = vert_normals # [n, 3] + self.face_normals = face_normals # [m, 3] + + self.kd_tree = cKDTree(self.verts) + self.len = len(self.verts) + + def query(self, points): + dists, idx = self.kd_tree.query(points, n_jobs=1) + # FIXME: because the eyebows are removed, cKDTree around eyebows + # are not accurate. Cause a few false-inside labels here. + dirs = points - self.verts[idx] + signs = (dirs * self.vert_normals[idx]).sum(axis=1) + signs = (signs > 0) * 2 - 1 + return signs * dists + + def contains(self, points): + + labels = trimesh.Trimesh(vertices=self.verts, + faces=self.faces).contains(points) + return labels + + def export(self, path): + if self.colors is not None: + save_obj_mesh_with_color(path, self.verts, self.faces, + self.colors[:, 0:3] / 255.0) + else: + save_obj_mesh(path, self.verts, self.faces) + + def export_ply(self, path): + save_ply(path, self.verts, self.colors[:, 0:3] / 255.0) + + def triangles(self): + return self.verts[self.faces] # [n, 3, 3] diff --git a/lib/dataset/mesh_util.py b/lib/dataset/mesh_util.py new file mode 100644 index 0000000000000000000000000000000000000000..d35983eda1912cda24573a6b8e2d8bccb5abfa7a --- /dev/null +++ b/lib/dataset/mesh_util.py @@ -0,0 +1,1263 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
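# Usage sketch for hoppeMesh.HoppeMesh: query() signs the nearest-vertex distance by the
# dot product with that vertex's normal, so interior points of a closed shape come out
# negative. The icosphere is purely illustrative; note that query() passes n_jobs to
# cKDTree.query, which newer SciPy releases have renamed to workers.
import numpy as np
import trimesh

sphere = trimesh.creation.icosphere(subdivisions=2, radius=1.0)
hmesh = HoppeMesh(np.array(sphere.vertices), np.array(sphere.faces),
                  np.array(sphere.vertex_normals), np.array(sphere.face_normals))
print(hmesh.query(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]])))  # roughly [-1, +1]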
+# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np +import cv2 +import pymeshlab +import torch +import torchvision +import trimesh +import json +from pytorch3d.io import load_obj +import os +from termcolor import colored +import os.path as osp +from scipy.spatial import cKDTree +import _pickle as cPickle +import open3d as o3d + +from pytorch3d.structures import Meshes +import torch.nn.functional as F +from lib.common.render_utils import Pytorch3dRasterizer, face_vertices + +from pytorch3d.renderer.mesh import rasterize_meshes +from PIL import Image, ImageFont, ImageDraw +from kaolin.ops.mesh import check_sign +from kaolin.metrics.trianglemesh import point_to_mesh_distance + +from pytorch3d.loss import (mesh_laplacian_smoothing, mesh_normal_consistency) + +# import tinyobjloader + + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def obj_loader(path): + # Create reader. + reader = tinyobjloader.ObjReader() + + # Load .obj(and .mtl) using default configuration + ret = reader.ParseFromFile(path) + + if ret == False: + print("Failed to load : ", path) + return None + + # note here for wavefront obj, #v might not equal to #vt, same as #vn. + attrib = reader.GetAttrib() + verts = np.array(attrib.vertices).reshape(-1, 3) + + shapes = reader.GetShapes() + tri = shapes[0].mesh.numpy_indices().reshape(-1, 9) + faces = tri[:, [0, 3, 6]] + + return verts, faces + + +class HoppeMesh: + + def __init__(self, verts, faces): + ''' + The HoppeSDF calculates signed distance towards a predefined oriented point cloud + http://hhoppe.com/recon.pdf + For clean and high-resolution pcl data, this is the fastest and accurate approximation of sdf + :param points: pts + :param normals: normals + ''' + self.trimesh = trimesh.Trimesh(verts, faces, process=True) + self.verts = np.array(self.trimesh.vertices) + self.faces = np.array(self.trimesh.faces) + self.vert_normals, self.faces_normals = compute_normal( + self.verts, self.faces) + + def contains(self, points): + + labels = check_sign( + torch.as_tensor(self.verts).unsqueeze(0), + torch.as_tensor(self.faces), + torch.as_tensor(points).unsqueeze(0)) + return labels.squeeze(0).numpy() + + def triangles(self): + return self.verts[self.faces] # [n, 3, 3] + + +def tensor2variable(tensor, device): + # [1,23,3,3] + return torch.tensor(tensor, device=device, requires_grad=True) + + +class GMoF(torch.nn.Module): + + def __init__(self, rho=1): + super(GMoF, self).__init__() + self.rho = rho + + def extra_repr(self): + return 'rho = {}'.format(self.rho) + + def forward(self, residual): + dist = torch.div(residual, residual + self.rho**2) + return self.rho**2 * dist + + +def mesh_edge_loss(meshes, target_length: float = 0.0): + """ + Computes mesh edge length regularization loss averaged across all meshes + in a batch. Each mesh contributes equally to the final loss, regardless of + the number of edges per mesh in the batch by weighting each mesh with the + inverse number of edges. 
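# Sanity sketch for rot6d_to_rotmat: with the view(-1, 3, 2) layout above, the 6-D vector
# that interleaves the columns [1, 0, 0] and [0, 1, 0] should map to the identity rotation.
import torch

x = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]])   # columns a1=[1,0,0], a2=[0,1,0]
print(torch.allclose(rot6d_to_rotmat(x)[0], torch.eye(3), atol=1e-6))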
For example, if mesh 3 (out of N) has only E=4 + edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to + contribute to the final loss. + + Args: + meshes: Meshes object with a batch of meshes. + target_length: Resting value for the edge length. + + Returns: + loss: Average loss across the batch. Returns 0 if meshes contains + no meshes or all empty meshes. + """ + if meshes.isempty(): + return torch.tensor([0.0], + dtype=torch.float32, + device=meshes.device, + requires_grad=True) + + N = len(meshes) + edges_packed = meshes.edges_packed() # (sum(E_n), 3) + verts_packed = meshes.verts_packed() # (sum(V_n), 3) + edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), ) + num_edges_per_mesh = meshes.num_edges_per_mesh() # N + + # Determine the weight for each edge based on the number of edges in the + # mesh it corresponds to. + # TODO (nikhilar) Find a faster way of computing the weights for each edge + # as this is currently a bottleneck for meshes with a large number of faces. + weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx) + weights = 1.0 / weights.float() + + verts_edges = verts_packed[edges_packed] + v0, v1 = verts_edges.unbind(1) + loss = ((v0 - v1).norm(dim=1, p=2) - target_length)**2.0 + loss_vertex = loss * weights + # loss_outlier = torch.topk(loss, 100)[0].mean() + # loss_all = (loss_vertex.sum() + loss_outlier.mean()) / N + loss_all = loss_vertex.sum() / N + + return loss_all + + +def remesh(obj_path, perc, device): + + ms = pymeshlab.MeshSet() + ms.load_new_mesh(obj_path) + ms.apply_coord_laplacian_smoothing() + ms.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.PercentageValue(perc), adaptive=True) + # ms.remeshing_isotropic_explicit_remeshing( + # targetlen=pymeshlab.Percentage(perc), adaptive=True) + ms.save_current_mesh(obj_path.replace("recon", "remesh")) + polished_mesh = trimesh.load_mesh(obj_path.replace("recon", "remesh")) + verts_pr = torch.tensor( + polished_mesh.vertices).float().unsqueeze(0).to(device) + faces_pr = torch.tensor(polished_mesh.faces).long().unsqueeze(0).to(device) + + return verts_pr, faces_pr + + +def possion(mesh, obj_path): + + mesh.export(obj_path) + ms = pymeshlab.MeshSet() + ms.load_new_mesh(obj_path) + ms.surface_reconstruction_screened_poisson(depth=10) + ms.set_current_mesh(1) + ms.save_current_mesh(obj_path) + + return trimesh.load(obj_path) + + +def get_mask(tensor, dim): + + mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0 + mask = mask.type_as(tensor) + + return mask + + +def blend_rgb_norm(rgb, norm, mask): + + # [0,0,0] or [127,127,127] should be marked as mask + final = rgb * (1 - mask) + norm * (mask) + + return final.astype(np.uint8) + + +def unwrap(image, data): + + img_uncrop = uncrop( + np.array( + Image.fromarray(image).resize( + data['uncrop_param']['box_shape'][:2])), + data['uncrop_param']['center'], data['uncrop_param']['scale'], + data['uncrop_param']['crop_shape']) + + img_orig = cv2.warpAffine(img_uncrop, + np.linalg.inv(data['uncrop_param']['M'])[:2, :], + data['uncrop_param']['ori_shape'][::-1][1:], + flags=cv2.INTER_CUBIC) + + return img_orig + + +# Losses to smooth / regularize the mesh shape +def update_mesh_shape_prior_losses(mesh, losses): + + # and (b) the edge length of the predicted mesh + losses["edge"]['value'] = mesh_edge_loss(mesh) + # mesh normal consistency + losses["nc"]['value'] = mesh_normal_consistency(mesh) + # mesh laplacian smoothing + losses["laplacian"]['value'] = mesh_laplacian_smoothing(mesh, + method="uniform") + + +def 
rename(old_dict, old_name, new_name): + new_dict = {} + for key, value in zip(old_dict.keys(), old_dict.values()): + new_key = key if key != old_name else new_name + new_dict[new_key] = old_dict[key] + return new_dict + + +def load_checkpoint(model, cfg): + + model_dict = model.state_dict() + main_dict = {} + normal_dict = {} + + device = torch.device(f"cuda:{cfg['test_gpus'][0]}") + + if os.path.exists(cfg.resume_path) and cfg.resume_path.endswith("ckpt"): + main_dict = torch.load(cfg.resume_path, + map_location=device)['state_dict'] + + main_dict = { + k: v + for k, v in main_dict.items() + if k in model_dict and v.shape == model_dict[k].shape and ( + 'reconEngine' not in k) and ("normal_filter" not in k) and ( + 'voxelization' not in k) + } + print(colored(f"Resume MLP weights from {cfg.resume_path}", 'green')) + + if os.path.exists(cfg.normal_path) and cfg.normal_path.endswith("ckpt"): + normal_dict = torch.load(cfg.normal_path, + map_location=device)['state_dict'] + + for key in normal_dict.keys(): + normal_dict = rename(normal_dict, key, + key.replace("netG", "netG.normal_filter")) + + normal_dict = { + k: v + for k, v in normal_dict.items() + if k in model_dict and v.shape == model_dict[k].shape + } + print(colored(f"Resume normal model from {cfg.normal_path}", 'green')) + + model_dict.update(main_dict) + model_dict.update(normal_dict) + model.load_state_dict(model_dict) + + model.netG = model.netG.to(device) + model.reconEngine = model.reconEngine.to(device) + + model.netG.training = False + model.netG.eval() + + del main_dict + del normal_dict + del model_dict + + return model + + +def read_smpl_constants(folder): + """Load smpl vertex code""" + smpl_vtx_std = np.loadtxt(os.path.join(folder, 'vertices.txt')) + min_x = np.min(smpl_vtx_std[:, 0]) + max_x = np.max(smpl_vtx_std[:, 0]) + min_y = np.min(smpl_vtx_std[:, 1]) + max_y = np.max(smpl_vtx_std[:, 1]) + min_z = np.min(smpl_vtx_std[:, 2]) + max_z = np.max(smpl_vtx_std[:, 2]) + + smpl_vtx_std[:, 0] = (smpl_vtx_std[:, 0] - min_x) / (max_x - min_x) + smpl_vtx_std[:, 1] = (smpl_vtx_std[:, 1] - min_y) / (max_y - min_y) + smpl_vtx_std[:, 2] = (smpl_vtx_std[:, 2] - min_z) / (max_z - min_z) + smpl_vertex_code = np.float32(np.copy(smpl_vtx_std)) + """Load smpl faces & tetrahedrons""" + smpl_faces = np.loadtxt(os.path.join(folder, 'faces.txt'), + dtype=np.int32) - 1 + smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] + + smpl_vertex_code[smpl_faces[:, 1]] + + smpl_vertex_code[smpl_faces[:, 2]]) / 3.0 + smpl_tetras = np.loadtxt(os.path.join(folder, 'tetrahedrons.txt'), + dtype=np.int32) - 1 + + return smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras + + +def surface_field_deformation(xyz, de_nn_verts, de_nn_normals, ori_nn_verts, ori_nn_normals): + ''' + xyz: [B, N, 3] + de_nn_verts: [B, N, 3] + de_nn_normals: [B, N, 3] + ori_nn_verts: [B, N, 3] + ori_nn_normals: [B, N, 3] + ''' + vector=xyz-de_nn_verts # [B, N, 3] + delta=torch.sum(vector*de_nn_normals, dim=-1, keepdim=True)*ori_nn_normals + ori_xyz=ori_nn_verts+delta + + return ori_xyz # the deformed xyz + + +def feat_select(feat, select): + + # feat [B, featx2, N] + # select [B, 1, N] + # return [B, feat, N] + + dim = feat.shape[1] // 2 + idx = torch.tile((1-select), (1, dim, 1))*dim + \ + torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select) + feat_select = torch.gather(feat, 1, idx.long()) + + return feat_select + +def get_visibility_color(xy, z, faces): + """get the visibility of vertices + + Args: + xy (torch.tensor): [N,2] + z (torch.tensor): [N,1] + faces 
(torch.tensor): [N,3] + size (int): resolution of rendered image + """ + + xyz = torch.cat((xy, -z), dim=1) + xyz = (xyz + 1.0) / 2.0 + faces = faces.long() + + rasterizer = Pytorch3dRasterizer(image_size=2**12) + meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...]) + raster_settings = rasterizer.raster_settings + + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, + ) + + vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :]) + vis_mask = torch.zeros(size=(z.shape[0], 1)) + vis_mask[vis_vertices_id] = 1.0 + + # 新增的部分: 检测边缘像素 + edge_mask = torch.zeros_like(pix_to_face) + offset=1 + for i in range(-1-offset, 2+offset): + for j in range(-1-offset, 2+offset): + if i == 0 and j == 0: + continue + shifted = torch.roll(pix_to_face, shifts=(i,j), dims=(0,1)) + edge_mask = torch.logical_or(edge_mask, shifted == -1) + + # 更新可见性掩码 + edge_faces = torch.unique(pix_to_face[edge_mask]) + edge_vertices = torch.unique(faces[edge_faces]) + vis_mask[edge_vertices] = 0.0 + + return vis_mask + + +def get_visibility(xy, z, faces): + """get the visibility of vertices + + Args: + xy (torch.tensor): [N,2] + z (torch.tensor): [N,1] + faces (torch.tensor): [N,3] + size (int): resolution of rendered image + """ + + xyz = torch.cat((xy, -z), dim=1) + xyz = (xyz + 1.0) / 2.0 + faces = faces.long() + + rasterizer = Pytorch3dRasterizer(image_size=2**12) + meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...]) + raster_settings = rasterizer.raster_settings + + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, + ) + + vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :]) + vis_mask = torch.zeros(size=(z.shape[0], 1)) + vis_mask[vis_vertices_id] = 1.0 + + # print("------------------------\n") + # print(f"keep points : {vis_mask.sum()/len(vis_mask)}") + + return vis_mask + + +def barycentric_coordinates_of_projection(points, vertices): + ''' https://github.com/MPI-IS/mesh/blob/master/mesh/geometry/barycentric_coordinates_of_projection.py + ''' + """Given a point, gives projected coords of that point to a triangle + in barycentric coordinates. + See + **Heidrich**, Computing the Barycentric Coordinates of a Projected Point, JGT 05 + at http://www.cs.ubc.ca/~heidrich/Papers/JGT.05.pdf + + :param p: point to project. [B, 3] + :param v0: first vertex of triangles. [B, 3] + :returns: barycentric coordinates of ``p``'s projection in triangle defined by ``q``, ``u``, ``v`` + vectorized so ``p``, ``q``, ``u``, ``v`` can all be ``3xN`` + """ + #(p, q, u, v) + v0, v1, v2 = vertices[:, 0], vertices[:, 1], vertices[:, 2] + p = points + + q = v0 + u = v1 - v0 + v = v2 - v0 + n = torch.cross(u, v) + s = torch.sum(n * n, dim=1) + # If the triangle edges are collinear, cross-product is zero, + # which makes "s" 0, which gives us divide by zero. 
So we + # make the arbitrary choice to set s to epsv (=numpy.spacing(1)), + # the closest thing to zero + s[s == 0] = 1e-6 + oneOver4ASquared = 1.0 / s + w = p - q + b2 = torch.sum(torch.cross(u, w) * n, dim=1) * oneOver4ASquared + b1 = torch.sum(torch.cross(w, v) * n, dim=1) * oneOver4ASquared + weights = torch.stack((1 - b1 - b2, b1, b2), dim=-1) + # check barycenric weights + # p_n = v0*weights[:,0:1] + v1*weights[:,1:2] + v2*weights[:,2:3] + return weights + + +def cal_sdf_batch(verts, faces, cmaps, vis, points): + + # verts [B, N_vert, 3] + # faces [B, N_face, 3] + # triangles [B, N_face, 3, 3] + # points [B, N_point, 3] + # cmaps [B, N_vert, 3] + + Bsize = points.shape[0] + + normals = Meshes(verts, faces).verts_normals_padded() + + # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth + # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929 + # 2. fill mouth holes with 30 more faces + + if verts.shape[1] == 10475: + faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask] + mouth_faces = torch.as_tensor( + SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(Bsize, 1, + 1).to(faces.device) + faces = torch.cat([faces, mouth_faces], dim=1) + + triangles = face_vertices(verts, faces) + normals = face_vertices(normals, faces) + cmaps = face_vertices(cmaps, faces) + vis = face_vertices(vis, faces) + + residues, pts_ind, _ = point_to_mesh_distance(points, triangles) + closest_triangles = torch.gather( + triangles, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, + 3)).view(-1, 3, 3) + closest_normals = torch.gather( + normals, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, + 3)).view(-1, 3, 3) + closest_cmaps = torch.gather( + cmaps, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, + 3)).view(-1, 3, 3) + closest_vis = torch.gather(vis, 1, pts_ind[:, :, None, + None].expand(-1, -1, 3, + 1)).view(-1, 3, 1) + bary_weights = barycentric_coordinates_of_projection( + points.view(-1, 3), closest_triangles) + + pts_cmap = (closest_cmaps * bary_weights[:, :, None]).sum(1).unsqueeze(0) + pts_vis = (closest_vis * + bary_weights[:, :, None]).sum(1).unsqueeze(0).ge(1e-1) + pts_norm = (closest_normals * + bary_weights[:, :, None]).sum(1).unsqueeze(0) * torch.tensor( + [-1.0, 1.0, -1.0]).type_as(normals) + pts_norm = F.normalize(pts_norm, dim=2) + pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3)) + + pts_signs = 2.0 * (check_sign(verts, faces[0], points).float() - 0.5) + pts_sdf = (pts_dist * pts_signs).unsqueeze(-1) + + return pts_sdf.view(Bsize, -1, + 1), pts_norm.view(Bsize, -1, 3), pts_cmap.view( + Bsize, -1, 3), pts_vis.view(Bsize, -1, 1) + + +def orthogonal(points, calibrations, transforms=None): + ''' + Compute the orthogonal projections of 3D points into the image plane by given projection matrix + :param points: [B, 3, N] Tensor of 3D points + :param calibrations: [B, 3, 4] Tensor of projection matrix + :param transforms: [B, 2, 3] Tensor of image transform matrix + :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane + ''' + rot = calibrations[:, :3, :3] + trans = calibrations[:, :3, 3:4] + pts = torch.baddbmm(trans, rot, points) # [B, 3, N] + if transforms is not None: + scale = transforms[:2, :2] + shift = transforms[:2, 2:3] + pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :]) + return pts + + +def projection(points, calib): + if torch.is_tensor(points): + calib = torch.as_tensor(calib) if not torch.is_tensor(calib) else calib + return torch.mm(calib[:3, :3], points.T).T + calib[:3, 3] + else: + return np.matmul(calib[:3, :3], points.T).T + calib[:3, 
3] + + +def load_calib(calib_path): + calib_data = np.loadtxt(calib_path, dtype=float) + extrinsic = calib_data[:4, :4] + intrinsic = calib_data[4:8, :4] + calib_mat = np.matmul(intrinsic, extrinsic) + calib_mat = torch.from_numpy(calib_mat).float() + return calib_mat + + +def load_obj_mesh_for_Hoppe(mesh_file): + vertex_data = [] + face_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + [values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + vertices = np.array(vertex_data) + faces = np.array(face_data) + faces[faces > 0] -= 1 + + normals, _ = compute_normal(vertices, faces) + + return vertices, normals, faces + + +def load_obj_mesh_with_color(mesh_file): + vertex_data = [] + color_data = [] + face_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + c = list(map(float, values[4:7])) + color_data.append(c) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + [values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + vertices = np.array(vertex_data) + colors = np.array(color_data) + faces = np.array(face_data) + faces[faces > 0] -= 1 + + return vertices, colors, faces + + +def load_obj_mesh(mesh_file, with_normal=False, with_texture=False): + vertex_data = [] + norm_data = [] + uv_data = [] + + face_data = [] + face_norm_data = [] + face_uv_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + elif values[0] == 'vn': + vn = list(map(float, values[1:4])) + norm_data.append(vn) + elif values[0] == 'vt': + vt = list(map(float, values[1:3])) + uv_data.append(vt) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + [values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + # deal with texture + if len(values[1].split('/')) >= 2: + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[1]), values[1:4])) + face_uv_data.append(f) + f = list( + map(lambda x: int(x.split('/')[1]), + [values[3], values[4], 
values[1]])) + face_uv_data.append(f) + # tri mesh + elif len(values[1].split('/')[1]) != 0: + f = list(map(lambda x: int(x.split('/')[1]), values[1:4])) + face_uv_data.append(f) + # deal with normal + if len(values[1].split('/')) == 3: + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[2]), values[1:4])) + face_norm_data.append(f) + f = list( + map(lambda x: int(x.split('/')[2]), + [values[3], values[4], values[1]])) + face_norm_data.append(f) + # tri mesh + elif len(values[1].split('/')[2]) != 0: + f = list(map(lambda x: int(x.split('/')[2]), values[1:4])) + face_norm_data.append(f) + + vertices = np.array(vertex_data) + faces = np.array(face_data) + faces[faces > 0] -= 1 + + if with_texture and with_normal: + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) + face_uvs[face_uvs > 0] -= 1 + norms = np.array(norm_data) + if norms.shape[0] == 0: + norms, _ = compute_normal(vertices, faces) + face_normals = faces + else: + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) + face_normals[face_normals > 0] -= 1 + return vertices, faces, norms, face_normals, uvs, face_uvs + + if with_texture: + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) - 1 + return vertices, faces, uvs, face_uvs + + if with_normal: + norms = np.array(norm_data) + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) - 1 + return vertices, faces, norms, face_normals + + return vertices, faces + + +def normalize_v3(arr): + ''' Normalize a numpy array of 3 component vectors shape=(n,3) ''' + lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2) + eps = 0.00000001 + lens[lens < eps] = eps + arr[:, 0] /= lens + arr[:, 1] /= lens + arr[:, 2] /= lens + return arr + + +def compute_normal(vertices, faces): + # Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal + vert_norms = np.zeros(vertices.shape, dtype=vertices.dtype) + # Create an indexed view into the vertex array using the array of three indices for triangles + tris = vertices[faces] + # Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle + face_norms = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0]) + # n is now an array of normals per triangle. The length of each normal is dependent the vertices, + # we need to normalize these, so that our next step weights each normal equally. + normalize_v3(face_norms) + # now we have a normalized array of normals, one per triangle, i.e., per triangle normals. + # But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle, + # the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards. 
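    # Note on the three fancy-indexed "+=" lines below: NumPy does not accumulate
    # repeated indices in this form (each vertex keeps only one contribution per column),
    # so the summed vertex normals are an approximation; an unbuffered alternative is
    # np.add.at(vert_norms, faces[:, 0], face_norms), and likewise for columns 1 and 2.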
+ # The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array + vert_norms[faces[:, 0]] += face_norms + vert_norms[faces[:, 1]] += face_norms + vert_norms[faces[:, 2]] += face_norms + normalize_v3(vert_norms) + + return vert_norms, face_norms + + +def save_obj_mesh(mesh_path, verts, faces): + file = open(mesh_path, 'w') + for v in verts: + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def save_obj_mesh_with_color(mesh_path, verts, faces, colors): + file = open(mesh_path, 'w') + + for idx, v in enumerate(verts): + c = colors[idx] + file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % + (v[0], v[1], v[2], c[0], c[1], c[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def calculate_mIoU(outputs, labels): + + SMOOTH = 1e-6 + + outputs = outputs.int() + labels = labels.int() + + intersection = ( + outputs + & labels).float().sum() # Will be zero if Truth=0 or Prediction=0 + union = (outputs | labels).float().sum() # Will be zzero if both are 0 + + iou = (intersection + SMOOTH) / (union + SMOOTH + ) # We smooth our devision to avoid 0/0 + + thresholded = torch.clamp( + 20 * (iou - 0.5), 0, + 10).ceil() / 10 # This is equal to comparing with thresolds + + return thresholded.mean().detach().cpu().numpy( + ) # Or thresholded.mean() if you are interested in average across the batch + + +def mask_filter(mask, number=1000): + """only keep {number} True items within a mask + + Args: + mask (bool array): [N, ] + number (int, optional): total True item. Defaults to 1000. + """ + true_ids = np.where(mask)[0] + keep_ids = np.random.choice(true_ids, size=number) + filter_mask = np.isin(np.arange(len(mask)), keep_ids) + + return filter_mask + + +def query_mesh(path): + + verts, faces_idx, _ = load_obj(path) + + return verts, faces_idx.verts_idx + + +def add_alpha(colors, alpha=0.7): + + colors_pad = np.pad(colors, ((0, 0), (0, 1)), + mode='constant', + constant_values=alpha) + + return colors_pad + + +def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type='smpl'): + + font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf") + font = ImageFont.truetype(font_path, 30) + grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), + nrow=nrow) + grid_img = Image.fromarray( + ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * + 255.0).astype(np.uint8)) + + # add text + draw = ImageDraw.Draw(grid_img) + grid_size = 512 + if loss is not None: + draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font) + + if type == 'smpl': + for col_id, col_txt in enumerate([ + 'image', 'smpl-norm(render)', 'cloth-norm(pred)', 'diff-norm', + 'diff-mask' + ]): + draw.text((10 + (col_id * grid_size), 5), + col_txt, (255, 0, 0), + font=font) + elif type == 'cloth': + for col_id, col_txt in enumerate( + ['cloth-norm(recon)']): + draw.text((10 + (col_id * grid_size), 5), + col_txt, (255, 0, 0), + font=font) + for col_id, col_txt in enumerate(['0', '90', '180', '270']): + draw.text((10 + (col_id * grid_size), grid_size * 2 + 5), + col_txt, (255, 0, 0), + font=font) + else: + print(f"{type} should be 'smpl' or 'cloth'") + + grid_img = grid_img.resize((grid_img.size[0], grid_img.size[1]), + Image.LANCZOS) + + return grid_img + + +def clean_mesh(verts, faces): + + device = verts.device + + mesh_lst = 
trimesh.Trimesh(verts.detach().cpu().numpy(), + faces.detach().cpu().numpy()) + mesh_lst = mesh_lst.split(only_watertight=False) + comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst] + mesh_clean = mesh_lst[comp_num.index(max(comp_num))] + + final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device) + final_faces = torch.as_tensor(mesh_clean.faces).int().to(device) + + return final_verts, final_faces + + +def merge_mesh(verts_A, faces_A, verts_B, faces_B, color=False): + + sep_mesh = trimesh.Trimesh(np.concatenate([verts_A, verts_B], axis=0), + np.concatenate( + [faces_A, faces_B + faces_A.max() + 1], + axis=0), + maintain_order=True, + process=False) + if color: + colors = np.ones_like(sep_mesh.vertices) + colors[:verts_A.shape[0]] *= np.array([255.0, 0.0, 0.0]) + colors[verts_A.shape[0]:] *= np.array([0.0, 255.0, 0.0]) + sep_mesh.visual.vertex_colors = colors + + # union_mesh = trimesh.boolean.union([trimesh.Trimesh(verts_A, faces_A), + # trimesh.Trimesh(verts_B, faces_B)], engine='blender') + + return sep_mesh + + +def mesh_move(mesh_lst, step, scale=1.0): + + trans = np.array([1.0, 0.0, 0.0]) * step + + resize_matrix = trimesh.transformations.scale_and_translate( + scale=(scale), translate=trans) + + results = [] + + for mesh in mesh_lst: + mesh.apply_transform(resize_matrix) + results.append(mesh) + + return results + + +def rescale_smpl(fitted_path, scale=100, translate=(0, 0, 0)): + + fitted_body = trimesh.load(fitted_path, + process=False, + maintain_order=True, + skip_materials=True) + resize_matrix = trimesh.transformations.scale_and_translate( + scale=(scale), translate=translate) + + fitted_body.apply_transform(resize_matrix) + + return np.array(fitted_body.vertices) + + +class SMPLX(): + + def __init__(self): + + self.current_dir = "smpl_related" # new smplx file in ECON folder + + self.smpl_verts_path = osp.join(self.current_dir, + "smpl_data/smpl_verts.npy") + self.smpl_faces_path = osp.join(self.current_dir, + "smpl_data/smpl_faces.npy") + self.smplx_verts_path = osp.join(self.current_dir, + "smpl_data/smplx_verts.npy") + self.smplx_faces_path = osp.join(self.current_dir, + "smpl_data/smplx_faces.npy") + self.cmap_vert_path = osp.join(self.current_dir, + "smpl_data/smplx_cmap.npy") + + self.smplx_to_smplx_path = osp.join(self.current_dir, + "smpl_data/smplx_to_smpl.pkl") + + self.smplx_eyeball_fid = osp.join(self.current_dir, + "smpl_data/eyeball_fid.npy") + self.smplx_fill_mouth_fid = osp.join(self.current_dir, + "smpl_data/fill_mouth_fid.npy") + + self.smplx_faces = np.load(self.smplx_faces_path) + self.smplx_verts = np.load(self.smplx_verts_path) + self.smpl_verts = np.load(self.smpl_verts_path) + self.smpl_faces = np.load(self.smpl_faces_path) + + self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid) + self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid) + + self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, 'rb')) + + self.model_dir = osp.join(self.current_dir, "models") + # self.tedra_dir = osp.join(self.current_dir, "../tedra_data") + + + + # copy from econ + self.smplx_flame_vid_path = osp.join( + self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy" + ) + self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl") + self.smpl_vert_seg_path = osp.join( + self.current_dir, "smpl_vert_segmentation.json" + ) + self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy") + self.smplx_vertex_lmkid_path = osp.join( + self.current_dir, 
"smpl_data/smplx_vertex_lmkid.npy" + ) + + self.smplx_vertex_lmkid = np.load(self.smplx_vertex_lmkid_path) + self.smpl_vert_seg = json.load(open(self.smpl_vert_seg_path)) + self.smpl_mano_vid = np.concatenate( + [ + self.smpl_vert_seg["rightHand"], self.smpl_vert_seg["rightHandIndex1"], + self.smpl_vert_seg["leftHand"], self.smpl_vert_seg["leftHandIndex1"] + ] + ) + + self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True) + self.smplx_mano_vid = np.concatenate( + [self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]] + ) + self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True) + self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)] + + + # hands + self.smplx_mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_mano_vid), 1.0 + ) + self.smpl_mano_vertex_mask = torch.zeros(self.smpl_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smpl_mano_vid), 1.0 + ) + + # face + self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_front_flame_vid), 1.0 + ) + self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0 + ) + + + self.ghum_smpl_pairs = torch.tensor( + [ + (0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19), + (15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40), + (23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29), + (32, 32) + ] + ).long() + + # smpl-smplx correspondence + self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73] + self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68] + self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist() + + self.extra_joint_ids = np.array( + [ + 61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72, + 74, 73 + ] + ) + + self.extra_joint_ids += 68 + + self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist()) + + + def cmap_smpl_vids(self, type): + + # keys: + # closest_faces - [6890, 3] with smplx vert_idx + # bc - [6890, 3] with barycentric weights + + cmap_smplx = torch.as_tensor(np.load(self.cmap_vert_path)).float() + if type == 'smplx': + return cmap_smplx + elif type == 'smpl': + bc = torch.as_tensor(self.smplx_to_smpl['bc'].astype(np.float32)) + closest_faces = self.smplx_to_smpl['closest_faces'].astype( + np.int32) + + cmap_smpl = torch.einsum('bij, bi->bj', cmap_smplx[closest_faces], + bc) + + return cmap_smpl + + + +# copy from ECON + +def apply_face_mask(mesh, face_mask): + + mesh.update_faces(face_mask) + mesh.remove_unreferenced_vertices() + + return mesh + + +def apply_vertex_mask(mesh, vertex_mask): + + faces_mask = vertex_mask[mesh.faces].any(dim=1) + mesh = apply_face_mask(mesh, faces_mask) + + return mesh + + +def apply_vertex_face_mask(mesh, vertex_mask, face_mask): + + faces_mask = vertex_mask[mesh.faces].any(dim=1) * torch.tensor(face_mask) + mesh.update_faces(faces_mask) + mesh.remove_unreferenced_vertices() + + return mesh + + +def clean_floats(mesh): + thres = mesh.vertices.shape[0] * 1e-2 + mesh_lst = mesh.split(only_watertight=False) + clean_mesh_lst = [mesh for mesh in mesh_lst if mesh.vertices.shape[0] > thres] + return sum(clean_mesh_lst) + +def isin(input, test_elements): + # 扩展输入和测试元素的维度以进行广播 + input = 
input.unsqueeze(-1) + test_elements = test_elements.unsqueeze(0) + + # 比较两个张量的元素 + comparison_result = torch.eq(input, test_elements) + + # 沿着新添加的维度进行求和,以检查每个输入元素是否在测试元素中 + isin_result = comparison_result.sum(-1).bool() + + return isin_result + + + +def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=True): + + smpl_tree = cKDTree(smpl_obj.vertices) + SMPL_container = SMPLX() + + from lib.dataset.PointFeat import ECON_PointFeat + + part_extractor = ECON_PointFeat( + torch.tensor(part_mesh.vertices).unsqueeze(0).to(device), + torch.tensor(part_mesh.faces).unsqueeze(0).to(device) + ) + + (part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device)) + + remove_mask = part_dist < thres + + if region == "hand": + _, idx = smpl_tree.query(full_mesh.vertices, k=1) + full_lmkid = SMPL_container.smplx_vertex_lmkid[idx] + remove_mask = torch.logical_and( + remove_mask, + torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0) + ) + + elif region == "face": + _, idx = smpl_tree.query(full_mesh.vertices, k=5) + face_space_mask = isin( + torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid) + ) + remove_mask = torch.logical_and( + remove_mask, + face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0) + ) + + BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1) + full_mesh.update_faces(BNI_part_mask.detach().cpu()) + full_mesh.remove_unreferenced_vertices() + + if clean: + full_mesh = clean_floats(full_mesh) + + return full_mesh + +def keep_largest(mesh): + mesh_lst = mesh.split(only_watertight=False) + keep_mesh = mesh_lst[0] + for mesh in mesh_lst: + if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]: + keep_mesh = mesh + return keep_mesh + + +def poisson(mesh, obj_path, depth=10, decimation=True): + + pcd_path = obj_path[:-4] + "_soups.ply" + assert (mesh.vertex_normals.shape[1] == 3) + mesh.export(pcd_path) + pcl = o3d.io.read_point_cloud(pcd_path) + with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm: + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( + pcl, depth=depth, n_threads=6 + ) + os.remove(pcd_path) + # only keep the largest component + largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles))) + + if decimation: + # mesh decimation for faster rendering + low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000) + return low_res_mesh + else: + return largest_mesh + + + \ No newline at end of file diff --git a/lib/dataset/tbfo.ttf b/lib/dataset/tbfo.ttf new file mode 100644 index 0000000000000000000000000000000000000000..6cc76fcd568a5a42edd71272a19b15214de0b0d5 Binary files /dev/null and b/lib/dataset/tbfo.ttf differ diff --git a/lib/hybrik/models/layers/Resnet.py b/lib/hybrik/models/layers/Resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..173c15e436a6fc3226e165ed5642e1cd0353559c --- /dev/null +++ b/lib/hybrik/models/layers/Resnet.py @@ -0,0 +1,223 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + 
dcn=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + norm_layer=nn.BatchNorm2d, + dcn=None): + super(Bottleneck, self).__init__() + self.dcn = dcn + self.with_dcn = dcn is not None + + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(planes, momentum=0.1) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + self.bn2 = norm_layer(planes, momentum=0.1) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * 4, momentum=0.1) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = F.relu(self.bn1(self.conv1(x)), inplace=True) + if not self.with_dcn: + out = F.relu(self.bn2(self.conv2(out)), inplace=True) + elif self.with_modulated_dcn: + offset_mask = self.conv2_offset(out) + offset = offset_mask[:, :18 * self.deformable_groups, :, :] + mask = offset_mask[:, -9 * self.deformable_groups:, :, :] + mask = mask.sigmoid() + out = F.relu(self.bn2(self.conv2(out, offset, mask))) + else: + offset = self.conv2_offset(out) + out = F.relu(self.bn2(self.conv2(out, offset)), inplace=True) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = F.relu(out) + + return out + + +class ResNet(nn.Module): + """ ResNet """ + + def __init__(self, + architecture, + norm_layer=nn.BatchNorm2d, + dcn=None, + stage_with_dcn=(False, False, False, False)): + super(ResNet, self).__init__() + self._norm_layer = norm_layer + assert architecture in [ + "resnet18", "resnet34", "resnet50", "resnet101", 'resnet152' + ] + layers = { + 'resnet18': [2, 2, 2, 2], + 'resnet34': [3, 4, 6, 3], + 'resnet50': [3, 4, 6, 3], + 'resnet101': [3, 4, 23, 3], + 'resnet152': [3, 8, 36, 3], + } + self.inplanes = 64 + if architecture == "resnet18" or architecture == 'resnet34': + self.block = BasicBlock + else: + self.block = Bottleneck + self.layers = layers[architecture] + + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = norm_layer(64, eps=1e-5, momentum=0.1, affine=True) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + stage_dcn = [dcn if with_dcn else None for with_dcn in stage_with_dcn] + + self.layer1 = self.make_layer(self.block, + 64, + self.layers[0], + dcn=stage_dcn[0]) + self.layer2 = self.make_layer(self.block, + 128, + self.layers[1], + stride=2, + dcn=stage_dcn[1]) + 
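        # Usage sketch (shapes assume a 256x256 input; the overall stride is 32):
        #   backbone = ResNet('resnet34')
        #   feat = backbone(torch.randn(1, 3, 256, 256))   # -> [1, 512, 8, 8]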
self.layer3 = self.make_layer(self.block, + 256, + self.layers[2], + stride=2, + dcn=stage_dcn[2]) + + self.layer4 = self.make_layer(self.block, + 512, + self.layers[3], + stride=2, + dcn=stage_dcn[3]) + + def forward(self, x): + x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) # 64 * h/4 * w/4 + x = self.layer1(x) # 256 * h/4 * w/4 + x = self.layer2(x) # 512 * h/8 * w/8 + x = self.layer3(x) # 1024 * h/16 * w/16 + x = self.layer4(x) # 2048 * h/32 * w/32 + return x + + def stages(self): + return [self.layer1, self.layer2, self.layer3, self.layer4] + + def make_layer(self, block, planes, blocks, stride=1, dcn=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + self._norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, + planes, + stride, + downsample, + norm_layer=self._norm_layer, + dcn=dcn)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + norm_layer=self._norm_layer, + dcn=dcn)) + + return nn.Sequential(*layers) diff --git a/lib/hybrik/models/layers/smpl/SMPL.py b/lib/hybrik/models/layers/smpl/SMPL.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5d65117cb632c54860a58c6d4380b32ffcdcab --- /dev/null +++ b/lib/hybrik/models/layers/smpl/SMPL.py @@ -0,0 +1,327 @@ +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn + +from .lbs import lbs, hybrik, rotmat_to_quat, quat_to_rotmat, rotation_matrix_to_angle_axis + +try: + import cPickle as pk +except ImportError: + import pickle as pk + +ModelOutput = namedtuple( + 'ModelOutput', ['vertices', 'joints', 'joints_from_verts', 'rot_mats']) +ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields) + + +def to_tensor(array, dtype=torch.float32): + if 'torch.tensor' not in str(type(array)): + return torch.tensor(array, dtype=dtype) + + +class Struct(object): + + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +class SMPL_layer(nn.Module): + NUM_JOINTS = 23 + NUM_BODY_JOINTS = 23 + NUM_BETAS = 10 + JOINT_NAMES = [ + 'pelvis', + 'left_hip', + 'right_hip', # 2 + 'spine1', + 'left_knee', + 'right_knee', # 5 + 'spine2', + 'left_ankle', + 'right_ankle', # 8 + 'spine3', + 'left_foot', + 'right_foot', # 11 + 'neck', + 'left_collar', + 'right_collar', # 14 + 'jaw', # 15 + 'left_shoulder', + 'right_shoulder', # 17 + 'left_elbow', + 'right_elbow', # 19 + 'left_wrist', + 'right_wrist', # 21 + 'left_thumb', + 'right_thumb', # 23 + 'head', + 'left_middle', + 'right_middle', # 26 + 'left_bigtoe', + 'right_bigtoe' # 28 + ] + LEAF_NAMES = [ + 'head', 'left_middle', 'right_middle', 'left_bigtoe', 'right_bigtoe' + ] + root_idx_17 = 0 + root_idx_smpl = 0 + + def __init__(self, + model_path, + h36m_jregressor, + gender='neutral', + dtype=torch.float32, + num_joints=29): + ''' SMPL model layers + + Parameters: + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + gender: str, optional + Which gender to load + ''' + super(SMPL_layer, self).__init__() + + self.ROOT_IDX = self.JOINT_NAMES.index('pelvis') + self.LEAF_IDX = [ + self.JOINT_NAMES.index(name) for 
name in self.LEAF_NAMES + ] + self.SPINE3_IDX = 9 + + with open(model_path, 'rb') as smpl_file: + self.smpl_data = Struct(**pk.load(smpl_file, encoding='latin1')) + + self.gender = gender + + self.dtype = dtype + + self.faces = self.smpl_data.f + ''' Register Buffer ''' + # Faces + self.register_buffer( + 'faces_tensor', + to_tensor(to_np(self.smpl_data.f, dtype=np.int64), + dtype=torch.long)) + + # The vertices of the template model, (6890, 3) + self.register_buffer( + 'v_template', + to_tensor(to_np(self.smpl_data.v_template), dtype=dtype)) + + # The shape components + # Shape blend shapes basis, (6890, 3, 10) + self.register_buffer( + 'shapedirs', to_tensor(to_np(self.smpl_data.shapedirs), + dtype=dtype)) + + # Pose blend shape basis: 6890 x 3 x 23*9, reshaped to 6890*3 x 23*9 + num_pose_basis = self.smpl_data.posedirs.shape[-1] + # 23*9 x 6890*3 + posedirs = np.reshape(self.smpl_data.posedirs, [-1, num_pose_basis]).T + self.register_buffer('posedirs', to_tensor(to_np(posedirs), + dtype=dtype)) + + # Vertices to Joints location (23 + 1, 6890) + self.register_buffer( + 'J_regressor', + to_tensor(to_np(self.smpl_data.J_regressor), dtype=dtype)) + # Vertices to Human3.6M Joints location (17, 6890) + self.register_buffer('J_regressor_h36m', + to_tensor(to_np(h36m_jregressor), dtype=dtype)) + + self.num_joints = num_joints + + # indices of parents for each joints + parents = torch.zeros(len(self.JOINT_NAMES), dtype=torch.long) + parents[:(self.NUM_JOINTS + 1)] = to_tensor( + to_np(self.smpl_data.kintree_table[0])).long() + parents[0] = -1 + # extend kinematic tree + parents[24] = 15 + parents[25] = 22 + parents[26] = 23 + parents[27] = 10 + parents[28] = 11 + if parents.shape[0] > self.num_joints: + parents = parents[:24] + + self.register_buffer('children_map', + self._parents_to_children(parents)) + # (24,) + self.register_buffer('parents', parents) + + # (6890, 23 + 1) + self.register_buffer( + 'lbs_weights', to_tensor(to_np(self.smpl_data.weights), + dtype=dtype)) + + def _parents_to_children(self, parents): + children = torch.ones_like(parents) * -1 + for i in range(self.num_joints): + if children[parents[i]] < 0: + children[parents[i]] = i + for i in self.LEAF_IDX: + if i < children.shape[0]: + children[i] = -1 + + children[self.SPINE3_IDX] = -3 + children[0] = 3 + children[self.SPINE3_IDX] = self.JOINT_NAMES.index('neck') + + return children + + def forward(self, + pose_axis_angle, + betas, + global_orient, + transl=None, + return_verts=True): + ''' Forward pass for the SMPL model + + Parameters + ---------- + pose_axis_angle: torch.tensor, optional, shape Bx(J*3) + It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + betas: torch.tensor, optional, shape Bx10 + It can used if shape parameters + `betas` are predicted from some external model. + (default=None) + global_orient: torch.tensor, optional, shape Bx3 + Global Orientations. + transl: torch.tensor, optional, shape Bx3 + Global Translations. + return_verts: bool, optional + Return the vertices. 
(default=True) + + Returns + ------- + ''' + # batch_size = pose_axis_angle.shape[0] + + # concate root orientation with thetas + if global_orient is not None: + full_pose = torch.cat([global_orient, pose_axis_angle], dim=1) + else: + full_pose = pose_axis_angle + + # Translate thetas to rotation matrics + pose2rot = True + # vertices: (B, N, 3), joints: (B, K, 3) + vertices, joints, rot_mats, joints_from_verts_h36m = lbs( + betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.J_regressor_h36m, + self.parents, + self.lbs_weights, + pose2rot=pose2rot, + dtype=self.dtype) + + if transl is not None: + # apply translations + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + joints_from_verts_h36m += transl.unsqueeze(dim=1) + else: + vertices = vertices - \ + joints_from_verts_h36m[:, self.root_idx_17, :].unsqueeze( + 1).detach() + joints = joints - \ + joints[:, self.root_idx_smpl, :].unsqueeze(1).detach() + joints_from_verts_h36m = joints_from_verts_h36m - \ + joints_from_verts_h36m[:, self.root_idx_17, :].unsqueeze( + 1).detach() + + output = ModelOutput(vertices=vertices, + joints=joints, + rot_mats=rot_mats, + joints_from_verts=joints_from_verts_h36m) + return output + + def hybrik(self, + pose_skeleton, + betas, + phis, + global_orient, + transl=None, + return_verts=True, + leaf_thetas=None): + ''' Inverse pass for the SMPL model + + Parameters + ---------- + pose_skeleton: torch.tensor, optional, shape Bx(J*3) + It should be a tensor that contains joint locations in + (X, Y, Z) format. (default=None) + betas: torch.tensor, optional, shape Bx10 + It can used if shape parameters + `betas` are predicted from some external model. + (default=None) + global_orient: torch.tensor, optional, shape Bx3 + Global Orientations. + transl: torch.tensor, optional, shape Bx3 + Global Translations. + return_verts: bool, optional + Return the vertices. 
(default=True) + + Returns + ------- + ''' + batch_size = pose_skeleton.shape[0] + + if leaf_thetas is not None: + leaf_thetas = leaf_thetas.reshape(batch_size * 5, 4) + leaf_thetas = quat_to_rotmat(leaf_thetas) + + vertices, new_joints, rot_mats, joints_from_verts = hybrik( + betas, + global_orient, + pose_skeleton, + phis, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.J_regressor_h36m, + self.parents, + self.children_map, + self.lbs_weights, + dtype=self.dtype, + train=self.training, + leaf_thetas=leaf_thetas) + + rot_mats = rot_mats.reshape(batch_size, 24, 3, 3) + # rot_aa = rotation_matrix_to_angle_axis(rot_mats) + # rot_mats = rotmat_to_quat(rot_mats).reshape(batch_size, 24 * 4) + + if transl is not None: + new_joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + # joints_from_verts += transl.unsqueeze(dim=1) + else: + vertices = vertices - \ + joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach() + new_joints = new_joints - \ + new_joints[:, self.root_idx_smpl, :].unsqueeze(1).detach() + # joints_from_verts = joints_from_verts - joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach() + + output = ModelOutput(vertices=vertices, + joints=new_joints, + rot_mats=rot_mats, + joints_from_verts=joints_from_verts) + return output diff --git a/lib/hybrik/models/layers/smpl/lbs.py b/lib/hybrik/models/layers/smpl/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2414b8d9775e5c74e82864bbcddbdff2318c4c --- /dev/null +++ b/lib/hybrik/models/layers/smpl/lbs.py @@ -0,0 +1,1476 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + +import torch +import torch.nn.functional as F + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) + + +def find_dynamic_lmk_idx_and_bcoords(vertices, + pose, + dynamic_lmk_faces_idx, + dynamic_lmk_b_coords, + neck_kin_chain, + dtype=torch.float32): + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. + + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. 
+ + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + neck_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + ''' + + batch_size = vertices.shape[0] + + aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + neck_kin_chain) + rot_mats = batch_rodrigues(aa_pose.view(-1, 3), + dtype=dtype).view(batch_size, -1, 3, 3) + + rel_rot_mat = torch.eye(3, device=vertices.device, + dtype=dtype).unsqueeze_(dim=0).repeat( + batch_size, 1, 1) + for idx in range(len(neck_kin_chain)): + rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, + y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. + lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( + batch_size, -1, 3) + + lmk_faces += torch.arange(batch_size, dtype=torch.long, + device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def joints2bones(joints, parents): + ''' Decompose joints location to bone length and direction. 
+ + Parameters + ---------- + joints: torch.tensor Bx24x3 + ''' + assert joints.shape[1] == parents.shape[0] + bone_dirs = torch.zeros_like(joints) + bone_lens = torch.zeros_like(joints[:, :, :1]) + + for c_id in range(parents.shape[0]): + p_id = parents[c_id] + if p_id == -1: + # Parent node + bone_dirs[:, c_id] = joints[:, c_id] + else: + # Child node + # (B, 3) + diff = joints[:, c_id] - joints[:, p_id] + length = torch.norm(diff, dim=1, keepdim=True) + 1e-8 + direct = diff / length + + bone_dirs[:, c_id] = direct + bone_lens[:, c_id] = length + + return bone_dirs, bone_lens + + +def bones2joints(bone_dirs, bone_lens, parents): + ''' Recover bone length and direction to joints location. + + Parameters + ---------- + bone_dirs: torch.tensor 1x24x3 + bone_lens: torch.tensor Bx24x1 + ''' + batch_size = bone_lens.shape[0] + joints = torch.zeros_like(bone_dirs).expand(batch_size, 24, 3) + + for c_id in range(parents.shape[0]): + p_id = parents[c_id] + if p_id == -1: + # Parent node + joints[:, c_id] = bone_dirs[:, c_id] + else: + # Child node + joints[:, c_id] = joints[:, p_id] + \ + bone_dirs[:, c_id] * bone_lens[:, c_id] + + return joints + + +def lbs(betas, + pose, + v_template, + shapedirs, + posedirs, + J_regressor, + J_regressor_h36m, + parents, + lbs_weights, + pose2rot=True, + dtype=torch.float32): + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + rot_mats: torch.tensor BxJx3x3 + The rotation matrics of each joints + ''' + batch_size = max(betas.shape[0], pose.shape[0]) + device = betas.device + + # Add shape contribution + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + + # 3. 
Add pose blend shapes + # N x J x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + if pose2rot: + if pose.numel() == batch_size * 24 * 4: + rot_mats = quat_to_rotmat(pose.reshape(batch_size * 24, + 4)).reshape( + batch_size, 24, 3, 3) + else: + rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view( + [batch_size, -1, 3, 3]) + + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + # (N x P) x (P, V * 3) -> N x V x 3 + pose_offsets = torch.matmul(pose_feature, posedirs) \ + .view(batch_size, -1, 3) + else: + pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident + rot_mats = pose.view(batch_size, -1, 3, 3) + + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, + J, + parents[:24], + dtype=dtype) + + # 5. Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, + device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + J_from_verts = vertices2joints(J_regressor_h36m, verts) + + return verts, J_transformed, rot_mats, J_from_verts + + +def hybrik(betas, + global_orient, + pose_skeleton, + phis, + v_template, + shapedirs, + posedirs, + J_regressor, + J_regressor_h36m, + parents, + children, + lbs_weights, + dtype=torch.float32, + train=False, + leaf_thetas=None): + ''' Performs Linear Blend Skinning with the given shape and skeleton joints + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + global_orient : torch.tensor Bx3 + The tensor of global orientation + pose_skeleton : torch.tensor BxJ*3 + The pose skeleton in (X, Y, Z) format + phis : torch.tensor BxJx2 + The rotation on bone axis parameters + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + J_regressor_h36m : torch.tensor 17xV + The regressor array that is used to calculate the 17 Human3.6M joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic parents for the model + children: dict + The dictionary that describes the kinematic chidrens for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + rot_mats: torch.tensor BxJx3x3 + The rotation matrics of each joints + ''' + batch_size = max(betas.shape[0], pose_skeleton.shape[0]) + device = betas.device + + # 1. Add shape contribution + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # 2. 
Get the rest joints + # NxJx3 array + if leaf_thetas is not None: + rest_J = vertices2joints(J_regressor, v_shaped) + else: + rest_J = torch.zeros((v_shaped.shape[0], 29, 3), + dtype=dtype, + device=device) + rest_J[:, :24] = vertices2joints(J_regressor, v_shaped) + + leaf_number = [411, 2445, 5905, 3216, 6617] + leaf_vertices = v_shaped[:, leaf_number].clone() + rest_J[:, 24:] = leaf_vertices + + # 3. Get the rotation matrics + if train: + rot_mats, rotate_rest_pose = batch_inverse_kinematics_transform( + pose_skeleton, + global_orient, + phis, + rest_J.clone(), + children, + parents, + dtype=dtype, + train=train, + leaf_thetas=leaf_thetas) + else: + rot_mats, rotate_rest_pose = batch_inverse_kinematics_transform_optimized( + pose_skeleton, + phis, + rest_J.clone(), + children, + parents, + dtype=dtype, + train=train, + leaf_thetas=leaf_thetas) + + test_joints = True + if test_joints: + J_transformed, A = batch_rigid_transform(rot_mats, + rest_J[:, :24].clone(), + parents[:24], + dtype=dtype) + else: + J_transformed = None + + # assert torch.mean(torch.abs(rotate_rest_pose - J_transformed)) < 1e-5 + # 4. Add pose blend shapes + # rot_mats: N x (J + 1) x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + pose_feature = (rot_mats[:, 1:] - ident).view([batch_size, -1]) + pose_offsets = torch.matmul(pose_feature, posedirs) \ + .view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + + # 5. Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, + device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + if J_regressor_h36m is not None: + J_from_verts_h36m = vertices2joints(J_regressor_h36m, verts) + else: + J_from_verts_h36m = None + + return verts, J_transformed, rot_mats, J_from_verts_h36m + + +def vertices2joints(J_regressor, vertices): + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas, shape_disps): + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. 
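+    # Shape sketch for intuition (not executed here): betas of shape (B, 10)
+    # and shape_disps of shape (V, 3, 10) contract over the last axis in the
+    # einsum below, yielding a (B, V, 3) per-vertex offset that is added to
+    # the template vertices.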
+ blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def transform_mat(R, t): + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], + dim=2) + + +def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + joints : torch.tensor BxNx3 + Locations of joints. 
(Template Pose) + parents : torch.tensor BxN + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + joints = torch.unsqueeze(joints, dim=-1) + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]].clone() + + # (B, K + 1, 4, 4) + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape( + -1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + # (B, 4, 4) x (B, 4, 4) + curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, + i]) + transform_chain.append(curr_res) + + # (B, K + 1, 4, 4) + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms + + +def batch_inverse_kinematics_transform(pose_skeleton, + global_orient, + phis, + rest_pose, + children, + parents, + dtype=torch.float32, + train=False, + leaf_thetas=None): + """ + Applies a batch of inverse kinematics transfoirm to the joints + + Parameters + ---------- + pose_skeleton : torch.tensor BxNx3 + Locations of estimated pose skeleton. + global_orient : torch.tensor Bx1x3x3 + Tensor of global rotation matrices + phis : torch.tensor BxNx2 + The rotation on bone axis parameters + rest_pose : torch.tensor Bx(N+1)x3 + Locations of rest_pose. 
(Template Pose) + children: dict + The dictionary that describes the kinematic chidrens for the model + parents : torch.tensor Bx(N+1) + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + rot_mats: torch.tensor Bx(N+1)x3x3 + The rotation matrics of each joints + rel_transforms : torch.tensor Bx(N+1)x4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + batch_size = pose_skeleton.shape[0] + device = pose_skeleton.device + + rel_rest_pose = rest_pose.clone() + rel_rest_pose[:, 1:] -= rest_pose[:, parents[1:]].clone() + rel_rest_pose = torch.unsqueeze(rel_rest_pose, dim=-1) + + # rotate the T pose + rotate_rest_pose = torch.zeros_like(rel_rest_pose) + # set up the root + rotate_rest_pose[:, 0] = rel_rest_pose[:, 0] + + rel_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1).detach() + rel_pose_skeleton[:, 1:] = rel_pose_skeleton[:, 1:] - \ + rel_pose_skeleton[:, parents[1:]].clone() + rel_pose_skeleton[:, 0] = rel_rest_pose[:, 0] + + # the predicted final pose + final_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1) + final_pose_skeleton = final_pose_skeleton - \ + final_pose_skeleton[:, 0:1] + rel_rest_pose[:, 0:1] + + rel_rest_pose = rel_rest_pose + rel_pose_skeleton = rel_pose_skeleton + final_pose_skeleton = final_pose_skeleton + rotate_rest_pose = rotate_rest_pose + + assert phis.dim() == 3 + phis = phis / (torch.norm(phis, dim=2, keepdim=True) + 1e-8) + + # TODO + if train: + global_orient_mat = batch_get_pelvis_orient(rel_pose_skeleton.clone(), + rel_rest_pose.clone(), + parents, children, dtype) + else: + global_orient_mat = batch_get_pelvis_orient_svd( + rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents, + children, dtype) + + rot_mat_chain = [global_orient_mat] + rot_mat_local = [global_orient_mat] + # leaf nodes rot_mats + if leaf_thetas is not None: + leaf_cnt = 0 + leaf_rot_mats = leaf_thetas.view([batch_size, 5, 3, 3]) + + for i in range(1, parents.shape[0]): + if children[i] == -1: + # leaf nodes + if leaf_thetas is not None: + rot_mat = leaf_rot_mats[:, leaf_cnt, :, :] + leaf_cnt += 1 + + rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[ + i]] + torch.matmul(rot_mat_chain[parents[i]], + rel_rest_pose[:, i]) + + rot_mat_chain.append( + torch.matmul(rot_mat_chain[parents[i]], rot_mat)) + rot_mat_local.append(rot_mat) + elif children[i] == -3: + # three children + rotate_rest_pose[:, + i] = rotate_rest_pose[:, + parents[i]] + torch.matmul( + rot_mat_chain[ + parents[i]], + rel_rest_pose[:, i]) + + spine_child = [] + for c in range(1, parents.shape[0]): + if parents[c] == i and c not in spine_child: + spine_child.append(c) + + # original + spine_child = [] + for c in range(1, parents.shape[0]): + if parents[c] == i and c not in spine_child: + spine_child.append(c) + + children_final_loc = [] + children_rest_loc = [] + for c in spine_child: + temp = final_pose_skeleton[:, c] - rotate_rest_pose[:, i] + children_final_loc.append(temp) + + children_rest_loc.append(rel_rest_pose[:, c].clone()) + + rot_mat = batch_get_3children_orient_svd(children_final_loc, + children_rest_loc, + rot_mat_chain[parents[i]], + spine_child, dtype) + + rot_mat_chain.append( + torch.matmul(rot_mat_chain[parents[i]], rot_mat)) + rot_mat_local.append(rot_mat) + else: + # (B, 3, 1) + rotate_rest_pose[:, + i] = rotate_rest_pose[:, + parents[i]] + torch.matmul( + rot_mat_chain[ + parents[i]], + rel_rest_pose[:, i]) + # 
(B, 3, 1) + child_final_loc = final_pose_skeleton[:, children[ + i]] - rotate_rest_pose[:, i] + + if not train: + orig_vec = rel_pose_skeleton[:, children[i]] + template_vec = rel_rest_pose[:, children[i]] + norm_t = torch.norm(template_vec, dim=1, keepdim=True) + orig_vec = orig_vec * norm_t / \ + torch.norm(orig_vec, dim=1, keepdim=True) + + diff = torch.norm(child_final_loc - orig_vec, + dim=1, + keepdim=True) + big_diff_idx = torch.where(diff > 15 / 1000)[0] + + child_final_loc[big_diff_idx] = orig_vec[big_diff_idx] + + child_final_loc = torch.matmul( + rot_mat_chain[parents[i]].transpose(1, 2), child_final_loc) + + child_rest_loc = rel_rest_pose[:, children[i]] + # (B, 1, 1) + child_final_norm = torch.norm(child_final_loc, dim=1, keepdim=True) + child_rest_norm = torch.norm(child_rest_loc, dim=1, keepdim=True) + + child_final_norm = torch.norm(child_final_loc, dim=1, keepdim=True) + + # (B, 3, 1) + axis = torch.cross(child_rest_loc, child_final_loc, dim=1) + axis_norm = torch.norm(axis, dim=1, keepdim=True) + + # (B, 1, 1) + cos = torch.sum( + child_rest_loc * child_final_loc, dim=1, + keepdim=True) / (child_rest_norm * child_final_norm + 1e-8) + sin = axis_norm / (child_rest_norm * child_final_norm + 1e-8) + + # (B, 3, 1) + axis = axis / (axis_norm + 1e-8) + + # Convert location revolve to rot_mat by rodrigues + # (B, 1, 1) + rx, ry, rz = torch.split(axis, 1, dim=1) + zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) + + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat_loc = ident + sin * K + (1 - cos) * torch.bmm(K, K) + + # Convert spin to rot_mat + # (B, 3, 1) + spin_axis = child_rest_loc / child_rest_norm + # (B, 1, 1) + rx, ry, rz = torch.split(spin_axis, 1, dim=1) + zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + # (B, 1, 1) + cos, sin = torch.split(phis[:, i - 1], 1, dim=1) + cos = torch.unsqueeze(cos, dim=2) + sin = torch.unsqueeze(sin, dim=2) + rot_mat_spin = ident + sin * K + (1 - cos) * torch.bmm(K, K) + rot_mat = torch.matmul(rot_mat_loc, rot_mat_spin) + + rot_mat_chain.append( + torch.matmul(rot_mat_chain[parents[i]], rot_mat)) + rot_mat_local.append(rot_mat) + + # (B, K + 1, 3, 3) + rot_mats = torch.stack(rot_mat_local, dim=1) + + return rot_mats, rotate_rest_pose.squeeze(-1) + + +def batch_inverse_kinematics_transform_optimized(pose_skeleton, + phis, + rest_pose, + children, + parents, + dtype=torch.float32, + train=False, + leaf_thetas=None): + """ + Applies a batch of inverse kinematics transfoirm to the joints + + Parameters + ---------- + pose_skeleton : torch.tensor BxNx3 + Locations of estimated pose skeleton. + global_orient : torch.tensor Bx1x3x3 + Tensor of global rotation matrices + phis : torch.tensor BxNx2 + The rotation on bone axis parameters + rest_pose : torch.tensor Bx(N+1)x3 + Locations of rest_pose. 
(Template Pose) + children: dict + The dictionary that describes the kinematic chidrens for the model + parents : torch.tensor Bx(N+1) + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + rot_mats: torch.tensor Bx(N+1)x3x3 + The rotation matrics of each joints + rel_transforms : torch.tensor Bx(N+1)x4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + batch_size = pose_skeleton.shape[0] + device = pose_skeleton.device + + rel_rest_pose = rest_pose.clone() + rel_rest_pose[:, 1:] -= rest_pose[:, parents[1:]].clone() + rel_rest_pose = torch.unsqueeze(rel_rest_pose, dim=-1) + + # rotate the T pose + rotate_rest_pose = torch.zeros_like(rel_rest_pose) + # set up the root + rotate_rest_pose[:, 0] = rel_rest_pose[:, 0] + + rel_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1).detach() + rel_pose_skeleton[:, 1:] = rel_pose_skeleton[:, 1:] - \ + rel_pose_skeleton[:, parents[1:]].clone() + rel_pose_skeleton[:, 0] = rel_rest_pose[:, 0] + + # the predicted final pose + final_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1) + final_pose_skeleton = final_pose_skeleton - \ + final_pose_skeleton[:, [0]] + rel_rest_pose[:, [0]] + + # assert phis.dim() == 3 + phis = phis / (torch.norm(phis, dim=2, keepdim=True) + 1e-8) + + # TODO + if train: + global_orient_mat = batch_get_pelvis_orient(rel_pose_skeleton.clone(), + rel_rest_pose.clone(), + parents, children, dtype) + else: + global_orient_mat = batch_get_pelvis_orient_svd( + rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents, + children, dtype) + + # rot_mat_chain = [global_orient_mat] + # rot_mat_local = [global_orient_mat] + + rot_mat_chain = torch.zeros((batch_size, 24, 3, 3), + dtype=torch.float32, + device=pose_skeleton.device) + rot_mat_local = torch.zeros_like(rot_mat_chain) + rot_mat_chain[:, 0] = global_orient_mat + rot_mat_local[:, 0] = global_orient_mat + + # leaf nodes rot_mats + if leaf_thetas is not None: + # leaf_cnt = 0 + leaf_rot_mats = leaf_thetas.view([batch_size, 5, 3, 3]) + + idx_levs = [ + [0], # 0 + [3], # 1 + [6], # 2 + [9], # 3 + [1, 2, 12, 13, 14], # 4 + [4, 5, 15, 16, 17], # 5 + [7, 8, 18, 19], # 6 + [10, 11, 20, 21], # 7 + [22, 23], # 8 + [24, 25, 26, 27, 28] # 9 + ] + if leaf_thetas is not None: + idx_levs = idx_levs[:-1] + + for idx_lev in range(1, len(idx_levs)): + indices = idx_levs[idx_lev] + if idx_lev == len(idx_levs) - 1: + # leaf nodes + if leaf_thetas is not None: + rot_mat = leaf_rot_mats[:, :, :, :] + parent_indices = [15, 22, 23, 10, 11] + + # rotate_rest_pose[:, indices] = rotate_rest_pose[:, parent_indices] + torch.matmul( + # rot_mat_chain[:, parent_indices], + # rel_rest_pose[:, indices] + # ) + + # rot_mat_chain[:, indices] = torch.matmul( + # rot_mat_chain[:, parent_indices], + # rot_mat + # ) + rot_mat_local[:, parent_indices] = rot_mat + + if (torch.det(rot_mat) < 0).any(): + # print( + # 0, + # torch.det(rot_mat_loc) < 0, + # torch.det(rot_mat_spin) < 0 + # ) + print('Something wrong.') + elif idx_lev == 3: + # three children + idx = indices[0] + rotate_rest_pose[:, idx] = rotate_rest_pose[:, parents[ + idx]] + torch.matmul(rot_mat_chain[:, parents[idx]], + rel_rest_pose[:, idx]) + + # original + spine_child = [12, 13, 14] + # for c in range(1, parents.shape[0]): + # if parents[c] == idx and c not in spine_child: + # spine_child.append(c) + + children_final_loc = [] + children_rest_loc = [] + for c in spine_child: + 
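+                # For each spine child (joints 12, 13, 14), collect its target
+                # offset from the rotated rest pose and its rest-pose offset;
+                # batch_get_3children_orient_svd then fits a single rotation
+                # aligning all three bones at once.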
temp = final_pose_skeleton[:, c] - rotate_rest_pose[:, idx] + children_final_loc.append(temp) + + children_rest_loc.append(rel_rest_pose[:, c].clone()) + + rot_mat = batch_get_3children_orient_svd( + children_final_loc, children_rest_loc, + rot_mat_chain[:, parents[idx]], spine_child, dtype) + + rot_mat_chain[:, + idx] = torch.matmul(rot_mat_chain[:, parents[idx]], + rot_mat) + + rot_mat_local[:, idx] = rot_mat + + if (torch.det(rot_mat) < 0).any(): + print(1) + else: + len_indices = len(indices) + # (B, K, 3, 1) + rotate_rest_pose[:, indices] = rotate_rest_pose[:, parents[ + indices]] + torch.matmul(rot_mat_chain[:, parents[indices]], + rel_rest_pose[:, indices]) + # (B, 3, 1) + child_final_loc = final_pose_skeleton[:, children[ + indices]] - rotate_rest_pose[:, indices] + + if not train: + orig_vec = rel_pose_skeleton[:, children[indices]] + template_vec = rel_rest_pose[:, children[indices]] + + norm_t = torch.norm(template_vec, dim=2, + keepdim=True) # B x K x 1 + + orig_vec = orig_vec * norm_t / \ + torch.norm(orig_vec, dim=2, keepdim=True) # B x K x 3 + + diff = torch.norm(child_final_loc - orig_vec, + dim=2, + keepdim=True).reshape(-1) + big_diff_idx = torch.where(diff > 15 / 1000)[0] + + # child_final_loc[big_diff_idx] = orig_vec[big_diff_idx] + child_final_loc = child_final_loc.reshape( + batch_size * len_indices, 3, 1) + orig_vec = orig_vec.reshape(batch_size * len_indices, 3, 1) + child_final_loc[big_diff_idx] = orig_vec[big_diff_idx] + child_final_loc = child_final_loc.reshape( + batch_size, len_indices, 3, 1) + + child_final_loc = torch.matmul( + rot_mat_chain[:, parents[indices]].transpose(2, 3), + child_final_loc) + + # need rotation back ? + child_rest_loc = rel_rest_pose[:, children[indices]] + # (B, K, 1, 1) + child_final_norm = torch.norm(child_final_loc, dim=2, keepdim=True) + child_rest_norm = torch.norm(child_rest_loc, dim=2, keepdim=True) + + # (B, K, 3, 1) + axis = torch.cross(child_rest_loc, child_final_loc, dim=2) + axis_norm = torch.norm(axis, dim=2, keepdim=True) + + # (B, K, 1, 1) + cos = torch.sum( + child_rest_loc * child_final_loc, dim=2, + keepdim=True) / (child_rest_norm * child_final_norm + 1e-8) + sin = axis_norm / (child_rest_norm * child_final_norm + 1e-8) + + # (B, K, 3, 1) + axis = axis / (axis_norm + 1e-8) + + # Convert location revolve to rot_mat by rodrigues + # (B, K, 1, 1) + rx, ry, rz = torch.split(axis, 1, dim=2) + zeros = torch.zeros((batch_size, len_indices, 1, 1), + dtype=dtype, + device=device) + + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=2) \ + .view((batch_size, len_indices, 3, 3)) + ident = torch.eye(3, dtype=dtype, + device=device).reshape(1, 1, 3, 3) + rot_mat_loc = ident + sin * K + (1 - cos) * torch.matmul(K, K) + + # Convert spin to rot_mat + # (B, K, 3, 1) + spin_axis = child_rest_loc / child_rest_norm + # (B, K, 1, 1) + rx, ry, rz = torch.split(spin_axis, 1, dim=2) + zeros = torch.zeros((batch_size, len_indices, 1, 1), + dtype=dtype, + device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=2) \ + .view((batch_size, len_indices, 3, 3)) + ident = torch.eye(3, dtype=dtype, + device=device).reshape(1, 1, 3, 3) + # (B, K, 1, 1) + phi_indices = [item - 1 for item in indices] + cos, sin = torch.split(phis[:, phi_indices], 1, dim=2) + cos = torch.unsqueeze(cos, dim=3) + sin = torch.unsqueeze(sin, dim=3) + rot_mat_spin = ident + sin * K + (1 - cos) * torch.matmul(K, K) + rot_mat = torch.matmul(rot_mat_loc, rot_mat_spin) + + if (torch.det(rot_mat) < 0).any(): + print(2, + 
torch.det(rot_mat_loc) < 0, + torch.det(rot_mat_spin) < 0) + + rot_mat_chain[:, indices] = torch.matmul( + rot_mat_chain[:, parents[indices]], rot_mat) + rot_mat_local[:, indices] = rot_mat + + # (B, K + 1, 3, 3) + # rot_mats = torch.stack(rot_mat_local, dim=1) + rot_mats = rot_mat_local + + return rot_mats, rotate_rest_pose.squeeze(-1) + + +def batch_get_pelvis_orient_svd(rel_pose_skeleton, rel_rest_pose, parents, + children, dtype): + pelvis_child = [int(children[0])] + for i in range(1, parents.shape[0]): + if parents[i] == 0 and i not in pelvis_child: + pelvis_child.append(i) + + rest_mat = [] + target_mat = [] + for child in pelvis_child: + rest_mat.append(rel_rest_pose[:, child].clone()) + target_mat.append(rel_pose_skeleton[:, child].clone()) + + rest_mat = torch.cat(rest_mat, dim=2) + target_mat = torch.cat(target_mat, dim=2) + S = rest_mat.bmm(target_mat.transpose(1, 2)) + + mask_zero = S.sum(dim=(1, 2)) + + S_non_zero = S[mask_zero != 0].reshape(-1, 3, 3) + + U, _, V = torch.svd(S_non_zero) + + rot_mat = torch.zeros_like(S) + rot_mat[mask_zero == 0] = torch.eye(3, device=S.device) + + # rot_mat_non_zero = torch.bmm(V, U.transpose(1, 2)) + det_u_v = torch.det(torch.bmm(V, U.transpose(1, 2))) + det_modify_mat = torch.eye(3, device=U.device).unsqueeze(0).expand( + U.shape[0], -1, -1).clone() + det_modify_mat[:, 2, 2] = det_u_v + rot_mat_non_zero = torch.bmm(torch.bmm(V, det_modify_mat), + U.transpose(1, 2)) + + rot_mat[mask_zero != 0] = rot_mat_non_zero + + assert torch.sum(torch.isnan(rot_mat)) == 0, ('rot_mat', rot_mat) + + return rot_mat + + +def batch_get_pelvis_orient(rel_pose_skeleton, rel_rest_pose, parents, + children, dtype): + batch_size = rel_pose_skeleton.shape[0] + device = rel_pose_skeleton.device + + assert children[0] == 3 + pelvis_child = [int(children[0])] + for i in range(1, parents.shape[0]): + if parents[i] == 0 and i not in pelvis_child: + pelvis_child.append(i) + + spine_final_loc = rel_pose_skeleton[:, int(children[0])].clone() + spine_rest_loc = rel_rest_pose[:, int(children[0])].clone() + spine_norm = torch.norm(spine_final_loc, dim=1, keepdim=True) + spine_norm = spine_final_loc / (spine_norm + 1e-8) + + rot_mat_spine = vectors2rotmat(spine_rest_loc, spine_final_loc, dtype) + + assert torch.sum(torch.isnan(rot_mat_spine)) == 0, ('rot_mat_spine', + rot_mat_spine) + center_final_loc = 0 + center_rest_loc = 0 + for child in pelvis_child: + if child == int(children[0]): + continue + center_final_loc = center_final_loc + \ + rel_pose_skeleton[:, child].clone() + center_rest_loc = center_rest_loc + rel_rest_pose[:, child].clone() + center_final_loc = center_final_loc / (len(pelvis_child) - 1) + center_rest_loc = center_rest_loc / (len(pelvis_child) - 1) + + center_rest_loc = torch.matmul(rot_mat_spine, center_rest_loc) + + center_final_loc = center_final_loc - \ + torch.sum(center_final_loc * spine_norm, + dim=1, keepdim=True) * spine_norm + center_rest_loc = center_rest_loc - \ + torch.sum(center_rest_loc * spine_norm, + dim=1, keepdim=True) * spine_norm + + center_final_loc_norm = torch.norm(center_final_loc, dim=1, keepdim=True) + center_rest_loc_norm = torch.norm(center_rest_loc, dim=1, keepdim=True) + + # (B, 3, 1) + axis = torch.cross(center_rest_loc, center_final_loc, dim=1) + axis_norm = torch.norm(axis, dim=1, keepdim=True) + + # (B, 1, 1) + cos = torch.sum(center_rest_loc * center_final_loc, dim=1, keepdim=True) / \ + (center_rest_loc_norm * center_final_loc_norm + 1e-8) + sin = axis_norm / (center_rest_loc_norm * center_final_loc_norm + 1e-8) + + 
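+    # cos/sin above give the angle about `axis` that aligns the averaged
+    # hip-child direction (projected perpendicular to the spine); they feed
+    # the Rodrigues formula below, and the resulting rot_mat_center is
+    # composed with rot_mat_spine to form the full pelvis orientation.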
assert torch.sum(torch.isnan(cos)) == 0, ('cos', cos) + assert torch.sum(torch.isnan(sin)) == 0, ('sin', sin) + # (B, 3, 1) + axis = axis / (axis_norm + 1e-8) + + # Convert location revolve to rot_mat by rodrigues + # (B, 1, 1) + rx, ry, rz = torch.split(axis, 1, dim=1) + zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) + + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat_center = ident + sin * K + (1 - cos) * torch.bmm(K, K) + + rot_mat = torch.matmul(rot_mat_center, rot_mat_spine) + + return rot_mat + + +def batch_get_3children_orient_svd(rel_pose_skeleton, rel_rest_pose, + rot_mat_chain_parent, children_list, dtype): + rest_mat = [] + target_mat = [] + for c, child in enumerate(children_list): + if isinstance(rel_pose_skeleton, list): + target = rel_pose_skeleton[c].clone() + template = rel_rest_pose[c].clone() + else: + target = rel_pose_skeleton[:, child].clone() + template = rel_rest_pose[:, child].clone() + + target = torch.matmul(rot_mat_chain_parent.transpose(1, 2), target) + + target_mat.append(target) + rest_mat.append(template) + + rest_mat = torch.cat(rest_mat, dim=2) + target_mat = torch.cat(target_mat, dim=2) + S = rest_mat.bmm(target_mat.transpose(1, 2)) + + U, _, V = torch.svd(S) + + # rot_mat = torch.bmm(V, U.transpose(1, 2)) + det_u_v = torch.det(torch.bmm(V, U.transpose(1, 2))) + det_modify_mat = torch.eye(3, device=U.device).unsqueeze(0).expand( + U.shape[0], -1, -1).clone() + det_modify_mat[:, 2, 2] = det_u_v + rot_mat = torch.bmm(torch.bmm(V, det_modify_mat), U.transpose(1, 2)) + + assert torch.sum(torch.isnan(rot_mat)) == 0, ('3children rot_mat', rot_mat) + return rot_mat + + +def vectors2rotmat(vec_rest, vec_final, dtype): + batch_size = vec_final.shape[0] + device = vec_final.device + + # (B, 1, 1) + vec_final_norm = torch.norm(vec_final, dim=1, keepdim=True) + vec_rest_norm = torch.norm(vec_rest, dim=1, keepdim=True) + + # (B, 3, 1) + axis = torch.cross(vec_rest, vec_final, dim=1) + axis_norm = torch.norm(axis, dim=1, keepdim=True) + + # (B, 1, 1) + cos = torch.sum(vec_rest * vec_final, dim=1, keepdim=True) / \ + (vec_rest_norm * vec_final_norm + 1e-8) + sin = axis_norm / (vec_rest_norm * vec_final_norm + 1e-8) + + # (B, 3, 1) + axis = axis / (axis_norm + 1e-8) + + # Convert location revolve to rot_mat by rodrigues + # (B, 1, 1) + rx, ry, rz = torch.split(axis, 1, dim=1) + zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) + + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat_loc = ident + sin * K + (1 - cos) * torch.bmm(K, K) + + return rot_mat_loc + + +def rotmat_to_quat(rotation_matrix): + assert rotation_matrix.shape[1:] == (3, 3) + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], + dtype=torch.float32, + device=rotation_matrix.device) + hom = hom.reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + return quaternion + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + 
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / (norm_quat.norm(p=2, dim=1, keepdim=True) + 1e-8) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(B, 3, 3) + return rotMat + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. 
+ + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3, 3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], + dtype=torch.float32, + device=rotation_matrix.device) + hom = hom.reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format( + quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis diff --git a/lib/hybrik/models/simple3dpose.py b/lib/hybrik/models/simple3dpose.py new file mode 100644 index 0000000000000000000000000000000000000000..292124c04decb222305cb69dfe57d23aae62d370 --- /dev/null +++ b/lib/hybrik/models/simple3dpose.py @@ -0,0 +1,420 @@ +from collections import namedtuple +import os + +import numpy as np +import torch +import torch.nn as nn +import yaml +from torch.nn import functional as F + +from .layers.Resnet import ResNet +from .layers.smpl.SMPL import SMPL_layer + +ModelOutput = namedtuple( + typename='ModelOutput', + field_names=[ + 'pred_shape', 'pred_theta_mats', 'pred_phi', 'pred_delta_shape', + 'pred_leaf', 'pred_uvd_jts', 'pred_xyz_jts_29', 'pred_xyz_jts_24', + 'pred_xyz_jts_24_struct', 'pred_xyz_jts_17', 'pred_vertices', + 'maxvals', 'cam_scale', 'cam_trans', 'cam_root', 'uvd_heatmap', + 'transl', 'img_feat', 'pred_camera', 'pred_aa' + ]) +ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields) + + +def update_config(config_file): + with open(config_file) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + return config + + +def norm_heatmap(norm_type, 
heatmap): + # Input tensor shape: [N,C,...] + shape = heatmap.shape + if norm_type == 'softmax': + heatmap = heatmap.reshape(*shape[:2], -1) + # global soft max + heatmap = F.softmax(heatmap, 2) + return heatmap.reshape(*shape) + else: + raise NotImplementedError + + +class HybrIKBaseSMPLCam(nn.Module): + + def __init__(self, + cfg_file, + smpl_path, + data_path, + norm_layer=nn.BatchNorm2d): + super(HybrIKBaseSMPLCam, self).__init__() + + cfg = update_config(cfg_file)['MODEL'] + + self.deconv_dim = cfg['NUM_DECONV_FILTERS'] + self._norm_layer = norm_layer + self.num_joints = cfg['NUM_JOINTS'] + self.norm_type = cfg['POST']['NORM_TYPE'] + self.depth_dim = cfg['EXTRA']['DEPTH_DIM'] + self.height_dim = cfg['HEATMAP_SIZE'][0] + self.width_dim = cfg['HEATMAP_SIZE'][1] + self.smpl_dtype = torch.float32 + + backbone = ResNet + + self.preact = backbone(f"resnet{cfg['NUM_LAYERS']}") + + # Imagenet pretrain model + import torchvision.models as tm + if cfg['NUM_LAYERS'] == 101: + ''' Load pretrained model ''' + x = tm.resnet101(pretrained=True) + self.feature_channel = 2048 + elif cfg['NUM_LAYERS'] == 50: + x = tm.resnet50(pretrained=True) + self.feature_channel = 2048 + elif cfg['NUM_LAYERS'] == 34: + x = tm.resnet34(pretrained=True) + self.feature_channel = 512 + elif cfg['NUM_LAYERS'] == 18: + x = tm.resnet18(pretrained=True) + self.feature_channel = 512 + else: + raise NotImplementedError + model_state = self.preact.state_dict() + state = { + k: v + for k, v in x.state_dict().items() + if k in self.preact.state_dict() + and v.size() == self.preact.state_dict()[k].size() + } + model_state.update(state) + self.preact.load_state_dict(model_state) + + self.deconv_layers = self._make_deconv_layer() + self.final_layer = nn.Conv2d(self.deconv_dim[2], + self.num_joints * self.depth_dim, + kernel_size=1, + stride=1, + padding=0) + + h36m_jregressor = np.load( + os.path.join(data_path, 'J_regressor_h36m.npy')) + self.smpl = SMPL_layer(smpl_path, + h36m_jregressor=h36m_jregressor, + dtype=self.smpl_dtype) + + self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), + (16, 17), (18, 19), (20, 21), (22, 23)) + + self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), + (16, 17), (18, 19), (20, 21), (22, 23), + (25, 26), (27, 28)) + + self.leaf_pairs = ((0, 1), (3, 4)) + self.root_idx_smpl = 0 + + # mean shape + init_shape = np.load(os.path.join(data_path, 'h36m_mean_beta.npy')) + self.register_buffer('init_shape', torch.Tensor(init_shape).float()) + + init_cam = torch.tensor([0.9, 0, 0]) + self.register_buffer('init_cam', torch.Tensor(init_cam).float()) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Linear(self.feature_channel, 1024) + self.drop1 = nn.Dropout(p=0.5) + self.fc2 = nn.Linear(1024, 1024) + self.drop2 = nn.Dropout(p=0.5) + self.decshape = nn.Linear(1024, 10) + self.decphi = nn.Linear(1024, 23 * 2) # [cos(phi), sin(phi)] + self.deccam = nn.Linear(1024, 3) + + self.focal_length = cfg['FOCAL_LENGTH'] + self.input_size = 256.0 + + def _make_deconv_layer(self): + deconv_layers = [] + deconv1 = nn.ConvTranspose2d(self.feature_channel, + self.deconv_dim[0], + kernel_size=4, + stride=2, + padding=int(4 / 2) - 1, + bias=False) + bn1 = self._norm_layer(self.deconv_dim[0]) + deconv2 = nn.ConvTranspose2d(self.deconv_dim[0], + self.deconv_dim[1], + kernel_size=4, + stride=2, + padding=int(4 / 2) - 1, + bias=False) + bn2 = self._norm_layer(self.deconv_dim[1]) + deconv3 = nn.ConvTranspose2d(self.deconv_dim[1], + self.deconv_dim[2], + kernel_size=4, + stride=2, + padding=int(4 / 
2) - 1, + bias=False) + bn3 = self._norm_layer(self.deconv_dim[2]) + + deconv_layers.append(deconv1) + deconv_layers.append(bn1) + deconv_layers.append(nn.ReLU(inplace=True)) + deconv_layers.append(deconv2) + deconv_layers.append(bn2) + deconv_layers.append(nn.ReLU(inplace=True)) + deconv_layers.append(deconv3) + deconv_layers.append(bn3) + deconv_layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*deconv_layers) + + def _initialize(self): + for name, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + nn.init.normal_(m.weight, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + nn.init.constant_(m.bias, 0) + + def flip_uvd_coord(self, pred_jts, shift=False, flatten=True): + if flatten: + assert pred_jts.dim() == 2 + num_batches = pred_jts.shape[0] + pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3) + else: + assert pred_jts.dim() == 3 + num_batches = pred_jts.shape[0] + + # flip + if shift: + pred_jts[:, :, 0] = -pred_jts[:, :, 0] + else: + pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0] + + for pair in self.joint_pairs_29: + dim0, dim1 = pair + idx = torch.Tensor((dim0, dim1)).long() + inv_idx = torch.Tensor((dim1, dim0)).long() + pred_jts[:, idx] = pred_jts[:, inv_idx] + + if flatten: + pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3) + + return pred_jts + + def flip_xyz_coord(self, pred_jts, flatten=True): + if flatten: + assert pred_jts.dim() == 2 + num_batches = pred_jts.shape[0] + pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3) + else: + assert pred_jts.dim() == 3 + num_batches = pred_jts.shape[0] + + pred_jts[:, :, 0] = -pred_jts[:, :, 0] + + for pair in self.joint_pairs_29: + dim0, dim1 = pair + idx = torch.Tensor((dim0, dim1)).long() + inv_idx = torch.Tensor((dim1, dim0)).long() + pred_jts[:, idx] = pred_jts[:, inv_idx] + + if flatten: + pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3) + + return pred_jts + + def flip_phi(self, pred_phi): + pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1] + + for pair in self.joint_pairs_24: + dim0, dim1 = pair + idx = torch.Tensor((dim0 - 1, dim1 - 1)).long() + inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long() + pred_phi[:, idx] = pred_phi[:, inv_idx] + + return pred_phi + + def forward(self, + x, + flip_item=None, + flip_output=False, + gt_uvd=None, + gt_uvd_weight=None, + **kwargs): + + batch_size = x.shape[0] + + # torch.cuda.synchronize() + # model_start_t = time.time() + + x0 = self.preact(x) + out = self.deconv_layers(x0) + out = self.final_layer(out) + + # torch.cuda.synchronize() + # preat_end_t = time.time() + + out = out.reshape((out.shape[0], self.num_joints, -1)) + + maxvals, _ = torch.max(out, dim=2, keepdim=True) + + out = norm_heatmap(self.norm_type, out) + assert out.dim() == 3, out.shape + + heatmaps = out / out.sum(dim=2, keepdim=True) + + heatmaps = heatmaps.reshape( + (heatmaps.shape[0], self.num_joints, self.depth_dim, + self.height_dim, self.width_dim)) + + hm_x0 = heatmaps.sum((2, 3)) + hm_y0 = heatmaps.sum((2, 4)) + hm_z0 = heatmaps.sum((3, 4)) + + range_tensor = torch.arange(hm_x0.shape[-1], + dtype=torch.float32, + device=hm_x0.device) + hm_x = hm_x0 * range_tensor + hm_y = hm_y0 * range_tensor + hm_z = hm_z0 * range_tensor + + coord_x = hm_x.sum(dim=2, keepdim=True) + coord_y = hm_y.sum(dim=2, keepdim=True) + coord_z = hm_z.sum(dim=2, keepdim=True) + + 
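+        # Soft-argmax: coord_{x,y,z} above are expectations of the heatmap
+        # indices along each axis; the normalization below maps them into the
+        # [-0.5, 0.5] range used by the rest of the pipeline.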
coord_x = coord_x / float(self.width_dim) - 0.5 + coord_y = coord_y / float(self.height_dim) - 0.5 + coord_z = coord_z / float(self.depth_dim) - 0.5 + + # -0.5 ~ 0.5 + pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2) + + x0 = self.avg_pool(x0) + x0 = x0.view(x0.size(0), -1) + init_shape = self.init_shape.expand(batch_size, -1) # (B, 10,) + init_cam = self.init_cam.expand(batch_size, -1) # (B, 3,) + + xc = x0 + + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + + delta_shape = self.decshape(xc) + pred_shape = delta_shape + init_shape + pred_phi = self.decphi(xc) + pred_camera = self.deccam(xc).reshape(batch_size, -1) + init_cam + + camScale = pred_camera[:, :1].unsqueeze(1) + camTrans = pred_camera[:, 1:].unsqueeze(1) + + camDepth = self.focal_length / (self.input_size * camScale + 1e-9) + + pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29) + pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, + 2:].clone() # unit: 2.2m + pred_xyz_jts_29_meter = (pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length) \ + * (pred_xyz_jts_29[:, :, 2:]*2.2 + camDepth) - camTrans # unit: m + + pred_xyz_jts_29[:, :, :2] = pred_xyz_jts_29_meter / 2.2 # unit: 2.2m + + camera_root = pred_xyz_jts_29[:, [0], ] * 2.2 + camera_root[:, :, :2] += camTrans + camera_root[:, :, [2]] += camDepth + + if not self.training: + pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]] + + if flip_item is not None: + assert flip_output is not None + pred_xyz_jts_29_orig, pred_phi_orig, pred_leaf_orig, pred_shape_orig = flip_item + + if flip_output: + pred_xyz_jts_29 = self.flip_xyz_coord(pred_xyz_jts_29, + flatten=False) + if flip_output and flip_item is not None: + pred_xyz_jts_29 = (pred_xyz_jts_29 + pred_xyz_jts_29_orig.reshape( + batch_size, 29, 3)) / 2 + + pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1) + + pred_phi = pred_phi.reshape(batch_size, 23, 2) + + if flip_output: + pred_phi = self.flip_phi(pred_phi) + + if flip_output and flip_item is not None: + pred_phi = (pred_phi + pred_phi_orig) / 2 + pred_shape = (pred_shape + pred_shape_orig) / 2 + + output = self.smpl.hybrik( + pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * + 2.2, # unit: meter + betas=pred_shape.type(self.smpl_dtype), + phis=pred_phi.type(self.smpl_dtype), + global_orient=None, + return_verts=True) + pred_vertices = output.vertices.float() + # -0.5 ~ 0.5 + # pred_xyz_jts_24_struct = output.joints.float() / 2.2 + pred_xyz_jts_24_struct = output.joints.float() / 2 + # -0.5 ~ 0.5 + # pred_xyz_jts_17 = output.joints_from_verts.float() / 2.2 + pred_xyz_jts_17 = output.joints_from_verts.float() / 2 + pred_theta_mats = output.rot_mats.float().reshape(batch_size, 24, 3, 3) + pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, + 72) / 2 + pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72) + pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3) + + transl = pred_xyz_jts_29[:, 0, :] * \ + 2.2 - pred_xyz_jts_17[:, 0, :] * 2.2 + transl[:, :2] += camTrans[:, 0] + transl[:, 2] += camDepth[:, 0, 0] + + new_cam = torch.zeros_like(transl) + new_cam[:, 1:] = transl[:, :2] + new_cam[:, 0] = self.focal_length / \ + (self.input_size * transl[:, 2] + 1e-9) + + # pred_aa = output.rot_aa.reshape(batch_size, 24, 3) + + output = dict( + pred_phi=pred_phi, + pred_delta_shape=delta_shape, + pred_shape=pred_shape, + # pred_aa=pred_aa, + pred_theta_mats=pred_theta_mats, + pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1), + 
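Above, the normalized UVD joints are lifted to metric space with a weak-perspective style camera: the camera head predicts a scale and an XY translation, the scale is converted to a depth via focal_length / (input_size * scale), and each joint's XY is unprojected at its own depth. A sketch of that conversion; the focal_length and input_size values here are assumptions for illustration only:

```python
import torch

focal_length, input_size = 1000.0, 256.0      # assumed values; the model reads them from its config

B, J = 2, 29
pred_uvd = torch.rand(B, J, 3) - 0.5          # normalized joints in [-0.5, 0.5]
pred_camera = torch.randn(B, 3)               # (scale, tx, ty) from the camera head

cam_scale = pred_camera[:, :1].unsqueeze(1)   # (B, 1, 1)
cam_trans = pred_camera[:, 1:].unsqueeze(1)   # (B, 1, 2)
cam_depth = focal_length / (input_size * cam_scale + 1e-9)

z_rel = pred_uvd[:, :, 2:]                    # relative depth in units of 2.2 m
xy_meter = (pred_uvd[:, :, :2] * input_size / focal_length) \
           * (z_rel * 2.2 + cam_depth) - cam_trans      # metric XY, as in the forward pass above
```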
pred_xyz_jts_29=pred_xyz_jts_29_flat, + pred_xyz_jts_24=pred_xyz_jts_24, + pred_xyz_jts_24_struct=pred_xyz_jts_24_struct, + pred_xyz_jts_17=pred_xyz_jts_17_flat, + pred_vertices=pred_vertices, + maxvals=maxvals, + cam_scale=camScale[:, 0], + cam_trans=camTrans[:, 0], + cam_root=camera_root, + pred_camera=new_cam, + transl=transl, + # uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2), + # uvd_heatmap=heatmaps, + # img_feat=x0 + ) + return output + + def forward_gt_theta(self, gt_theta, gt_beta): + + output = self.smpl(pose_axis_angle=gt_theta, + betas=gt_beta, + global_orient=None, + return_verts=True) + + return output diff --git a/lib/net/BasePIFuNet.py b/lib/net/BasePIFuNet.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ad149e2258df3ded19b686286a609cb39c6fe5 --- /dev/null +++ b/lib/net/BasePIFuNet.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import torch.nn as nn +import pytorch_lightning as pl + +from .geometry import index, orthogonal, perspective + + +class BasePIFuNet(pl.LightningModule): + + def __init__( + self, + projection_mode='orthogonal', + error_term=nn.MSELoss(), + ): + """ + :param projection_mode: + Either orthogonal or perspective. + It will call the corresponding function for projection. + :param error_term: + nn Loss between the predicted [B, Res, N] and the label [B, Res, N] + """ + super(BasePIFuNet, self).__init__() + self.name = 'base' + + self.error_term = error_term + + self.index = index + self.projection = orthogonal if projection_mode == 'orthogonal' else perspective + + def forward(self, points, images, calibs, transforms=None): + ''' + :param points: [B, 3, N] world space coordinates of points + :param images: [B, C, H, W] input images + :param calibs: [B, 3, 4] calibration matrices for each image + :param transforms: Optional [B, 2, 3] image space coordinate transforms + :return: [B, Res, N] predictions for each point + ''' + features = self.filter(images) + preds = self.query(features, points, calibs, transforms) + return preds + + def filter(self, images): + ''' + Filter the input images + store all intermediate features. + :param images: [B, C, H, W] input images + ''' + return None + + def query(self, features, points, calibs, transforms=None): + ''' + Given 3D points, query the network predictions for each point. + Image features should be pre-computed before this call. + store all intermediate features. + query() function may behave differently during training/testing. 
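BasePIFuNet above only fixes the filter/query/get_error contract that the concrete PIFu-style networks in this diff implement. A toy subclass sketch of that contract, assuming the imported geometry helpers behave as they are used later in HGPIFuNet (projection maps world points into image space, index does a bilinear lookup of a [B, C, H, W] map at [B, 2, N] coordinates); the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn
from lib.net.BasePIFuNet import BasePIFuNet   # added in this diff

class ToyPIFu(BasePIFuNet):
    """Minimal illustration of the filter/query contract, not the real network."""

    def __init__(self):
        super().__init__(projection_mode='orthogonal')
        self.backbone = nn.Conv2d(3, 8, 3, padding=1)   # stand-in image filter
        self.head = nn.Conv1d(8 + 1, 1, 1)              # per-point classifier

    def filter(self, images):                           # [B, 3, H, W] -> [B, 8, H, W]
        return self.backbone(images)

    def query(self, features, points, calibs, transforms=None):
        xyz = self.projection(points, calibs, transforms)   # world -> image space, [B, 3, N]
        xy, z = xyz.split([2, 1], dim=1)
        point_feat = self.index(features, xy)                # bilinear lookup, [B, 8, N]
        return torch.sigmoid(self.head(torch.cat([point_feat, z], dim=1)))
```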
+ :param points: [B, 3, N] world space coordinates of points + :param calibs: [B, 3, 4] calibration matrices for each image + :param transforms: Optional [B, 2, 3] image space coordinate transforms + :param labels: Optional [B, Res, N] gt labeling + :return: [B, Res, N] predictions for each point + ''' + return None + + def get_error(self, preds, labels): + ''' + Get the network loss from the last query + :return: loss term + ''' + return self.error_term(preds, labels) diff --git a/lib/net/FBNet.py b/lib/net/FBNet.py new file mode 100644 index 0000000000000000000000000000000000000000..d184233a7906026e6d4d76ea5218c0d3c3e6bfc4 --- /dev/null +++ b/lib/net/FBNet.py @@ -0,0 +1,391 @@ +''' +Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. +BSD License. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL +DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING +OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +''' +import torch +import torch.nn as nn +import functools +import numpy as np +import pytorch_lightning as pl + + +############################################################################### +# Functions +############################################################################### +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + else: + raise NotImplementedError('normalization layer [%s] is not found' % + norm_type) + return norm_layer + + +def define_G(input_nc, + output_nc, + ngf, + netG, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm='instance', + gpu_ids=[], + last_op=nn.Tanh()): + norm_layer = get_norm_layer(norm_type=norm) + if netG == 'global': + netG = GlobalGenerator(input_nc, + output_nc, + ngf, + n_downsample_global, + n_blocks_global, + norm_layer, + last_op=last_op) + elif netG == 'local': + netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, + n_blocks_global, n_local_enhancers, + n_blocks_local, norm_layer) + elif netG == 'encoder': + netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, + norm_layer) + else: + raise ('generator not implemented!') + # print(netG) + if len(gpu_ids) > 0: + assert (torch.cuda.is_available()) + netG.cuda(gpu_ids[0]) + netG.apply(weights_init) + return netG + + +def print_network(net): + if isinstance(net, list): + 
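define_G above is the generator factory used by the normal networks later in this diff; netG='global' builds a GlobalGenerator whose n downsampling convolutions are mirrored by n transposed convolutions, so the output resolution matches the input. A quick shape check; the 6 input channels are an assumed example value:

```python
import torch
from lib.net.FBNet import define_G   # factory defined above

# 6 input channels is an assumed example; the real count comes from the config.
net = define_G(6, 3, 64, "global", n_downsample_global=4, n_blocks_global=9, norm="instance")

x = torch.randn(1, 6, 512, 512)
with torch.no_grad():
    y = net(x)
print(y.shape)   # torch.Size([1, 3, 512, 512]): 4 downsamples then 4 upsamples keep the resolution
```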
net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + +############################################################################## +# Generator +############################################################################## +class LocalEnhancer(pl.LightningModule): + + def __init__(self, + input_nc, + output_nc, + ngf=32, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm_layer=nn.BatchNorm2d, + padding_type='reflect'): + super(LocalEnhancer, self).__init__() + self.n_local_enhancers = n_local_enhancers + + ###### global generator model ##### + ngf_global = ngf * (2**n_local_enhancers) + model_global = GlobalGenerator(input_nc, output_nc, ngf_global, + n_downsample_global, n_blocks_global, + norm_layer).model + model_global = [model_global[i] for i in range(len(model_global) - 3) + ] # get rid of final convolution layers + self.model = nn.Sequential(*model_global) + + ###### local enhancer layers ##### + for n in range(1, n_local_enhancers + 1): + # downsample + ngf_global = ngf * (2**(n_local_enhancers - n)) + model_downsample = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), + norm_layer(ngf_global), + nn.ReLU(True), + nn.Conv2d(ngf_global, + ngf_global * 2, + kernel_size=3, + stride=2, + padding=1), + norm_layer(ngf_global * 2), + nn.ReLU(True) + ] + # residual blocks + model_upsample = [] + for i in range(n_blocks_local): + model_upsample += [ + ResnetBlock(ngf_global * 2, + padding_type=padding_type, + norm_layer=norm_layer) + ] + + # upsample + model_upsample += [ + nn.ConvTranspose2d(ngf_global * 2, + ngf_global, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + norm_layer(ngf_global), + nn.ReLU(True) + ] + + # final convolution + if n == n_local_enhancers: + model_upsample += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + nn.Tanh() + ] + + setattr(self, 'model' + str(n) + '_1', + nn.Sequential(*model_downsample)) + setattr(self, 'model' + str(n) + '_2', + nn.Sequential(*model_upsample)) + + self.downsample = nn.AvgPool2d(3, + stride=2, + padding=[1, 1], + count_include_pad=False) + + def forward(self, input): + # create input pyramid + input_downsampled = [input] + for i in range(self.n_local_enhancers): + input_downsampled.append(self.downsample(input_downsampled[-1])) + + # output at coarest level + output_prev = self.model(input_downsampled[-1]) + # build up one layer at a time + for n_local_enhancers in range(1, self.n_local_enhancers + 1): + model_downsample = getattr(self, + 'model' + str(n_local_enhancers) + '_1') + model_upsample = getattr(self, + 'model' + str(n_local_enhancers) + '_2') + input_i = input_downsampled[self.n_local_enhancers - + n_local_enhancers] + output_prev = model_upsample( + model_downsample(input_i) + output_prev) + return output_prev + + +class GlobalGenerator(pl.LightningModule): + + def __init__(self, + input_nc, + output_nc, + ngf=64, + n_downsampling=3, + n_blocks=9, + norm_layer=nn.BatchNorm2d, + padding_type='reflect', + last_op=nn.Tanh()): + assert (n_blocks >= 0) + super(GlobalGenerator, self).__init__() + activation = nn.ReLU(True) + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), activation + ] + # downsample + for i in range(n_downsampling): + mult = 2**i + model += [ + nn.Conv2d(ngf * mult, + ngf * mult * 2, + kernel_size=3, + 
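LocalEnhancer works coarse-to-fine: the input is average-pooled into a pyramid, the coarsest level runs through the trimmed global generator, and each finer level adds its own downsample/residual/upsample branch on top of the coarser output. A small sketch of just the pyramid construction it relies on:

```python
import torch
import torch.nn as nn

downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

n_local_enhancers = 1
x = torch.randn(1, 3, 512, 512)
pyramid = [x]
for _ in range(n_local_enhancers):
    pyramid.append(downsample(pyramid[-1]))

print([p.shape[-1] for p in pyramid])   # [512, 256]; the last (coarsest) entry feeds the global model
```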
stride=2, + padding=1), + norm_layer(ngf * mult * 2), activation + ] + + # resnet blocks + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ + ResnetBlock(ngf * mult, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer) + ] + + # upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d(ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + norm_layer(int(ngf * mult / 2)), activation + ] + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) + ] + if last_op is not None: + model += [last_op] + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +# Define a resnet block +class ResnetBlock(pl.LightningModule): + + def __init__(self, + dim, + padding_type, + norm_layer, + activation=nn.ReLU(True), + use_dropout=False): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, + activation, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, + use_dropout): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % + padding_type) + + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), activation + ] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % + padding_type) + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim) + ] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class Encoder(pl.LightningModule): + + def __init__(self, + input_nc, + output_nc, + ngf=32, + n_downsampling=4, + norm_layer=nn.BatchNorm2d): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True) + ] + # downsample + for i in range(n_downsampling): + mult = 2**i + model += [ + nn.Conv2d(ngf * mult, + ngf * mult * 2, + kernel_size=3, + stride=2, + padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True) + ] + + # upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d(ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True) + ] + + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + nn.Tanh() + ] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + # instance-wise average pooling + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + for b in range(input.size()[0]): + indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4 + for j in range(self.output_nc): + output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, + indices[:, 2], indices[:, 3]] + mean_feat = 
torch.mean(output_ins).expand_as(output_ins) + outputs_mean[indices[:, 0] + b, indices[:, 1] + j, + indices[:, 2], indices[:, 3]] = mean_feat + return outputs_mean diff --git a/lib/net/HGFilters.py b/lib/net/HGFilters.py new file mode 100644 index 0000000000000000000000000000000000000000..0511ea44a2baa1d845e8d587e70e03ea0ab9e043 --- /dev/null +++ b/lib/net/HGFilters.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.net.net_util import * +import torch.nn as nn +import torch.nn.functional as F + + +class HourGlass(nn.Module): + + def __init__(self, num_modules, depth, num_features, opt): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + self.opt = opt + + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), + ConvBlock(self.features, self.features, self.opt)) + + self.add_module('b2_' + str(level), + ConvBlock(self.features, self.features, self.opt)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), + ConvBlock(self.features, self.features, self.opt)) + + self.add_module('b3_' + str(level), + ConvBlock(self.features, self.features, self.opt)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + # NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample + # if the pretrained model behaves weirdly, switch with the commented line. + # NOTE: I also found that "bicubic" works better. 
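Encoder.forward, which closes just above, ends with instance-wise average pooling: each feature channel is averaged over every instance mask and the mean is written back to those pixels, so downstream layers see one feature vector per object instance. A compact equivalent on toy tensors (batch size 1 here; the original also loops over the batch):

```python
import torch

feat = torch.randn(1, 4, 8, 8)              # encoder output (toy size)
inst = torch.randint(0, 3, (1, 1, 8, 8))    # instance-id map aligned with the image

pooled = feat.clone()
for i in inst.unique():
    mask = (inst == i)                      # [1, 1, 8, 8] boolean mask for this instance
    for c in range(feat.size(1)):
        channel = pooled[:, c:c + 1]        # view sharing storage with `pooled`
        channel[mask] = feat[:, c:c + 1][mask].mean()
```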
+ up2 = F.interpolate(low3, + scale_factor=2, + mode='bicubic', + align_corners=True) + # up2 = F.interpolate(low3, scale_factor=2, mode='nearest) + + return up1 + up2 + + def forward(self, x): + return self._forward(self.depth, x) + + +class HGFilter(nn.Module): + + def __init__(self, opt, num_modules, in_dim): + super(HGFilter, self).__init__() + self.num_modules = num_modules + + self.opt = opt + [k, s, d, p] = self.opt.conv1 + + # self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3) + self.conv1 = nn.Conv2d(in_dim, + 64, + kernel_size=k, + stride=s, + dilation=d, + padding=p) + + if self.opt.norm == 'batch': + self.bn1 = nn.BatchNorm2d(64) + elif self.opt.norm == 'group': + self.bn1 = nn.GroupNorm(32, 64) + + if self.opt.hg_down == 'conv64': + self.conv2 = ConvBlock(64, 64, self.opt) + self.down_conv2 = nn.Conv2d(64, + 128, + kernel_size=3, + stride=2, + padding=1) + elif self.opt.hg_down == 'conv128': + self.conv2 = ConvBlock(64, 128, self.opt) + self.down_conv2 = nn.Conv2d(128, + 128, + kernel_size=3, + stride=2, + padding=1) + elif self.opt.hg_down == 'ave_pool': + self.conv2 = ConvBlock(64, 128, self.opt) + else: + raise NameError('Unknown Fan Filter setting!') + + self.conv3 = ConvBlock(128, 128, self.opt) + self.conv4 = ConvBlock(128, 256, self.opt) + + # Stacking part + for hg_module in range(self.num_modules): + self.add_module('m' + str(hg_module), + HourGlass(1, opt.num_hourglass, 256, self.opt)) + + self.add_module('top_m_' + str(hg_module), + ConvBlock(256, 256, self.opt)) + self.add_module( + 'conv_last' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + if self.opt.norm == 'batch': + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + elif self.opt.norm == 'group': + self.add_module('bn_end' + str(hg_module), + nn.GroupNorm(32, 256)) + + self.add_module( + 'l' + str(hg_module), + nn.Conv2d(256, + opt.hourglass_dim, + kernel_size=1, + stride=1, + padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module( + 'bl' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module( + 'al' + str(hg_module), + nn.Conv2d(opt.hourglass_dim, + 256, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x)), True) + tmpx = x + if self.opt.hg_down == 'ave_pool': + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + elif self.opt.hg_down in ['conv64', 'conv128']: + x = self.conv2(x) + x = self.down_conv2(x) + else: + raise NameError('Unknown Fan Filter setting!') + + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + for i in range(self.num_modules): + hg = self._modules['m' + str(i)](previous) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu( + self._modules['bn_end' + str(i)]( + self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + outputs.append(tmp_out) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs + + + + + + + +class FuseHGFilter(nn.Module): + + def __init__(self, opt, num_modules, in_dim): + super(FuseHGFilter, self).__init__() + self.num_modules = num_modules + + self.opt = opt + [k, s, d, p] = self.opt.conv1 + + # self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3) + self.conv1 = nn.Conv2d(in_dim, + 64, + kernel_size=k, + stride=s, + dilation=d, + padding=p) + + if 
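HGFilter reads all of its hyper-parameters from an opt namespace; the attribute names below (conv1, norm, hg_down, num_hourglass, hourglass_dim) are exactly the ones accessed above, while the concrete values are chosen only for illustration. This sketch also assumes that ConvBlock from lib.net.net_util needs nothing beyond opt.norm and preserves spatial size, which is the usual behaviour for this kind of hourglass block:

```python
import torch
from types import SimpleNamespace
from lib.net.HGFilters import HGFilter   # defined above in this diff

opt = SimpleNamespace(
    conv1=[7, 2, 1, 3],      # kernel, stride, dilation, padding for the stem conv
    norm='group',
    hg_down='ave_pool',
    num_hourglass=2,
    hourglass_dim=6,
)

net = HGFilter(opt, num_modules=4, in_dim=6)
x = torch.randn(1, 6, 512, 512)
outputs = net(x)                         # one heatmap stack per hourglass module
print(len(outputs), outputs[0].shape)    # 4 stacks, each (1, opt.hourglass_dim, 128, 128)
                                         # if ConvBlock preserves spatial size
```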
self.opt.norm == 'batch': + self.bn1 = nn.BatchNorm2d(64) + elif self.opt.norm == 'group': + self.bn1 = nn.GroupNorm(32, 64) + + + self.conv2 = ConvBlock(64, 128, self.opt) + self.down_conv2 = nn.Conv2d(128, + 96, + kernel_size=3, + stride=2, + padding=1) + # elif self.opt.hg_down == 'conv128': + # self.conv2 = ConvBlock(64, 128, self.opt) + # self.down_conv2 = nn.Conv2d(128, + # 128, + # kernel_size=3, + # stride=2, + # padding=1) + + dim=96+32 + self.conv3 = ConvBlock(dim, dim, self.opt) + self.conv4 = ConvBlock(dim, 256, self.opt) + + # Stacking part + for hg_module in range(self.num_modules): + self.add_module('m' + str(hg_module), + HourGlass(1, opt.num_hourglass, 256, self.opt)) + + self.add_module('top_m_' + str(hg_module), + ConvBlock(256, 256, self.opt)) + self.add_module( + 'conv_last' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + if self.opt.norm == 'batch': + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + elif self.opt.norm == 'group': + self.add_module('bn_end' + str(hg_module), + nn.GroupNorm(32, 256)) + + hourglass_dim=256 + self.add_module( + 'l' + str(hg_module), + nn.Conv2d(256, + hourglass_dim, + kernel_size=1, + stride=1, + padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module( + 'bl' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module( + 'al' + str(hg_module), + nn.Conv2d(hourglass_dim, + 256, + kernel_size=1, + stride=1, + padding=0)) + + self.up_conv=nn.ConvTranspose2d(hourglass_dim,64,kernel_size=2,stride=2) + + def forward(self, x,plane): + x = F.relu(self.bn1(self.conv1(x)), True) # 64*256*256 + tmpx = x + + x = self.conv2(x) + x = self.down_conv2(x) + + x=torch.cat([x,plane],1) # 128*128*128 + + + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + for i in range(self.num_modules): + hg = self._modules['m' + str(i)](previous) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu( + self._modules['bn_end' + str(i)]( + self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + outputs.append(tmp_out) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + out=self.up_conv(outputs[-1]) + + return out \ No newline at end of file diff --git a/lib/net/HGPIFuNet.py b/lib/net/HGPIFuNet.py new file mode 100644 index 0000000000000000000000000000000000000000..a93fd7b5b679f19f5dc2b0cd8fdc78235f917614 --- /dev/null +++ b/lib/net/HGPIFuNet.py @@ -0,0 +1,500 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +#from lib.net.voxelize import Voxelization +from lib.renderer.mesh import compute_normal_batch +from lib.dataset.mesh_util import feat_select, read_smpl_constants, surface_field_deformation +from lib.net.NormalNet import NormalNet +from lib.net.MLP import MLP, DeformationMLP, TransformerEncoderLayer, SDF2Density, SDF2Occ +from lib.net.spatial import SpatialEncoder +from lib.dataset.PointFeat import PointFeat +from lib.dataset.mesh_util import SMPLX +from lib.net.VE import VolumeEncoder +from lib.net.ResBlkPIFuNet import ResnetFilter +from lib.net.UNet import UNet +from lib.net.HGFilters import * +from lib.net.Transformer import ViTVQ +from termcolor import colored +from lib.net.BasePIFuNet import BasePIFuNet +import torch.nn as nn +import torch +import numpy as np +import matplotlib.pyplot as plt +import torch.nn.functional as F +from lib.net.nerf_util import raw2outputs + + +def normalize(tensor): + min_val = tensor.min() + max_val = tensor.max() + normalized_tensor = (tensor - min_val) / (max_val - min_val) + return normalized_tensor + +def visualize_feature_map(feature_map, title, filename): + feature_map=feature_map.permute(0, 2, 3, 1) + # 选择一个样本(如果有多个) + sample_index = 0 + sample = feature_map[sample_index] + + # 选择一个通道(如果有多个) + channel_index = 0 + channel = sample[:, :, channel_index] + channel= normalize(channel) + + plt.imshow(channel.cpu().numpy(), cmap='hot') + # plt.title(title) + # plt.colorbar() + plt.axis('off') + plt.savefig(filename, dpi=300,bbox_inches='tight', pad_inches=0) # 保存图片到文件 + plt.close() # 关闭图形,释放资源 + + +class HGPIFuNet(BasePIFuNet): + """ + HG PIFu network uses Hourglass stacks as the image filter. + It does the following: + 1. Compute image feature stacks and store it in self.im_feat_list + self.im_feat_list[-1] is the last stack (output stack) + 2. Calculate calibration + 3. If training, it index on every intermediate stacks, + If testing, it index on the last stack. + 4. Classification. + 5. During training, error is calculated on all stacks. 
+ """ + + def __init__(self, + cfg, + projection_mode="orthogonal", + error_term=nn.MSELoss()): + + super(HGPIFuNet, self).__init__(projection_mode=projection_mode, + error_term=error_term) + + self.l1_loss = nn.SmoothL1Loss() + self.opt = cfg.net + self.root = cfg.root + self.overfit = cfg.overfit + + channels_IF = self.opt.mlp_dim + + self.use_filter = self.opt.use_filter + self.prior_type = self.opt.prior_type + self.smpl_feats = self.opt.smpl_feats + + self.smpl_dim = self.opt.smpl_dim + self.voxel_dim = self.opt.voxel_dim + self.hourglass_dim = self.opt.hourglass_dim + + self.in_geo = [item[0] for item in self.opt.in_geo] + self.in_nml = [item[0] for item in self.opt.in_nml] + + self.in_geo_dim = sum([item[1] for item in self.opt.in_geo]) + self.in_nml_dim = sum([item[1] for item in self.opt.in_nml]) + + self.in_total = self.in_geo + self.in_nml + self.smpl_feat_dict = None + self.smplx_data = SMPLX() + + image_lst = [0, 1, 2] + normal_F_lst = [0, 1, 2] if "image" not in self.in_geo else [3, 4, 5] + normal_B_lst = [3, 4, 5] if "image" not in self.in_geo else [6, 7, 8] + + # only ICON or ICON-Keypoint use visibility + + if self.prior_type in ["icon", "keypoint"]: + if "image" in self.in_geo: + self.channels_filter = [ + image_lst + normal_F_lst, + image_lst + normal_B_lst, + ] + else: + self.channels_filter = [normal_F_lst, normal_B_lst] + + else: + if "image" in self.in_geo: + self.channels_filter = [ + image_lst + normal_F_lst + normal_B_lst + ] + else: + self.channels_filter = [normal_F_lst + normal_B_lst] + + use_vis = (self.prior_type in ["icon", "keypoint" + ]) and ("vis" in self.smpl_feats) + if self.prior_type in ["pamir", "pifu"]: + use_vis = 1 + + if self.use_filter: + channels_IF[0] = (self.hourglass_dim) * (2 - use_vis) + else: + channels_IF[0] = len(self.channels_filter[0]) * (2 - use_vis) + + if self.prior_type in ["icon", "keypoint"]: + channels_IF[0] += self.smpl_dim + + elif self.prior_type == "pifu": + channels_IF[0] += 1 + else: + print(f"don't support {self.prior_type}!") + + self.base_keys = ["smpl_verts", "smpl_faces"] + + self.icon_keys = self.base_keys + [ + f"smpl_{feat_name}" for feat_name in self.smpl_feats + ] + self.keypoint_keys = self.base_keys + [ + f"smpl_{feat_name}" for feat_name in self.smpl_feats + ] + + self.pamir_keys = [ + "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num" + ] + self.pifu_keys = [] + + # channels_IF[0]+=self.hourglass_dim + # self.if_regressor = MLP( + # filter_channels=channels_IF, + # name="if", + # res_layers=self.opt.res_layers, + # norm=self.opt.norm_mlp, + # last_op=nn.Sigmoid() if not cfg.test_mode else None, + # ) + + self.deform_dim=64 + + #self.image_filter = ResnetFilter(self.opt, norm_layer=norm_type) + #self.image_filter = UNet(3,128) + # self.xy_plane_filter=ResnetFilter(self.opt, norm_layer=norm_type) + # self.yz_plane_filter=ViTVQ(image_size=512) # ResnetFilter(self.opt, norm_layer=norm_type) + # self.xz_plane_filter=ViTVQ(image_size=512) + self.image_filter=ViTVQ(image_size=512,channels=9) + # self.deformation_mlp=DeformationMLP(input_dim=self.deform_dim,opt=self.opt) + self.mlp=TransformerEncoderLayer(skips=4,multires=6,opt=self.opt) + # self.sdf2density=SDF2Density() + # self.sdf2occ=SDF2Occ() + self.color_loss=nn.L1Loss() + self.sp_encoder = SpatialEncoder() + self.step=0 + self.features_costume=None + + # network + if self.use_filter: + if self.opt.gtype == "HGPIFuNet": + self.F_filter = HGFilter(self.opt, self.opt.num_stack, + len(self.channels_filter[0])) + # self.refine_filter = FuseHGFilter(self.opt, 
self.opt.num_stack, + # len(self.channels_filter[0])) + + else: + print( + colored(f"Backbone {self.opt.gtype} is unimplemented", + "green")) + + summary_log = (f"{self.prior_type.upper()}:\n" + + f"w/ Global Image Encoder: {self.use_filter}\n" + + f"Image Features used by MLP: {self.in_geo}\n") + + if self.prior_type == "icon": + summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n" + summary_log += f"Dim of Image Features (local): {3 if (use_vis and not self.use_filter) else 6}\n" + summary_log += f"Dim of Geometry Features (ICON): {self.smpl_dim}\n" + elif self.prior_type == "keypoint": + summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n" + summary_log += f"Dim of Image Features (local): {3 if (use_vis and not self.use_filter) else 6}\n" + summary_log += f"Dim of Geometry Features (Keypoint): {self.smpl_dim}\n" + elif self.prior_type == "pamir": + summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n" + summary_log += f"Dim of Geometry Features (PaMIR): {self.voxel_dim}\n" + else: + summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n" + summary_log += f"Dim of Geometry Features (PIFu): 1 (z-value)\n" + + summary_log += f"Dim of MLP's first layer: {channels_IF[0]}\n" + + print(colored(summary_log, "yellow")) + + self.normal_filter = NormalNet(cfg) + + init_net(self, init_type="normal") + + def get_normal(self, in_tensor_dict): + + # insert normal features + if (not self.training) and (not self.overfit): + # print(colored("infer normal","blue")) + with torch.no_grad(): + feat_lst = [] + if "image" in self.in_geo: + feat_lst.append( + in_tensor_dict["image"]) # [1, 3, 512, 512] + if "normal_F" in self.in_geo and "normal_B" in self.in_geo: + if ("normal_F" not in in_tensor_dict.keys() + or "normal_B" not in in_tensor_dict.keys()): + (nmlF, nmlB) = self.normal_filter(in_tensor_dict) + else: + nmlF = in_tensor_dict["normal_F"] + nmlB = in_tensor_dict["normal_B"] + feat_lst.append(nmlF) # [1, 3, 512, 512] + feat_lst.append(nmlB) # [1, 3, 512, 512] + in_filter = torch.cat(feat_lst, dim=1) + + else: + in_filter = torch.cat([in_tensor_dict[key] for key in self.in_geo], + dim=1) + + return in_filter + + def get_mask(self, in_filter, size=128): + + mask = (F.interpolate( + in_filter[:, self.channels_filter[0]], + size=(size, size), + mode="bilinear", + align_corners=True, + ).abs().sum(dim=1, keepdim=True) != 0.0) + + return mask + + + def filter(self, in_tensor_dict, return_inter=False): + """ + Filter the input images + store all intermediate features. 
+ :param images: [B, C, H, W] input images + """ + + in_filter = self.get_normal(in_tensor_dict) + image= in_tensor_dict["image"] + fuse_image=torch.cat([image,in_filter], dim=1) + smpl_normals={ + "T_normal_B":in_tensor_dict['normal_B'], + "T_normal_R":in_tensor_dict['T_normal_R'], + "T_normal_L":in_tensor_dict['T_normal_L'] + } + features_G = [] + + # self.smpl_normal=in_tensor_dict['T_normal_L'] + + if self.prior_type in ["icon", "keypoint"]: + if self.use_filter: + triplane_features = self.image_filter(fuse_image,smpl_normals) + + features_F = self.F_filter(in_filter[:, + self.channels_filter[0]] + ) # [(B,hg_dim,128,128) * 4] + features_B = self.F_filter(in_filter[:, + self.channels_filter[1]] + ) # [(B,hg_dim,128,128) * 4] + else: + assert 0 + + F_plane_feat,B_plane_feat,R_plane_feat,L_plane_feat=triplane_features + + refine_F_plane_feat=F_plane_feat + features_G.append(refine_F_plane_feat) + features_G.append(B_plane_feat) + features_G.append(R_plane_feat) + features_G.append(L_plane_feat) + features_G.append(torch.cat([features_F[-1],features_B[-1]], dim=1)) + + else: + assert 0 + + self.smpl_feat_dict = { + k: in_tensor_dict[k] if k in in_tensor_dict.keys() else None + for k in getattr(self, f"{self.prior_type}_keys") + } + if 'animated_smpl_verts' not in in_tensor_dict.keys(): + self.point_feat_extractor = PointFeat(self.smpl_feat_dict["smpl_verts"], + self.smpl_feat_dict["smpl_faces"]) + else: + assert 0 + + self.features_G = features_G + + # If it is not in training, only produce the last im_feat + if not self.training: + features_out = features_G + else: + features_out = features_G + + if return_inter: + return features_out, in_filter + else: + return features_out + + + + def query(self, features, points, calibs, transforms=None,type='shape'): + + xyz = self.projection(points, calibs, transforms) # project to image plane + + (xy, z) = xyz.split([2, 1], dim=1) + + + zy=torch.cat([xyz[:,2:3],xyz[:,1:2]],dim=1) + + in_cube = (xyz > -1.0) & (xyz < 1.0) + in_cube = in_cube.all(dim=1, keepdim=True).detach().float() + + preds_list = [] + + + if self.prior_type in ["icon", "keypoint"]: + + + + densely_smpl=self.smpl_feat_dict['smpl_verts'].permute(0,2,1) + #smpl_origin=self.projection(densely_smpl, torch.inverse(calibs), transforms) + smpl_vis=self.smpl_feat_dict['smpl_vis'].permute(0,2,1) + #verts_ids=self.smpl_feat_dict['smpl_sample_id'] + + + + (smpl_xy,smpl_z)=densely_smpl.split([2,1],dim=1) + smpl_zy=torch.cat([densely_smpl[:,2:3],densely_smpl[:,1:2]],dim=1) + + point_feat_out = self.point_feat_extractor.query( # this extractor changes if has animated smpl + xyz.permute(0, 2, 1).contiguous(), self.smpl_feat_dict) + vis=point_feat_out['vis'].permute(0,2,1) + #sdf_body=-point_feat_out['sdf'] # this sdf needs to be multiplied by -1 + feat_lst = [ + point_feat_out[key] for key in self.smpl_feats + if key in point_feat_out.keys() + ] + smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1) + + if len(features)==5: + + F_plane_feat1,F_plane_feat2=features[0].chunk(2,dim=1) + B_plane_feat1,B_plane_feat2=features[1].chunk(2,dim=1) + R_plane_feat1,R_plane_feat2=features[2].chunk(2,dim=1) + L_plane_feat1,L_plane_feat2=features[3].chunk(2,dim=1) + in_feat=features[4] + + + F_feat=self.index(F_plane_feat1,xy) + B_feat=self.index(B_plane_feat1,xy) + R_feat=self.index(R_plane_feat1,zy) + L_feat=self.index(L_plane_feat1,zy) + normal_feat=feat_select(self.index(in_feat, xy),vis) + three_plane_feat=(B_feat+R_feat+L_feat)/3 + triplane_feat=torch.cat([F_feat,three_plane_feat],dim=1) # 32+32=64 + + 
### smpl query ### + smpl_F_feat=self.index(F_plane_feat2,smpl_xy) + smpl_B_feat=self.index(B_plane_feat2,smpl_xy) + smpl_R_feat=self.index(R_plane_feat2,smpl_zy) + smpl_L_feat=self.index(L_plane_feat2,smpl_zy) + + + + smpl_three_plane_feat=(smpl_B_feat+smpl_R_feat+smpl_L_feat)/3 + smpl_triplane_feat=torch.cat([smpl_F_feat,smpl_three_plane_feat],dim=1) # 32+32=64 + bary_centric_feat=self.point_feat_extractor.query_barycentirc_feats(xyz.permute(0,2,1).contiguous() + ,smpl_triplane_feat.permute(0,2,1)) + + + final_feat=torch.cat([triplane_feat,bary_centric_feat.permute(0,2,1),normal_feat],dim=1) # 64+64+6=134 + + if self.features_costume is not None: + assert 0 + if type=='shape': + if 'animated_smpl_verts' in self.smpl_feat_dict.keys(): + animated_smpl=self.smpl_feat_dict['animated_smpl_verts'] + + occ=self.mlp(xyz.permute(0,2,1).contiguous(),animated_smpl, + final_feat,smpl_feat,training=self.training,type=type) + else: + + occ=self.mlp(xyz.permute(0,2,1).contiguous(),densely_smpl.permute(0,2,1), + final_feat,smpl_feat,training=self.training,type=type) + occ=occ*in_cube + preds_list.append(occ) + + elif type=='color': + if 'animated_smpl_verts' in self.smpl_feat_dict.keys(): + animated_smpl=self.smpl_feat_dict['animated_smpl_verts'] + color_preds=self.mlp(xyz.permute(0,2,1).contiguous(),animated_smpl, + final_feat,smpl_feat,training=self.training,type=type) + + + else: + color_preds=self.mlp(xyz.permute(0,2,1).contiguous(),densely_smpl.permute(0,2,1), + final_feat,smpl_feat,training=self.training,type=type) + preds_list.append(color_preds) + + return preds_list + + + + + def get_error(self, preds_if_list, labels): + """calculate error + + Args: + preds_list (list): list of torch.tensor(B, 3, N) + labels (torch.tensor): (B, N_knn, N) + + Returns: + torch.tensor: error + """ + error_if = 0 + + for pred_id in range(len(preds_if_list)): + pred_if = preds_if_list[pred_id] + error_if += F.binary_cross_entropy(pred_if, labels) + + error_if /= len(preds_if_list) + + return error_if + + + def forward(self, in_tensor_dict): + + sample_tensor = in_tensor_dict["sample"] + calib_tensor = in_tensor_dict["calib"] + label_tensor = in_tensor_dict["label"] + + color_sample=in_tensor_dict["sample_color"] + color_label=in_tensor_dict["color"] + + + in_feat = self.filter(in_tensor_dict) + + + + preds_if_list = self.query(in_feat, + sample_tensor, + calib_tensor,type='shape') + + BCEloss = self.get_error(preds_if_list, label_tensor) + + color_preds=self.query(in_feat, + color_sample, + calib_tensor,type='color') + color_loss=self.color_loss(color_preds[0],color_label) + + + + if self.training: + + self.color3d_loss= color_loss + error=BCEloss+color_loss + self.grad_loss=torch.tensor(0.).float().to(BCEloss.device) + else: + error=BCEloss + + return preds_if_list[-1].detach(), error diff --git a/lib/net/HallucinatorNet.py b/lib/net/HallucinatorNet.py new file mode 100644 index 0000000000000000000000000000000000000000..e5074fab0898f6a4ad6d19c6cac09b716901d22f --- /dev/null +++ b/lib/net/HallucinatorNet.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. 
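The heart of query() above is a triplane-style lookup: projected points index the front plane at (x, y), the back/right/left planes are sampled and averaged, and the result is concatenated with barycentric SMPL features and the visibility-selected hourglass feature. A shape-level sketch of that sampling pattern, using plain grid_sample in place of the inherited index helper (assumed to behave the same way):

```python
import torch
import torch.nn.functional as F

def sample_plane(plane_feat, uv):
    """Bilinearly sample a [B, C, H, W] plane at [B, 2, N] coords in [-1, 1]."""
    grid = uv.permute(0, 2, 1).unsqueeze(2)          # [B, N, 1, 2]
    out = F.grid_sample(plane_feat, grid, align_corners=True)
    return out[..., 0]                               # [B, C, N]

B, C, N = 1, 32, 1024
F_plane, B_plane, R_plane, L_plane = (torch.randn(B, C, 128, 128) for _ in range(4))
xyz = torch.rand(B, 3, N) * 2 - 1                    # projected points in [-1, 1]

xy = xyz[:, :2]                                      # front/back planes indexed by (x, y)
zy = torch.cat([xyz[:, 2:3], xyz[:, 1:2]], dim=1)    # side planes indexed by (z, y)

F_feat = sample_plane(F_plane, xy)
three_plane = (sample_plane(B_plane, xy) +
               sample_plane(R_plane, zy) +
               sample_plane(L_plane, zy)) / 3
triplane_feat = torch.cat([F_feat, three_plane], dim=1)   # [B, 64, N], as in query() above
```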
+# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.net.FBNet import define_G +from lib.net.net_util import init_net, VGGLoss +from lib.net.HGFilters import * +from lib.net.BasePIFuNet import BasePIFuNet +import torch +import torch.nn as nn + + +class Hallucinator(BasePIFuNet): + ''' + HG PIFu network uses Hourglass stacks as the image filter. + It does the following: + 1. Compute image feature stacks and store it in self.im_feat_list + self.im_feat_list[-1] is the last stack (output stack) + 2. Calculate calibration + 3. If training, it index on every intermediate stacks, + If testing, it index on the last stack. + 4. Classification. + 5. During training, error is calculated on all stacks. + ''' + + def __init__(self, cfg, error_term=nn.SmoothL1Loss()): + + super(Hallucinator, self).__init__(error_term=error_term) + + self.l1_loss = nn.SmoothL1Loss() + + self.opt = cfg.net + + if self.training: + self.vgg_loss = [VGGLoss()] + + self.in_nmlB = [ + item[0] for item in self.opt.in_nml + if '_B' in item[0] or item[0] == 'image' + ] + self.in_nmlL = [ + item[0] for item in self.opt.in_nml + if '_L' in item[0] or item[0] == 'image' + ] + self.in_nmlB_dim = sum([ + item[1] for item in self.opt.in_nml + if '_B' in item[0] or item[0] == 'image' + ]) + self.in_nmlL_dim = sum([ + item[1] for item in self.opt.in_nml + if '_L' in item[0] or item[0] == 'image' + ]) + + self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, + "instance") + self.netL = define_G(self.in_nmlL_dim, 3, 64, "global", 4, 9, 1, 3, + "instance") + + init_net(self) + + def forward(self, in_tensor): + + inB_list = [] + inL_list = [] + + for name in self.in_nmlB: + inB_list.append(in_tensor[name]) + for name in self.in_nmlL: + inL_list.append(in_tensor[name]) + + nmlB = self.netB(torch.cat(inB_list, dim=1)) + nmlL = self.netL(torch.cat(inL_list, dim=1)) + + # ||normal|| == 1 + nmlB = nmlB / torch.norm(nmlB, dim=1, keepdim=True) + nmlL = nmlL / torch.norm(nmlL, dim=1, keepdim=True) + + # output: float_arr [-1,1] with [B, C, H, W] + + mask = (in_tensor['image'].abs().sum(dim=1, keepdim=True) != + 0.0).detach().float() + + nmlB = nmlB * mask + #nmlL = nmlL * mask + + return nmlB, nmlL + + def get_norm_error(self, prd_B, prd_L, tgt): + """calculate normal loss + + Args: + pred (torch.tensor): [B, 6, 512, 512] + tagt (torch.tensor): [B, 6, 512, 512] + """ + + tgt_B, tgt_L = tgt['render_B'], tgt['render_L'] + + l1_B_loss = self.l1_loss(prd_B, tgt_B) + l1_L_loss = self.l1_loss(prd_L, tgt_L) + + with torch.no_grad(): + vgg_B_loss = self.vgg_loss[0](prd_B, tgt_B) + vgg_L_loss = self.vgg_loss[0](prd_L, tgt_L) + + total_loss = [ + 5.0 * l1_B_loss + vgg_B_loss, 5.0 * l1_L_loss + vgg_L_loss + ] + + return total_loss diff --git a/lib/net/MLP.py b/lib/net/MLP.py new file mode 100644 index 0000000000000000000000000000000000000000..ca0ee6d10348d14646fae7a5c00ee9d7e09acb4a --- /dev/null +++ b/lib/net/MLP.py @@ -0,0 +1,332 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+ +import torch +import torch.nn as nn +import pytorch_lightning as pl +import torch.nn.functional as F +from torch.autograd import grad +# from fightingcv_attention.attention.SelfAttention import ScaledDotProductAttention +import numpy as np + +class SDF2Density(pl.LightningModule): + def __init__(self): + super(SDF2Density, self).__init__() + + # learnable parameters beta, with initial value 0.1 + self.beta = nn.Parameter(torch.tensor(0.1)) + + def forward(self, sdf): + # use Laplace CDF to compute the probability + # temporally use sigmoid to represent laplace CDF + return 1.0/(self.beta+1e-6)*F.sigmoid(-sdf/(self.beta+1e-6)) + +class SDF2Occ(pl.LightningModule): + def __init__(self): + super(SDF2Occ, self).__init__() + + # learnable parameters beta, with initial value 0.1 + self.beta = nn.Parameter(torch.tensor(0.1)) + + def forward(self, sdf): + # use Laplace CDF to compute the probability + # temporally use sigmoid to represent laplace CDF + return F.sigmoid(-sdf/(self.beta+1e-6)) + + +class DeformationMLP(pl.LightningModule): + def __init__(self,input_dim=64,output_dim=3,activation='LeakyReLU',name=None,opt=None): + super(DeformationMLP, self).__init__() + self.name = name + self.activation = activation + self.activate = nn.LeakyReLU(inplace=True) + # self.mlp = nn.Sequential( + # nn.Conv1d(input_dim+8+1+3, 64, 1), + # nn.LeakyReLU(inplace=True), + # nn.Conv1d(64, output_dim, 1), + # ) + channels=[input_dim+8+1+3,128, 64, output_dim] + self.deform_mlp=MLP(filter_channels=channels, + name="if", + res_layers=opt.res_layers, + norm=opt.norm_mlp, + last_op=None) # occupancy + smplx_dim = 10475 + k=8 + self.per_pt_code = nn.Embedding(smplx_dim,k) + + def forward(self, feature,smpl_vis,pts_id, xyz): + ''' + feature may include multiple view inputs + args: + feature: [B, C_in, N] + return: + [B, C_out, N] prediction + ''' + y = feature + e_code=self.per_pt_code(pts_id).permute(0,2,1) # a code that distinguishes each point on different parts of the body + y=torch.cat([y,xyz,smpl_vis,e_code],1) + y = self.deform_mlp(y) + return y + +class MLP(pl.LightningModule): + + def __init__(self, + filter_channels, + name=None, + res_layers=[], + norm='group', + last_op=None): + + super(MLP, self).__init__() + + self.filters = nn.ModuleList() + self.norms = nn.ModuleList() + self.res_layers = res_layers + self.norm = norm + self.last_op = last_op + self.name = name + self.activate = nn.LeakyReLU(inplace=True) + + for l in range(0, len(filter_channels) - 1): + if l in self.res_layers: + self.filters.append( + nn.Conv1d(filter_channels[l] + filter_channels[0], + filter_channels[l + 1], 1)) + else: + self.filters.append( + nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1)) + + if l != len(filter_channels) - 2: + if norm == 'group': + self.norms.append(nn.GroupNorm(32, filter_channels[l + 1])) + elif norm == 'batch': + self.norms.append(nn.BatchNorm1d(filter_channels[l + 1])) + elif norm == 'instance': + self.norms.append(nn.InstanceNorm1d(filter_channels[l + + 1])) + elif norm == 'weight': + self.filters[l] = nn.utils.weight_norm(self.filters[l], + name='weight') + # print(self.filters[l].weight_g.size(), + # self.filters[l].weight_v.size()) + + def forward(self, feature): + ''' + feature may include multiple view inputs + args: + feature: [B, C_in, N] + return: + [B, C_out, N] prediction + ''' + y = feature + tmpy = feature + + for i, f in enumerate(self.filters): + + y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1)) + if i != len(self.filters) - 1: + if self.norm not in ['batch', 
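The MLP above is a stack of 1-D convolutions over per-point features; layers listed in res_layers additionally see the raw input concatenated back in, and the normalization type is selectable. A minimal usage sketch with arbitrary channel sizes:

```python
import torch
from lib.net.MLP import MLP   # defined above in this diff

# Per-point MLP: 134 input channels here is arbitrary; output is one occupancy value per point.
mlp = MLP(filter_channels=[134, 512, 256, 128, 1],
          res_layers=[2, 3],          # these layers also see the raw input concatenated back in
          norm='group',
          last_op=torch.nn.Sigmoid())

feat = torch.randn(2, 134, 4096)      # [B, C_in, N] point features
occ = mlp(feat)                       # [B, 1, N] occupancy in (0, 1)
print(occ.shape)
```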
'group', 'instance']: + y = self.activate(y) + else: + y = self.activate(self.norms[i](y)) + + if self.last_op is not None: + y = self.last_op(y) + + return y + + +# Positional encoding (section 5.1) +class Embedder(pl.LightningModule): + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x : x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires=6, i=0): + if i == -1: + return nn.Identity(), 3 + + embed_kwargs = { + 'include_input' : True, + 'input_dims' : 3, + 'max_freq_log2' : multires-1, + 'num_freqs' : multires, + 'log_sampling' : True, + 'periodic_fns' : [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + embed = lambda x, eo=embedder_obj : eo.embed(x) + return embed, embedder_obj.out_dim + + +# Transformer encoder layer +# uses Embedder to add positional encoding to input points +# uses query points as query, deformed points as key, point features as value for attention +class TransformerEncoderLayer(pl.LightningModule): + def __init__(self, d_model=256, skips=4, multires=6, num_mlp_layers=8, dropout=0.1, opt=None): + super(TransformerEncoderLayer, self).__init__() + + embed_fn, input_ch = get_embedder(multires=multires) + self.skips=skips + self.dropout = dropout + D=num_mlp_layers + self.positional_encoding = embed_fn + self.d_model = d_model + triplane_dim=64 + opt.mlp_dim[0]=triplane_dim+6+8 + opt.mlp_dim_color[0]=triplane_dim+6+8 + + self.geo_mlp=MLP(filter_channels=opt.mlp_dim, + name="if", + res_layers=opt.res_layers, + norm=opt.norm_mlp, + last_op=nn.Sigmoid()) # occupancy + + self.color_mlp=MLP(filter_channels=opt.mlp_dim_color, + name="color_if", + res_layers=opt.res_layers, + norm=opt.norm_mlp, + last_op=nn.Tanh()) # color + + self.softmax = nn.Softmax(dim=-1) + + + + def forward(self,query_points,key_points,point_features,smpl_feat,training=True,type='shape'): + # Q=self.positional_encoding(query_points) #[B,N,39] + # K=self.positional_encoding(key_points) #[B,N',39] + # V=point_features.permute(0,2,1) #[B,N',192] + # t=0.1 + # #attn_output, attn_output_weights = self.attention(Q.permute(1,0,2), K.permute(1,0,2), V.permute(1,0,2)) #[B,N,192] + # attn_output_weights = torch.bmm(Q, K.transpose(1, 2)) #[B,N,N'] + # attn_output_weights = self.softmax(attn_output_weights/t) #[B,N,N'] + # # drop out + # attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True) + # # master feature + # attn_output = torch.bmm(attn_output_weights, V) #[B,N,192] + + attn_output=point_features # [B,N,192] bary centric interpolation + + feature=torch.cat([attn_output,smpl_feat],dim=1) + + if type=='shape': + h=feature + + h=self.geo_mlp(h) # [B,1,N] + return h + + + elif type=='color': + #f=self.head(feature) #[B,N,512] + + h=feature + + h=self.color_mlp(h) # [B,3,N] + return h + elif type=='shape_color': + h_s=feature + 
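The NeRF-style Embedder above concatenates the raw input with sin/cos terms at multires frequencies, so for 3-D points with multires=6 the embedded width is 3 + 3 * 2 * 6 = 39, matching the [B, N, 39] noted in the commented-out attention code. A quick check:

```python
import torch
from lib.net.MLP import get_embedder   # defined above

embed, out_dim = get_embedder(multires=6)
print(out_dim)                          # 39 = 3 (identity) + 3 * 2 * 6 (sin/cos per frequency)

pts = torch.rand(2, 1024, 3)
print(embed(pts).shape)                 # torch.Size([2, 1024, 39])
```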
h_c=feature + + h_s=self.geo_mlp(h_s) # [B,1,N] + + h_c=self.color_mlp(h_c) # [B,3,N] + + return h_s,h_c + + + + +class Swish(pl.LightningModule): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + x = x * F.sigmoid(x) + return x + + + + + + + + + +# # Import pytorch modules +# import torch +# import torch.nn as nn +# import torch.nn.functional as F + +# Define positional encoding class +class PositionalEncoding(nn.Module): + def __init__(self, d_model, max_len=1000): + super(PositionalEncoding, self).__init__() + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * + -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + +# # Define model parameters +# d_model = 256 # output size of MLP +# nhead = 8 # number of attention heads +# dim_feedforward = 512 # hidden size of MLP +# num_layers = 2 # number of MLP layers +# num_frequencies = 6 # number of frequencies for positional encoding +# dropout = 0.1 # dropout rate + +# # Define model components +# pos_encoder = PositionalEncoding(d_model, num_frequencies) # positional encoding layer +# encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) # transformer encoder layer +# encoder = nn.TransformerEncoder(encoder_layer, num_layers) # transformer encoder +# mlp_geo = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model)) # MLP for geometry +# mlp_alb = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model)) # MLP for albedo +# head_geo = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 3)) # geometry head +# head_alb = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 3), nn.Sigmoid()) # albedo head + +# # Define input tensors +# # deformed body points: (batch_size, num_points, 3) +# x = torch.randn(batch_size, num_points, 3) +# # query point positions: (batch_size, num_queries, 3) +# y = torch.randn(batch_size, num_queries, 3) + +# # Map both d + + diff --git a/lib/net/NormalNet.py b/lib/net/NormalNet.py new file mode 100644 index 0000000000000000000000000000000000000000..cf180d81a18a992d02b118b45e163ff0d5cde69d --- /dev/null +++ b/lib/net/NormalNet.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.net.FBNet import define_G +from lib.net.net_util import init_net, VGGLoss +from lib.net.HGFilters import * +from lib.net.BasePIFuNet import BasePIFuNet +import torch +import torch.nn as nn + + +class NormalNet(BasePIFuNet): + ''' + HG PIFu network uses Hourglass stacks as the image filter. 
+ It does the following: + 1. Compute image feature stacks and store it in self.im_feat_list + self.im_feat_list[-1] is the last stack (output stack) + 2. Calculate calibration + 3. If training, it index on every intermediate stacks, + If testing, it index on the last stack. + 4. Classification. + 5. During training, error is calculated on all stacks. + ''' + + def __init__(self, cfg, error_term=nn.SmoothL1Loss()): + + super(NormalNet, self).__init__(error_term=error_term) + + self.l1_loss = nn.SmoothL1Loss() + + self.opt = cfg.net + self.training=False + if self.training: + self.vgg_loss = [VGGLoss()] + + self.in_nmlF = [ + item[0] for item in self.opt.in_nml + if '_F' in item[0] or item[0] == 'image' + ] + self.in_nmlB = [ + item[0] for item in self.opt.in_nml + if '_B' in item[0] or item[0] == 'image' + ] + self.in_nmlF_dim = sum([ + item[1] for item in self.opt.in_nml + if '_F' in item[0] or item[0] == 'image' + ]) + self.in_nmlB_dim = sum([ + item[1] for item in self.opt.in_nml + if '_B' in item[0] or item[0] == 'image' + ]) + + self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, + "instance") + self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, + "instance") + + init_net(self) + + def forward(self, in_tensor): + + inF_list = [] + inB_list = [] + + for name in self.in_nmlF: + inF_list.append(in_tensor[name]) + for name in self.in_nmlB: + inB_list.append(in_tensor[name]) + + nmlF = self.netF(torch.cat(inF_list, dim=1)) + nmlB = self.netB(torch.cat(inB_list, dim=1)) + + # ||normal|| == 1 + nmlF = nmlF / torch.norm(nmlF, dim=1, keepdim=True) + nmlB = nmlB / torch.norm(nmlB, dim=1, keepdim=True) + + # output: float_arr [-1,1] with [B, C, H, W] + + mask = (in_tensor['image'].abs().sum(dim=1, keepdim=True) != + 0.0).detach().float() + + nmlF = nmlF * mask + nmlB = nmlB * mask + + return nmlF, nmlB + + def get_norm_error(self, prd_F, prd_B, tgt): + """calculate normal loss + + Args: + pred (torch.tensor): [B, 6, 512, 512] + tagt (torch.tensor): [B, 6, 512, 512] + """ + + tgt_F, tgt_B = tgt['normal_F'], tgt['normal_B'] + + l1_F_loss = self.l1_loss(prd_F, tgt_F) + l1_B_loss = self.l1_loss(prd_B, tgt_B) + + with torch.no_grad(): + vgg_F_loss = self.vgg_loss[0](prd_F, tgt_F) + vgg_B_loss = self.vgg_loss[0](prd_B, tgt_B) + + total_loss = [ + 5.0 * l1_F_loss + vgg_F_loss, 5.0 * l1_B_loss + vgg_B_loss + ] + + return total_loss diff --git a/lib/net/PymaridPoolingTransformer.py b/lib/net/PymaridPoolingTransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d02265700e9b58646cbd51e1e3ecaff6aa1489b4 --- /dev/null +++ b/lib/net/PymaridPoolingTransformer.py @@ -0,0 +1,397 @@ +from os import sep +from pickle import TRUE +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg + +import numpy as np + +__all__ = [ + 'p2t_tiny', 'p2t_small', 'p2t_base', 'p2t_large' +] + + +class IRB(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, ksize=3, act_layer=nn.Hardswish, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0) + self.act = act_layer() + self.conv = nn.Conv2d(hidden_features, hidden_features, kernel_size=ksize, padding=ksize//2, stride=1, groups=hidden_features) + 
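NormalNet.forward above post-processes its outputs in two cheap steps: each per-pixel normal is rescaled to unit length, and everything outside the person (pixels where the input image is exactly zero) is masked out. The same two operations in isolation:

```python
import torch

image = torch.randn(1, 3, 512, 512)
image[..., :100] = 0.0                  # pretend the left strip is background

nml = torch.randn(1, 3, 512, 512)       # raw network output
nml = nml / torch.norm(nml, dim=1, keepdim=True)              # ||normal|| == 1 per pixel
mask = (image.abs().sum(dim=1, keepdim=True) != 0.0).float()  # 1 inside the person, 0 outside
nml = nml * mask
```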
self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.permute(0,2,1).reshape(B, C, H, W) + x = self.fc1(x) + x = self.act(x) + x = self.conv(x) + x = self.act(x) + x = self.fc2(x) + return x.reshape(B, C, -1).permute(0,2,1) + + +class PoolingAttention(nn.Module): + def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., + pool_ratios=[1,2,3,6]): + + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + self.num_elements = np.array([t*t for t in pool_ratios]).sum() + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Sequential(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = nn.Sequential(nn.Linear(dim, dim * 2, bias=qkv_bias)) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.pool_ratios = pool_ratios + self.pools = nn.ModuleList() + + self.norm = nn.LayerNorm(dim) + + def forward(self, x, H, W, d_convs=None): + B, N, C = x.shape + + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + pools = [] + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + for (pool_ratio, l) in zip(self.pool_ratios, d_convs): + pool = F.adaptive_avg_pool2d(x_, (round(H/pool_ratio), round(W/pool_ratio))) + pool = pool + l(pool) # fix backward bug in higher torch versions when training + pools.append(pool.view(B, C, -1)) + + pools = torch.cat(pools, dim=2) + pools = self.norm(pools.permute(0,2,1)) + + kv = self.kv(pools).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + x = (attn @ v) + x = x.transpose(1,2).contiguous().reshape(B, N, C) + + x = self.proj(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_ratios=[12,16,20,24]): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = PoolingAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, pool_ratios=pool_ratios) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = IRB(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=nn.Hardswish, drop=drop, ksize=3) + + def forward(self, x, H, W, d_convs=None): + x = x + self.drop_path(self.attn(self.norm1(x), H, W, d_convs=d_convs)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + +class PatchEmbed(nn.Module): + """ (Overlapped) Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, kernel_size=3, in_chans=3, embed_dim=768, overlap=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." 
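+        # With overlap=True the projection is a strided convolution whose kernel
+        # (kernel_size, padding kernel_size // 2) is larger than its stride
+        # (patch_size), so neighbouring patches share pixels; in both cases the
+        # token grid is img_size // patch_size per side. For the stage-1 defaults
+        # used below (img_size=512, patch_size=4) that is a 128 x 128 = 16384-token map.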
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + if not overlap: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size, padding=kernel_size//2) + + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, (H, W) + + + +class PyramidPoolingTransformer(nn.Module): + def __init__(self, img_size=512, patch_size=2, in_chans=3, num_classes=1000, embed_dims=[64, 256, 320, 512], + num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0.1, norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 9, 3]): # + super().__init__() + self.num_classes = num_classes + self.depths = depths + + self.embed_dims = embed_dims + + # pyramid pooling ratios for each stage + pool_ratios = [[12,16,20,24], [6,8,10,12], [3,4,5,6], [1,2,3,4]] + + self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=4, kernel_size=7, in_chans=in_chans, + embed_dim=embed_dims[0], overlap=True) + + self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1], overlap=True) + self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2], overlap=True) + self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3], overlap=True) + + self.d_convs1 = nn.ModuleList([nn.Conv2d(embed_dims[0], embed_dims[0], kernel_size=3, stride=1, padding=1, groups=embed_dims[0]) for temp in pool_ratios[0]]) + self.d_convs2 = nn.ModuleList([nn.Conv2d(embed_dims[1], embed_dims[1], kernel_size=3, stride=1, padding=1, groups=embed_dims[1]) for temp in pool_ratios[1]]) + self.d_convs3 = nn.ModuleList([nn.Conv2d(embed_dims[2], embed_dims[2], kernel_size=3, stride=1, padding=1, groups=embed_dims[2]) for temp in pool_ratios[2]]) + self.d_convs4 = nn.ModuleList([nn.Conv2d(embed_dims[3], embed_dims[3], kernel_size=3, stride=1, padding=1, groups=embed_dims[3]) for temp in pool_ratios[3]]) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + + + ksize = 3 + + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, pool_ratios=pool_ratios[0]) + for i in range(depths[0])]) + + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, pool_ratios=pool_ratios[1]) + for i in range(depths[1])]) + + cur += depths[1] + + + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, pool_ratios=pool_ratios[2]) + for i in range(depths[2])]) + + cur += depths[2] + + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, 
qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, pool_ratios=pool_ratios[3]) + for i in range(depths[3])]) + + # classification head + self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + self.gap = nn.AdaptiveAvgPool1d(1) + + self.apply(self._init_weights) + + #print(self) + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + @torch.jit.ignore + def no_weight_decay(self): + # return {'pos_embed', 'cls_token'} # has pos_embed may be better + return {'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + + # stage 1 + x, (H, W) = self.patch_embed1(x) + + for idx, blk in enumerate(self.block1): + x = blk(x, H, W, self.d_convs1) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + + # stage 2 + x, (H, W) = self.patch_embed2(x) + + for idx, blk in enumerate(self.block2): + x = blk(x, H, W, self.d_convs2) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + + # # stage 3 + # x, (H, W) = self.patch_embed3(x) + + # for idx, blk in enumerate(self.block3): + # x = blk(x, H, W, self.d_convs3) + # x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + + # # stage 4 + # x, (H, W) = self.patch_embed4(x) + + # for idx, blk in enumerate(self.block4): + # x = blk(x, H, W, self.d_convs4) + + return x + + def forward_features_for_fpn(self, x): + outs = [] + + B = x.shape[0] + + # stage 1 + x, (H, W) = self.patch_embed1(x) + + for idx, blk in enumerate(self.block1): + x = blk(x, H, W, self.d_convs1) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + outs.append(x) + + # stage 2 + x, (H, W) = self.patch_embed2(x) + + for idx, blk in enumerate(self.block2): + x = blk(x, H, W, self.d_convs2) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + outs.append(x) + + x, (H, W) = self.patch_embed3(x) + + for idx, blk in enumerate(self.block3): + x = blk(x, H, W, self.d_convs3) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + outs.append(x) + + # stage 4 + x, (H, W) = self.patch_embed4(x) + + for idx, blk in enumerate(self.block4): + x = blk(x, H, W, self.d_convs4) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2) + outs.append(x) + + return outs + + def forward(self, x): + x = self.forward_features(x) + # x = torch.mean(x, dim=1) + # x = self.head(x) + + return x + + def forward_for_fpn(self, x): + return self.forward_features_for_fpn(x) + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in 
state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + + return out_dict + + +@register_model +def p2t_tiny(pretrained=False, **kwargs): + model = PyramidPoolingTransformer( + patch_size=4, embed_dims=[48, 96, 240, 384], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 6, 3], + **kwargs) + model.default_cfg = _cfg() + + return model + +@register_model +def p2t_small(pretrained=True, **kwargs): + model = PyramidPoolingTransformer( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 9, 3], **kwargs) + model.default_cfg = _cfg() + + return model + +@register_model +def p2t_base(pretrained=False, **kwargs): + model = PyramidPoolingTransformer( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], + **kwargs) + model.default_cfg = _cfg() + + return model + +@register_model +def p2t_medium(pretrained=False, **kwargs): + model = PyramidPoolingTransformer( + patch_size=4, embed_dims=[64, 128, 384, 512], num_heads=[1, 2, 6, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 15, 3], + **kwargs) + model.default_cfg = _cfg() + + return model + +@register_model +def p2t_large(pretrained=False, **kwargs): + model = PyramidPoolingTransformer( + patch_size=4, embed_dims=[64, 128, 320, 640], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], + **kwargs) + model.default_cfg = _cfg() + + return model diff --git a/lib/net/ResBlkPIFuNet.py b/lib/net/ResBlkPIFuNet.py new file mode 100644 index 0000000000000000000000000000000000000000..3033e2c4d61a3fbaef77d57cf43b07822c18940c --- /dev/null +++ b/lib/net/ResBlkPIFuNet.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .BasePIFuNet import BasePIFuNet +import functools + +from .net_util import * +from lib.dataset.PointFeat import PointFeat +from lib.dataset.mesh_util import feat_select + + +# class ResBlkPIFuNet(BasePIFuNet): +# def __init__(self, opt, +# projection_mode='orthogonal'): +# if opt.color_loss_type == 'l1': +# error_term = nn.L1Loss() +# elif opt.color_loss_type == 'mse': +# error_term = nn.MSELoss() + +# super(ResBlkPIFuNet, self).__init__( +# projection_mode=projection_mode, +# error_term=error_term) + +# self.name = 'respifu' +# self.opt = opt +# self.smpl_feats = self.opt.smpl_feats +# norm_type = get_norm_layer(norm_type=opt.norm_color) +# self.image_filter = ResnetFilter(opt, norm_layer=norm_type) +# self.smpl_feat_dict=None + +# self.surface_classifier = SurfaceClassifier( +# filter_channels=self.opt.mlp_dim_color, +# num_views=self.opt.num_views, +# no_residual=self.opt.no_residual, +# last_op=nn.Tanh()) + +# self.normalizer = DepthNormalizer(opt) + +# init_net(self) + +# def filter(self, images): +# ''' +# Filter the input images +# store all intermediate features. 
+# :param images: [B, C, H, W] input images +# ''' +# self.im_feat = self.image_filter(images) + +# def attach(self, im_feat): +# #self.im_feat = torch.cat([im_feat, self.im_feat], 1) +# self.geo_feat=im_feat + +# def query(self, points, calibs, transforms=None, labels=None): +# ''' +# Given 3D points, query the network predictions for each point. +# Image features should be pre-computed before this call. +# store all intermediate features. +# query() function may behave differently during training/testing. +# :param points: [B, 3, N] world space coordinates of points +# :param calibs: [B, 3, 4] calibration matrices for each image +# :param transforms: Optional [B, 2, 3] image space coordinate transforms +# :param labels: Optional [B, Res, N] gt labeling +# :return: [B, Res, N] predictions for each point +# ''' +# if labels is not None: +# self.labels = labels + + +# xyz = self.projection(points, calibs, transforms) +# xy = xyz[:, :2, :] +# z = xyz[:, 2:3, :] + +# z_feat = self.normalizer(z) + + +# if self.smpl_feat_dict==None: +# # This is a list of [B, Feat_i, N] features +# point_local_feat_list = [self.index(self.im_feat, xy), z_feat] +# # [B, Feat_all, N] +# point_local_feat = torch.cat(point_local_feat_list, 1) + +# self.preds = self.surface_classifier(point_local_feat) +# else: +# point_feat_extractor = PointFeat(self.smpl_feat_dict["smpl_verts"], +# self.smpl_feat_dict["smpl_faces"]) +# point_feat_out = point_feat_extractor.query( +# xyz.permute(0, 2, 1).contiguous(), self.smpl_feat_dict) + +# feat_lst = [ +# point_feat_out[key] for key in self.smpl_feats +# if key in point_feat_out.keys() +# ] +# smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1) +# point_normal_feat = feat_select(self.index(self.geo_feat, xy), # select front or back normal feature +# smpl_feat[:, [-1], :]) +# point_color_feat = torch.cat([self.index(self.im_feat, xy), z_feat],1) +# point_feat_list = [point_normal_feat, point_color_feat, smpl_feat[:, :-1, :]] +# point_feat = torch.cat(point_feat_list, 1) +# self.preds = self.surface_classifier(point_feat) + +# def forward(self, images, im_feat, points, calibs, transforms=None, labels=None): + +# self.filter(images) + +# self.attach(im_feat) + + +# self.query(points, calibs, transforms, labels) + + +# error = self.get_error(self.preds,self.labels) + +# return self.preds, error + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False): + """Initialize the Resnet block + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False): + """Construct a convolutional block. + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. 
+ use_bias (bool) -- if the conv layer uses bias or not + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + if last: + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + else: + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class ResnetFilter(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, + n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert (n_blocks >= 0) + super(ResnetFilter, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + if i == n_blocks - 1: + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=use_bias, last=True)] + else: + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=use_bias)] + + if opt.use_tanh: + model += [nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) diff --git a/lib/net/Transformer.py b/lib/net/Transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a4c3dc088d244f2527ecdf2b27252687a53921 --- /dev/null +++ b/lib/net/Transformer.py @@ -0,0 
+1,452 @@ +# ------------------------------------------------------------------------------------ +# Enhancing Transformers +# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. +# Licensed under the MIT License [see LICENSE for details] +# ------------------------------------------------------------------------------------ +# Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch) +# Copyright (c) 2020 Phil Wang. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import math +import numpy as np +from typing import Union, Tuple, List, Optional +from functools import partial +import pytorch_lightning as pl + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +def get_2d_sincos_pos_embed(embed_dim, grid_size): + """ + grid_size: int or (int, int) of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_size = (grid_size, grid_size) if type(grid_size) != tuple else grid_size + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def init_weights(m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + w = m.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: nn.Module) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor: + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.Tanh(), + nn.Linear(hidden_dim, dim) + ) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None: + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity() + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + return self.to_out(out) + +class CrossAttention(nn.Module): + def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None: + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.norm = nn.LayerNorm(dim) + + self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity() + self.multi_head_attention=PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)) + + + def forward(self, x: torch.FloatTensor, q_x:torch.FloatTensor) -> torch.FloatTensor: + + q_in = self.multi_head_attention(q_x)+q_x + q_in = self.norm(q_in) + + q = rearrange(self.to_q(q_in),'b n (h d) -> b h n d', h = self.heads) + kv = self.to_kv(x).chunk(2, dim = -1) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), kv) + + attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + return self.to_out(out),q_in + + +class Transformer(nn.Module): + def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None: + super().__init__() + self.layers = nn.ModuleList([]) + for idx in range(depth): + layer = 
nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)), + PreNorm(dim, FeedForward(dim, mlp_dim))]) + self.layers.append(layer) + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + +class CrossTransformer(nn.Module): + def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None: + super().__init__() + self.layers = nn.ModuleList([]) + for idx in range(depth): + layer = nn.ModuleList([CrossAttention(dim, heads=heads, dim_head=dim_head), + PreNorm(dim, FeedForward(dim, mlp_dim))]) + self.layers.append(layer) + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.FloatTensor, q_x:torch.FloatTensor) -> torch.FloatTensor: + encoder_output=x + for attn, ff in self.layers: + x,q_in = attn(encoder_output, q_x) + x = x + q_in + x = ff(x) + x + q_x=x + + return self.norm(q_x) + +class ViTEncoder(nn.Module): + def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], + dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None: + super().__init__() + image_height, image_width = image_size if isinstance(image_size, tuple) \ + else (image_size, image_size) + patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ + else (patch_size, patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + en_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) + + self.num_patches = (image_height // patch_height) * (image_width // patch_width) + self.patch_dim = channels * patch_height * patch_width + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b c h w -> b (h w) c'), + ) + self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False) + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + + self.apply(init_weights) + + def forward(self, img: torch.FloatTensor) -> torch.FloatTensor: + x = self.to_patch_embedding(img) + x = x + self.en_pos_embedding + x = self.transformer(x) + + return x + + +class ViTDecoder(nn.Module): + def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], + dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 32, dim_head: int = 64) -> None: + super().__init__() + image_height, image_width = image_size if isinstance(image_size, tuple) \ + else (image_size, image_size) + patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ + else (patch_size, patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
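+        # The decoder reuses the same fixed 2-D sine/cosine table as the encoder:
+        # get_2d_sincos_pos_embed returns a [num_patches, dim] array that is wrapped
+        # in a frozen nn.Parameter (requires_grad=False), so positions are never
+        # learned. With the ViTVQ defaults defined later in this file (image_size=512,
+        # patch_size=16, dim=256) this is a 32 * 32 = 1024 x 256 table.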
+ de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) + + self.num_patches = (image_height // patch_height) * (image_width // patch_width) + self.patch_dim = channels * patch_height * patch_width + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False) + self.to_pixel = nn.Sequential( + Rearrange('b (h w) c -> b c h w', h=image_height // patch_height), + nn.ConvTranspose2d(dim, channels, kernel_size=4, stride=4) + ) + + self.apply(init_weights) + + def forward(self, token: torch.FloatTensor) -> torch.FloatTensor: + x = token + self.de_pos_embedding + x = self.transformer(x) + x = self.to_pixel(x) + + return x + + def get_last_layer(self) -> nn.Parameter: + return self.to_pixel[-1].weight + + +class CrossAttDecoder(nn.Module): + def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], + dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 32, dim_head: int = 64) -> None: + super().__init__() + image_height, image_width = image_size if isinstance(image_size, tuple) \ + else (image_size, image_size) + patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ + else (patch_size, patch_size) + + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b c h w -> b (h w) c'), + ) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) + + self.num_patches = (image_height // patch_height) * (image_width // patch_width) + self.patch_dim = channels * patch_height * patch_width + + self.transformer = CrossTransformer(dim, depth, heads, dim_head, mlp_dim) + + self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False) + self.to_pixel = nn.Sequential( + Rearrange('b (h w) c -> b c h w', h=image_height // patch_height), + nn.ConvTranspose2d(dim, channels, kernel_size=4, stride=4) + ) + + self.apply(init_weights) + + def forward(self, token: torch.FloatTensor, query_img:torch.FloatTensor) -> torch.FloatTensor: + # batch_size=token.shape[0] + # query=self.query.repeat(batch_size,1,1)+self.de_pos_embedding + query=self.to_patch_embedding(query_img)+self.de_pos_embedding + x = token + self.de_pos_embedding + x = self.transformer(x,query) + x = self.to_pixel(x) + + return x + + def get_last_layer(self) -> nn.Parameter: + return self.to_pixel[-1].weight + + +class BaseQuantizer(nn.Module): + def __init__(self, embed_dim: int, n_embed: int, straight_through: bool = True, use_norm: bool = True, + use_residual: bool = False, num_quantizers: Optional[int] = None) -> None: + super().__init__() + self.straight_through = straight_through + self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x + + self.use_residual = use_residual + self.num_quantizers = num_quantizers + + self.embed_dim = embed_dim + self.n_embed = n_embed + + self.embedding = nn.Embedding(self.n_embed, self.embed_dim) + self.embedding.weight.data.normal_() + + def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: + pass + + def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: + 
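+        # Returns (z_q, loss, encoding_indices). Without use_residual a single
+        # codebook lookup is performed; with use_residual=True the residual is
+        # re-quantized num_quantizers times and the per-step losses are averaged.
+        # The straight-through estimator below, z + (z_q - z).detach(), passes
+        # gradients to the encoder even though the codebook argmin is not
+        # differentiable.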
if not self.use_residual: + z_q, loss, encoding_indices = self.quantize(z) + else: + z_q = torch.zeros_like(z) + residual = z.detach().clone() + + losses = [] + encoding_indices = [] + + for _ in range(self.num_quantizers): + z_qi, loss, indices = self.quantize(residual.clone()) + residual.sub_(z_qi) + z_q.add_(z_qi) + + encoding_indices.append(indices) + losses.append(loss) + + losses, encoding_indices = map(partial(torch.stack, dim = -1), (losses, encoding_indices)) + loss = losses.mean() + + # preserve gradients with straight-through estimator + if self.straight_through: + z_q = z + (z_q - z).detach() + + return z_q, loss, encoding_indices + + +class VectorQuantizer(BaseQuantizer): + def __init__(self, embed_dim: int, n_embed: int, beta: float = 0.25, use_norm: bool = True, + use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None: + super().__init__(embed_dim, n_embed, True, + use_norm, use_residual, num_quantizers) + + self.beta = beta + + def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: + z_reshaped_norm = self.norm(z.view(-1, self.embed_dim)) + embedding_norm = self.norm(self.embedding.weight) + + d = torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) + \ + torch.sum(embedding_norm ** 2, dim=1) - 2 * \ + torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm) + + encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + encoding_indices = encoding_indices.view(*z.shape[:-1]) + + z_q = self.embedding(encoding_indices).view(z.shape) + z_qnorm, z_norm = self.norm(z_q), self.norm(z) + + # compute loss for embedding + loss = self.beta * torch.mean((z_qnorm.detach() - z_norm)**2) + \ + torch.mean((z_qnorm - z_norm.detach())**2) + + return z_qnorm, loss, encoding_indices + + +class ViTVQ(pl.LightningModule): + def __init__(self,image_size=512, patch_size=16,channels=3) -> None: + super().__init__() + + self.encoder = ViTEncoder(image_size=image_size, patch_size=patch_size, dim=256,depth=8,heads=8,mlp_dim=2048,channels=channels) + self.F_decoder = ViTDecoder(image_size=image_size, patch_size=patch_size, dim=256,depth=3,heads=8,mlp_dim=2048) + self.B_decoder= CrossAttDecoder(image_size=image_size, patch_size=patch_size, dim=256,depth=3,heads=8,mlp_dim=2048) + self.R_decoder= CrossAttDecoder(image_size=image_size, patch_size=patch_size, dim=256,depth=3,heads=8,mlp_dim=2048) + self.L_decoder= CrossAttDecoder(image_size=image_size, patch_size=patch_size, dim=256,depth=3,heads=8,mlp_dim=2048) + # self.quantizer = VectorQuantizer(embed_dim=32,n_embed=8192) + # self.pre_quant = nn.Linear(512, 32) + # self.post_quant = nn.Linear(32, 512) + + + def forward(self, x: torch.FloatTensor,smpl_normal) -> torch.FloatTensor: + enc_out = self.encode(x) + dec = self.decode(enc_out,smpl_normal) + + return dec + + + def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + h = self.encoder(x) + # h = self.pre_quant(h) + # quant, emb_loss, _ = self.quantizer(h) + + return h #, emb_loss + + def decode(self, enc_out: torch.FloatTensor,smpl_normal) -> torch.FloatTensor: + back_query=smpl_normal['T_normal_B'] + right_query=smpl_normal['T_normal_R'] + left_query=smpl_normal['T_normal_L'] + # quant = self.post_quant(quant) + dec_F = self.F_decoder(enc_out) + dec_B = self.B_decoder(enc_out,back_query) + dec_R = self.R_decoder(enc_out,right_query) + dec_L = self.L_decoder(enc_out,left_query) + + return (dec_F,dec_B,dec_R,dec_L) + + # def encode_codes(self, x: torch.FloatTensor) -> 
torch.LongTensor: + # h = self.encoder(x) + # h = self.pre_quant(h) + # _, _, codes = self.quantizer(h) + + # return codes + + # def decode_codes(self, code: torch.LongTensor) -> torch.FloatTensor: + # quant = self.quantizer.embedding(code) + # quant = self.quantizer.norm(quant) + + # if self.quantizer.use_residual: + # quant = quant.sum(-2) + + # dec = self.decode(quant) + + # return dec \ No newline at end of file diff --git a/lib/net/UNet.py b/lib/net/UNet.py new file mode 100644 index 0000000000000000000000000000000000000000..11953793d625584651f22134bde33d79f5aa25c0 --- /dev/null +++ b/lib/net/UNet.py @@ -0,0 +1,127 @@ +""" Parts of the U-Net model """ + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + + + +""" Full assembly of the parts to form the complete network """ + + + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=False): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = (DoubleConv(n_channels, 64)) + self.down1 = (Down(64, 128)) + self.down2 = (Down(128, 256)) + self.down3 = (Down(256, 512)) + factor = 2 if bilinear else 1 + self.down4 = (Down(512, 1024 // factor)) + self.up1 = (Up(1024, 512 // factor, bilinear)) + self.up2 = (Up(512, 256 // factor, bilinear)) + self.up3 = (Up(256, 128 // 
factor, bilinear)) + self.up4 = (Up(128, 64, bilinear)) + self.outc = (OutConv(64, n_classes)) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits + + def use_checkpointing(self): + self.inc = torch.utils.checkpoint(self.inc) + self.down1 = torch.utils.checkpoint(self.down1) + self.down2 = torch.utils.checkpoint(self.down2) + self.down3 = torch.utils.checkpoint(self.down3) + self.down4 = torch.utils.checkpoint(self.down4) + self.up1 = torch.utils.checkpoint(self.up1) + self.up2 = torch.utils.checkpoint(self.up2) + self.up3 = torch.utils.checkpoint(self.up3) + self.up4 = torch.utils.checkpoint(self.up4) + self.outc = torch.utils.checkpoint(self.outc) \ No newline at end of file diff --git a/lib/net/VE.py b/lib/net/VE.py new file mode 100644 index 0000000000000000000000000000000000000000..2158a6aaa46090bbf5505fb9c9e7d1aed909735d --- /dev/null +++ b/lib/net/VE.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import torch.nn as nn +import pytorch_lightning as pl + + +class BaseNetwork(pl.LightningModule): + + def __init__(self): + super(BaseNetwork, self).__init__() + + def init_weights(self, init_type='xavier', gain=0.02): + ''' + initializes network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + +class Residual3D(BaseNetwork): + + def __init__(self, numIn, numOut): + super(Residual3D, self).__init__() + self.numIn = numIn + self.numOut = numOut + self.with_bias = True + # self.bn = nn.GroupNorm(4, self.numIn) + self.bn = nn.BatchNorm3d(self.numIn) + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv3d(self.numIn, + self.numOut, + bias=self.with_bias, + kernel_size=3, + stride=1, + padding=2, + dilation=2) + # self.bn1 = nn.GroupNorm(4, self.numOut) + self.bn1 = nn.BatchNorm3d(self.numOut) + self.conv2 = nn.Conv3d(self.numOut, + self.numOut, + 
bias=self.with_bias, + kernel_size=3, + stride=1, + padding=1) + # self.bn2 = nn.GroupNorm(4, self.numOut) + self.bn2 = nn.BatchNorm3d(self.numOut) + self.conv3 = nn.Conv3d(self.numOut, + self.numOut, + bias=self.with_bias, + kernel_size=3, + stride=1, + padding=1) + + if self.numIn != self.numOut: + self.conv4 = nn.Conv3d(self.numIn, + self.numOut, + bias=self.with_bias, + kernel_size=1) + self.init_weights() + + def forward(self, x): + residual = x + # out = self.bn(x) + # out = self.relu(out) + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + # out = self.conv3(out) + # out = self.relu(out) + + if self.numIn != self.numOut: + residual = self.conv4(x) + + return out + residual + + +class VolumeEncoder(BaseNetwork): + """CycleGan Encoder""" + + def __init__(self, num_in=3, num_out=32, num_stacks=2): + super(VolumeEncoder, self).__init__() + self.num_in = num_in + self.num_out = num_out + self.num_inter = 8 + self.num_stacks = num_stacks + self.with_bias = True + + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv3d(self.num_in, + self.num_inter, + bias=self.with_bias, + kernel_size=5, + stride=2, + padding=4, + dilation=2) + # self.bn1 = nn.GroupNorm(4, self.num_inter) + self.bn1 = nn.BatchNorm3d(self.num_inter) + self.conv2 = nn.Conv3d(self.num_inter, + self.num_out, + bias=self.with_bias, + kernel_size=5, + stride=2, + padding=4, + dilation=2) + # self.bn2 = nn.GroupNorm(4, self.num_out) + self.bn2 = nn.BatchNorm3d(self.num_out) + + self.conv_out1 = nn.Conv3d(self.num_out, + self.num_out, + bias=self.with_bias, + kernel_size=3, + stride=1, + padding=1, + dilation=1) + self.conv_out2 = nn.Conv3d(self.num_out, + self.num_out, + bias=self.with_bias, + kernel_size=3, + stride=1, + padding=1, + dilation=1) + + for idx in range(self.num_stacks): + self.add_module("res" + str(idx), + Residual3D(self.num_out, self.num_out)) + + self.init_weights() + + def forward(self, x, intermediate_output=True): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out_lst = [] + for idx in range(self.num_stacks): + out = self._modules["res" + str(idx)](out) + out_lst.append(out) + + if intermediate_output: + return out_lst + else: + return [out_lst[-1]] diff --git a/lib/net/__init__.py b/lib/net/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c0733316a97ee387fd81472e69482552b376f9 --- /dev/null +++ b/lib/net/__init__.py @@ -0,0 +1,6 @@ +from .BasePIFuNet import BasePIFuNet +from .HGPIFuNet import HGPIFuNet +from .NormalNet import NormalNet +from .VE import VolumeEncoder +from .UNet import UNet +from .HallucinatorNet import Hallucinator diff --git a/lib/net/geometry.py b/lib/net/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..270c3aeefafc1674c2c114432ae6abf9a6797002 --- /dev/null +++ b/lib/net/geometry.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import torch + + +def index(feat, uv): + ''' + :param feat: [B, C, H, W] image features + :param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1] + :return: [B, C, N] image features at the uv coordinates + ''' + uv = uv.transpose(1, 2) # [B, N, 2] + + (B, N, _) = uv.shape + C = feat.shape[1] + + if uv.shape[-1] == 3: + # uv = uv[:,:,[2,1,0]] + # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...] + uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3] + else: + uv = uv.unsqueeze(2) # [B, N, 1, 2] + + # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample + # for old versions, simply remove the aligned_corners argument. + samples = torch.nn.functional.grid_sample( + feat, uv, align_corners=True) # [B, C, N, 1] + #samples = grid_sample(feat, uv) # [B, C, N, 1] + return samples.view(B, C, N) # [B, C, N] + + +def grid_sample(image, optical): + N, C, IH, IW = image.shape + _, H, W, _ = optical.shape + + ix = optical[..., 0] + iy = optical[..., 1] + + ix = ((ix + 1) / 2) * (IW-1); + iy = ((iy + 1) / 2) * (IH-1); + with torch.no_grad(): + ix_nw = torch.floor(ix); + iy_nw = torch.floor(iy); + ix_ne = ix_nw + 1; + iy_ne = iy_nw; + ix_sw = ix_nw; + iy_sw = iy_nw + 1; + ix_se = ix_nw + 1; + iy_se = iy_nw + 1; + + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) + + with torch.no_grad(): + torch.clamp(ix_nw, 0, IW-1, out=ix_nw) + torch.clamp(iy_nw, 0, IH-1, out=iy_nw) + + torch.clamp(ix_ne, 0, IW-1, out=ix_ne) + torch.clamp(iy_ne, 0, IH-1, out=iy_ne) + + torch.clamp(ix_sw, 0, IW-1, out=ix_sw) + torch.clamp(iy_sw, 0, IH-1, out=iy_sw) + + torch.clamp(ix_se, 0, IW-1, out=ix_se) + torch.clamp(iy_se, 0, IH-1, out=iy_se) + + image = image.view(N, C, IH * IW) + + + nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)) + ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)) + sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)) + se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1)) + + out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + + ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) + + sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) + + se_val.view(N, C, H, W) * se.view(N, 1, H, W)) + + return out_val + +def orthogonal(points, calibrations, transforms=None): + ''' + Compute the orthogonal projections of 3D points into the image plane by given projection matrix + :param points: [B, 3, N] Tensor of 3D points + :param calibrations: [B, 3, 4] Tensor of projection matrix + :param transforms: [B, 2, 3] Tensor of image transform matrix + :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane + ''' + rot = calibrations[:, :3, :3] + trans = calibrations[:, :3, 3:4] + pts = torch.baddbmm(trans, rot, points) # [B, 3, N] + if transforms is not None: + scale = transforms[:2, :2] + shift = transforms[:2, 2:3] + pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :]) + return pts + + +def perspective(points, calibrations, transforms=None): + ''' + Compute the perspective projections of 3D points into the image plane by given projection matrix + :param points: [Bx3xN] Tensor of 3D points + :param calibrations: [Bx3x4] Tensor of 
projection matrix + :param transforms: [Bx2x3] Tensor of image transform matrix + :return: xy: [Bx2xN] Tensor of xy coordinates in the image plane + ''' + rot = calibrations[:, :3, :3] + trans = calibrations[:, :3, 3:4] + homo = torch.baddbmm(trans, rot, points) # [B, 3, N] + xy = homo[:, :2, :] / homo[:, 2:3, :] + if transforms is not None: + scale = transforms[:2, :2] + shift = transforms[:2, 2:3] + xy = torch.baddbmm(shift, scale, xy) + + xyz = torch.cat([xy, homo[:, 2:3, :]], 1) + return xyz diff --git a/lib/net/local_affine.py b/lib/net/local_affine.py new file mode 100644 index 0000000000000000000000000000000000000000..0292f13cd108d5373235642c23534e9d1ab9945e --- /dev/null +++ b/lib/net/local_affine.py @@ -0,0 +1,61 @@ +# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology. +# All rights reserved. +# This file is part of the pytorch-nicp, +# and is released under the "MIT License Agreement". Please see the LICENSE +# file that should have been included as part of this package. + +import torch +import torch.nn as nn +import torch.sparse as sp + + +# reference: https://github.com/wuhaozhe/pytorch-nicp +class LocalAffine(nn.Module): + + def __init__(self, num_points, batch_size=1, edges=None): + ''' + specify the number of points, the number of points should be constant across the batch + and the edges torch.Longtensor() with shape N * 2 + the local affine operator supports batch operation + batch size must be constant + add additional pooling on top of w matrix + ''' + super(LocalAffine, self).__init__() + self.A = nn.Parameter( + torch.eye(3).unsqueeze(0).unsqueeze(0).repeat( + batch_size, num_points, 1, 1)) + self.b = nn.Parameter( + torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat( + batch_size, num_points, 1, 1)) + self.edges = edges + self.num_points = num_points + + def stiffness(self): + ''' + calculate the stiffness of local affine transformation + f norm get infinity gradient when w is zero matrix, + ''' + if self.edges is None: + raise Exception("edges cannot be none when calculate stiff") + idx1 = self.edges[:, 0] + idx2 = self.edges[:, 1] + affine_weight = torch.cat((self.A, self.b), dim=3) + w1 = torch.index_select(affine_weight, dim=1, index=idx1) + w2 = torch.index_select(affine_weight, dim=1, index=idx2) + w_diff = (w1 - w2)**2 + w_rigid = (torch.linalg.det(self.A) - 1.0)**2 + return w_diff, w_rigid + + def forward(self, x, return_stiff=False): + ''' + x should have shape of B * N * 3 + ''' + x = x.unsqueeze(3) + out_x = torch.matmul(self.A, x) + out_x = out_x + self.b + out_x.squeeze_(3) + if return_stiff: + stiffness, rigid = self.stiffness() + return out_x, stiffness, rigid + else: + return out_x diff --git a/lib/net/nerf_util.py b/lib/net/nerf_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e25832723089a3cbe60b85f7d7d5f84cd508261c --- /dev/null +++ b/lib/net/nerf_util.py @@ -0,0 +1,262 @@ +import torch +import numpy as np +import cv2 +import torch.nn.functional as F +from PIL import Image + + + +def project(xyz, K, RT): + """ + xyz: [N, 3] + K: [3, 3] + RT: [3, 4] + """ + xyz = np.dot(RT[:, :3],xyz.T).T + RT[:, 3:].T + xyz = np.dot(K,xyz.T).T + xy = xyz[:, :2] + 256 + return xy + + +def get_rays(H, W, K, R, T): + # w2c=np.concatenate([R,T],axis=1) + # w2c=np.concatenate([w2c,[[0,0,0,1]]],axis=0) + # c2w=np.linalg.inv(w2c) + # i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') + # dirs = np.stack([(i-256)/K[0][0], -(j-256)/K[1][1], 
-np.ones_like(i)], -1) + # # Rotate ray directions from camera frame to the world frame + # rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] + # # Translate camera frame's origin to the world frame. It is the origin of all rays. + # rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) + # calculate the camera origin + rays_o = -np.dot(np.linalg.inv(R), T).ravel()+np.array([0,0,500]) + # calculate the world coordinates of pixels + i, j = np.meshgrid(np.arange(W, dtype=np.float32), + np.arange(H, dtype=np.float32), + indexing='xy') + #xy1 = np.stack([i, j, np.ones_like(i)], axis=2) + pixel_camera = np.stack([(i-256)/K[0][0], -(j-256)/K[1][1], -np.ones_like(i)], -1) + pixel_world = np.dot(R.T, (pixel_camera - T.ravel()).reshape(-1,3).T).T.reshape(H,W,3) + # calculate the ray direction + rays_d = pixel_world - rays_o[None, None] + rays_d = rays_d / np.linalg.norm(rays_d, axis=2, keepdims=True) + rays_o = np.broadcast_to(rays_o, rays_d.shape) + return rays_o, rays_d + + +def get_bound_corners(bounds): + min_x, min_y, min_z = bounds[0] + max_x, max_y, max_z = bounds[1] + corners_3d = np.array([ + [min_x, min_y, min_z], + [min_x, min_y, max_z], + [min_x, max_y, min_z], + [min_x, max_y, max_z], + [max_x, min_y, min_z], + [max_x, min_y, max_z], + [max_x, max_y, min_z], + [max_x, max_y, max_z], + ]) + return corners_3d + + +def get_bound_2d_mask(bounds, K, pose, H, W): + corners_3d = get_bound_corners(bounds) + corners_2d = project(corners_3d, K, pose) + corners_2d = np.round(corners_2d).astype(int) + mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(mask, [corners_2d[[0, 1, 3, 2, 0]]], 1) + cv2.fillPoly(mask, [corners_2d[[4, 5, 7, 6, 4]]], 1) + cv2.fillPoly(mask, [corners_2d[[0, 1, 5, 4, 0]]], 1) + cv2.fillPoly(mask, [corners_2d[[2, 3, 7, 6, 2]]], 1) + cv2.fillPoly(mask, [corners_2d[[0, 2, 6, 4, 0]]], 1) + cv2.fillPoly(mask, [corners_2d[[1, 3, 7, 5, 1]]], 1) + return mask + + +def get_near_far(bounds, ray_o, ray_d): + """calculate intersections with 3d bounding box""" + bounds = bounds + np.array([-0.01, 0.01])[:, None] + nominator = bounds[None] - ray_o[:, None] + # calculate the step of intersections at six planes of the 3d bounding box + d_intersect = (nominator / (ray_d[:, None] + 1e-9)).reshape(-1, 6) + # calculate the six interections + p_intersect = d_intersect[..., None] * ray_d[:, None] + ray_o[:, None] + # calculate the intersections located at the 3d bounding box + min_x, min_y, min_z, max_x, max_y, max_z = bounds.ravel() + eps = 1e-6 + p_mask_at_box = (p_intersect[..., 0] >= (min_x - eps)) * \ + (p_intersect[..., 0] <= (max_x + eps)) * \ + (p_intersect[..., 1] >= (min_y - eps)) * \ + (p_intersect[..., 1] <= (max_y + eps)) * \ + (p_intersect[..., 2] >= (min_z - eps)) * \ + (p_intersect[..., 2] <= (max_z + eps)) + # obtain the intersections of rays which intersect exactly twice + mask_at_box = p_mask_at_box.sum(-1) == 2 + p_intervals = p_intersect[mask_at_box][p_mask_at_box[mask_at_box]].reshape( + -1, 2, 3) + + # calculate the step of intersections + ray_o = ray_o[mask_at_box] + ray_d = ray_d[mask_at_box] + norm_ray = np.linalg.norm(ray_d, axis=1) + d0 = np.linalg.norm(p_intervals[:, 0] - ray_o, axis=1) / norm_ray + d1 = np.linalg.norm(p_intervals[:, 1] - ray_o, axis=1) / norm_ray + near = np.minimum(d0, d1) + far = np.maximum(d0, d1) + + return near, far, mask_at_box + + +def sample_ray_h36m(img, msk, K, R, T, bounds, nrays, training = True): + H, W = img.shape[:2] + K[2,2]=1 + ray_o, ray_d = get_rays(H, W, K, R, 
T) # world coordinate + +    pose = np.concatenate([R, T], axis=1) +    bound_mask = get_bound_2d_mask(bounds, K, pose, H, W) # visualize the bound mask +    # # bound_mask [512,512] +    # # save bound mask as image +    # bound_mask = bound_mask.astype(np.uint8) +    # bound_mask = bound_mask * 255 +    # bound_mask = Image.fromarray(bound_mask) +    # msk_image=Image.fromarray(msk) +    # bound_mask.save('bound_mask.png') +    # msk_image.save('msk.png') + + +    img[bound_mask != 1] = 0 + +    #msk = msk * bound_mask + + +    if training: +        nsampled_rays = 0 +        # face_sample_ratio = cfg.face_sample_ratio +        # body_sample_ratio = cfg.body_sample_ratio +        body_sample_ratio = 0.8 +        ray_o_list = [] +        ray_d_list = [] +        rgb_list = [] +        body_mask_list = [] +        near_list = [] +        far_list = [] +        coord_list = [] +        mask_at_box_list = [] + +        while nsampled_rays < nrays: +            n_body = int((nrays - nsampled_rays) * body_sample_ratio) +            n_rand = (nrays - nsampled_rays) - n_body + +            # sample rays on body +            coord_body = np.argwhere(msk > 0) + +            coord_body = coord_body[np.random.randint(0, len(coord_body)-1, n_body)] + +            # sample rays in the bound mask +            coord = np.argwhere(bound_mask > 0) +            coord = coord[np.random.randint(0, len(coord), n_rand)] + +            coord = np.concatenate([coord_body, coord], axis=0) + +            ray_o_ = ray_o[coord[:, 0], coord[:, 1]] +            ray_d_ = ray_d[coord[:, 0], coord[:, 1]] +            rgb_ = img[coord[:, 0], coord[:, 1]] +            body_mask_ = msk[coord[:, 0], coord[:, 1]] + +            near_, far_, mask_at_box = get_near_far(bounds, ray_o_, ray_d_) + +            ray_o_list.append(ray_o_[mask_at_box]) +            ray_d_list.append(ray_d_[mask_at_box]) +            rgb_list.append(rgb_[mask_at_box]) +            body_mask_list.append(body_mask_[mask_at_box]) +            near_list.append(near_) +            far_list.append(far_) +            coord_list.append(coord[mask_at_box]) +            mask_at_box_list.append(mask_at_box[mask_at_box]) +            nsampled_rays += len(near_) + +        ray_o = np.concatenate(ray_o_list).astype(np.float32) +        ray_d = np.concatenate(ray_d_list).astype(np.float32) +        rgb = np.concatenate(rgb_list).astype(np.float32) +        body_mask = (np.concatenate(body_mask_list) > 0).astype(np.float32) +        near = np.concatenate(near_list).astype(np.float32) +        far = np.concatenate(far_list).astype(np.float32) +        coord = np.concatenate(coord_list) +        mask_at_box = np.concatenate(mask_at_box_list) +    else: +        rgb = img.reshape(-1, 3).astype(np.float32) +        body_mask = msk.reshape(-1).astype(np.float32) +        ray_o = ray_o.reshape(-1, 3).astype(np.float32) +        ray_d = ray_d.reshape(-1, 3).astype(np.float32) +        near, far, mask_at_box = get_near_far(bounds, ray_o, ray_d) +        mask_at_box = np.logical_and(mask_at_box > 0, body_mask > 0) +        near = near.astype(np.float32) +        far = far.astype(np.float32) +        rgb = rgb[mask_at_box] +        body_mask = body_mask[mask_at_box] +        ray_o = ray_o[mask_at_box] +        ray_d = ray_d[mask_at_box] +        coord = np.argwhere(mask_at_box.reshape(H, W) == 1) + +    return rgb, body_mask, ray_o, ray_d, near, far, coord, mask_at_box + + +def raw2outputs(raw, z_vals, rays_d, white_bkgd=False): +    """Transforms model's predictions to semantically meaningful values. +    Args: +        raw: [num_rays, num_samples along ray, 4]. Prediction from model. +        z_vals: [num_rays, num_samples along ray]. Integration time. +        rays_d: [num_rays, 3]. Direction of each ray. +    Returns: +        rgb_map: [num_rays, 3]. Estimated RGB color of a ray. +        disp_map: [num_rays]. Disparity map. Inverse of depth map. +        acc_map: [num_rays]. Sum of weights along each ray. +        weights: [num_rays, num_samples]. Weights assigned to each sampled color. +        depth_map: [num_rays]. Estimated distance to object.
+    """ +    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) + +    dists = z_vals[...,1:] - z_vals[...,:-1] +    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape).to(z_vals.device)], -1) # [N_rays, N_samples] + +    dists = dists * torch.norm(rays_d[...,None,:], dim=-1) + +    rgb = raw[...,:3] # [N_rays, N_samples, 3] +    noise = 0. + +    alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] +    # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) +    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)).to(z_vals.device), 1.-alpha + 1e-10], -1), -1)[:, :-1] # the cumprod accumulates the (1 - alpha) terms along the ray, i.e. the transmittance term T_i +    rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3], accumulated color C from the per-sample colors c + +    depth_map = torch.sum(weights * z_vals, -1) +    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map).to(z_vals.device), depth_map / torch.sum(weights, -1)) +    acc_map = torch.sum(weights, -1) + +    if white_bkgd: +        rgb_map = rgb_map + (1.-acc_map[...,None]) +    return rgb_map, disp_map, acc_map, weights, depth_map + + +def get_wsampling_points(ray_o, ray_d, near, far): +    """ +    sample pts on rays +    """ +    N_samples=64 +    # calculate the steps for each ray +    t_vals = torch.linspace(0., 1., steps=N_samples) +    z_vals = near[..., None] * (1. - t_vals) + far[..., None] * t_vals + + +    # get intervals between samples +    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) +    upper = torch.cat([mids, z_vals[..., -1:]], -1) +    lower = torch.cat([z_vals[..., :1], mids], -1) +    # stratified samples in those intervals +    t_rand = torch.rand(z_vals.shape) +    z_vals = lower + (upper - lower) * t_rand + +    pts = ray_o[:, None] + ray_d[:, None] * z_vals[..., None] + +    return pts, z_vals diff --git a/lib/net/net_util.py b/lib/net/net_util.py new file mode 100644 index 0000000000000000000000000000000000000000..14dc56c472f95bb470372c7cf59fb7c8b7f6d0df --- /dev/null +++ b/lib/net/net_util.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved.
+# +# Contact: ps-license@tuebingen.mpg.de + +from torchvision import models +import torch +from torch.nn import init +import torch.nn as nn +import torch.nn.functional as F +import functools +from torch.autograd import grad + + +def gradient(inputs, outputs): + d_points = torch.ones_like(outputs, + requires_grad=False, + device=outputs.device) + points_grad = grad(outputs=outputs, + inputs=inputs, + grad_outputs=d_points, + create_graph=True, + retain_graph=True, + only_inputs=True, + allow_unused=True)[0] + return points_grad + + +# def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): +# "3x3 convolution with padding" +# return nn.Conv2d(in_planes, out_planes, kernel_size=3, +# stride=strd, padding=padding, bias=bias) + + +def conv3x3(in_planes, + out_planes, + kernel=3, + strd=1, + dilation=1, + padding=1, + bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=kernel, + dilation=dilation, + stride=strd, + padding=padding, + bias=bias) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % + init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find( + 'BatchNorm2d' + ) != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + # print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. 
+ """ + if len(gpu_ids) > 0: + assert (torch.cuda.is_available()) + net = torch.nn.DataParallel(net) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def imageSpaceRotation(xy, rot): + ''' + args: + xy: (B, 2, N) input + rot: (B, 2) x,y axis rotation angles + + rotation center will be always image center (other rotation center can be represented by additional z translation) + ''' + disp = rot.unsqueeze(2).sin().expand_as(xy) + return (disp * xy).sum(dim=1) + + +def cal_gradient_penalty(netD, + real_data, + fake_data, + device, + type='mixed', + constant=1.0, + lambda_gp=10.0): + """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 + + Arguments: + netD (network) -- discriminator network + real_data (tensor array) -- real images + fake_data (tensor array) -- generated images from the generator + device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + type (str) -- if we mix real and fake data or not [real | fake | mixed]. + constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 + lambda_gp (float) -- weight for this loss + + Returns the gradient penalty loss + """ + if lambda_gp > 0.0: + # either use real images, fake images, or a linear interpolation of two. + if type == 'real': + interpolatesv = real_data + elif type == 'fake': + interpolatesv = fake_data + elif type == 'mixed': + alpha = torch.rand(real_data.shape[0], 1) + alpha = alpha.expand( + real_data.shape[0], + real_data.nelement() // + real_data.shape[0]).contiguous().view(*real_data.shape) + alpha = alpha.to(device) + interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) + else: + raise NotImplementedError('{} not implemented'.format(type)) + interpolatesv.requires_grad_(True) + disc_interpolates = netD(interpolatesv) + gradients = torch.autograd.grad( + outputs=disc_interpolates, + inputs=interpolatesv, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, + retain_graph=True, + only_inputs=True) + gradients = gradients[0].view(real_data.size(0), -1) # flat the data + gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant)** + 2).mean() * lambda_gp # added eps + return gradient_penalty, gradients + else: + return 0.0, None + + +def get_norm_layer(norm_type='instance'): + """Return a normalization layer + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
+ """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, + affine=True, + track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, + affine=False, + track_running_stats=False) + elif norm_type == 'group': + norm_layer = functools.partial(nn.GroupNorm, 32) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % + norm_type) + return norm_layer + + +class Flatten(nn.Module): + + def forward(self, input): + return input.view(input.size(0), -1) + + +class ConvBlock(nn.Module): + + def __init__(self, in_planes, out_planes, opt): + super(ConvBlock, self).__init__() + [k, s, d, p] = opt.conv3x3 + self.conv1 = conv3x3(in_planes, int(out_planes / 2), k, s, d, p) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), k, s, d, + p) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), k, s, d, + p) + + if opt.norm == 'batch': + self.bn1 = nn.BatchNorm2d(in_planes) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.bn4 = nn.BatchNorm2d(in_planes) + elif opt.norm == 'group': + self.bn1 = nn.GroupNorm(32, in_planes) + self.bn2 = nn.GroupNorm(32, int(out_planes / 2)) + self.bn3 = nn.GroupNorm(32, int(out_planes / 4)) + self.bn4 = nn.GroupNorm(32, in_planes) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + self.bn4, + nn.ReLU(True), + nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=1, + bias=False), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class Vgg19(torch.nn.Module): + + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class VGGLoss(nn.Module): + + def __init__(self): + super(VGGLoss, self).__init__() + self.vgg = Vgg19().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], + 
y_vgg[i].detach()) + return loss diff --git a/lib/net/spatial.py b/lib/net/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..c85f22652a6960a941f15ca1cfff09859e05a00b --- /dev/null +++ b/lib/net/spatial.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import torch +import pytorch_lightning as pl +import numpy as np + + +class SpatialEncoder(pl.LightningModule): + + def __init__(self, + sp_level=1, + sp_type="rel_z_decay", + scale=1.0, + n_kpt=24, + sigma=0.2): + + super().__init__() + + self.sp_type = sp_type + self.sp_level = sp_level + self.n_kpt = n_kpt + self.scale = scale + self.sigma = sigma + + @staticmethod + def position_embedding(x, nlevels, scale=1.0): + """ + args: + x: (B, N, C) + return: + (B, N, C * n_levels * 2) + """ + if nlevels <= 0: + return x + vec = SpatialEncoder.pe_vector(nlevels, x.device, scale) + + B, N, _ = x.shape + y = x[:, :, None, :] * vec[None, None, :, None] + z = torch.cat((torch.sin(y), torch.cos(y)), axis=-1).view(B, N, -1) + + return torch.cat([x, z], -1) + + @staticmethod + def pe_vector(nlevels, device, scale=1.0): + v, val = [], 1 + for _ in range(nlevels): + v.append(scale * np.pi * val) + val *= 2 + return torch.from_numpy(np.asarray(v, dtype=np.float32)).to(device) + + def get_dim(self): + if self.sp_type in ["z", "rel_z", "rel_z_decay"]: + if "rel" in self.sp_type: + return (1 + 2 * self.sp_level) * self.n_kpt + else: + return 1 + 2 * self.sp_level + elif "xyz" in self.sp_type: + if "rel" in self.sp_type: + return (1 + 2 * self.sp_level) * 3 * self.n_kpt + else: + return (1 + 2 * self.sp_level) * 3 + + return 0 + + def forward(self, cxyz, kptxyz): + + B, N = cxyz.shape[:2] + K = kptxyz.shape[1] + + dz = cxyz[:, :, None, 2:3] - kptxyz[:, None, :, 2:3] + dxyz = cxyz[:, :, None] - kptxyz[:, None, :] + + # (B, N, K) + weight = torch.exp(-(dxyz**2).sum(-1) / (2.0 * (self.sigma**2))) + + # position embedding ( B, N, K * (2*n_levels+1) ) + out = self.position_embedding(dz.view(B, N, K), self.sp_level) + + # BV,N,K,(2*n_levels+1) * B,N,K,1 = B,N,K*(2*n_levels+1) -> BV,K*(2*n_levels+1),N + out = (out.view(B, N, -1, K) * weight[:, :, None]).view(B, N, -1).permute(0,2,1) + + return out + + +if __name__ == "__main__": + pts = torch.randn(2, 10000, 3).to("cuda") + kpts = torch.randn(2, 24, 3).to("cuda") + + sp_encoder = SpatialEncoder(sp_level=3, + sp_type="rel_z_decay", + scale=1.0, + n_kpt=24, + sigma=0.1).to("cuda") + out = sp_encoder(pts, kpts) + print(out.shape) diff --git a/lib/net/voxelize.py b/lib/net/voxelize.py new file mode 100644 index 0000000000000000000000000000000000000000..572e80e36371f0a5b1f7fb3a96a3ae7decc5e621 --- /dev/null +++ b/lib/net/voxelize.py @@ -0,0 +1,185 @@ +from __future__ import division, print_function +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.autograd import Function + +import voxelize_cuda + + +class VoxelizationFunction(Function): + """ + Definition of differentiable voxelization function + Currently implemented only for cuda Tensors + """ + + @staticmethod + def forward(ctx, smpl_vertices, smpl_face_center, smpl_face_normal, + smpl_vertex_code, smpl_face_code, smpl_tetrahedrons, + volume_res, sigma, smooth_kernel_size): + """ + forward pass + Output format: (batch_size, z_dims, y_dims, x_dims, channel_num) + """ + assert (smpl_vertices.size()[1] == 
smpl_vertex_code.size()[1]) + assert (smpl_face_center.size()[1] == smpl_face_normal.size()[1]) + assert (smpl_face_center.size()[1] == smpl_face_code.size()[1]) + ctx.batch_size = smpl_vertices.size()[0] + ctx.volume_res = volume_res + ctx.sigma = sigma + ctx.smooth_kernel_size = smooth_kernel_size + ctx.smpl_vertex_num = smpl_vertices.size()[1] + ctx.device = smpl_vertices.device + + smpl_vertices = smpl_vertices.contiguous() + smpl_face_center = smpl_face_center.contiguous() + smpl_face_normal = smpl_face_normal.contiguous() + smpl_vertex_code = smpl_vertex_code.contiguous() + smpl_face_code = smpl_face_code.contiguous() + smpl_tetrahedrons = smpl_tetrahedrons.contiguous() + + occ_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, + ctx.volume_res, + ctx.volume_res).fill_(0.0) + semantic_volume = torch.cuda.FloatTensor(ctx.batch_size, + ctx.volume_res, + ctx.volume_res, + ctx.volume_res, 3).fill_(0.0) + weight_sum_volume = torch.cuda.FloatTensor(ctx.batch_size, + ctx.volume_res, + ctx.volume_res, + ctx.volume_res).fill_(1e-3) + + # occ_volume [B, volume_res, volume_res, volume_res] + # semantic_volume [B, volume_res, volume_res, volume_res, 3] + # weight_sum_volume [B, volume_res, volume_res, volume_res] + + occ_volume, semantic_volume, weight_sum_volume = voxelize_cuda.forward_semantic_voxelization( + smpl_vertices, smpl_vertex_code, smpl_tetrahedrons, occ_volume, + semantic_volume, weight_sum_volume, sigma) + + return semantic_volume + + +class Voxelization(nn.Module): + """ + Wrapper around the autograd function VoxelizationFunction + """ + + def __init__(self, smpl_vertex_code, smpl_face_code, smpl_face_indices, + smpl_tetraderon_indices, volume_res, sigma, + smooth_kernel_size, batch_size, device): + super(Voxelization, self).__init__() + assert (len(smpl_face_indices.shape) == 2) + assert (len(smpl_tetraderon_indices.shape) == 2) + assert (smpl_face_indices.shape[1] == 3) + assert (smpl_tetraderon_indices.shape[1] == 4) + + self.volume_res = volume_res + self.sigma = sigma + self.smooth_kernel_size = smooth_kernel_size + self.batch_size = batch_size + self.device = device + + self.smpl_vertex_code = smpl_vertex_code + self.smpl_face_code = smpl_face_code + self.smpl_face_indices = smpl_face_indices + self.smpl_tetraderon_indices = smpl_tetraderon_indices + + def update_param(self, batch_size, smpl_tetra): + + self.batch_size = batch_size + self.smpl_tetraderon_indices = smpl_tetra + + smpl_vertex_code_batch = np.tile(self.smpl_vertex_code, + (self.batch_size, 1, 1)) + smpl_face_code_batch = np.tile(self.smpl_face_code, + (self.batch_size, 1, 1)) + smpl_face_indices_batch = np.tile(self.smpl_face_indices, + (self.batch_size, 1, 1)) + smpl_tetraderon_indices_batch = np.tile(self.smpl_tetraderon_indices, + (self.batch_size, 1, 1)) + + smpl_vertex_code_batch = torch.from_numpy( + smpl_vertex_code_batch).contiguous().to(self.device) + smpl_face_code_batch = torch.from_numpy( + smpl_face_code_batch).contiguous().to(self.device) + smpl_face_indices_batch = torch.from_numpy( + smpl_face_indices_batch).contiguous().to(self.device) + smpl_tetraderon_indices_batch = torch.from_numpy( + smpl_tetraderon_indices_batch).contiguous().to(self.device) + + self.register_buffer('smpl_vertex_code_batch', smpl_vertex_code_batch) + self.register_buffer('smpl_face_code_batch', smpl_face_code_batch) + self.register_buffer('smpl_face_indices_batch', + smpl_face_indices_batch) + self.register_buffer('smpl_tetraderon_indices_batch', + smpl_tetraderon_indices_batch) + + def forward(self, 
smpl_vertices): + """ + Generate semantic volumes from SMPL vertices + """ + assert (smpl_vertices.size()[0] == self.batch_size) + self.check_input(smpl_vertices) + smpl_faces = self.vertices_to_faces(smpl_vertices) + smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices) + smpl_face_center = self.calc_face_centers(smpl_faces) + smpl_face_normal = self.calc_face_normals(smpl_faces) + smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1] + smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :] + vol = VoxelizationFunction.apply(smpl_vertices_surface, + smpl_face_center, smpl_face_normal, + self.smpl_vertex_code_batch, + self.smpl_face_code_batch, + smpl_tetrahedrons, self.volume_res, + self.sigma, self.smooth_kernel_size) + return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw) + + def vertices_to_faces(self, vertices): + assert (vertices.ndimension() == 3) + bs, nv = vertices.shape[:2] + device = vertices.device + face = self.smpl_face_indices_batch + ( + torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] + vertices_ = vertices.reshape((bs * nv, 3)) + return vertices_[face.long()] + + def vertices_to_tetrahedrons(self, vertices): + assert (vertices.ndimension() == 3) + bs, nv = vertices.shape[:2] + device = vertices.device + tets = self.smpl_tetraderon_indices_batch + ( + torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] + vertices_ = vertices.reshape((bs * nv, 3)) + return vertices_[tets.long()] + + def calc_face_centers(self, face_verts): + assert len(face_verts.shape) == 4 + assert face_verts.shape[2] == 3 + assert face_verts.shape[3] == 3 + bs, nf = face_verts.shape[:2] + face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + + face_verts[:, :, 2, :]) / 3.0 + face_centers = face_centers.reshape((bs, nf, 3)) + return face_centers + + def calc_face_normals(self, face_verts): + assert len(face_verts.shape) == 4 + assert face_verts.shape[2] == 3 + assert face_verts.shape[3] == 3 + bs, nf = face_verts.shape[:2] + face_verts = face_verts.reshape((bs * nf, 3, 3)) + v10 = face_verts[:, 0] - face_verts[:, 1] + v12 = face_verts[:, 2] - face_verts[:, 1] + normals = F.normalize(torch.cross(v10, v12), eps=1e-5) + normals = normals.reshape((bs, nf, 3)) + return normals + + def check_input(self, x): + if x.device == 'cpu': + raise TypeError('Voxelization module supports only cuda tensors') + if x.type() != 'torch.cuda.FloatTensor': + raise TypeError( + 'Voxelization module supports only float32 tensors') diff --git a/lib/pixielib/__init__.py b/lib/pixielib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/pixielib/models/FLAME.py b/lib/pixielib/models/FLAME.py new file mode 100644 index 0000000000000000000000000000000000000000..cad57e28eec25d514399e03486839147f56d8d00 --- /dev/null +++ b/lib/pixielib/models/FLAME.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# Using this computer program means that you agree to the terms +# in the LICENSE file included with this software distribution. +# Any use not explicitly granted by the LICENSE is prohibited. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# For comments or questions, please email us at pixie@tue.mpg.de +# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de + +import torch +import torch.nn as nn +import numpy as np +import pickle +import torch.nn.functional as F + + +class FLAMETex(nn.Module): + """ + FLAME texture: + https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64 + FLAME texture converted from BFM: + https://github.com/TimoBolkart/BFM_to_FLAME + """ + + def __init__(self, config): + super(FLAMETex, self).__init__() + if config.tex_type == 'BFM': + mu_key = 'MU' + pc_key = 'PC' + n_pc = 199 + tex_path = config.tex_path + tex_space = np.load(tex_path) + texture_mean = tex_space[mu_key].reshape(1, -1) + texture_basis = tex_space[pc_key].reshape(-1, n_pc) + + elif config.tex_type == 'FLAME': + mu_key = 'mean' + pc_key = 'tex_dir' + n_pc = 200 + tex_path = config.flame_tex_path + tex_space = np.load(tex_path) + texture_mean = tex_space[mu_key].reshape(1, -1) / 255. + texture_basis = tex_space[pc_key].reshape(-1, n_pc) / 255. + else: + print('texture type ', config.tex_type, 'not exist!') + raise NotImplementedError + + n_tex = config.n_tex + num_components = texture_basis.shape[1] + texture_mean = torch.from_numpy(texture_mean).float()[None, ...] + texture_basis = torch.from_numpy( + texture_basis[:, :n_tex]).float()[None, ...] + self.register_buffer('texture_mean', texture_mean) + self.register_buffer('texture_basis', texture_basis) + + def forward(self, texcode=None): + ''' + texcode: [batchsize, n_tex] + texture: [bz, 3, 256, 256], range: 0-1 + ''' + texture = self.texture_mean + \ + (self.texture_basis*texcode[:, None, :]).sum(-1) + texture = texture.reshape(texcode.shape[0], 512, 512, + 3).permute(0, 3, 1, 2) + texture = F.interpolate(texture, [256, 256]) + texture = texture[:, [2, 1, 0], :, :] + return texture + + +def texture_flame2smplx(cached_data, flame_texture, smplx_texture): + ''' Convert flame texture map (face-only) into smplx texture map (includes body texture) + TODO: pytorch version ==> grid sample + ''' + if smplx_texture.shape[0] != smplx_texture.shape[1]: + print('SMPL-X texture not squared (%d != %d)' % + (smplx_texture[0], smplx_texture[1])) + return + if smplx_texture.shape[0] != cached_data['target_resolution']: + print( + 'SMPL-X texture size does not match cached image resolution (%d != %d)' + % (smplx_texture.shape[0], cached_data['target_resolution'])) + return + x_coords = cached_data['x_coords'] + y_coords = cached_data['y_coords'] + target_pixel_ids = cached_data['target_pixel_ids'] + source_uv_points = cached_data['source_uv_points'] + + source_tex_coords = np.zeros_like((source_uv_points)).astype(int) + source_tex_coords[:, 0] = np.clip( + flame_texture.shape[0] * (1.0 - source_uv_points[:, 1]), 0.0, + flame_texture.shape[0]).astype(int) + source_tex_coords[:, 1] = np.clip( + flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0, + flame_texture.shape[1]).astype(int) + + smplx_texture[y_coords[target_pixel_ids].astype(int), + x_coords[target_pixel_ids].astype(int), :] = flame_texture[ + source_tex_coords[:, 0], source_tex_coords[:, 1]] + + return smplx_texture diff --git a/lib/pixielib/models/SMPLX.py b/lib/pixielib/models/SMPLX.py new file mode 100644 index 0000000000000000000000000000000000000000..da49d6925696782e2c255cfd75a02f31ca554c9c --- /dev/null +++ b/lib/pixielib/models/SMPLX.py @@ -0,0 +1,1039 @@ +""" +original from https://github.com/vchoutas/smplx +modified by Vassilis and Yao +""" + +import 
torch +import torch.nn as nn +import numpy as np +import pickle + +from .lbs import ( + Struct, + to_tensor, + to_np, + lbs, + vertices2landmarks, + JointsFromVerticesSelector, + find_dynamic_lmk_idx_and_bcoords, +) + +# SMPLX +J14_NAMES = [ + "right_ankle", + "right_knee", + "right_hip", + "left_hip", + "left_knee", + "left_ankle", + "right_wrist", + "right_elbow", + "right_shoulder", + "left_shoulder", + "left_elbow", + "left_wrist", + "neck", + "head", +] +SMPLX_names = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "jaw", + "left_eye_smplx", + "right_eye_smplx", + "left_index1", + "left_index2", + "left_index3", + "left_middle1", + "left_middle2", + "left_middle3", + "left_pinky1", + "left_pinky2", + "left_pinky3", + "left_ring1", + "left_ring2", + "left_ring3", + "left_thumb1", + "left_thumb2", + "left_thumb3", + "right_index1", + "right_index2", + "right_index3", + "right_middle1", + "right_middle2", + "right_middle3", + "right_pinky1", + "right_pinky2", + "right_pinky3", + "right_ring1", + "right_ring2", + "right_ring3", + "right_thumb1", + "right_thumb2", + "right_thumb3", + "right_eye_brow1", + "right_eye_brow2", + "right_eye_brow3", + "right_eye_brow4", + "right_eye_brow5", + "left_eye_brow5", + "left_eye_brow4", + "left_eye_brow3", + "left_eye_brow2", + "left_eye_brow1", + "nose1", + "nose2", + "nose3", + "nose4", + "right_nose_2", + "right_nose_1", + "nose_middle", + "left_nose_1", + "left_nose_2", + "right_eye1", + "right_eye2", + "right_eye3", + "right_eye4", + "right_eye5", + "right_eye6", + "left_eye4", + "left_eye3", + "left_eye2", + "left_eye1", + "left_eye6", + "left_eye5", + "right_mouth_1", + "right_mouth_2", + "right_mouth_3", + "mouth_top", + "left_mouth_3", + "left_mouth_2", + "left_mouth_1", + "left_mouth_5", + "left_mouth_4", + "mouth_bottom", + "right_mouth_4", + "right_mouth_5", + "right_lip_1", + "right_lip_2", + "lip_top", + "left_lip_2", + "left_lip_1", + "left_lip_3", + "lip_bottom", + "right_lip_3", + "right_contour_1", + "right_contour_2", + "right_contour_3", + "right_contour_4", + "right_contour_5", + "right_contour_6", + "right_contour_7", + "right_contour_8", + "contour_middle", + "left_contour_8", + "left_contour_7", + "left_contour_6", + "left_contour_5", + "left_contour_4", + "left_contour_3", + "left_contour_2", + "left_contour_1", + "head_top", + "left_big_toe", + "left_ear", + "left_eye", + "left_heel", + "left_index", + "left_middle", + "left_pinky", + "left_ring", + "left_small_toe", + "left_thumb", + "nose", + "right_big_toe", + "right_ear", + "right_eye", + "right_heel", + "right_index", + "right_middle", + "right_pinky", + "right_ring", + "right_small_toe", + "right_thumb", +] +extra_names = [ + "head_top", + "left_big_toe", + "left_ear", + "left_eye", + "left_heel", + "left_index", + "left_middle", + "left_pinky", + "left_ring", + "left_small_toe", + "left_thumb", + "nose", + "right_big_toe", + "right_ear", + "right_eye", + "right_heel", + "right_index", + "right_middle", + "right_pinky", + "right_ring", + "right_small_toe", + "right_thumb", +] +SMPLX_names += extra_names + +part_indices = {} +part_indices["body"] = np.array([ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 
23, + 24, + 123, + 124, + 125, + 126, + 127, + 132, + 134, + 135, + 136, + 137, + 138, + 143, +]) +part_indices["torso"] = np.array([ + 0, + 1, + 2, + 3, + 6, + 9, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 22, + 23, + 24, + 55, + 56, + 57, + 58, + 59, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, +]) +part_indices["head"] = np.array([ + 12, + 15, + 22, + 23, + 24, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 125, + 126, + 134, + 136, + 137, +]) +part_indices["face"] = np.array([ + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, +]) +part_indices["upper"] = np.array([ + 12, + 13, + 14, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, +]) +part_indices["hand"] = np.array([ + 20, + 21, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 128, + 129, + 130, + 131, + 133, + 139, + 140, + 141, + 142, + 144, +]) +part_indices["left_hand"] = np.array([ + 20, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 128, + 129, + 130, + 131, + 133, +]) +part_indices["right_hand"] = np.array([ + 21, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 139, + 140, + 141, + 142, + 144, +]) +# kinematic tree +head_kin_chain = [15, 12, 9, 6, 3, 0] + +# --smplx joints +# 00 - Global +# 01 - L_Thigh +# 02 - R_Thigh +# 03 - Spine +# 04 - L_Calf +# 05 - R_Calf +# 06 - Spine1 +# 07 - L_Foot +# 08 - R_Foot +# 09 - Spine2 +# 10 - L_Toes +# 11 - R_Toes +# 12 - Neck +# 13 - L_Shoulder +# 14 - R_Shoulder +# 15 - Head +# 16 - L_UpperArm +# 17 - R_UpperArm +# 18 - L_ForeArm +# 19 - R_ForeArm +# 20 - L_Hand +# 21 - R_Hand +# 22 - Jaw +# 23 - L_Eye +# 24 - R_Eye + + +class SMPLX(nn.Module): + """ + Given smplx parameters, this class generates a differentiable SMPLX function + which outputs a mesh and 3D joints + """ + + def __init__(self, config): + super(SMPLX, 
self).__init__() + # print("creating the SMPLX Decoder") + ss = np.load(config.smplx_model_path, allow_pickle=True) + smplx_model = Struct(**ss) + + self.dtype = torch.float32 + self.register_buffer( + "faces_tensor", + to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long), + ) + # The vertices of the template model + self.register_buffer( + "v_template", + to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)) + # The shape components and expression + # expression space is the same as FLAME + shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype) + shapedirs = torch.cat( + [ + shapedirs[:, :, :config.n_shape], + shapedirs[:, :, 300:300 + config.n_exp], + ], + 2, + ) + self.register_buffer("shapedirs", shapedirs) + # The pose components + num_pose_basis = smplx_model.posedirs.shape[-1] + posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T + self.register_buffer("posedirs", + to_tensor(to_np(posedirs), dtype=self.dtype)) + self.register_buffer( + "J_regressor", + to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)) + parents = to_tensor(to_np(smplx_model.kintree_table[0])).long() + parents[0] = -1 + self.register_buffer("parents", parents) + self.register_buffer( + "lbs_weights", + to_tensor(to_np(smplx_model.weights), dtype=self.dtype)) + # for face keypoints + self.register_buffer( + "lmk_faces_idx", + torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)) + self.register_buffer( + "lmk_bary_coords", + torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype), + ) + self.register_buffer( + "dynamic_lmk_faces_idx", + torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long), + ) + self.register_buffer( + "dynamic_lmk_bary_coords", + torch.tensor(smplx_model.dynamic_lmk_bary_coords, + dtype=self.dtype), + ) + # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks + self.register_buffer("head_kin_chain", + torch.tensor(head_kin_chain, dtype=torch.long)) + + # -- initialize parameters + # shape and expression + self.register_buffer( + "shape_params", + nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), + requires_grad=False), + ) + self.register_buffer( + "expression_params", + nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), + requires_grad=False), + ) + # pose: represented as rotation matrx [number of joints, 3, 3] + self.register_buffer( + "global_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "head_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "neck_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "jaw_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "eye_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(2, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "body_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(21, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "left_hand_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15, 1, 1), + requires_grad=False, + ), + ) + self.register_buffer( + "right_hand_pose", + nn.Parameter( + torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15, 1, 
1), + requires_grad=False, + ), + ) + + if config.extra_joint_path: + self.extra_joint_selector = JointsFromVerticesSelector( + fname=config.extra_joint_path) + self.use_joint_regressor = True + self.keypoint_names = SMPLX_names + if self.use_joint_regressor: + with open(config.j14_regressor_path, "rb") as f: + j14_regressor = pickle.load(f, encoding="latin1") + source = [] + target = [] + for idx, name in enumerate(self.keypoint_names): + if name in J14_NAMES: + source.append(idx) + target.append(J14_NAMES.index(name)) + source = np.asarray(source) + target = np.asarray(target) + self.register_buffer("source_idxs", torch.from_numpy(source)) + self.register_buffer("target_idxs", torch.from_numpy(target)) + joint_regressor = torch.from_numpy(j14_regressor).to( + dtype=torch.float32) + self.register_buffer("extra_joint_regressor", joint_regressor) + self.part_indices = part_indices + + def forward( + self, + shape_params=None, + expression_params=None, + global_pose=None, + body_pose=None, + jaw_pose=None, + eye_pose=None, + left_hand_pose=None, + right_hand_pose=None, + ): + """ + Args: + shape_params: [N, number of shape parameters] + expression_params: [N, number of expression parameters] + global_pose: pelvis pose, [N, 1, 3, 3] + body_pose: [N, 21, 3, 3] + jaw_pose: [N, 1, 3, 3] + eye_pose: [N, 2, 3, 3] + left_hand_pose: [N, 15, 3, 3] + right_hand_pose: [N, 15, 3, 3] + Returns: + vertices: [N, number of vertices, 3] + landmarks: [N, number of landmarks (68 face keypoints), 3] + joints: [N, number of smplx joints (145), 3] + """ + if shape_params is None: + batch_size = global_pose.shape[0] + shape_params = self.shape_params.expand(batch_size, -1) + else: + batch_size = shape_params.shape[0] + if expression_params is None: + expression_params = self.expression_params.expand(batch_size, -1) + if global_pose is None: + global_pose = self.global_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + if body_pose is None: + body_pose = self.body_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + if jaw_pose is None: + jaw_pose = self.jaw_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + if eye_pose is None: + eye_pose = self.eye_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + if left_hand_pose is None: + left_hand_pose = self.left_hand_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + if right_hand_pose is None: + right_hand_pose = self.right_hand_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + + shape_components = torch.cat([shape_params, expression_params], dim=1) + full_pose = torch.cat( + [ + global_pose, + body_pose, + jaw_pose, + eye_pose, + left_hand_pose, + right_hand_pose, + ], + dim=1, + ) + template_vertices = self.v_template.unsqueeze(0).expand( + batch_size, -1, -1) + # smplx + vertices, joints = lbs( + shape_components, + full_pose, + template_vertices, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + dtype=self.dtype, + pose2rot=False, + ) + # face dynamic landmarks + lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( + batch_size, -1) + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand( + batch_size, -1, -1) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords( + vertices, + full_pose, + self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.head_kin_chain, + ) + lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1) + landmarks = vertices2landmarks(vertices, 
self.faces_tensor, + lmk_faces_idx, lmk_bary_coords) + + final_joint_set = [joints, landmarks] + if hasattr(self, "extra_joint_selector"): + # Add any extra joints that might be needed + extra_joints = self.extra_joint_selector(vertices, + self.faces_tensor) + final_joint_set.append(extra_joints) + # Create the final joint set + joints = torch.cat(final_joint_set, dim=1) + # if self.use_joint_regressor: + # reg_joints = torch.einsum("ji,bik->bjk", + # self.extra_joint_regressor, vertices) + # joints[:, self.source_idxs] = ( + # joints[:, self.source_idxs].detach() * 0.0 + + # reg_joints[:, self.target_idxs] * 1.0) + return vertices, landmarks, joints + + def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"): + """change absolute pose to relative pose + Basic knowledge for SMPLX kinematic tree: + absolute pose = parent pose * relative pose + Here, pose must be represented as rotation matrix (batch_sizexnx3x3) + """ + if abs_joint == "head": + # Pelvis -> Spine 1, 2, 3 -> Neck -> Head + kin_chain = [15, 12, 9, 6, 3, 0] + elif abs_joint == "neck": + # Pelvis -> Spine 1, 2, 3 -> Neck -> Head + kin_chain = [12, 9, 6, 3, 0] + elif abs_joint == "right_wrist": + # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder + # -> right elbow -> right wrist + kin_chain = [21, 19, 17, 14, 9, 6, 3, 0] + elif abs_joint == "left_wrist": + # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder + # -> Left elbow -> Left wrist + kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] + else: + raise NotImplementedError( + f"pose_abs2rel does not support: {abs_joint}") + + batch_size = global_pose.shape[0] + dtype = global_pose.dtype + device = global_pose.device + full_pose = torch.cat([global_pose, body_pose], dim=1) + rel_rot_mat = (torch.eye(3, device=device, + dtype=dtype).unsqueeze_(dim=0).repeat( + batch_size, 1, 1)) + for idx in kin_chain[1:]: + rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat) + + # This contains the absolute pose of the parent + abs_parent_pose = rel_rot_mat.detach() + # Let's assume that in the input this specific joint is predicted as an absolute value + abs_joint_pose = body_pose[:, kin_chain[0] - 1] + # abs_head = parents(abs_neck) * rel_head ==> rel_head = abs_neck.T * abs_head + rel_joint_pose = torch.matmul( + abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2), + abs_joint_pose.reshape(-1, 3, 3), + ) + # Replace the new relative pose + body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose + return body_pose + + def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"): + """change relative pose to absolute pose + Basic knowledge for SMPLX kinematic tree: + absolute pose = parent pose * relative pose + Here, pose must be represented as rotation matrix (batch_sizexnx3x3) + """ + full_pose = torch.cat([global_pose, body_pose], dim=1) + + if abs_joint == "head": + # Pelvis -> Spine 1, 2, 3 -> Neck -> Head + kin_chain = [15, 12, 9, 6, 3, 0] + elif abs_joint == "neck": + # Pelvis -> Spine 1, 2, 3 -> Neck -> Head + kin_chain = [12, 9, 6, 3, 0] + elif abs_joint == "right_wrist": + # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder + # -> right elbow -> right wrist + kin_chain = [21, 19, 17, 14, 9, 6, 3, 0] + elif abs_joint == "left_wrist": + # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder + # -> Left elbow -> Left wrist + kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] + else: + raise NotImplementedError( + f"pose_rel2abs does not support: {abs_joint}") + rel_rot_mat = torch.eye(3, + device=full_pose.device, + dtype=full_pose.dtype).unsqueeze_(dim=0) + for idx 
in kin_chain: + rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat) + abs_pose = rel_rot_mat[:, None, :, :] + return abs_pose diff --git a/lib/pixielib/models/__init__.py b/lib/pixielib/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/pixielib/models/encoders.py b/lib/pixielib/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..dfed3dd69392b382151e12320a0761dcabc213bb --- /dev/null +++ b/lib/pixielib/models/encoders.py @@ -0,0 +1,75 @@ +import numpy as np +import torch.nn as nn +import torch +import torch.nn.functional as F + + +class ResnetEncoder(nn.Module): + + def __init__(self, append_layers=None): + super(ResnetEncoder, self).__init__() + from . import resnet + # feature_size = 2048 + self.feature_dim = 2048 + self.encoder = resnet.load_ResNet50Model() # out: 2048 + # regressor + self.append_layers = append_layers + # for normalize input images + MEAN = [0.485, 0.456, 0.406] + STD = [0.229, 0.224, 0.225] + self.register_buffer('MEAN', torch.tensor(MEAN)[None, :, None, None]) + self.register_buffer('STD', torch.tensor(STD)[None, :, None, None]) + + def forward(self, inputs): + ''' inputs: [bz, 3, h, w], range: [0,1] + ''' + inputs = (inputs - self.MEAN) / self.STD + features = self.encoder(inputs) + if self.append_layers: + features = self.last_op(features) + return features + + +class MLP(nn.Module): + + def __init__(self, channels=[2048, 1024, 1], last_op=None): + super(MLP, self).__init__() + layers = [] + + for l in range(0, len(channels) - 1): + layers.append(nn.Linear(channels[l], channels[l + 1])) + if l < len(channels) - 2: + layers.append(nn.ReLU()) + if last_op: + layers.append(last_op) + + self.layers = nn.Sequential(*layers) + + def forward(self, inputs): + outs = self.layers(inputs) + return outs + + +class HRNEncoder(nn.Module): + + def __init__(self, append_layers=None): + super(HRNEncoder, self).__init__() + from . 
import hrnet + self.feature_dim = 2048 + self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048 + # regressor + self.append_layers = append_layers + # for normalize input images + MEAN = [0.485, 0.456, 0.406] + STD = [0.229, 0.224, 0.225] + self.register_buffer('MEAN', torch.tensor(MEAN)[None, :, None, None]) + self.register_buffer('STD', torch.tensor(STD)[None, :, None, None]) + + def forward(self, inputs): + ''' inputs: [bz, 3, h, w], range: [0,1] + ''' + inputs = (inputs - self.MEAN) / self.STD + features = self.encoder(inputs)['concat'] + if self.append_layers: + features = self.last_op(features) + return features diff --git a/lib/pixielib/models/hrnet.py b/lib/pixielib/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..76ee54e88ff4bae0197392ff39db1b542d1d249e --- /dev/null +++ b/lib/pixielib/models/hrnet.py @@ -0,0 +1,562 @@ +''' +borrowed from https://github.com/vchoutas/expose/blob/master/expose/models/backbone/hrnet.py +''' + +import os.path as osp +import torch +import torch.nn as nn + +from torchvision.models.resnet import Bottleneck, BasicBlock + +BN_MOMENTUM = 0.1 + + +def load_HRNet(pretrained=False): + hr_net_cfg_dict = { + 'use_old_impl': False, + 'pretrained_layers': ['*'], + 'stage1': { + 'num_modules': 1, + 'num_branches': 1, + 'num_blocks': [4], + 'num_channels': [64], + 'block': 'BOTTLENECK', + 'fuse_method': 'SUM' + }, + 'stage2': { + 'num_modules': 1, + 'num_branches': 2, + 'num_blocks': [4, 4], + 'num_channels': [48, 96], + 'block': 'BASIC', + 'fuse_method': 'SUM' + }, + 'stage3': { + 'num_modules': 4, + 'num_branches': 3, + 'num_blocks': [4, 4, 4], + 'num_channels': [48, 96, 192], + 'block': 'BASIC', + 'fuse_method': 'SUM' + }, + 'stage4': { + 'num_modules': 3, + 'num_branches': 4, + 'num_blocks': [4, 4, 4, 4], + 'num_channels': [48, 96, 192, 384], + 'block': 'BASIC', + 'fuse_method': 'SUM' + } + } + hr_net_cfg = hr_net_cfg_dict + model = HighResolutionNet(hr_net_cfg) + + return model + + +class HighResolutionModule(nn.Module): + + def __init__(self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, + num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] 
* block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample(scale_factor=2**(j - i), + mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + self.inplanes = 64 + super(HighResolutionNet, self).__init__() + use_old_impl = cfg.get('use_old_impl') + self.use_old_impl = use_old_impl + + # stem net + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg.get('stage1', {}) + num_channels = self.stage1_cfg['num_channels'][0] + block = blocks_dict[self.stage1_cfg['block']] + num_blocks = self.stage1_cfg['num_blocks'][0] + self.layer1 = self._make_layer(block, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = cfg.get('stage2', {}) + num_channels = self.stage2_cfg.get('num_channels', (32, 64)) + block 
= blocks_dict[self.stage2_cfg.get('block')] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + stage2_num_channels = num_channels + self.transition1 = self._make_transition_layer([stage1_out_channel], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg.get('stage3') + num_channels = self.stage3_cfg['num_channels'] + block = blocks_dict[self.stage3_cfg['block']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + stage3_num_channels = num_channels + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg.get('stage4') + num_channels = self.stage4_cfg['num_channels'] + block = blocks_dict[self.stage4_cfg['block']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + stage_4_out_channels = num_channels + + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, + num_channels, + multi_scale_output=not self.use_old_impl) + stage4_num_channels = num_channels + + self.output_channels_dim = pre_stage_channels + + self.pretrained_layers = cfg['pretrained_layers'] + self.init_weights() + + self.avg_pooling = nn.AdaptiveAvgPool2d(1) + + if use_old_impl: + in_dims = (2**2 * stage2_num_channels[-1] + + 2**1 * stage3_num_channels[-1] + + stage_4_out_channels[-1]) + else: + # TODO: Replace with parameters + in_dims = 4 * 384 + self.subsample_4 = self._make_subsample_layer( + in_channels=stage4_num_channels[0], num_layers=3) + + self.subsample_3 = self._make_subsample_layer( + in_channels=stage2_num_channels[-1], num_layers=2) + self.subsample_2 = self._make_subsample_layer( + in_channels=stage3_num_channels[-1], num_layers=1) + self.conv_layers = self._make_conv_layer(in_channels=in_dims, + num_layers=5) + + def get_output_dim(self): + base_output = { + f'layer{idx + 1}': val + for idx, val in enumerate(self.output_channels_dim) + } + output = base_output.copy() + for key in base_output: + output[f'{key}_avg_pooling'] = output[key] + output['concat'] = 2048 + return output + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i]), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d(inchannels, + outchannels, + 3, + 2, + 1, + bias=False), nn.BatchNorm2d(outchannels), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = 
nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_conv_layer(self, + in_channels=2048, + num_layers=3, + num_filters=2048, + stride=1): + + layers = [] + for i in range(num_layers): + + downsample = nn.Conv2d(in_channels, + num_filters, + stride=1, + kernel_size=1, + bias=False) + layers.append( + Bottleneck(in_channels, + num_filters // 4, + downsample=downsample)) + in_channels = num_filters + + return nn.Sequential(*layers) + + def _make_subsample_layer(self, in_channels=96, num_layers=3, stride=2): + + layers = [] + for i in range(num_layers): + + layers.append( + nn.Conv2d(in_channels=in_channels, + out_channels=2 * in_channels, + kernel_size=3, + stride=stride, + padding=1)) + in_channels = 2 * in_channels + layers.append(nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + + def _make_stage(self, + layer_config, + num_inchannels, + multi_scale_output=True, + log=False): + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = blocks_dict[layer_config['block']] + fuse_method = layer_config['fuse_method'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, block, num_blocks, + num_inchannels, num_channels, fuse_method, + reset_multi_scale_output)) + modules[-1].log = log + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + if i < self.stage2_cfg['num_branches']: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + if i < self.stage3_cfg['num_branches']: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + if not self.use_old_impl: + y_list = self.stage4(x_list) + + output = {} + for idx, x in enumerate(y_list): + output[f'layer{idx + 1}'] = x + + feat_list = [] + if self.use_old_impl: + x3 = self.subsample_3(x_list[1]) + x2 = self.subsample_2(x_list[2]) + x1 = x_list[3] + feat_list = [x3, x2, x1] + else: + x4 = self.subsample_4(y_list[0]) + x3 = self.subsample_3(y_list[1]) + x2 = self.subsample_2(y_list[2]) + x1 = y_list[3] + feat_list = [x4, x3, x2, x1] + + 
xf = self.conv_layers(torch.cat(feat_list, dim=1)) + xf = xf.mean(dim=(2, 3)) + xf = xf.view(xf.size(0), -1) + output['concat'] = xf + # y_list = self.stage4(x_list) + # output['stage4'] = y_list[0] + # output['stage4_avg_pooling'] = self.avg_pooling(y_list[0]).view( + # *y_list[0].shape[:2]) + + # concat_outputs = y_list + x_list + # output['concat'] = torch.cat([ + # self.avg_pooling(tensor).view(*tensor.shape[:2]) + # for tensor in concat_outputs], + # dim=1) + + return output + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + + def load_weights(self, pretrained=''): + pretrained = osp.expandvars(pretrained) + if osp.isfile(pretrained): + pretrained_state_dict = torch.load( + pretrained, map_location=torch.device("cpu")) + + need_init_state_dict = {} + for name, m in pretrained_state_dict.items(): + if (name.split('.')[0] in self.pretrained_layers + or self.pretrained_layers[0] == '*'): + need_init_state_dict[name] = m + missing, unexpected = self.load_state_dict(need_init_state_dict, + strict=False) + elif pretrained: + raise ValueError('{} is not exist!'.format(pretrained)) diff --git a/lib/pixielib/models/lbs.py b/lib/pixielib/models/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..df496d880602d046a98bf3cf651144614f4f9cc7 --- /dev/null +++ b/lib/pixielib/models/lbs.py @@ -0,0 +1,466 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np +import os +import yaml +import torch +import torch.nn.functional as F +from torch import nn + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) + + +def find_dynamic_lmk_idx_and_bcoords(vertices, + pose, + dynamic_lmk_faces_idx, + dynamic_lmk_b_coords, + head_kin_chain, + dtype=torch.float32): + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. 
+ + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + head_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + ''' + + batch_size = vertices.shape[0] + pose = pose.detach() + # aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + # head_kin_chain) + # rot_mats = batch_rodrigues( + # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) + rot_mats = torch.index_select(pose, 1, head_kin_chain) + + rel_rot_mat = torch.eye(3, device=vertices.device, + dtype=dtype).unsqueeze_(dim=0) + for idx in range(len(head_kin_chain)): + # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + # print(y_rot_angle[0]) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle) + # print(y_rot_angle[0]) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, + y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. 
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( + batch_size, -1, 3) + + lmk_faces += torch.arange(batch_size, dtype=torch.long, + device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def lbs(betas, + pose, + v_template, + shapedirs, + posedirs, + J_regressor, + parents, + lbs_weights, + pose2rot=True, + dtype=torch.float32): + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device = betas.device + + # Add shape contribution + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + + # 3. Add pose blend shapes + # N x J x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + if pose2rot: + rot_mats = batch_rodrigues(pose.view(-1, 3), + dtype=dtype).view([batch_size, -1, 3, 3]) + + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + # (N x P) x (P, V * 3) -> N x V x 3 + pose_offsets = torch.matmul(pose_feature, posedirs) \ + .view(batch_size, -1, 3) + else: + pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident + rot_mats = pose.view(batch_size, -1, 3, 3) + + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + # 5. 
Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, + device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + return verts, J_transformed + + +def vertices2joints(J_regressor, vertices): + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas, shape_disps): + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. + blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def transform_mat(R, t): + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], + dim=2) + + +def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + joints : torch.tensor BxNx3 + Locations of joints + parents : torch.tensor BxN + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + 
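For reference, a quick sanity check of the axis-angle and homogeneous-transform helpers above — a standalone sketch, assuming `batch_rodrigues` and `transform_mat` from this `lbs.py` are in scope; the test values and tolerance are arbitrary:
```
import torch

# Sanity check for batch_rodrigues / transform_mat defined above (lbs.py).
aa = torch.tensor([[0.0, 0.0, 1.5708]])            # ~90 deg rotation about z
R = batch_rodrigues(aa)                             # (1, 3, 3) rotation matrix
assert torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(1, 3, 3), atol=1e-5)

T = transform_mat(R, torch.zeros(1, 3, 1))          # (1, 4, 4) homogeneous transform
assert torch.allclose(T[:, 3], torch.tensor([[0.0, 0.0, 0.0, 1.0]]))
```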
posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape( + -1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, + i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + # # The last column of the transformations contains the posed joints + # posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms + + +class JointsFromVerticesSelector(nn.Module): + + def __init__(self, fname): + ''' Selects extra joints from vertices + ''' + super(JointsFromVerticesSelector, self).__init__() + + err_msg = ('Either pass a filename or triangle face ids, names and' + ' barycentrics') + assert fname is not None or (face_ids is not None and bcs is not None + and names is not None), err_msg + if fname is not None: + fname = os.path.expanduser(os.path.expandvars(fname)) + with open(fname, 'r') as f: + data = yaml.safe_load(f) + names = list(data.keys()) + bcs = [] + face_ids = [] + for name, d in data.items(): + face_ids.append(d['face']) + bcs.append(d['bc']) + bcs = np.array(bcs, dtype=np.float32) + face_ids = np.array(face_ids, dtype=np.int32) + assert len(bcs) == len(face_ids), ( + 'The number of barycentric coordinates must be equal to the faces') + assert len(names) == len(face_ids), ( + 'The number of names must be equal to the number of ') + + self.names = names + self.register_buffer('bcs', torch.tensor(bcs, dtype=torch.float32)) + self.register_buffer('face_ids', + torch.tensor(face_ids, dtype=torch.long)) + + def extra_joint_names(self): + ''' Returns the names of the extra joints + ''' + return self.names + + def forward(self, vertices, faces): + if len(self.face_ids) < 1: + return [] + vertex_ids = faces[self.face_ids].reshape(-1) + # Should be BxNx3x3 + triangles = torch.index_select(vertices, 1, vertex_ids).reshape( + -1, len(self.bcs), 3, 3) + return (triangles * self.bcs[None, :, :, None]).sum(dim=2) + + +# def to_tensor(array, dtype=torch.float32): +# if torch.is_tensor(array): +# return array +# else: +# return torch.tensor(array, dtype=dtype) + + +def to_tensor(array, dtype=torch.float32): + if 'torch.tensor' not in str(type(array)): + return torch.tensor(array, dtype=dtype) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +class Struct(object): + + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) diff --git a/lib/pixielib/models/moderators.py b/lib/pixielib/models/moderators.py new file mode 100644 index 0000000000000000000000000000000000000000..00bed7fac1cc1ad5153e08f65aea37a20d37c7fc --- 
/dev/null +++ b/lib/pixielib/models/moderators.py @@ -0,0 +1,110 @@ +''' Moderator +# Input feature: body, part(head, hand) +# output: fused feature, weight +''' +import numpy as np +import torch.nn as nn +import torch +import torch.nn.functional as F + +# MLP + temperature softmax +# w = SoftMax(w^\prime * temperature) + + +class TempSoftmaxFusion(nn.Module): + + def __init__(self, + channels=[2048 * 2, 1024, 1], + detach_inputs=False, + detach_feature=False): + super(TempSoftmaxFusion, self).__init__() + self.detach_inputs = detach_inputs + self.detach_feature = detach_feature + # weight + layers = [] + for l in range(0, len(channels) - 1): + layers.append(nn.Linear(channels[l], channels[l + 1])) + if l < len(channels) - 2: + layers.append(nn.ReLU()) + self.layers = nn.Sequential(*layers) + # temperature + self.register_parameter('temperature', nn.Parameter(torch.ones(1))) + + def forward(self, x, y, work=True): + ''' + x: feature from body + y: feature from part(head/hand) + work: whether to fuse features + ''' + if work: + # 1. cat input feature, predict the weights + f_in = torch.cat([x, y], dim=1) + if self.detach_inputs: + f_in = f_in.detach() + f_temp = self.layers(f_in) + f_weight = F.softmax(f_temp * self.temperature, dim=1) + + # 2. feature fusion + if self.detach_feature: + x = x.detach() + y = y.detach() + f_out = f_weight[:, [0]] * x + f_weight[:, [1]] * y + x_out = f_out + y_out = f_out + else: + x_out = x + y_out = y + f_weight = None + return x_out, y_out, f_weight + + +# MLP + Gumbel-Softmax trick +# w = w^{\prime} - w^{\prime}\text{.detach()} + w^{\prime}\text{.gt(0.5)} + + +class GumbelSoftmaxFusion(nn.Module): + + def __init__(self, + channels=[2048 * 2, 1024, 1], + detach_inputs=False, + detach_feature=False): + super(GumbelSoftmaxFusion, self).__init__() + self.detach_inputs = detach_inputs + self.detach_feature = detach_feature + + # weight + layers = [] + for l in range(0, len(channels) - 1): + layers.append(nn.Linear(channels[l], channels[l + 1])) + if l < len(channels) - 2: + layers.append(nn.ReLU()) + layers.append(nn.Softmax()) + self.layers = nn.Sequential(*layers) + + def forward(self, x, y, work=True): + ''' + x: feature from body + y: feature from part(head/hand) + work: whether to fuse features + ''' + if work: + # 1. cat input feature, predict the weights + f_in = torch.cat([x, y], dim=-1) + if self.detach_inputs: + f_in = f_in.detach() + f_weight = self.layers(f_in) + # weight to be hard + f_weight = f_weight - f_weight.detach() + f_weight.gt(0.5) + + # 2. feature fusion + if self.detach_feature: + x = x.detach() + y = y.detach() + f_out = f_weight[:, [0]] * x + f_weight[:, [1]] * y + x_out = f_out + y_out = f_out + else: + x_out = x + y_out = y + f_weight = None + return x_out, y_out, f_weight diff --git a/lib/pixielib/models/resnet.py b/lib/pixielib/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ef00ceb8ba9bc04ed8690203f6d267109d8bca3a --- /dev/null +++ b/lib/pixielib/models/resnet.py @@ -0,0 +1,325 @@ +""" +Author: Soubhik Sanyal +Copyright (c) 2019, Soubhik Sanyal +All rights reserved. 
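A minimal usage sketch of the temperature-softmax moderator above, with toy channel sizes chosen only for illustration (in `pixie.py` the real moderators are built with `channels = [2048 * 2] + cfg channels + [2]`):
```
import torch

# TempSoftmaxFusion from moderators.py above, with toy channel sizes.
fusion = TempSoftmaxFusion(channels=[2 * 4, 3, 2])
x_body = torch.randn(5, 4)                 # stand-in body feature
y_part = torch.randn(5, 4)                 # stand-in head/hand feature
x_out, y_out, w = fusion(x_body, y_part, work=True)
assert x_out.shape == (5, 4)
assert torch.allclose(w.sum(dim=1), torch.ones(5))   # softmax weights per sample
```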
+Loads different resnet models +""" +''' + file: Resnet.py + date: 2018_05_02 + author: zhangxiong(1025679612@qq.com) + mark: copied from pytorch source code +''' + +import torch.nn as nn +import torch.nn.functional as F +import torch +from torch.nn.parameter import Parameter +import torch.optim as optim +import numpy as np +import math +import torchvision + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x1 = self.layer4(x) + + x2 = self.avgpool(x1) + x2 = x2.view(x2.size(0), -1) + # x = self.fc(x) + # x2: [bz, 2048] for shape + # x1: [bz, 2048, 7, 7] for texture + return x2 + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + 
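To illustrate, the `ResNet`/`Bottleneck` definitions above can be assembled into the ResNet-50 layout used later by `load_ResNet50Model`; a shape check, assuming those classes are in scope:
```
import torch

# ResNet-50 layout from the ResNet/Bottleneck classes above; forward() returns
# the pooled 2048-d feature (the fc layer is commented out in this variant).
backbone = ResNet(Bottleneck, [3, 4, 6, 3])
feats = backbone(torch.randn(2, 3, 224, 224))
assert feats.shape == (2, 2048)
```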
self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +def copy_parameter_from_resnet(model, resnet_dict): + cur_state_dict = model.state_dict() + # import ipdb; ipdb.set_trace() + for name, param in list(resnet_dict.items())[0:None]: + if name not in cur_state_dict: + # print(name, ' not available in reconstructed resnet') + continue + if isinstance(param, Parameter): + param = param.data + try: + cur_state_dict[name].copy_(param) + except: + # print(name, ' is inconsistent!') + continue + # print('copy resnet state dict finished!') + + +def load_ResNet50Model(): + model = ResNet(Bottleneck, [3, 4, 6, 3]) + copy_parameter_from_resnet( + model, + torchvision.models.resnet50(pretrained=True).state_dict()) + return model + + +def load_ResNet101Model(): + model = ResNet(Bottleneck, [3, 4, 23, 3]) + copy_parameter_from_resnet( + model, + torchvision.models.resnet101(pretrained=True).state_dict()) + return model + + +def load_ResNet152Model(): + model = ResNet(Bottleneck, [3, 8, 36, 3]) + copy_parameter_from_resnet( + model, + torchvision.models.resnet152(pretrained=True).state_dict()) + return model + + +# model.load_state_dict(checkpoint['model_state_dict']) + +# Unet + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, + mode='bilinear', + align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels // 2, + in_channels // 2, + kernel_size=2, + stride=2) + + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad( + x1, + [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + + def __init__(self, in_channels, out_channels): 
+ super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + + +class UNet(nn.Module): + + def __init__(self, n_channels, n_classes, bilinear=True): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + self.down4 = Down(512, 512) + self.up1 = Up(1024, 256, bilinear) + self.up2 = Up(512, 128, bilinear) + self.up3 = Up(256, 64, bilinear) + self.up4 = Up(128, 64, bilinear) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + x = F.normalize(x) + return x diff --git a/lib/pixielib/pixie.py b/lib/pixielib/pixie.py new file mode 100644 index 0000000000000000000000000000000000000000..556c2a4c3fd115fe501c8dfb1b41fc48a2907a85 --- /dev/null +++ b/lib/pixielib/pixie.py @@ -0,0 +1,611 @@ +# -*- coding: utf-8 -*- +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# Using this computer program means that you agree to the terms +# in the LICENSE file included with this software distribution. +# Any use not explicitly granted by the LICENSE is prohibited. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# For comments or questions, please email us at pixie@tue.mpg.de +# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de + +import os +import torch +import torchvision +import torch.nn.functional as F +import torch.nn as nn + +import numpy as np +from skimage.io import imread +import cv2 + +from .models.encoders import ResnetEncoder, MLP, HRNEncoder +from .models.moderators import TempSoftmaxFusion +from .models.SMPLX import SMPLX +from .utils import util +from .utils import rotation_converter as converter +from .utils import tensor_cropper +from .utils.config import cfg + + +class PIXIE(object): + + def __init__(self, config=None, device="cuda:0"): + if config is None: + self.cfg = cfg + else: + self.cfg = config + + self.device = device + # parameters setting + self.param_list_dict = {} + for lst in self.cfg.params.keys(): + param_list = cfg.params.get(lst) + self.param_list_dict[lst] = { + i: cfg.model.get("n_" + i) + for i in param_list + } + + # Build the models + self._create_model() + # Set up the cropping modules used to generate face/hand crops from the body predictions + self._setup_cropper() + + def forward(self, data): + + # encode + decode + param_dict = self.encode( + {"body": { + "image": data + }}, + threthold=True, + keep_local=True, + copy_and_paste=False, + ) + opdict = self.decode(param_dict["body"], param_type="body") + + return opdict + + def _setup_cropper(self): + self.Cropper = {} + for crop_part in ["head", "hand"]: + data_cfg = self.cfg.dataset[crop_part] + scale_size = (data_cfg.scale_min + data_cfg.scale_max) * 0.5 + self.Cropper[crop_part] = tensor_cropper.Cropper( + crop_size=data_cfg.image_size, + scale=[scale_size, scale_size], + trans_scale=0, + ) + + def _create_model(self): + self.model_dict = {} + # Build all image encoders + # Hand 
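A small shape check for the `UNet` defined above (a sketch only, assuming the bilinear upsampling path and an arbitrary 64x64 input; the output is L2-normalized over channels by the final `F.normalize`):
```
import torch

# Shape check for the UNet defined above (bilinear upsampling path).
net = UNet(n_channels=3, n_classes=1, bilinear=True)
out = net(torch.randn(1, 3, 64, 64))
assert out.shape == (1, 1, 64, 64)         # per-pixel output, normalized over dim=1
```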
encoder only works for right hand, for left hand, flip inputs and flip the results back + self.Encoder = {} + for key in self.cfg.network.encoder.keys(): + if self.cfg.network.encoder.get(key).type == "resnet50": + self.Encoder[key] = ResnetEncoder().to(self.device) + elif self.cfg.network.encoder.get(key).type == "hrnet": + self.Encoder[key] = HRNEncoder().to(self.device) + self.model_dict[f"Encoder_{key}"] = self.Encoder[key].state_dict() + + # Build the parameter regressors + self.Regressor = {} + for key in self.cfg.network.regressor.keys(): + n_output = sum(self.param_list_dict[f"{key}_list"].values()) + channels = ([2048] + self.cfg.network.regressor.get(key).channels + + [n_output]) + if self.cfg.network.regressor.get(key).type == "mlp": + self.Regressor[key] = MLP(channels=channels).to(self.device) + self.model_dict[f"Regressor_{key}"] = self.Regressor[ + key].state_dict() + + # Build the extractors + # to extract separate head/left hand/right hand feature from body feature + self.Extractor = {} + for key in self.cfg.network.extractor.keys(): + channels = [ + 2048 + ] + self.cfg.network.extractor.get(key).channels + [2048] + if self.cfg.network.extractor.get(key).type == "mlp": + self.Extractor[key] = MLP(channels=channels).to(self.device) + self.model_dict[f"Extractor_{key}"] = self.Extractor[ + key].state_dict() + + # Build the moderators + self.Moderator = {} + for key in self.cfg.network.moderator.keys(): + share_part = key.split("_")[0] + detach_inputs = self.cfg.network.moderator.get(key).detach_inputs + detach_feature = self.cfg.network.moderator.get(key).detach_feature + channels = [2048 * 2 + ] + self.cfg.network.moderator.get(key).channels + [2] + self.Moderator[key] = TempSoftmaxFusion( + detach_inputs=detach_inputs, + detach_feature=detach_feature, + channels=channels, + ).to(self.device) + self.model_dict[f"Moderator_{key}"] = self.Moderator[ + key].state_dict() + + # Build the SMPL-X body model, which we also use to represent faces and + # hands, using the relevant parts only + self.smplx = SMPLX(self.cfg.model).to(self.device) + self.part_indices = self.smplx.part_indices + + # -- resume model + model_path = self.cfg.pretrained_modelpath + if os.path.exists(model_path): + checkpoint = torch.load(model_path, map_location=self.device) + for key in self.model_dict.keys(): + util.copy_state_dict(self.model_dict[key], checkpoint[key]) + else: + print(f"pixie trained model path: {model_path} does not exist!") + exit() + # eval mode + for module in [ + self.Encoder, self.Regressor, self.Moderator, self.Extractor + ]: + for net in module.values(): + net.eval() + + def decompose_code(self, code, num_dict): + """Convert a flattened parameter vector to a dictionary of parameters""" + code_dict = {} + start = 0 + for key in num_dict: + end = start + int(num_dict[key]) + code_dict[key] = code[:, start:end] + start = end + return code_dict + + def part_from_body(self, image, part_key, points_dict, crop_joints=None): + """crop part(head/left_hand/right_hand) out from body data, joints also change accordingly""" + assert part_key in ["head", "left_hand", "right_hand"] + assert "smplx_kpt" in points_dict.keys() + if part_key == "head": + # use face 68 kpts for cropping head image + indices_key = "face" + elif part_key == "left_hand": + indices_key = "left_hand" + elif part_key == "right_hand": + indices_key = "right_hand" + + # get points for cropping + part_indices = self.part_indices[indices_key] + if crop_joints is not None: + points_for_crop = crop_joints[:, part_indices] + else: 
+ points_for_crop = points_dict["smplx_kpt"][:, part_indices] + + # crop + cropper_key = "hand" if "hand" in part_key else part_key + points_scale = image.shape[-2:] + cropped_image, tform = self.Cropper[cropper_key].crop( + image, points_for_crop, points_scale) + # transform points(must be normalized to [-1.1]) accordingly + cropped_points_dict = {} + for points_key in points_dict.keys(): + points = points_dict[points_key] + cropped_points = self.Cropper[cropper_key].transform_points( + points, tform, points_scale, normalize=True) + cropped_points_dict[points_key] = cropped_points + return cropped_image, cropped_points_dict + + @torch.no_grad() + def encode( + self, + data, + threthold=True, + keep_local=True, + copy_and_paste=False, + body_only=False, + ): + """Encode images to smplx parameters + Args: + data: dict + key: image_type (body/head/hand) + value: + image: [bz, 3, 224, 224], range [0,1] + image_hd(needed if key==body): a high res version of image, only for cropping parts from body image + head_image: optinal, well-cropped head from body image + left_hand_image: optinal, well-cropped left hand from body image + right_hand_image: optinal, well-cropped right hand from body image + Returns: + param_dict: dict + key: image_type (body/head/hand) + value: param_dict + """ + for key in data.keys(): + assert key in ["body", "head", "hand"] + + feature = {} + param_dict = {} + + # Encode features + for key in data.keys(): + part = key + # encode feature + feature[key] = {} + feature[key][part] = self.Encoder[part](data[key]["image"]) + + # for head/hand image + if key == "head" or key == "hand": + # predict head/hand-only parameters from part feature + part_dict = self.decompose_code( + self.Regressor[part](feature[key][part]), + self.param_list_dict[f"{part}_list"], + ) + # if input is part data, skip feature fusion: share feature is the same as part feature + # then predict share parameters + feature[key][f"{key}_share"] = feature[key][key] + share_dict = self.decompose_code( + self.Regressor[f"{part}_share"]( + feature[key][f"{part}_share"]), + self.param_list_dict[f"{part}_share_list"], + ) + # compose parameters + param_dict[key] = {**share_dict, **part_dict} + + # for body image + if key == "body": + fusion_weight = {} + f_body = feature["body"]["body"] + # extract part feature + for part_name in ["head", "left_hand", "right_hand"]: + feature["body"][f"{part_name}_share"] = self.Extractor[ + f"{part_name}_share"](f_body) + + # -- check if part crops are given, if not, crop parts by coarse body estimation + if ("head_image" not in data[key].keys() + or "left_hand_image" not in data[key].keys() + or "right_hand_image" not in data[key].keys()): + # - run without fusion to get coarse estimation, for cropping parts + # body only + body_dict = self.decompose_code( + self.Regressor[part](feature[key][part]), + self.param_list_dict[part + "_list"], + ) + # head share + head_share_dict = self.decompose_code( + self.Regressor["head" + "_share"]( + feature[key]["head" + "_share"]), + self.param_list_dict["head" + "_share_list"], + ) + # right hand share + right_hand_share_dict = self.decompose_code( + self.Regressor["hand" + "_share"]( + feature[key]["right_hand" + "_share"]), + self.param_list_dict["hand" + "_share_list"], + ) + # left hand share + left_hand_share_dict = self.decompose_code( + self.Regressor["hand" + "_share"]( + feature[key]["left_hand" + "_share"]), + self.param_list_dict["hand" + "_share_list"], + ) + # change the dict name from right to left + left_hand_share_dict[ + 
"left_hand_pose"] = left_hand_share_dict.pop( + "right_hand_pose") + left_hand_share_dict[ + "left_wrist_pose"] = left_hand_share_dict.pop( + "right_wrist_pose") + param_dict[key] = { + **body_dict, + **head_share_dict, + **left_hand_share_dict, + **right_hand_share_dict, + } + if body_only: + param_dict["moderator_weight"] = None + return param_dict + prediction_body_only = self.decode(param_dict[key], + param_type="body") + # crop + for part_name in ["head", "left_hand", "right_hand"]: + part = part_name.split("_")[-1] + points_dict = { + "smplx_kpt": + prediction_body_only["smplx_kpt"], + "trans_verts": + prediction_body_only["transformed_vertices"], + } + image_hd = torchvision.transforms.Resize(1024)( + data["body"]["image"]) + cropped_image, cropped_joints_dict = self.part_from_body( + image_hd, part_name, points_dict) + data[key][part_name + "_image"] = cropped_image + + # -- encode features from part crops, then fuse feature using the weight from moderator + for part_name in ["head", "left_hand", "right_hand"]: + part = part_name.split("_")[-1] + cropped_image = data[key][part_name + "_image"] + # if left hand, flip it as if it is right hand + if part_name == "left_hand": + cropped_image = torch.flip(cropped_image, dims=(-1, )) + # run part regressor + f_part = self.Encoder[part](cropped_image) + part_dict = self.decompose_code( + self.Regressor[part](f_part), + self.param_list_dict[f"{part}_list"], + ) + part_share_dict = self.decompose_code( + self.Regressor[f"{part}_share"](f_part), + self.param_list_dict[f"{part}_share_list"], + ) + param_dict["body_" + part_name] = { + **part_dict, + **part_share_dict + } + + # moderator to assign weight, then integrate features + f_body_out, f_part_out, f_weight = self.Moderator[ + f"{part}_share"](feature["body"][f"{part_name}_share"], + f_part, + work=True) + if copy_and_paste: + # copy and paste strategy always trusts the results from part + feature["body"][f"{part_name}_share"] = f_part + elif threthold and part == "hand": + # for hand, if part weight > 0.7 (very confident, then fully trust part) + part_w = f_weight[:, [1]] + part_w[part_w > 0.7] = 1.0 + f_body_out = (feature["body"][f"{part_name}_share"] * + (1.0 - part_w) + f_part * part_w) + feature["body"][f"{part_name}_share"] = f_body_out + else: + feature["body"][f"{part_name}_share"] = f_body_out + fusion_weight[part_name] = f_weight + # save weights from moderator, that can be further used for optimization/running specific tasks on parts + param_dict["moderator_weight"] = fusion_weight + + # -- predict parameters from fused body feature + # head share + head_share_dict = self.decompose_code( + self.Regressor["head" + "_share"](feature[key]["head" + + "_share"]), + self.param_list_dict["head" + "_share_list"], + ) + # right hand share + right_hand_share_dict = self.decompose_code( + self.Regressor["hand" + "_share"]( + feature[key]["right_hand" + "_share"]), + self.param_list_dict["hand" + "_share_list"], + ) + # left hand share + left_hand_share_dict = self.decompose_code( + self.Regressor["hand" + "_share"]( + feature[key]["left_hand" + "_share"]), + self.param_list_dict["hand" + "_share_list"], + ) + # change the dict name from right to left + left_hand_share_dict[ + "left_hand_pose"] = left_hand_share_dict.pop( + "right_hand_pose") + left_hand_share_dict[ + "left_wrist_pose"] = left_hand_share_dict.pop( + "right_wrist_pose") + param_dict["body"] = { + **body_dict, + **head_share_dict, + **left_hand_share_dict, + **right_hand_share_dict, + } + # copy tex param from head param 
dict to body param dict + param_dict["body"]["tex"] = param_dict["body_head"]["tex"] + param_dict["body"]["light"] = param_dict["body_head"]["light"] + + if keep_local: + # for local change that will not affect whole body and produce unnatral pose, trust part + param_dict[key]["exp"] = param_dict["body_head"]["exp"] + param_dict[key]["right_hand_pose"] = param_dict[ + "body_right_hand"]["right_hand_pose"] + param_dict[key]["left_hand_pose"] = param_dict[ + "body_left_hand"]["right_hand_pose"] + + return param_dict + + def convert_pose(self, param_dict, param_type): + """Convert pose parameters to rotation matrix + Args: + param_dict: smplx parameters + param_type: should be one of body/head/hand + Returns: + param_dict: smplx parameters + """ + assert param_type in ["body", "head", "hand"] + + # convert pose representations: the output from network are continous repre or axis angle, + # while the input pose for smplx need to be rotation matrix + for key in param_dict: + if "pose" in key and "jaw" not in key: + param_dict[key] = converter.batch_cont2matrix(param_dict[key]) + if param_type == "body" or param_type == "head": + param_dict["jaw_pose"] = converter.batch_euler2matrix( + param_dict["jaw_pose"])[:, None, :, :] + + # complement params if it's not in given param dict + if param_type == "head": + batch_size = param_dict["shape"].shape[0] + param_dict["abs_head_pose"] = param_dict["head_pose"].clone() + param_dict["global_pose"] = param_dict["head_pose"] + param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze( + 0).expand( + batch_size, -1, -1, + -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]] + param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze( + 0).expand(batch_size, -1, -1, -1) + param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze( + 0).expand(batch_size, -1, -1, -1) + param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze( + 0).expand(batch_size, -1, -1, -1) + param_dict[ + "right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze( + 0).expand(batch_size, -1, -1, -1) + elif param_type == "hand": + batch_size = param_dict["right_hand_pose"].shape[0] + param_dict["abs_right_wrist_pose"] = param_dict[ + "right_wrist_pose"].clone() + dtype = param_dict["right_hand_pose"].dtype + device = param_dict["right_hand_pose"].device + x_180_pose = (torch.eye(3, dtype=dtype, + device=device).unsqueeze(0).repeat( + 1, 1, 1)) + x_180_pose[0, 2, 2] = -1.0 + x_180_pose[0, 1, 1] = -1.0 + param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + param_dict["shape"] = self.smplx.shape_params.expand( + batch_size, -1) + param_dict["exp"] = self.smplx.expression_params.expand( + batch_size, -1) + param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1) + param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze( + 0).expand( + batch_size, -1, -1, + -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]] + param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze( + 0).expand(batch_size, -1, -1, -1) + param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze( + 0).expand(batch_size, -1, -1, -1) + elif param_type == "body": + # the predcition from the head and hand share regressor is 
always absolute pose + batch_size = param_dict["shape"].shape[0] + param_dict["abs_head_pose"] = param_dict["head_pose"].clone() + param_dict["abs_right_wrist_pose"] = param_dict[ + "right_wrist_pose"].clone() + param_dict["abs_left_wrist_pose"] = param_dict[ + "left_wrist_pose"].clone() + # the body-hand share regressor is working for right hand + # so we assume body network get the flipped feature for the left hand. then get the parameters + # then we need to flip it back to left, which matches the input left hand + param_dict["left_wrist_pose"] = util.flip_pose( + param_dict["left_wrist_pose"]) + param_dict["left_hand_pose"] = util.flip_pose( + param_dict["left_hand_pose"]) + else: + exit() + + return param_dict + + def decode(self, param_dict, param_type): + """Decode model parameters to smplx vertices & joints & texture + Args: + param_dict: smplx parameters + param_type: should be one of body/head/hand + Returns: + predictions: smplx predictions + """ + if "jaw_pose" in param_dict.keys() and len( + param_dict["jaw_pose"].shape) == 2: + self.convert_pose(param_dict, param_type) + elif param_dict["right_wrist_pose"].shape[-1] == 6: + self.convert_pose(param_dict, param_type) + + # concatenate body pose + partbody_pose = param_dict["partbody_pose"] + param_dict["body_pose"] = torch.cat( + [ + partbody_pose[:, :11], + param_dict["neck_pose"], + partbody_pose[:, 11:11 + 2], + param_dict["head_pose"], + partbody_pose[:, 13:13 + 4], + param_dict["left_wrist_pose"], + param_dict["right_wrist_pose"], + ], + dim=1, + ) + + # change absolute head&hand pose to relative pose according to rest body pose + if param_type == "head" or param_type == "body": + param_dict["body_pose"] = self.smplx.pose_abs2rel( + param_dict["global_pose"], + param_dict["body_pose"], + abs_joint="head") + if param_type == "hand" or param_type == "body": + param_dict["body_pose"] = self.smplx.pose_abs2rel( + param_dict["global_pose"], + param_dict["body_pose"], + abs_joint="left_wrist", + ) + param_dict["body_pose"] = self.smplx.pose_abs2rel( + param_dict["global_pose"], + param_dict["body_pose"], + abs_joint="right_wrist", + ) + + if self.cfg.model.check_pose: + # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose) + # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left) + for pose_ind in [14]: # head [15-1, 20-1, 21-1]: + curr_pose = param_dict["body_pose"][:, pose_ind] + euler_pose = converter._compute_euler_from_matrix(curr_pose) + for i, max_angle in enumerate([20, 70, 10]): + euler_pose_curr = euler_pose[:, i] + euler_pose_curr[euler_pose_curr != torch.clamp( + euler_pose_curr, + min=-max_angle * np.pi / 180, + max=max_angle * np.pi / 180, + )] = 0.0 + param_dict[ + "body_pose"][:, pose_ind] = converter.batch_euler2matrix( + euler_pose) + + # SMPLX + verts, landmarks, joints = self.smplx( + shape_params=param_dict["shape"], + expression_params=param_dict["exp"], + global_pose=param_dict["global_pose"], + body_pose=param_dict["body_pose"], + jaw_pose=param_dict["jaw_pose"], + left_hand_pose=param_dict["left_hand_pose"], + right_hand_pose=param_dict["right_hand_pose"], + ) + smplx_kpt3d = joints.clone() + + # projection + cam = param_dict[param_type + "_cam"] + trans_verts = util.batch_orth_proj(verts, cam) + predicted_landmarks = util.batch_orth_proj(landmarks, cam)[:, :, :2] + predicted_joints = util.batch_orth_proj(joints, cam)[:, :, :2] + + prediction = { + "vertices": verts, + "transformed_vertices": trans_verts, + 
"face_kpt": predicted_landmarks, + "smplx_kpt": predicted_joints, + "smplx_kpt3d": smplx_kpt3d, + "joints": joints, + "cam": param_dict[param_type + "_cam"], + } + + # change the order of face keypoints, to be the same as "standard" 68 keypoints + prediction["face_kpt"] = torch.cat( + [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], + dim=1) + + prediction.update(param_dict) + + return prediction + + def decode_Tpose(self, param_dict): + """return body mesh in T pose, support body and head param dict only""" + verts, _, _ = self.smplx( + shape_params=param_dict["shape"], + expression_params=param_dict["exp"], + jaw_pose=param_dict["jaw_pose"], + ) + return verts diff --git a/lib/pixielib/utils/array_cropper.py b/lib/pixielib/utils/array_cropper.py new file mode 100644 index 0000000000000000000000000000000000000000..308d7d619bf194e9d372f6fcdb48b5b287bf6659 --- /dev/null +++ b/lib/pixielib/utils/array_cropper.py @@ -0,0 +1,86 @@ +''' +crop +for numpy array +Given image, bbox(center, bboxsize) +return: cropped image, tform(used for transform the keypoint accordingly) + +only support crop to squared images +''' + +import numpy as np +from skimage.transform import estimate_transform, warp, resize, rescale + + +def points2bbox(points, points_scale=None): + # recover range + if points_scale: + points[:, 0] = points[:, 0] * points_scale[1] / 2 + points_scale[1] / 2 + points[:, 1] = points[:, 1] * points_scale[0] / 2 + points_scale[0] / 2 + + left = np.min(points[:, 0]) + right = np.max(points[:, 0]) + top = np.min(points[:, 1]) + bottom = np.max(points[:, 1]) + size = max(right - left, bottom - top) + # + old_size*0.1]) + center = np.array( + [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) + return center, size + # translate center + + +def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.): + trans_scale = (np.random.rand(2) * 2 - 1) * trans_scale + center = center + trans_scale * bbox_size # 0.5 + scale = np.random.rand() * (scale[1] - scale[0]) + scale[0] + size = int(bbox_size * scale) + return center, size + + +def crop_array(image, center, bboxsize, crop_size): + ''' for single image only + Args: + image (numpy.Array): the reference array of shape HxWXC. + size (Tuple[int, int]): a tuple with the height and width that will be + used to resize the extracted patches. + Returns: + cropped_image + tform: 3x3 affine matrix + ''' + # points: top-left, top-right, bottom-right + src_pts = np.array([[center[0] - bboxsize / 2, center[1] - bboxsize / 2], + [center[0] + bboxsize / 2, center[1] - bboxsize / 2], + [center[0] + bboxsize / 2, center[1] + bboxsize / 2]]) + DST_PTS = np.array([[0, 0], [crop_size - 1, 0], + [crop_size - 1, crop_size - 1]]) + + # estimate transformation between points + tform = estimate_transform('similarity', src_pts, DST_PTS) + + # warp images + cropped_image = warp(image, + tform.inverse, + output_shape=(crop_size, crop_size)) + + return cropped_image, tform.params.T + + +class Cropper(object): + + def __init__(self, crop_size, scale=[1, 1], trans_scale=0.): + self.crop_size = crop_size + self.scale = scale + self.trans_scale = trans_scale + + def crop(self, image, points, points_scale=None): + # points to bbox + center, bbox_size = points2bbox(points, points_scale) + # argument bbox. 
+ center, bbox_size = augment_bbox(center, + bbox_size, + scale=self.scale, + trans_scale=self.trans_scale) + # crop + cropped_image, tform = crop_array(image, center, bbox_size, + self.crop_size) + return cropped_image, tform diff --git a/lib/pixielib/utils/config.py b/lib/pixielib/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..127c43445a8a640185d92a8c3ded6ed8eb9ff519 --- /dev/null +++ b/lib/pixielib/utils/config.py @@ -0,0 +1,205 @@ +""" +Default config for PIXIE +""" +from yacs.config import CfgNode as CN +import argparse +import yaml +import os + +cfg = CN() + +abs_pixie_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..")) +cfg.pixie_dir = abs_pixie_dir +cfg.device = "cuda" +cfg.device_id = "0" +cfg.pretrained_modelpath = os.path.join("smpl_related/HPS/pixie_data", + "pixie_model.tar") +# smplx parameter settings +cfg.params = CN() +cfg.params.body_list = [ + "body_cam", "global_pose", "partbody_pose", "neck_pose" +] +cfg.params.head_list = ["head_cam", "tex", "light"] +cfg.params.head_share_list = ["shape", "exp", "head_pose", "jaw_pose"] +cfg.params.hand_list = ["hand_cam"] +cfg.params.hand_share_list = [ + "right_wrist_pose", + "right_hand_pose", +] # only for right hand + +# ---------------------------------------------------------------------------- # +# Options for Body model +# ---------------------------------------------------------------------------- # +cfg.model = CN() +cfg.model.topology_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "SMPL_X_template_FLAME_uv.obj") +cfg.model.topology_smplxtex_path = os.path.join(cfg.pixie_dir, + "smpl_related/HPS/pixie_data", + "smplx_tex.obj") +cfg.model.topology_smplx_hand_path = os.path.join(cfg.pixie_dir, + "smpl_related/HPS/pixie_data", + "smplx_hand.obj") +cfg.model.smplx_model_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "SMPLX_NEUTRAL_2020.npz") +cfg.model.face_mask_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "uv_face_mask.png") +cfg.model.face_eye_mask_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "uv_face_eye_mask.png") +cfg.model.tex_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "FLAME_albedo_from_BFM.npz") +cfg.model.extra_joint_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "smplx_extra_joints.yaml") +cfg.model.j14_regressor_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "SMPLX_to_J14.pkl") +cfg.model.flame2smplx_cached_path = os.path.join(cfg.pixie_dir, + "smpl_related/HPS/pixie_data", + "flame2smplx_tex_1024.npy") +cfg.model.smplx_tex_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "smplx_tex.png") +cfg.model.mano_ids_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "MANO_SMPLX_vertex_ids.pkl") +cfg.model.flame_ids_path = os.path.join(cfg.pixie_dir, "smpl_related/HPS/pixie_data", + "SMPL-X__FLAME_vertex_ids.npy") +cfg.model.uv_size = 256 +cfg.model.n_shape = 200 +cfg.model.n_tex = 50 +cfg.model.n_exp = 50 +cfg.model.n_body_cam = 3 +cfg.model.n_head_cam = 3 +cfg.model.n_hand_cam = 3 +cfg.model.tex_type = "BFM" # BFM, FLAME, albedoMM +cfg.model.uvtex_type = "SMPLX" # FLAME or SMPLX +cfg.model.use_tex = False # whether to use flame texture model +cfg.model.flame_tex_path = "" + +# pose +cfg.model.n_global_pose = 3 * 2 +cfg.model.n_head_pose = 3 * 2 +cfg.model.n_neck_pose = 3 * 2 +cfg.model.n_jaw_pose = 3 # euler angle +cfg.model.n_body_pose = 21 * 3 * 2 
+cfg.model.n_partbody_pose = (21 - 4) * 3 * 2 +cfg.model.n_left_hand_pose = 15 * 3 * 2 +cfg.model.n_right_hand_pose = 15 * 3 * 2 +cfg.model.n_left_wrist_pose = 1 * 3 * 2 +cfg.model.n_right_wrist_pose = 1 * 3 * 2 +cfg.model.n_light = 27 +cfg.model.check_pose = True + +# ---------------------------------------------------------------------------- # +# Options for Dataset +# ---------------------------------------------------------------------------- # +cfg.dataset = CN() +cfg.dataset.source = ["body", "head", "hand"] + +# head/face dataset +cfg.dataset.head = CN() +cfg.dataset.head.batch_size = 24 +cfg.dataset.head.num_workers = 2 +cfg.dataset.head.from_body = True +cfg.dataset.head.image_size = 224 +cfg.dataset.head.image_hd_size = 224 +cfg.dataset.head.scale_min = 1.8 +cfg.dataset.head.scale_max = 2.2 +cfg.dataset.head.trans_scale = 0.3 +# body datset +cfg.dataset.body = CN() +cfg.dataset.body.batch_size = 24 +cfg.dataset.body.num_workers = 2 +cfg.dataset.body.image_size = 224 +cfg.dataset.body.image_hd_size = 1024 +cfg.dataset.body.use_hd = True +# hand datset +cfg.dataset.hand = CN() +cfg.dataset.hand.batch_size = 24 +cfg.dataset.hand.num_workers = 2 +cfg.dataset.hand.image_size = 224 +cfg.dataset.hand.image_hd_size = 512 +cfg.dataset.hand.scale_min = 2.2 +cfg.dataset.hand.scale_max = 2.6 +cfg.dataset.hand.trans_scale = 0.4 + +# ---------------------------------------------------------------------------- # +# Options for Network +# ---------------------------------------------------------------------------- # +cfg.network = CN() +cfg.network.encoder = CN() +cfg.network.encoder.body = CN() +cfg.network.encoder.body.type = "hrnet" +cfg.network.encoder.head = CN() +cfg.network.encoder.head.type = "resnet50" +cfg.network.encoder.hand = CN() +cfg.network.encoder.hand.type = "resnet50" + +cfg.network.regressor = CN() +cfg.network.regressor.head_share = CN() +cfg.network.regressor.head_share.type = "mlp" +cfg.network.regressor.head_share.channels = [1024, 1024] +cfg.network.regressor.hand_share = CN() +cfg.network.regressor.hand_share.type = "mlp" +cfg.network.regressor.hand_share.channels = [1024, 1024] +cfg.network.regressor.body = CN() +cfg.network.regressor.body.type = "mlp" +cfg.network.regressor.body.channels = [1024] +cfg.network.regressor.head = CN() +cfg.network.regressor.head.type = "mlp" +cfg.network.regressor.head.channels = [1024] +cfg.network.regressor.hand = CN() +cfg.network.regressor.hand.type = "mlp" +cfg.network.regressor.hand.channels = [1024] + +cfg.network.extractor = CN() +cfg.network.extractor.head_share = CN() +cfg.network.extractor.head_share.type = "mlp" +cfg.network.extractor.head_share.channels = [] +cfg.network.extractor.left_hand_share = CN() +cfg.network.extractor.left_hand_share.type = "mlp" +cfg.network.extractor.left_hand_share.channels = [] +cfg.network.extractor.right_hand_share = CN() +cfg.network.extractor.right_hand_share.type = "mlp" +cfg.network.extractor.right_hand_share.channels = [] + +cfg.network.moderator = CN() +cfg.network.moderator.head_share = CN() +cfg.network.moderator.head_share.detach_inputs = False +cfg.network.moderator.head_share.detach_feature = False +cfg.network.moderator.head_share.type = "temp-softmax" +cfg.network.moderator.head_share.channels = [1024, 1024] +cfg.network.moderator.head_share.reduction = 4 +cfg.network.moderator.head_share.scale_type = "scalars" +cfg.network.moderator.head_share.scale_init = 1.0 +cfg.network.moderator.hand_share = CN() +cfg.network.moderator.hand_share.detach_inputs = False 
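The tree above is a standard yacs `CfgNode`, so downstream code typically clones the defaults and overrides individual fields from a YAML file or a flat key/value list. A small usage sketch (the YAML filename is hypothetical):

```
# Typical way the yacs defaults above are consumed; "my_pixie.yaml" is a
# hypothetical override file.
from yacs.config import CfgNode as CN

defaults = CN()
defaults.model = CN()
defaults.model.uv_size = 256
defaults.model.use_tex = False

cfg = defaults.clone()                       # never mutate the module-level defaults
# cfg.merge_from_file("my_pixie.yaml")       # YAML overrides (hypothetical file)
cfg.merge_from_list(["model.uv_size", 512,   # command-line style overrides
                     "model.use_tex", True])
cfg.freeze()                                 # make the config read-only from here on
print(cfg.model.uv_size, cfg.model.use_tex)
```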
+cfg.network.moderator.hand_share.detach_feature = False +cfg.network.moderator.hand_share.type = "temp-softmax" +cfg.network.moderator.hand_share.channels = [1024, 1024] +cfg.network.moderator.hand_share.reduction = 4 +cfg.network.moderator.hand_share.scale_type = "scalars" +cfg.network.moderator.hand_share.scale_init = 0.0 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return cfg.clone() + + +def update_cfg(cfg, cfg_file): + # cfg.merge_from_file(cfg_file, allow_unsafe=True) + cfg.merge_from_file(cfg_file) + return cfg.clone() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--cfg", type=str, help="cfg file path") + + args = parser.parse_args() + cfg = get_cfg_defaults() + if args.cfg is not None: + cfg_file = args.cfg + cfg = update_cfg(cfg, args.cfg) + cfg.cfg_file = cfg_file + return cfg diff --git a/lib/pixielib/utils/renderer.py b/lib/pixielib/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..03e235d52684de9d8456e1609a0ca9edc8d0c2d0 --- /dev/null +++ b/lib/pixielib/utils/renderer.py @@ -0,0 +1,590 @@ +""" +Author: Yao Feng +Copyright (c) 2020, Yao Feng +All rights reserved. +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from skimage.io import imread +import imageio + +from . import util + + +def set_rasterizer(type='pytorch3d'): + if type == 'pytorch3d': + global Meshes, load_obj, rasterize_meshes + from pytorch3d.structures import Meshes + from pytorch3d.io import load_obj + from pytorch3d.renderer.mesh import rasterize_meshes + elif type == 'standard': + global standard_rasterize, load_obj + import os + from .util import load_obj + # Use JIT Compiling Extensions + # ref: https://pytorch.org/tutorials/advanced/cpp_extension.html + from torch.utils.cpp_extension import load, CUDA_HOME + curr_dir = os.path.dirname(__file__) + standard_rasterize_cuda = \ + load(name='standard_rasterize_cuda', + sources=[f'{curr_dir}/rasterizer/standard_rasterize_cuda.cpp', + f'{curr_dir}/rasterizer/standard_rasterize_cuda_kernel.cu'], + extra_cuda_cflags=['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7 + from standard_rasterize_cuda import standard_rasterize + # If JIT does not work, try manually installation first + # 1. see instruction here: pixielib/utils/rasterizer/INSTALL.md + # 2. 
add this: "from .rasterizer.standard_rasterize_cuda import standard_rasterize" here + + +class StandardRasterizer(nn.Module): + """ Alg: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation + Notice: + x,y,z are in image space, normalized to [-1, 1] + can render non-squared image + not differentiable + """ + + def __init__(self, height, width=None): + """ + use fixed raster_settings for rendering faces + """ + super().__init__() + if width is None: + width = height + self.h = h = height + self.w = w = width + + def forward(self, vertices, faces, attributes=None, h=None, w=None): + device = vertices.device + if h is None: + h = self.h + if w is None: + w = self.h + bz = vertices.shape[0] + depth_buffer = torch.zeros([bz, h, w]).float().to(device) + 1e6 + triangle_buffer = torch.zeros([bz, h, w]).int().to(device) - 1 + baryw_buffer = torch.zeros([bz, h, w, 3]).float().to(device) + vert_vis = torch.zeros([bz, vertices.shape[1]]).float().to(device) + + vertices = vertices.clone().float() + vertices[..., 0] = vertices[..., 0] * w / 2 + w / 2 + vertices[..., 1] = vertices[..., 1] * h / 2 + h / 2 + vertices[..., 2] = vertices[..., 2] * w / 2 + f_vs = util.face_vertices(vertices, faces) + + standard_rasterize(f_vs, depth_buffer, triangle_buffer, baryw_buffer, + h, w) + pix_to_face = triangle_buffer[:, :, :, None].long() + bary_coords = baryw_buffer[:, :, :, None, :] + vismask = (pix_to_face > -1).float() + D = attributes.shape[-1] + attributes = attributes.clone() + attributes = attributes.view(attributes.shape[0] * attributes.shape[1], + 3, attributes.shape[-1]) + N, H, W, K, _ = bary_coords.shape + mask = pix_to_face == -1 + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) + pixel_vals = torch.cat( + [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + return pixel_vals + + +class Pytorch3dRasterizer(nn.Module): + """ Borrowed from https://github.com/facebookresearch/pytorch3d + This class implements methods for rasterizing a batch of heterogenous Meshes. 
+ Notice: + x,y,z are in image space, normalized + can only render squared image now + """ + + def __init__(self, image_size=224): + """ + use fixed raster_settings for rendering faces + """ + super().__init__() + raster_settings = { + 'image_size': image_size, + 'blur_radius': 0.0, + 'faces_per_pixel': 1, + 'bin_size': None, + 'max_faces_per_bin': None, + 'perspective_correct': False, + } + raster_settings = util.dict2obj(raster_settings) + self.raster_settings = raster_settings + + def forward(self, vertices, faces, attributes=None, h=None, w=None): + fixed_vertices = vertices.clone() + fixed_vertices[..., :2] = -fixed_vertices[..., :2] + meshes_screen = Meshes(verts=fixed_vertices.float(), + faces=faces.long()) + raster_settings = self.raster_settings + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + ) + vismask = (pix_to_face > -1).float() + D = attributes.shape[-1] + attributes = attributes.clone() + attributes = attributes.view(attributes.shape[0] * attributes.shape[1], + 3, attributes.shape[-1]) + N, H, W, K, _ = bary_coords.shape + mask = pix_to_face == -1 + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) + pixel_vals = torch.cat( + [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + return pixel_vals + + +class SRenderY(nn.Module): + + def __init__(self, + image_size, + obj_filename, + uv_size=256, + rasterizer_type='standard'): + super(SRenderY, self).__init__() + self.image_size = image_size + self.uv_size = uv_size + + if rasterizer_type == 'pytorch3d': + self.rasterizer = Pytorch3dRasterizer(image_size) + self.uv_rasterizer = Pytorch3dRasterizer(uv_size) + verts, faces, aux = load_obj(obj_filename) + uvcoords = aux.verts_uvs[None, ...] # (N, V, 2) + uvfaces = faces.textures_idx[None, ...] # (N, F, 3) + faces = faces.verts_idx[None, ...] + elif rasterizer_type == 'standard': + self.rasterizer = StandardRasterizer(image_size) + self.uv_rasterizer = StandardRasterizer(uv_size) + verts, uvcoords, faces, uvfaces = load_obj(obj_filename) + verts = verts[None, ...] + uvcoords = uvcoords[None, ...] + faces = faces[None, ...] + uvfaces = uvfaces[None, ...] + else: + NotImplementedError + + # faces + dense_triangles = util.generate_triangles(uv_size, uv_size) + self.register_buffer( + 'dense_faces', + torch.from_numpy(dense_triangles).long()[None, :, :]) + self.register_buffer('faces', faces) + self.register_buffer('raw_uvcoords', uvcoords) + + # uv coords + uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0. 
+ 1.], + -1) # [bz, ntv, 3] + uvcoords = uvcoords * 2 - 1 + uvcoords[..., 1] = -uvcoords[..., 1] + face_uvcoords = util.face_vertices(uvcoords, uvfaces) + self.register_buffer('uvcoords', uvcoords) + self.register_buffer('uvfaces', uvfaces) + self.register_buffer('face_uvcoords', face_uvcoords) + + # shape colors, for rendering shape overlay + colors = torch.tensor([180, 180, 180])[None, None, :].repeat( + 1, + faces.max() + 1, 1).float() / 255. + face_colors = util.face_vertices(colors, faces) + self.register_buffer('vertex_colors', colors) + self.register_buffer('face_colors', face_colors) + + # SH factors for lighting + pi = np.pi + constant_factor = torch.tensor([ + 1 / np.sqrt(4 * pi), ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), + ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), ((2 * pi) / 3) * + (np.sqrt(3 / (4 * pi))), (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))) + ]).float() + self.register_buffer('constant_factor', constant_factor) + + def forward(self, + vertices, + transformed_vertices, + albedos, + lights=None, + light_type='point', + background=None, + h=None, + w=None): + ''' + -- Texture Rendering + vertices: [batch_size, V, 3], vertices in world space, for calculating normals, then shading + transformed_vertices: [batch_size, V, 3], rnage:[-1,1], projected vertices, in image space, for rasterization + albedos: [batch_size, 3, h, w], uv map + lights: + spherical homarnic: [N, 9(shcoeff), 3(rgb)] + points/directional lighting: [N, n_lights, 6(xyzrgb)] + light_type: + point or directional + ''' + batch_size = vertices.shape[0] + # normalize z to 10-90 for raterization (in pytorch3d, near far: 0-100) + transformed_vertices = transformed_vertices.clone() + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] - transformed_vertices[:, :, + 2].min( + ) + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] / transformed_vertices[:, :, + 2].max( + ) + transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10 + + # attributes + face_vertices = util.face_vertices( + vertices, self.faces.expand(batch_size, -1, -1)) + normals = util.vertex_normals(vertices, + self.faces.expand(batch_size, -1, -1)) + face_normals = util.face_vertices( + normals, self.faces.expand(batch_size, -1, -1)) + transformed_normals = util.vertex_normals( + transformed_vertices, self.faces.expand(batch_size, -1, -1)) + transformed_face_normals = util.face_vertices( + transformed_normals, self.faces.expand(batch_size, -1, -1)) + attributes = torch.cat([ + self.face_uvcoords.expand(batch_size, -1, -1, -1), + transformed_face_normals.detach(), + face_vertices.detach(), face_normals + ], -1) + + # rasterize + rendering = self.rasterizer(transformed_vertices, + self.faces.expand(batch_size, -1, -1), + attributes, h, w) + + #### + # vis mask + alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() + + # albedo + uvcoords_images = rendering[:, :3, :, :] + grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] + albedo_images = F.grid_sample(albedos, grid, align_corners=False) + + # visible mask for pixels with positive normal direction + transformed_normal_map = rendering[:, 3:6, :, :].detach() + pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float() + + # shading + normal_images = rendering[:, 9:12, :, :] + if lights is not None: + if lights.shape[1] == 9: + shading_images = 
self.add_SHlight(normal_images, lights) + else: + if light_type == 'point': + vertice_images = rendering[:, 6:9, :, :].detach() + shading = self.add_pointlight( + vertice_images.permute(0, 2, 3, + 1).reshape([batch_size, -1, 3]), + normal_images.permute(0, 2, 3, + 1).reshape([batch_size, -1, 3]), + lights) + shading_images = shading.reshape([ + batch_size, albedo_images.shape[2], + albedo_images.shape[3], 3 + ]).permute(0, 3, 1, 2) + else: + shading = self.add_directionlight( + normal_images.permute(0, 2, 3, + 1).reshape([batch_size, -1, 3]), + lights) + shading_images = shading.reshape([ + batch_size, albedo_images.shape[2], + albedo_images.shape[3], 3 + ]).permute(0, 3, 1, 2) + images = albedo_images * shading_images + else: + images = albedo_images + shading_images = images.detach() * 0. + + if background is None: + images = images*alpha_images + \ + torch.ones_like(images).to(vertices.device)*(1-alpha_images) + else: + # background = F.interpolate(background, [self.image_size, self.image_size]) + images = images * alpha_images + background.contiguous() * ( + 1 - alpha_images) + + outputs = { + 'images': images, + 'albedo_images': albedo_images, + 'alpha_images': alpha_images, + 'pos_mask': pos_mask, + 'shading_images': shading_images, + 'grid': grid, + 'normals': normals, + 'normal_images': normal_images, + 'transformed_normals': transformed_normals, + } + + return outputs + + def add_SHlight(self, normal_images, sh_coeff): + ''' + sh_coeff: [bz, 9, 3] + ''' + N = normal_images + sh = torch.stack([ + N[:, 0] * 0. + 1., N[:, 0], N[:, 1], N[:, 2], N[:, 0] * N[:, 1], + N[:, 0] * N[:, 2], N[:, 1] * N[:, 2], N[:, 0]**2 - N[:, 1]**2, 3 * + (N[:, 2]**2) - 1 + ], 1) # [bz, 9, h, w] + sh = sh * self.constant_factor[None, :, None, None] + # [bz, 9, 3, h, w] + shading = torch.sum( + sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1) + return shading + + def add_pointlight(self, vertices, normals, lights): + ''' + vertices: [bz, nv, 3] + lights: [bz, nlight, 6] + returns: + shading: [bz, nv, 3] + ''' + light_positions = lights[:, :, :3] + light_intensities = lights[:, :, 3:] + directions_to_lights = F.normalize(light_positions[:, :, None, :] - + vertices[:, None, :, :], + dim=3) + # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) + normals_dot_lights = (normals[:, None, :, :] * + directions_to_lights).sum(dim=3) + shading = normals_dot_lights[:, :, :, + None] * light_intensities[:, :, None, :] + return shading.mean(1) + + def add_directionlight(self, normals, lights): + ''' + normals: [bz, nv, 3] + lights: [bz, nlight, 6] + returns: + shading: [bz, nv, 3] + ''' + light_direction = lights[:, :, :3] + light_intensities = lights[:, :, 3:] + directions_to_lights = F.normalize( + light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], + -1), + dim=3) + # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) + # normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) + normals_dot_lights = torch.clamp( + (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0., 1.) 
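`add_SHlight` above shades with 9 second-order spherical-harmonic coefficients per colour channel: evaluate the 9 SH basis functions on the per-pixel normals, scale by the precomputed constant factors, and take the coefficient-weighted sum. A sketch with random normals and coefficients:

```
# Sketch of the spherical-harmonics shading in add_SHlight above.
# Random normals and coefficients for illustration only.
import numpy as np
import torch
import torch.nn.functional as F

pi = np.pi
constant_factor = torch.tensor([
    1 / np.sqrt(4 * pi),
    ((2 * pi) / 3) * np.sqrt(3 / (4 * pi)),
    ((2 * pi) / 3) * np.sqrt(3 / (4 * pi)),
    ((2 * pi) / 3) * np.sqrt(3 / (4 * pi)),
    (pi / 4) * 3 * np.sqrt(5 / (12 * pi)),
    (pi / 4) * 3 * np.sqrt(5 / (12 * pi)),
    (pi / 4) * 3 * np.sqrt(5 / (12 * pi)),
    (pi / 4) * (3 / 2) * np.sqrt(5 / (12 * pi)),
    (pi / 4) * (1 / 2) * np.sqrt(5 / (4 * pi)),
]).float()

N = F.normalize(torch.randn(2, 3, 8, 8), dim=1)      # [bz, 3, h, w] unit normals
sh_coeff = torch.randn(2, 9, 3)                      # [bz, 9, 3(rgb)]

sh = torch.stack([
    torch.ones_like(N[:, 0]), N[:, 0], N[:, 1], N[:, 2],
    N[:, 0] * N[:, 1], N[:, 0] * N[:, 2], N[:, 1] * N[:, 2],
    N[:, 0] ** 2 - N[:, 1] ** 2, 3 * N[:, 2] ** 2 - 1,
], dim=1) * constant_factor[None, :, None, None]     # [bz, 9, h, w]

shading = (sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :]).sum(1)
print(shading.shape)                                 # (2, 3, 8, 8)
```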
+ shading = normals_dot_lights[:, :, :, + None] * light_intensities[:, :, None, :] + return shading.mean(1) + + def render_shape(self, + vertices, + transformed_vertices, + colors=None, + background=None, + detail_normal_images=None, + lights=None, + return_grid=False, + uv_detail_normals=None, + h=None, + w=None): + ''' + -- rendering shape with detail normal map + ''' + batch_size = vertices.shape[0] + if lights is None: + light_positions = torch.tensor([ + [-5, 5, -5], + [5, 5, -5], + [-5, -5, -5], + [5, -5, -5], + [0, 0, -5], + ])[None, :, :].expand(batch_size, -1, -1).float() + + light_intensities = torch.ones_like(light_positions).float() * 1.7 + lights = torch.cat((light_positions, light_intensities), + 2).to(vertices.device) + # normalize z to 10-90 for raterization (in pytorch3d, near far: 0-100) + transformed_vertices = transformed_vertices.clone() + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] - transformed_vertices[:, :, + 2].min( + ) + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] / transformed_vertices[:, :, + 2].max( + ) + transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10 + + # Attributes + face_vertices = util.face_vertices( + vertices, self.faces.expand(batch_size, -1, -1)) + normals = util.vertex_normals(vertices, + self.faces.expand(batch_size, -1, -1)) + face_normals = util.face_vertices( + normals, self.faces.expand(batch_size, -1, -1)) + transformed_normals = util.vertex_normals( + transformed_vertices, self.faces.expand(batch_size, -1, -1)) + transformed_face_normals = util.face_vertices( + transformed_normals, self.faces.expand(batch_size, -1, -1)) + if colors is None: + colors = self.face_colors.expand(batch_size, -1, -1, -1) + attributes = torch.cat([ + colors, + transformed_face_normals.detach(), + face_vertices.detach(), face_normals, + self.face_uvcoords.expand(batch_size, -1, -1, -1) + ], -1) + # rasterize + rendering = self.rasterizer(transformed_vertices, + self.faces.expand(batch_size, -1, -1), + attributes, h, w) + + #### + alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() + + # albedo + albedo_images = rendering[:, :3, :, :] + # mask + transformed_normal_map = rendering[:, 3:6, :, :].detach() + pos_mask = (transformed_normal_map[:, 2:, :, :] < 0).float() + + # shading + normal_images = rendering[:, 9:12, :, :].detach() + vertice_images = rendering[:, 6:9, :, :].detach() + if detail_normal_images is not None: + normal_images = detail_normal_images + if uv_detail_normals is not None: + uvcoords_images = rendering[:, 12:15, :, :] + grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] + detail_normal_images = F.grid_sample(uv_detail_normals, + grid, + align_corners=False) + normal_images = detail_normal_images + + shading = self.add_directionlight( + normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), + lights) + shading_images = shading.reshape( + [batch_size, albedo_images.shape[2], albedo_images.shape[3], + 3]).permute(0, 3, 1, 2).contiguous() + shaded_images = albedo_images * shading_images + + if background is None: + shape_images = shaded_images*alpha_images + \ + torch.ones_like(shaded_images).to( + vertices.device)*(1-alpha_images) + else: + # background = F.interpolate(background, [self.image_size, self.image_size]) + shape_images = shaded_images*alpha_images + \ + background.contiguous()*(1-alpha_images) + + if return_grid: + uvcoords_images = rendering[:, 12:15, :, :] + grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] + return shape_images, 
normal_images, grid + else: + return shape_images + + def render_depth(self, transformed_vertices): + ''' + -- rendering depth + ''' + transformed_vertices = transformed_vertices.clone() + batch_size = transformed_vertices.shape[0] + + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] - transformed_vertices[:, :, + 2].min( + ) + z = -transformed_vertices[:, :, 2:].repeat(1, 1, 3) + z = z - z.min() + z = z / z.max() + # Attributes + attributes = util.face_vertices(z, + self.faces.expand(batch_size, -1, -1)) + # rasterize + rendering = self.rasterizer(transformed_vertices, + self.faces.expand(batch_size, -1, -1), + attributes) + + #### + alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() + depth_images = rendering[:, :1, :, :] + return depth_images + + def render_colors(self, transformed_vertices, colors, h=None, w=None): + ''' + -- rendering colors: could be rgb color/ normals, etc + colors: [bz, num of vertices, 3] + ''' + transformed_vertices = transformed_vertices.clone() + batch_size = colors.shape[0] + # normalize z to 10-90 for raterization (in pytorch3d, near far: 0-100) + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] - transformed_vertices[:, :, + 2].min( + ) + transformed_vertices[:, :, + 2] = transformed_vertices[:, :, + 2] / transformed_vertices[:, :, + 2].max( + ) + transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10 + # Attributes + attributes = util.face_vertices(colors, + self.faces.expand(batch_size, -1, -1)) + # rasterize + rendering = self.rasterizer(transformed_vertices, + self.faces.expand(batch_size, -1, -1), + attributes, + h=h, + w=w) + #### + alpha_images = rendering[:, [-1], :, :].detach() + images = rendering[:, :3, :, :] * alpha_images + return images + + def world2uv(self, vertices): + ''' + project vertices from world space to uv space + vertices: [bz, V, 3] + uv_vertices: [bz, 3, h, w] + ''' + batch_size = vertices.shape[0] + face_vertices = util.face_vertices( + vertices, self.faces.expand(batch_size, -1, -1)) + uv_vertices = self.uv_rasterizer( + self.uvcoords.expand(batch_size, -1, -1), + self.uvfaces.expand(batch_size, -1, -1), face_vertices)[:, :3] + return uv_vertices diff --git a/lib/pixielib/utils/rotation_converter.py b/lib/pixielib/utils/rotation_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d628b4427ca2c8e8582f34142607400b9f070bc --- /dev/null +++ b/lib/pixielib/utils/rotation_converter.py @@ -0,0 +1,545 @@ +import torch +import torch.nn.functional as F +import numpy as np +''' Rotation Converter + This function is borrowed from https://github.com/kornia/kornia + +ref: https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html# +Repre: euler angle(3), axis angle(3), rotation matrix(3x3), quaternion(4), continuous rotation representation (6) +batch_rodrigues: axis angle -> matrix +''' +pi = torch.Tensor([3.14159265358979323846]) + + +def rad2deg(tensor): + """Function that converts angles from radians to degrees. + + See :class:`~torchgeometry.RadToDeg` for details. + + Args: + tensor (Tensor): Tensor of arbitrary shape. + + Returns: + Tensor: Tensor with same shape as input. + + Example: + >>> input = tgm.pi * torch.rand(1, 3, 3) + >>> output = tgm.rad2deg(input) + """ + if not torch.is_tensor(tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(tensor))) + + return 180. 
* tensor / pi.to(tensor.device).type(tensor.dtype) + + +def deg2rad(tensor): + """Function that converts angles from degrees to radians. + + See :class:`~torchgeometry.DegToRad` for details. + + Args: + tensor (Tensor): Tensor of arbitrary shape. + + Returns: + Tensor: Tensor with same shape as input. + + Examples:: + + >>> input = 360. * torch.rand(1, 3, 3) + >>> output = tgm.deg2rad(input) + """ + if not torch.is_tensor(tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(tensor))) + + return tensor * pi.to(tensor.device).type(tensor.dtype) / 180. + + +# to quaternion + + +def euler_to_quaternion(r): + x = r[..., 0] + y = r[..., 1] + z = r[..., 2] + + z = z / 2.0 + y = y / 2.0 + x = x / 2.0 + cz = torch.cos(z) + sz = torch.sin(z) + cy = torch.cos(y) + sy = torch.sin(y) + cx = torch.cos(x) + sx = torch.sin(x) + quaternion = torch.zeros_like(r.repeat(1, 2))[..., :4].to(r.device) + quaternion[..., 0] += cx * cy * cz - sx * sy * sz + quaternion[..., 1] += cx * sy * sz + cy * cz * sx + quaternion[..., 2] += cx * cz * sy - sx * cy * sz + quaternion[..., 3] += cx * cy * sz + sx * cz * sy + return quaternion + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + # if not rotation_matrix.shape[-2:] == (3, 4): + # raise ValueError( + # "Input size must be a N x 3 x 4 tensor. 
Got {}".format( + # rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1.float() + mask_c1 = mask_d2 * (1 - mask_d0_d1.float()) + mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1 + mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float()) + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor: + """Convert an angle axis to a quaternion. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + angle_axis (torch.Tensor): tensor with angle axis. + + Return: + torch.Tensor: tensor with quaternion. + + Shape: + - Input: :math:`(*, 3)` where `*` means, any number of dimensions + - Output: :math:`(*, 4)` + + Example: + >>> angle_axis = torch.rand(2, 4) # Nx4 + >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3 + """ + if not torch.is_tensor(angle_axis): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(angle_axis))) + + if not angle_axis.shape[-1] == 3: + raise ValueError( + "Input must be a tensor of shape Nx3 or 3. Got {}".format( + angle_axis.shape)) + # unpack input and compute conversion + a0: torch.Tensor = angle_axis[..., 0:1] + a1: torch.Tensor = angle_axis[..., 1:2] + a2: torch.Tensor = angle_axis[..., 2:3] + theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2 + + theta: torch.Tensor = torch.sqrt(theta_squared) + half_theta: torch.Tensor = theta * 0.5 + + mask: torch.Tensor = theta_squared > 0.0 + ones: torch.Tensor = torch.ones_like(half_theta) + + k_neg: torch.Tensor = 0.5 * ones + k_pos: torch.Tensor = torch.sin(half_theta) / theta + k: torch.Tensor = torch.where(mask, k_pos, k_neg) + w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones) + + quaternion: torch.Tensor = torch.zeros_like(angle_axis) + quaternion[..., 0:1] += a0 * k + quaternion[..., 1:2] += a1 * k + quaternion[..., 2:3] += a2 * k + return torch.cat([w, quaternion], dim=-1) + + +# quaternion to + + +def quaternion_to_rotation_matrix(quat): + """Convert quaternion coefficients to rotation matrix. 
+ Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(B, 3, 3) + return rotMat + + +def quaternion_to_angle_axis(quaternion: torch.Tensor): + """Convert quaternion vector to angle axis of rotation. TODO: CORRECT + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format( + quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * \ + torch.ones_like(sin_theta).to(quaternion.device) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion).to( + quaternion.device)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +# credit to Muhammed Kocabas +# matrix to euler angle +# Device = Union[str, torch.device] +_AXIS_TO_IND = {'x': 0, 'y': 1, 'z': 2} + + +def _elementary_basis_vector(axis): + b = torch.zeros(3) + b[_AXIS_TO_IND[axis]] = 1 + return b + + +def _compute_euler_from_matrix(dcm, seq='xyz', extrinsic=False): + # The algorithm assumes intrinsic frame transformations. For representation + # the paper uses transformation matrices, which are transpose of the + # direction cosine matrices used by our Rotation class. + # Adapt the algorithm for our case by + # 1. Instead of transposing our representation, use the transpose of the + # O matrix as defined in the paper, and be careful to swap indices + # 2. 
Reversing both axis sequence and angles for extrinsic rotations + orig_device = dcm.device + dcm = dcm.to('cpu') + seq = seq.lower() + + if extrinsic: + seq = seq[::-1] + + if dcm.ndim == 2: + dcm = dcm[None, :, :] + num_rotations = dcm.shape[0] + + device = dcm.device + + # Step 0 + # Algorithm assumes axes as column vectors, here we use 1D vectors + n1 = _elementary_basis_vector(seq[0]) + n2 = _elementary_basis_vector(seq[1]) + n3 = _elementary_basis_vector(seq[2]) + + # Step 2 + sl = torch.dot(torch.cross(n1, n2), n3) + cl = torch.dot(n1, n3) + + # angle offset is lambda from the paper referenced in [2] from docstring of + # `as_euler` function + offset = torch.atan2(sl, cl) + c = torch.stack((n2, torch.cross(n1, n2), n1)).type(dcm.dtype).to(device) + + # Step 3 + rot = torch.tensor([ + [1, 0, 0], + [0, cl, sl], + [0, -sl, cl], + ]).type(dcm.dtype) + # import IPython; IPython.embed(); exit + res = torch.einsum('ij,...jk->...ik', c, dcm) + dcm_transformed = torch.einsum('...ij,jk->...ik', res, c.T @ rot) + + # Step 4 + angles = torch.zeros((num_rotations, 3), dtype=dcm.dtype, device=device) + + # Ensure less than unit norm + positive_unity = dcm_transformed[:, 2, 2] > 1 + negative_unity = dcm_transformed[:, 2, 2] < -1 + dcm_transformed[positive_unity, 2, 2] = 1 + dcm_transformed[negative_unity, 2, 2] = -1 + angles[:, 1] = torch.acos(dcm_transformed[:, 2, 2]) + + # Steps 5, 6 + eps = 1e-7 + safe1 = (torch.abs(angles[:, 1]) >= eps) + safe2 = (torch.abs(angles[:, 1] - np.pi) >= eps) + + # Step 4 (Completion) + angles[:, 1] += offset + + # 5b + safe_mask = torch.logical_and(safe1, safe2) + angles[safe_mask, 0] = torch.atan2(dcm_transformed[safe_mask, 0, 2], + -dcm_transformed[safe_mask, 1, 2]) + angles[safe_mask, 2] = torch.atan2(dcm_transformed[safe_mask, 2, 0], + dcm_transformed[safe_mask, 2, 1]) + if extrinsic: + # For extrinsic, set first angle to zero so that after reversal we + # ensure that third angle is zero + # 6a + angles[~safe_mask, 0] = 0 + # 6b + angles[~safe1, 2] = torch.atan2( + dcm_transformed[~safe1, 1, 0] - dcm_transformed[~safe1, 0, 1], + dcm_transformed[~safe1, 0, 0] + dcm_transformed[~safe1, 1, 1]) + # 6c + angles[~safe2, 2] = -torch.atan2( + dcm_transformed[~safe2, 1, 0] + dcm_transformed[~safe2, 0, 1], + dcm_transformed[~safe2, 0, 0] - dcm_transformed[~safe2, 1, 1]) + else: + # For instrinsic, set third angle to zero + # 6a + angles[~safe_mask, 2] = 0 + # 6b + angles[~safe1, 0] = torch.atan2( + dcm_transformed[~safe1, 1, 0] - dcm_transformed[~safe1, 0, 1], + dcm_transformed[~safe1, 0, 0] + dcm_transformed[~safe1, 1, 1]) + # 6c + angles[~safe2, 0] = torch.atan2( + dcm_transformed[~safe2, 1, 0] + dcm_transformed[~safe2, 0, 1], + dcm_transformed[~safe2, 0, 0] - dcm_transformed[~safe2, 1, 1]) + + # Step 7 + if seq[0] == seq[2]: + # lambda = 0, so we can only ensure angle2 -> [0, pi] + adjust_mask = torch.logical_or(angles[:, 1] < 0, angles[:, 1] > np.pi) + else: + # lambda = + or - pi/2, so we can ensure angle2 -> [-pi/2, pi/2] + adjust_mask = torch.logical_or(angles[:, 1] < -np.pi / 2, + angles[:, 1] > np.pi / 2) + + # Dont adjust gimbal locked angle sequences + adjust_mask = torch.logical_and(adjust_mask, safe_mask) + + angles[adjust_mask, 0] += np.pi + angles[adjust_mask, 1] = 2 * offset - angles[adjust_mask, 1] + angles[adjust_mask, 2] -= np.pi + + angles[angles < -np.pi] += 2 * np.pi + angles[angles > np.pi] -= 2 * np.pi + + # Step 8 + if not torch.all(safe_mask): + print("Gimbal lock detected. 
Setting third angle to zero since" + "it is not possible to uniquely determine all angles.") + + # Reverse role of extrinsic and intrinsic rotations, but let third angle be + # zero for gimbal locked cases + if extrinsic: + # angles = angles[:, ::-1] + angles = torch.flip(angles, dims=[ + -1, + ]) + + angles = angles.to(orig_device) + return angles + + +# batch converter + + +def batch_euler2axis(r): + return quaternion_to_angle_axis(euler_to_quaternion(r)) + + +def batch_euler2matrix(r): + return quaternion_to_rotation_matrix(euler_to_quaternion(r)) + + +def batch_matrix2euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + # only y biw + # TODO: add x, z + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) + + +def batch_matrix2axis(rot_mats): + return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rot_mats)) + + +def batch_axis2matrix(theta): + # angle axis to rotation matrix + # theta N x 3 + # return quat2mat(quat) + # batch_rodrigues + return quaternion_to_rotation_matrix(angle_axis_to_quaternion(theta)) + + +def batch_axis2euler(theta): + return batch_matrix2euler(batch_axis2matrix(theta)) + + +def batch_axis2euler(r): + return rot_mat_to_euler(batch_rodrigues(r)) + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + ''' same as batch_matrix2axis + Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + Code from smplx/flame, what PS people often use + ''' + + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def batch_cont2matrix(module_input): + ''' Decoder for transforming a latent representation to rotation matrices + + Implements the decoding method described in: + "On the Continuity of Rotation Representations in Neural Networks" + Code from https://github.com/vchoutas/expose + ''' + batch_size = module_input.shape[0] + reshaped_input = module_input.reshape(-1, 3, 2) + + # Normalize the first vector + b1 = F.normalize(reshaped_input[:, :, 0].clone(), dim=1) + + dot_prod = torch.sum(b1 * reshaped_input[:, :, 1].clone(), + dim=1, + keepdim=True) + # Compute the second vector by finding the orthogonal complement to it + b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=1) + # Finish building the basis by taking the cross product + b3 = torch.cross(b1, b2, dim=1) + rot_mats = torch.stack([b1, b2, b3], dim=-1) + + return rot_mats.view(batch_size, -1, 3, 3) diff --git a/lib/pixielib/utils/tensor_cropper.py b/lib/pixielib/utils/tensor_cropper.py new file mode 100644 index 
0000000000000000000000000000000000000000..bef5881dfa49426ccacaa63d7eb2581bfc1571f7 --- /dev/null +++ b/lib/pixielib/utils/tensor_cropper.py @@ -0,0 +1,172 @@ +''' +crop +for torch tensor +Given image, bbox(center, bboxsize) +return: cropped image, tform(used for transform the keypoint accordingly) + +only support crop to squared images +''' +import torch +from kornia.geometry.transform.imgwarp import (warp_perspective, + get_perspective_transform, + warp_affine) + + +def points2bbox(points, points_scale=None): + if points_scale: + assert points_scale[0] == points_scale[1] + points = points.clone() + points[:, :, :2] = (points[:, :, :2] * 0.5 + 0.5) * points_scale[0] + min_coords, _ = torch.min(points, dim=1) + xmin, ymin = min_coords[:, 0], min_coords[:, 1] + max_coords, _ = torch.max(points, dim=1) + xmax, ymax = max_coords[:, 0], max_coords[:, 1] + center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5 + + width = (xmax - xmin) + height = (ymax - ymin) + # Convert the bounding box to a square box + size = torch.max(width, height).unsqueeze(-1) + return center, size + + +def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.): + batch_size = center.shape[0] + trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2. - + 1.) * trans_scale + center = center + trans_scale * bbox_size # 0.5 + scale = torch.rand([batch_size, 1], device=center.device) * \ + (scale[1] - scale[0]) + scale[0] + size = bbox_size * scale + return center, size + + +def crop_tensor(image, + center, + bbox_size, + crop_size, + interpolation='bilinear', + align_corners=False): + ''' for batch image + Args: + image (torch.Tensor): the reference tensor of shape BXHxWXC. + center: [bz, 2] + bboxsize: [bz, 1] + crop_size; + interpolation (str): Interpolation flag. Default: 'bilinear'. + align_corners (bool): mode for grid_generation. Default: False. 
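`crop_tensor` above is the batched, GPU-friendly counterpart of the numpy cropper: it maps the four corners of the (center, size) box onto the output square with kornia's `get_perspective_transform`, then warps with `warp_affine`. A compact sketch with a random batch and made-up boxes:

```
# Minimal sketch of the kornia-based square crop implemented by crop_tensor
# above. Random images and made-up boxes, for illustration only.
import torch
from kornia.geometry.transform import get_perspective_transform, warp_affine

image = torch.rand(2, 3, 256, 256)                   # B x C x H x W
center = torch.tensor([[128., 128.], [100., 140.]])  # B x 2
size = torch.tensor([[180.], [120.]])                # B x 1
crop_size = 224

half = size * 0.5
src = torch.stack([
    torch.cat([center[:, :1] - half, center[:, 1:] - half], dim=1),  # top-left
    torch.cat([center[:, :1] + half, center[:, 1:] - half], dim=1),  # top-right
    torch.cat([center[:, :1] + half, center[:, 1:] + half], dim=1),  # bottom-right
    torch.cat([center[:, :1] - half, center[:, 1:] + half], dim=1),  # bottom-left
], dim=1)                                                            # B x 4 x 2
dst = torch.tensor([[0., 0.], [crop_size - 1., 0.],
                    [crop_size - 1., crop_size - 1.],
                    [0., crop_size - 1.]]).repeat(2, 1, 1)

M = get_perspective_transform(src, dst)              # B x 3 x 3
crops = warp_affine(image, M[:, :2, :], (crop_size, crop_size))
print(crops.shape)                                   # (2, 3, 224, 224)
```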
See + https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details + Returns: + cropped_image + tform + ''' + dtype = image.dtype + device = image.device + batch_size = image.shape[0] + # points: top-left, top-right, bottom-right, bottom-left + src_pts = torch.zeros([4, 2], dtype=dtype, + device=device).unsqueeze(0).expand( + batch_size, -1, -1).contiguous() + + src_pts[:, 0, :] = center - bbox_size * 0.5 # / (self.crop_size - 1) + src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5 + src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5 + src_pts[:, 2, :] = center + bbox_size * 0.5 + src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5 + src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5 + + DST_PTS = torch.tensor([[ + [0, 0], + [crop_size - 1, 0], + [crop_size - 1, crop_size - 1], + [0, crop_size - 1], + ]], + dtype=dtype, + device=device).expand(batch_size, -1, -1) + # estimate transformation between points + dst_trans_src = get_perspective_transform(src_pts, DST_PTS) + # simulate broadcasting + # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1) + + # warp images + cropped_image = warp_affine(image, + dst_trans_src[:, :2, :], + (crop_size, crop_size), + mode=interpolation, + align_corners=align_corners) + + tform = torch.transpose(dst_trans_src, 2, 1) + # tform = torch.inverse(dst_trans_src) + return cropped_image, tform + + +class Cropper(object): + + def __init__(self, crop_size, scale=[1, 1], trans_scale=0.): + self.crop_size = crop_size + self.scale = scale + self.trans_scale = trans_scale + + def crop(self, image, points, points_scale=None): + # points to bbox + center, bbox_size = points2bbox(points.clone(), points_scale) + # argument bbox. TODO: add rotation? + center, bbox_size = augment_bbox(center, + bbox_size, + scale=self.scale, + trans_scale=self.trans_scale) + # crop + cropped_image, tform = crop_tensor(image, center, bbox_size, + self.crop_size) + return cropped_image, tform + + def transform_points(self, + points, + tform, + points_scale=None, + normalize=True): + points_2d = points[:, :, :2] + + #'input points must use original range' + if points_scale: + assert points_scale[0] == points_scale[1] + points_2d = (points_2d * 0.5 + 0.5) * points_scale[0] + + batch_size, n_points, _ = points.shape + trans_points_2d = torch.bmm( + torch.cat([ + points_2d, + torch.ones([batch_size, n_points, 1], + device=points.device, + dtype=points.dtype) + ], + dim=-1), tform) + trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], + dim=-1) + if normalize: + trans_points[:, :, :2] = trans_points[:, :, :2] / \ + self.crop_size*2 - 1 + return trans_points + + +def transform_points(points, tform, points_scale=None): + points_2d = points[:, :, :2] + + #'input points must use original range' + if points_scale: + assert points_scale[0] == points_scale[1] + points_2d = (points_2d * 0.5 + 0.5) * points_scale[0] + # import ipdb; ipdb.set_trace() + + batch_size, n_points, _ = points.shape + trans_points_2d = torch.bmm( + torch.cat([ + points_2d, + torch.ones([batch_size, n_points, 1], + device=points.device, + dtype=points.dtype) + ], + dim=-1), tform) + trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], + dim=-1) + return trans_points diff --git a/lib/pixielib/utils/util.py b/lib/pixielib/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..392551c138767bb2156c7f11d81d46035768e79a --- /dev/null +++ b/lib/pixielib/utils/util.py @@ -0,0 +1,704 @@ +import numpy as np +import 
torch +import torch.nn.functional as F +import math +from collections import OrderedDict +import os +from scipy.ndimage import morphology +import PIL.Image as pil_img +from skimage.io import imsave +import cv2 +import pickle + +# ---------------------------- process/generate vertices, normals, faces + + +def generate_triangles(h, w, mask=None): + ''' + quad layout: + 0 1 ... w-1 + w w+1 + . + w*h + ''' + triangles = [] + margin = 0 + for x in range(margin, w - 1 - margin): + for y in range(margin, h - 1 - margin): + triangle0 = [y * w + x, y * w + x + 1, (y + 1) * w + x] + triangle1 = [y * w + x + 1, (y + 1) * w + x + 1, (y + 1) * w + x] + triangles.append(triangle0) + triangles.append(triangle1) + triangles = np.array(triangles) + triangles = triangles[:, [0, 2, 1]] + return triangles + + +def face_vertices(vertices, faces): + """ + borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py + :param vertices: [batch size, number of vertices, 3] + :param faces: [batch size, number of faces, 3] + :return: [batch size, number of faces, 3, 3] + """ + assert (vertices.ndimension() == 3) + assert (faces.ndimension() == 3) + assert (vertices.shape[0] == faces.shape[0]) + assert (vertices.shape[2] == 3) + assert (faces.shape[2] == 3) + + bs, nv = vertices.shape[:2] + bs, nf = faces.shape[:2] + device = vertices.device + faces = faces + \ + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] + vertices = vertices.reshape((bs * nv, 3)) + # pytorch only supports long and byte tensors for indexing + return vertices[faces.long()] + + +def vertex_normals(vertices, faces): + """ + borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py + :param vertices: [batch size, number of vertices, 3] + :param faces: [batch size, number of faces, 3] + :return: [batch size, number of vertices, 3] + """ + assert (vertices.ndimension() == 3) + assert (faces.ndimension() == 3) + assert (vertices.shape[0] == faces.shape[0]) + assert (vertices.shape[2] == 3) + assert (faces.shape[2] == 3) + bs, nv = vertices.shape[:2] + bs, nf = faces.shape[:2] + device = vertices.device + normals = torch.zeros(bs * nv, 3).to(device) + + faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * + nv)[:, None, None] # expanded faces + vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()] + + faces = faces.reshape(-1, 3) + vertices_faces = vertices_faces.reshape(-1, 3, 3) + + normals.index_add_( + 0, faces[:, 1].long(), + torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1], + vertices_faces[:, 0] - vertices_faces[:, 1])) + normals.index_add_( + 0, faces[:, 2].long(), + torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2], + vertices_faces[:, 1] - vertices_faces[:, 2])) + normals.index_add_( + 0, faces[:, 0].long(), + torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0], + vertices_faces[:, 2] - vertices_faces[:, 0])) + + normals = F.normalize(normals, eps=1e-6, dim=1) + normals = normals.reshape((bs, nv, 3)) + # pytorch only supports long and byte tensors for indexing + return normals + + +def batch_orth_proj(X, camera): + ''' + X is N x num_verts x 3 + ''' + camera = camera.clone().view(-1, 1, 3) + X_trans = X[:, :, :2] + camera[:, :, 1:] + X_trans = torch.cat([X_trans, X[:, :, 2:]], 2) + Xn = (camera[:, :, 0:1] * X_trans) + return Xn + + +# borrowed from https://github.com/vchoutas/expose +DIM_FLIP = np.array([1, -1, -1], dtype=np.float32) +DIM_FLIP_TENSOR = torch.tensor([1, -1, 
-1], dtype=torch.float32) + + +def flip_pose(pose_vector, pose_format='rot-mat'): + if pose_format == 'aa': + if torch.is_tensor(pose_vector): + dim_flip = DIM_FLIP_TENSOR + else: + dim_flip = DIM_FLIP + return (pose_vector.reshape(-1, 3) * dim_flip).reshape(-1) + elif pose_format == 'rot-mat': + rot_mats = pose_vector.reshape(-1, 9).clone() + + rot_mats[:, [1, 2, 3, 6]] *= -1 + return rot_mats.view_as(pose_vector) + else: + raise ValueError(f'Unknown rotation format: {pose_format}') + + +# -------------------------------------- image processing +# ref: https://torchgeometry.readthedocs.io/en/latest/_modules/kornia/filters +def gaussian(window_size, sigma): + + def gauss_fcn(x): + return -(x - window_size // 2)**2 / float(2 * sigma**2) + + gauss = torch.stack( + [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)]) + return gauss / gauss.sum() + + +def get_gaussian_kernel(kernel_size: int, sigma: float): + r"""Function that returns Gaussian filter coefficients. + + Args: + kernel_size (int): filter size. It should be odd and positive. + sigma (float): gaussian standard deviation. + + Returns: + Tensor: 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples:: + + >>> kornia.image.get_gaussian_kernel(3, 2.5) + tensor([0.3243, 0.3513, 0.3243]) + + >>> kornia.image.get_gaussian_kernel(5, 1.5) + tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) + """ + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or \ + kernel_size <= 0: + raise TypeError("kernel_size must be an odd positive integer. " + "Got {}".format(kernel_size)) + window_1d = gaussian(kernel_size, sigma) + return window_1d + + +def get_gaussian_kernel2d(kernel_size, sigma): + r"""Function that returns Gaussian filter matrix coefficients. + + Args: + kernel_size (Tuple[int, int]): filter sizes in the x and y direction. + Sizes should be odd and positive. + sigma (Tuple[int, int]): gaussian standard deviation in the x and y + direction. + + Returns: + Tensor: 2D tensor with gaussian filter matrix coefficients. + + Shape: + - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` + + Examples:: + + >>> kornia.image.get_gaussian_kernel2d((3, 3), (1.5, 1.5)) + tensor([[0.0947, 0.1183, 0.0947], + [0.1183, 0.1478, 0.1183], + [0.0947, 0.1183, 0.0947]]) + + >>> kornia.image.get_gaussian_kernel2d((3, 5), (1.5, 1.5)) + tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370], + [0.0462, 0.0899, 0.1123, 0.0899, 0.0462], + [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) + """ + if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: + raise TypeError( + "kernel_size must be a tuple of length two. Got {}".format( + kernel_size)) + if not isinstance(sigma, tuple) or len(sigma) != 2: + raise TypeError( + "sigma must be a tuple of length two. Got {}".format(sigma)) + ksize_x, ksize_y = kernel_size + sigma_x, sigma_y = sigma + kernel_x = get_gaussian_kernel(ksize_x, sigma_x) + kernel_y = get_gaussian_kernel(ksize_y, sigma_y) + kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), + kernel_y.unsqueeze(-1).t()) + return kernel_2d + + +def gaussian_blur(x, kernel_size=(5, 5), sigma=(1.3, 1.3)): + b, c, h, w = x.shape + kernel = get_gaussian_kernel2d(kernel_size, sigma).to(x.device).to(x.dtype) + kernel = kernel.repeat(c, 1, 1, 1) + padding = [(k - 1) // 2 for k in kernel_size] + return F.conv2d(x, kernel, padding=padding, stride=1, groups=c) + + +def _compute_binary_kernel(window_size): + r"""Creates a binary kernel to extract the patches. 
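`gaussian_blur` above is a depthwise convolution: the 2-D kernel is the outer product of two 1-D Gaussians and is applied to each channel separately via `groups=C`. A self-contained sketch of the same idea (random image):

```
# Depthwise Gaussian blur, the pattern used by gaussian_blur above.
import torch
import torch.nn.functional as F

def gaussian_1d(ksize: int, sigma: float) -> torch.Tensor:
    x = torch.arange(ksize, dtype=torch.float32) - ksize // 2
    g = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    return g / g.sum()

def blur(x: torch.Tensor, ksize=(5, 5), sigma=(1.3, 1.3)) -> torch.Tensor:
    kx, ky = gaussian_1d(ksize[0], sigma[0]), gaussian_1d(ksize[1], sigma[1])
    kernel = torch.outer(kx, ky)[None, None]          # 1 x 1 x kH x kW
    c = x.shape[1]
    kernel = kernel.to(x).repeat(c, 1, 1, 1)          # one kernel per channel
    padding = [(k - 1) // 2 for k in ksize]
    return F.conv2d(x, kernel, padding=padding, stride=1, groups=c)

img = torch.rand(1, 3, 64, 64)
print(blur(img).shape)                                # (1, 3, 64, 64)
```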
If the window size + is HxW will create a (H*W)xHxW kernel. + """ + window_range = window_size[0] * window_size[1] + kernel: torch.Tensor = torch.zeros(window_range, window_range) + for i in range(window_range): + kernel[i, i] += 1.0 + return kernel.view(window_range, 1, window_size[0], window_size[1]) + + +def median_blur(x, kernel_size=(3, 3)): + b, c, h, w = x.shape + kernel = _compute_binary_kernel(kernel_size).to(x.device).to(x.dtype) + kernel = kernel.repeat(c, 1, 1, 1) + padding = [(k - 1) // 2 for k in kernel_size] + features = F.conv2d(x, kernel, padding=padding, stride=1, groups=c) + features = features.view(b, c, -1, h, w) + median = torch.median(features, dim=2)[0] + return median + + +def get_laplacian_kernel2d(kernel_size: int): + r"""Function that returns Gaussian filter matrix coefficients. + + Args: + kernel_size (int): filter size should be odd. + + Returns: + Tensor: 2D tensor with laplacian filter matrix coefficients. + + Shape: + - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` + + Examples:: + + >>> kornia.image.get_laplacian_kernel2d(3) + tensor([[ 1., 1., 1.], + [ 1., -8., 1.], + [ 1., 1., 1.]]) + + >>> kornia.image.get_laplacian_kernel2d(5) + tensor([[ 1., 1., 1., 1., 1.], + [ 1., 1., 1., 1., 1.], + [ 1., 1., -24., 1., 1.], + [ 1., 1., 1., 1., 1.], + [ 1., 1., 1., 1., 1.]]) + + """ + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or \ + kernel_size <= 0: + raise TypeError("ksize must be an odd positive integer. Got {}".format( + kernel_size)) + + kernel = torch.ones((kernel_size, kernel_size)) + mid = kernel_size // 2 + kernel[mid, mid] = 1 - kernel_size**2 + kernel_2d: torch.Tensor = kernel + return kernel_2d + + +def laplacian(x): + # https://torchgeometry.readthedocs.io/en/latest/_modules/kornia/filters/laplacian.html + b, c, h, w = x.shape + kernel_size = 3 + kernel = get_laplacian_kernel2d(kernel_size).to(x.device).to(x.dtype) + kernel = kernel.repeat(c, 1, 1, 1) + padding = (kernel_size - 1) // 2 + return F.conv2d(x, kernel, padding=padding, stride=1, groups=c) + + +# -------------------------------------- io + + +def copy_state_dict(cur_state_dict, pre_state_dict, prefix='', load_name=None): + + def _get_params(key): + key = prefix + key + if key in pre_state_dict: + return pre_state_dict[key] + return None + + for k in cur_state_dict.keys(): + if load_name is not None: + if load_name not in k: + continue + v = _get_params(k) + try: + if v is None: + # print('parameter {} not found'.format(k)) + continue + cur_state_dict[k].copy_(v) + except: + # print('copy param {} failed'.format(k)) + continue + + +def dict2obj(d): + # if isinstance(d, list): + # d = [dict2obj(x) for x in d] + if not isinstance(d, dict): + return d + + class C(object): + pass + + o = C() + for k in d: + o.__dict__[k] = dict2obj(d[k]) + return o + + +# original saved file with DataParallel + + +def remove_module(state_dict): + # create new OrderedDict that does not contain `module.` + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + return new_state_dict + + +def tensor2image(tensor): + image = tensor.detach().cpu().numpy() + image = image * 255. 
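`remove_module` above exists because checkpoints saved from an `nn.DataParallel` wrapper prefix every parameter name with `module.`, which a plain model's `load_state_dict` rejects. A small sketch of the strip-and-load pattern (the tiny linear model only illustrates the key handling):

```
# Stripping the "module." prefix from a DataParallel-style state dict,
# as remove_module above does. Toy model for illustration.
from collections import OrderedDict
import torch

plain = torch.nn.Linear(4, 2)
# simulate a checkpoint written from nn.DataParallel: keys gain "module."
state = OrderedDict(("module." + k, v) for k, v in plain.state_dict().items())

clean = OrderedDict((k[len("module."):] if k.startswith("module.") else k, v)
                    for k, v in state.items())
torch.nn.Linear(4, 2).load_state_dict(clean)   # loads without "unexpected key" errors
print(list(clean.keys()))                      # ['weight', 'bias']
```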
+ image = np.maximum(np.minimum(image, 255), 0) + image = image.transpose(1, 2, 0)[:, :, [2, 1, 0]] + return image.astype(np.uint8).copy() + + +def dict_tensor2npy(tensor_dict): + npy_dict = {} + for key in tensor_dict: + npy_dict[key] = tensor_dict[key][0].cpu().numpy() + return npy_dict + + +def load_config(cfg_file): + import yaml + with open(cfg_file, 'r') as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + return cfg + + +def move_dict_to_device(dict, device, tensor2float=False): + for k, v in dict.items(): + if isinstance(v, torch.Tensor): + if tensor2float: + dict[k] = v.float().to(device) + else: + dict[k] = v.to(device) + + +def write_obj( + obj_name, + vertices, + faces, + colors=None, + texture=None, + uvcoords=None, + uvfaces=None, + inverse_face_order=False, + normal_map=None, +): + ''' Save 3D face model with texture. + borrowed from https://github.com/YadiraF/PRNet/blob/master/utils/write.py + Args: + obj_name: str + vertices: shape = (nver, 3) + colors: shape = (nver, 3) + faces: shape = (ntri, 3) + texture: shape = (uv_size, uv_size, 3) + uvcoords: shape = (nver, 2) max value<=1 + ''' + if obj_name.split('.')[-1] != 'obj': + obj_name = obj_name + '.obj' + mtl_name = obj_name.replace('.obj', '.mtl') + texture_name = obj_name.replace('.obj', '.png') + material_name = 'FaceTexture' + + faces = faces.copy() + # mesh lab start with 1, python/c++ start from 0 + faces += 1 + if inverse_face_order: + faces = faces[:, [2, 1, 0]] + if uvfaces is not None: + uvfaces = uvfaces[:, [2, 1, 0]] + + # write obj + with open(obj_name, 'w') as f: + if texture is not None: + f.write('mtllib %s\n\n' % os.path.basename(mtl_name)) + + # write vertices + if colors is None: + for i in range(vertices.shape[0]): + f.write('v {} {} {}\n'.format(vertices[i, 0], vertices[i, 1], + vertices[i, 2])) + else: + for i in range(vertices.shape[0]): + f.write('v {} {} {} {} {} {}\n'.format(vertices[i, 0], + vertices[i, 1], + vertices[i, + 2], colors[i, + 0], + colors[i, + 1], colors[i, + 2])) + + # write uv coords + if texture is None: + for i in range(faces.shape[0]): + f.write('f {} {} {}\n'.format(faces[i, 0], faces[i, 1], + faces[i, 2])) + else: + for i in range(uvcoords.shape[0]): + f.write('vt {} {}\n'.format(uvcoords[i, 0], uvcoords[i, 1])) + f.write('usemtl %s\n' % material_name) + # write f: ver ind/ uv ind + uvfaces = uvfaces + 1 + for i in range(faces.shape[0]): + f.write('f {}/{} {}/{} {}/{}\n'.format(faces[i, 0], uvfaces[i, + 0], + faces[i, 1], uvfaces[i, + 1], + faces[i, + 2], uvfaces[i, + 2])) + # write mtl + with open(mtl_name, 'w') as f: + f.write('newmtl %s\n' % material_name) + s = 'map_Kd {}\n'.format( + os.path.basename(texture_name)) # map to image + f.write(s) + + if normal_map is not None: + if torch.is_tensor(normal_map): + normal_map = normal_map.detach().cpu().numpy().squeeze( + ) + + normal_map = np.transpose(normal_map, (1, 2, 0)) + name, _ = os.path.splitext(obj_name) + normal_name = f'{name}_normals.png' + f.write(f'disp {normal_name}') + + out_normal_map = normal_map / (np.linalg.norm( + normal_map, axis=-1, keepdims=True) + 1e-9) + out_normal_map = (out_normal_map + 1) * 0.5 + + cv2.imwrite(normal_name, (out_normal_map * 255).astype( + np.uint8)[:, :, ::-1]) + + cv2.imwrite(texture_name, texture) + + +def save_pkl(savepath, params, ind=0): + out_data = {} + for k, v in params.items(): + if torch.is_tensor(v): + out_data[k] = v[ind].detach().cpu().numpy() + else: + out_data[k] = v + # import ipdb; ipdb.set_trace() + with open(savepath, 'wb') as f: + pickle.dump(out_data, f, 
protocol=2) + + +# load obj, similar to load_obj from pytorch3d + + +def load_obj(obj_filename): + """ Ref: https://github.com/facebookresearch/pytorch3d/blob/25c065e9dafa90163e7cec873dbb324a637c68b7/pytorch3d/io/obj_io.py + Load a mesh from a file-like object. + """ + with open(obj_filename, 'r') as f: + lines = [line.strip() for line in f] + + verts, uvcoords = [], [] + faces, uv_faces = [], [] + # startswith expects each line to be a string. If the file is read in as + # bytes then first decode to strings. + if lines and isinstance(lines[0], bytes): + lines = [el.decode("utf-8") for el in lines] + + for line in lines: + tokens = line.strip().split() + if line.startswith("v "): # Line is a vertex. + vert = [float(x) for x in tokens[1:4]] + if len(vert) != 3: + msg = "Vertex %s does not have 3 values. Line: %s" + raise ValueError(msg % (str(vert), str(line))) + verts.append(vert) + elif line.startswith("vt "): # Line is a texture. + tx = [float(x) for x in tokens[1:3]] + if len(tx) != 2: + raise ValueError( + "Texture %s does not have 2 values. Line: %s" % + (str(tx), str(line))) + uvcoords.append(tx) + elif line.startswith("f "): # Line is a face. + # Update face properties info. + face = tokens[1:] + face_list = [f.split("/") for f in face] + for vert_props in face_list: + # Vertex index. + faces.append(int(vert_props[0])) + if len(vert_props) > 1: + if vert_props[1] != "": + # Texture index is present e.g. f 4/1/1. + uv_faces.append(int(vert_props[1])) + + verts = torch.tensor(verts, dtype=torch.float32) + uvcoords = torch.tensor(uvcoords, dtype=torch.float32) + faces = torch.tensor(faces, dtype=torch.long) + faces = faces.reshape(-1, 3) - 1 + uv_faces = torch.tensor(uv_faces, dtype=torch.long) + uv_faces = uv_faces.reshape(-1, 3) - 1 + return (verts, uvcoords, faces, uv_faces) + + +# ---------------------------------- visualization +def draw_rectangle(img, + bbox, + bbox_color=(255, 255, 255), + thickness=3, + is_opaque=False, + alpha=0.5): + """Draws the rectangle around the object + borrowed from: https://bbox-visualizer.readthedocs.io/en/latest/_modules/bbox_visualizer/bbox_visualizer.html + Parameters + ---------- + img : ndarray + the actual image + bbox : list + a list containing x_min, y_min, x_max and y_max of the rectangle positions + bbox_color : tuple, optional + the color of the box, by default (255,255,255) + thickness : int, optional + thickness of the outline of the box, by default 3 + is_opaque : bool, optional + if False, draws a solid rectangular outline. 
Else, a filled rectangle which is semi transparent, by default False + alpha : float, optional + strength of the opacity, by default 0.5 + + Returns + ------- + ndarray + the image with the bounding box drawn + """ + + output = img.copy() + if not is_opaque: + cv2.rectangle(output, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + bbox_color, thickness) + else: + overlay = img.copy() + + cv2.rectangle(overlay, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + bbox_color, -1) + # cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output) + + return output + + +def plot_bbox(image, bbox): + ''' Draw bbox + Args: + image: the input image + bbox: [left, top, right, bottom] + ''' + image = cv2.rectangle(image.copy(), (bbox[1], bbox[0]), (bbox[3], bbox[2]), + [0, 255, 0], + thickness=3) + # image = draw_rectangle(image, bbox, bbox_color=[0,255,0]) + return image + + +end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype=np.int32) - 1 + + +def plot_kpts(image, kpts, color='r'): + ''' Draw 68 key points + Args: + image: the input image + kpt: (68, 3). + ''' + kpts = kpts.copy().astype(np.int32) + if color == 'r': + c = (255, 0, 0) + elif color == 'g': + c = (0, 255, 0) + elif color == 'b': + c = (255, 0, 0) + image = image.copy() + kpts = kpts.copy() + + for i in range(kpts.shape[0]): + st = kpts[i, :2] + if kpts.shape[1] == 4: + if kpts[i, 3] > 0.5: + c = (0, 255, 0) + else: + c = (0, 0, 255) + image = cv2.circle(image, (st[0], st[1]), 1, c, 2) + if i in end_list: + continue + ed = kpts[i + 1, :2] + image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]), + (255, 255, 255), 1) + + return image + + +def plot_verts(image, kpts, color='r'): + ''' Draw 68 key points + Args: + image: the input image + kpt: (68, 3). + ''' + kpts = kpts.copy().astype(np.int32) + if color == 'r': + c = (255, 0, 0) + elif color == 'g': + c = (0, 255, 0) + elif color == 'b': + c = (0, 0, 255) + elif color == 'y': + c = (0, 255, 255) + image = image.copy() + + for i in range(kpts.shape[0]): + st = kpts[i, :2] + image = cv2.circle(image, (st[0], st[1]), 1, c, 5) + + return image + + +def tensor_vis_landmarks(images, + landmarks, + gt_landmarks=None, + color='g', + isScale=True): + # visualize landmarks + vis_landmarks = [] + images = images.cpu().numpy() + predicted_landmarks = landmarks.detach().cpu().numpy() + if gt_landmarks is not None: + gt_landmarks_np = gt_landmarks.detach().cpu().numpy() + for i in range(images.shape[0]): + image = images[i] + image = image.transpose(1, 2, 0)[:, :, [2, 1, 0]].copy() + image = (image * 255) + if isScale: + predicted_landmark = predicted_landmarks[i] * \ + image.shape[0]/2 + image.shape[0]/2 + else: + predicted_landmark = predicted_landmarks[i] + if predicted_landmark.shape[0] == 68: + image_landmarks = plot_kpts(image, predicted_landmark, color) + if gt_landmarks is not None: + image_landmarks = plot_verts( + image_landmarks, gt_landmarks_np[i] * image.shape[0] / 2 + + image.shape[0] / 2, 'r') + else: + image_landmarks = plot_verts(image, predicted_landmark, color) + if gt_landmarks is not None: + image_landmarks = plot_verts( + image_landmarks, gt_landmarks_np[i] * image.shape[0] / 2 + + image.shape[0] / 2, 'r') + vis_landmarks.append(image_landmarks) + + vis_landmarks = np.stack(vis_landmarks) + vis_landmarks = torch.from_numpy( + vis_landmarks[:, :, :, [2, 1, 0]].transpose( + 0, 3, 1, 2)) / 255. 
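+    # BGR back to RGB, NHWC -> NCHW, rescaled to [0, 1]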
# , dtype=torch.float32) + return vis_landmarks diff --git a/lib/pymaf/configs/pymaf_config.yaml b/lib/pymaf/configs/pymaf_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f9b57256adbc29f5611b88db31472c2857e3955 --- /dev/null +++ b/lib/pymaf/configs/pymaf_config.yaml @@ -0,0 +1,47 @@ +SOLVER: + MAX_ITER: 500000 + TYPE: Adam + BASE_LR: 0.00005 + GAMMA: 0.1 + STEPS: [0] + EPOCHS: [0] +DEBUG: False +LOGDIR: '' +DEVICE: cuda +NUM_WORKERS: 8 +SEED_VALUE: -1 +LOSS: + KP_2D_W: 300.0 + KP_3D_W: 300.0 + SHAPE_W: 0.06 + POSE_W: 60.0 + VERT_W: 0.0 + INDEX_WEIGHTS: 2.0 + # Loss weights for surface parts. (24 Parts) + PART_WEIGHTS: 0.3 + # Loss weights for UV regression. + POINT_REGRESSION_WEIGHTS: 0.5 +TRAIN: + NUM_WORKERS: 8 + BATCH_SIZE: 64 + PIN_MEMORY: True +TEST: + BATCH_SIZE: 32 +MODEL: + PyMAF: + BACKBONE: 'res50' + MLP_DIM: [256, 128, 64, 5] + N_ITER: 3 + AUX_SUPV_ON: True + DP_HEATMAP_SIZE: 56 +RES_MODEL: + DECONV_WITH_BIAS: False + NUM_DECONV_LAYERS: 3 + NUM_DECONV_FILTERS: + - 256 + - 256 + - 256 + NUM_DECONV_KERNELS: + - 4 + - 4 + - 4 diff --git a/lib/pymaf/core/__init__.py b/lib/pymaf/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/pymaf/core/base_trainer.py b/lib/pymaf/core/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..757420fd8c22876a477163039d0706fb77233cb5 --- /dev/null +++ b/lib/pymaf/core/base_trainer.py @@ -0,0 +1,106 @@ +# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py +from __future__ import division +import logging +from utils import CheckpointSaver +from tensorboardX import SummaryWriter + +import torch +from tqdm import tqdm + +tqdm.monitor_interval = 0 + +logger = logging.getLogger(__name__) + + +class BaseTrainer(object): + """Base class for Trainer objects. + Takes care of checkpointing/logging/resuming training. + """ + + def __init__(self, options): + self.options = options + if options.multiprocessing_distributed: + self.device = torch.device('cuda', options.gpu) + else: + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + # override this function to define your model, optimizers etc. + self.saver = CheckpointSaver(save_dir=options.checkpoint_dir, + overwrite=options.overwrite) + if options.rank == 0: + self.summary_writer = SummaryWriter(self.options.summary_dir) + self.init_fn() + + self.checkpoint = None + if options.resume and self.saver.exists_checkpoint(): + self.checkpoint = self.saver.load_checkpoint( + self.models_dict, self.optimizers_dict) + + if self.checkpoint is None: + self.epoch_count = 0 + self.step_count = 0 + else: + self.epoch_count = self.checkpoint['epoch'] + self.step_count = self.checkpoint['total_step_count'] + + if self.checkpoint is not None: + self.checkpoint_batch_idx = self.checkpoint['batch_idx'] + else: + self.checkpoint_batch_idx = 0 + + self.best_performance = float('inf') + + def load_pretrained(self, checkpoint_file=None): + """Load a pretrained checkpoint. + This is different from resuming training using --resume. 
+ """ + if checkpoint_file is not None: + checkpoint = torch.load(checkpoint_file) + for model in self.models_dict: + if model in checkpoint: + self.models_dict[model].load_state_dict(checkpoint[model], + strict=True) + print(f'Checkpoint {model} loaded') + + def move_dict_to_device(self, dict, device, tensor2float=False): + for k, v in dict.items(): + if isinstance(v, torch.Tensor): + if tensor2float: + dict[k] = v.float().to(device) + else: + dict[k] = v.to(device) + + # The following methods (with the possible exception of test) have to be implemented in the derived classes + def train(self, epoch): + raise NotImplementedError('You need to provide an train method') + + def init_fn(self): + raise NotImplementedError('You need to provide an _init_fn method') + + def train_step(self, input_batch): + raise NotImplementedError('You need to provide a _train_step method') + + def train_summaries(self, input_batch): + raise NotImplementedError( + 'You need to provide a _train_summaries method') + + def visualize(self, input_batch): + raise NotImplementedError('You need to provide a visualize method') + + def validate(self): + pass + + def test(self): + pass + + def evaluate(self): + pass + + def fit(self): + # Run training for num_epochs epochs + for epoch in tqdm(range(self.epoch_count, self.options.num_epochs), + total=self.options.num_epochs, + initial=self.epoch_count): + self.epoch_count = epoch + self.train(epoch) + return diff --git a/lib/pymaf/core/cfgs.py b/lib/pymaf/core/cfgs.py new file mode 100644 index 0000000000000000000000000000000000000000..09ac4fa48483aa9e595b7e4b27dfa7426cb11d33 --- /dev/null +++ b/lib/pymaf/core/cfgs.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import os +import json +from yacs.config import CfgNode as CN + +# Configuration variables +cfg = CN(new_allowed=True) + +cfg.OUTPUT_DIR = 'results' +cfg.DEVICE = 'cuda' +cfg.DEBUG = False +cfg.LOGDIR = '' +cfg.VAL_VIS_BATCH_FREQ = 200 +cfg.TRAIN_VIS_ITER_FERQ = 1000 +cfg.SEED_VALUE = -1 + +cfg.TRAIN = CN(new_allowed=True) + +cfg.LOSS = CN(new_allowed=True) +cfg.LOSS.KP_2D_W = 300.0 +cfg.LOSS.KP_3D_W = 300.0 +cfg.LOSS.SHAPE_W = 0.06 +cfg.LOSS.POSE_W = 60.0 +cfg.LOSS.VERT_W = 0.0 + +# Loss weights for dense correspondences +cfg.LOSS.INDEX_WEIGHTS = 2.0 +# Loss weights for surface parts. (24 Parts) +cfg.LOSS.PART_WEIGHTS = 0.3 +# Loss weights for UV regression. 
+cfg.LOSS.POINT_REGRESSION_WEIGHTS = 0.5 + +cfg.MODEL = CN(new_allowed=True) + +cfg.MODEL.PyMAF = CN(new_allowed=True) + +# switch +cfg.TRAIN.VAL_LOOP = True + +cfg.TEST = CN(new_allowed=True) + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + # return cfg.clone() + return cfg + + +def update_cfg(cfg_file): + # cfg = get_cfg_defaults() + cfg.merge_from_file(cfg_file) + # return cfg.clone() + return cfg + + +def parse_args(args): + cfg_file = args.cfg_file + if args.cfg_file is not None: + cfg = update_cfg(args.cfg_file) + else: + cfg = get_cfg_defaults() + + # if args.misc is not None: + # cfg.merge_from_list(args.misc) + + return cfg + + +def parse_args_extend(args): + if args.resume: + if not os.path.exists(args.log_dir): + raise ValueError( + 'Experiment are set to resume mode, but log directory does not exist.' + ) + + # load log's cfg + cfg_file = os.path.join(args.log_dir, 'cfg.yaml') + cfg = update_cfg(cfg_file) + + if args.misc is not None: + cfg.merge_from_list(args.misc) + else: + parse_args(args) diff --git a/lib/pymaf/core/constants.py b/lib/pymaf/core/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e832437d732fb5c32fc50f975bd9f1c7a750ec --- /dev/null +++ b/lib/pymaf/core/constants.py @@ -0,0 +1,153 @@ +# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/constants.py +FOCAL_LENGTH = 5000. +IMG_RES = 224 + +# Mean and standard deviation for normalizing input image +IMG_NORM_MEAN = [0.485, 0.456, 0.406] +IMG_NORM_STD = [0.229, 0.224, 0.225] +""" +We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides. +We keep a superset of 24 joints such that we include all joints from every dataset. +If a dataset doesn't provide annotations for a specific joint, we simply ignore it. 
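+In total the superset contains 49 joints: the 25 OpenPose joints listed first, followed by the 24 ground-truth joints.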
+The joints used here are the following: +""" +JOINT_NAMES = [ + # 25 OpenPose joints (in the order provided by OpenPose) + 'OP Nose', + 'OP Neck', + 'OP RShoulder', + 'OP RElbow', + 'OP RWrist', + 'OP LShoulder', + 'OP LElbow', + 'OP LWrist', + 'OP MidHip', + 'OP RHip', + 'OP RKnee', + 'OP RAnkle', + 'OP LHip', + 'OP LKnee', + 'OP LAnkle', + 'OP REye', + 'OP LEye', + 'OP REar', + 'OP LEar', + 'OP LBigToe', + 'OP LSmallToe', + 'OP LHeel', + 'OP RBigToe', + 'OP RSmallToe', + 'OP RHeel', + # 24 Ground Truth joints (superset of joints from different datasets) + 'Right Ankle', + 'Right Knee', + 'Right Hip', # 2 + 'Left Hip', + 'Left Knee', # 4 + 'Left Ankle', + 'Right Wrist', # 6 + 'Right Elbow', + 'Right Shoulder', # 8 + 'Left Shoulder', + 'Left Elbow', # 10 + 'Left Wrist', + 'Neck (LSP)', # 12 + 'Top of Head (LSP)', + 'Pelvis (MPII)', # 14 + 'Thorax (MPII)', + 'Spine (H36M)', # 16 + 'Jaw (H36M)', + 'Head (H36M)', # 18 + 'Nose', + 'Left Eye', + 'Right Eye', + 'Left Ear', + 'Right Ear' +] + +# Dict containing the joints in numerical order +JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))} + +# Map joints to SMPL joints +JOINT_MAP = { + 'OP Nose': 24, + 'OP Neck': 12, + 'OP RShoulder': 17, + 'OP RElbow': 19, + 'OP RWrist': 21, + 'OP LShoulder': 16, + 'OP LElbow': 18, + 'OP LWrist': 20, + 'OP MidHip': 0, + 'OP RHip': 2, + 'OP RKnee': 5, + 'OP RAnkle': 8, + 'OP LHip': 1, + 'OP LKnee': 4, + 'OP LAnkle': 7, + 'OP REye': 25, + 'OP LEye': 26, + 'OP REar': 27, + 'OP LEar': 28, + 'OP LBigToe': 29, + 'OP LSmallToe': 30, + 'OP LHeel': 31, + 'OP RBigToe': 32, + 'OP RSmallToe': 33, + 'OP RHeel': 34, + 'Right Ankle': 8, + 'Right Knee': 5, + 'Right Hip': 45, + 'Left Hip': 46, + 'Left Knee': 4, + 'Left Ankle': 7, + 'Right Wrist': 21, + 'Right Elbow': 19, + 'Right Shoulder': 17, + 'Left Shoulder': 16, + 'Left Elbow': 18, + 'Left Wrist': 20, + 'Neck (LSP)': 47, + 'Top of Head (LSP)': 48, + 'Pelvis (MPII)': 49, + 'Thorax (MPII)': 50, + 'Spine (H36M)': 51, + 'Jaw (H36M)': 52, + 'Head (H36M)': 53, + 'Nose': 24, + 'Left Eye': 26, + 'Right Eye': 25, + 'Left Ear': 28, + 'Right Ear': 27 +} + +# Joint selectors +# Indices to get the 14 LSP joints from the 17 H36M joints +H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] +H36M_TO_J14 = H36M_TO_J17[:14] +# Indices to get the 14 LSP joints from the ground truth joints +J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17] +J24_TO_J14 = J24_TO_J17[:14] +J24_TO_J19 = J24_TO_J17[:14] + [19, 20, 21, 22, 23] +J24_TO_JCOCO = [19, 20, 21, 22, 23, 9, 8, 10, 7, 11, 6, 3, 2, 4, 1, 5, 0] + +# Permutation of SMPL pose parameters when flipping the shape +SMPL_JOINTS_FLIP_PERM = [ + 0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, + 20, 23, 22 +] +SMPL_POSE_FLIP_PERM = [] +for i in SMPL_JOINTS_FLIP_PERM: + SMPL_POSE_FLIP_PERM.append(3 * i) + SMPL_POSE_FLIP_PERM.append(3 * i + 1) + SMPL_POSE_FLIP_PERM.append(3 * i + 2) +# Permutation indices for the 24 ground truth joints +J24_FLIP_PERM = [ + 5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, + 20, 23, 22 +] +# Permutation indices for the full set of 49 joints +J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\ + + [25+i for i in J24_FLIP_PERM] +SMPL_J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\ + + [25+i for i in SMPL_JOINTS_FLIP_PERM] diff --git a/lib/pymaf/core/fits_dict.py b/lib/pymaf/core/fits_dict.py new file 
mode 100644 index 0000000000000000000000000000000000000000..efd13e0a90ae0a9281c9ef1548ceaa0ea4a0d2ba --- /dev/null +++ b/lib/pymaf/core/fits_dict.py @@ -0,0 +1,133 @@ +''' +This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py +''' +import os +import cv2 +import torch +import numpy as np +from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis + +from core import path_config, constants + +import logging + +logger = logging.getLogger(__name__) + + +class FitsDict(): + """ Dictionary keeping track of the best fit per image in the training set """ + + def __init__(self, options, train_dataset): + self.options = options + self.train_dataset = train_dataset + self.fits_dict = {} + self.valid_fit_state = {} + # array used to flip SMPL pose parameters + self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM, + dtype=torch.int64) + # Load dictionary state + for ds_name, ds in train_dataset.dataset_dict.items(): + if ds_name in ['h36m']: + dict_file = os.path.join(path_config.FINAL_FITS_DIR, + ds_name + '.npy') + self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file)) + self.valid_fit_state[ds_name] = torch.ones(len( + self.fits_dict[ds_name]), + dtype=torch.uint8) + else: + dict_file = os.path.join(path_config.FINAL_FITS_DIR, + ds_name + '.npz') + fits_dict = np.load(dict_file) + opt_pose = torch.from_numpy(fits_dict['pose']) + opt_betas = torch.from_numpy(fits_dict['betas']) + opt_valid_fit = torch.from_numpy(fits_dict['valid_fit']).to( + torch.uint8) + self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas], + dim=1) + self.valid_fit_state[ds_name] = opt_valid_fit + + if not options.single_dataset: + for ds in train_dataset.datasets: + if ds.dataset not in ['h36m']: + ds.pose = self.fits_dict[ds.dataset][:, :72].numpy() + ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy() + ds.has_smpl = self.valid_fit_state[ds.dataset].numpy() + + def save(self): + """ Save dictionary state to disk """ + for ds_name in self.train_dataset.dataset_dict.keys(): + dict_file = os.path.join(self.options.checkpoint_dir, + ds_name + '_fits.npy') + np.save(dict_file, self.fits_dict[ds_name].cpu().numpy()) + + def __getitem__(self, x): + """ Retrieve dictionary entries """ + dataset_name, ind, rot, is_flipped = x + batch_size = len(dataset_name) + pose = torch.zeros((batch_size, 72)) + betas = torch.zeros((batch_size, 10)) + for ds, i, n in zip(dataset_name, ind, range(batch_size)): + params = self.fits_dict[ds][i] + pose[n, :] = params[:72] + betas[n, :] = params[72:] + pose = pose.clone() + # Apply flipping and rotation + pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped) + betas = betas.clone() + return pose, betas + + def get_vaild_state(self, dataset_name, ind): + batch_size = len(dataset_name) + valid_fit = torch.zeros(batch_size, dtype=torch.uint8) + for ds, i, n in zip(dataset_name, ind, range(batch_size)): + valid_fit[n] = self.valid_fit_state[ds][i] + valid_fit = valid_fit.clone() + return valid_fit + + def __setitem__(self, x, val): + """ Update dictionary entries """ + dataset_name, ind, rot, is_flipped, update = x + pose, betas = val + batch_size = len(dataset_name) + # Undo flipping and rotation + pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot) + params = torch.cat((pose, betas), dim=-1).cpu() + for ds, i, n in zip(dataset_name, ind, range(batch_size)): + if update[n]: + self.fits_dict[ds][i] = params[n] + + def flip_pose(self, pose, is_flipped): + """flip SMPL pose parameters""" 
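+        # horizontal flip: swap left/right parts via SMPL_POSE_FLIP_PERM and
+        # negate the y/z components of each axis-angle rotation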
+ is_flipped = is_flipped.byte() + pose_f = pose.clone() + pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts] + # we also negate the second and the third dimension of the axis-angle representation + pose_f[is_flipped, 1::3] *= -1 + pose_f[is_flipped, 2::3] *= -1 + return pose_f + + def rotate_pose(self, pose, rot): + """Rotate SMPL pose parameters by rot degrees""" + pose = pose.clone() + cos = torch.cos(-np.pi * rot / 180.) + sin = torch.sin(-np.pi * rot / 180.) + zeros = torch.zeros_like(cos) + r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device) + r3[:, 0, -1] = 1 + R = torch.cat([ + torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1), + torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3 + ], + dim=1) + global_pose = pose[:, :3] + global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose) + global_pose_rotmat_3b3 = global_pose_rotmat[:, :3, :3] + global_pose_rotmat_3b3 = torch.matmul(R, global_pose_rotmat_3b3) + global_pose_rotmat[:, :3, :3] = global_pose_rotmat_3b3 + global_pose_rotmat = global_pose_rotmat[:, :-1, :-1].cpu().numpy() + global_pose_np = np.zeros((global_pose.shape[0], 3)) + for i in range(global_pose.shape[0]): + aa, _ = cv2.Rodrigues(global_pose_rotmat[i]) + global_pose_np[i, :] = aa.squeeze() + pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device) + return pose diff --git a/lib/pymaf/core/path_config.py b/lib/pymaf/core/path_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c98c5167d268aaeb81d82506ac1fea91a33dc68e --- /dev/null +++ b/lib/pymaf/core/path_config.py @@ -0,0 +1,38 @@ +""" +This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/path_config.py +path configuration +This file contains definitions of useful data stuctures and the paths +for the datasets and data files necessary to run the code. 
+Things you need to change: *_ROOT that indicate the path to each dataset +""" +import os + +# pymaf +pymaf_data_dir = os.path.join(os.path.dirname(__file__), + "../../../data/HPS/pymaf_data") + +SMPL_MEAN_PARAMS = os.path.join(pymaf_data_dir, "smpl_mean_params.npz") +SMPL_MODEL_DIR = os.path.join(pymaf_data_dir, "../../smpl_related/models/smpl") +MESH_DOWNSAMPLEING = os.path.join(pymaf_data_dir, "mesh_downsampling.npz") + +CUBE_PARTS_FILE = os.path.join(pymaf_data_dir, "cube_parts.npy") +JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(pymaf_data_dir, + "J_regressor_extra.npy") +JOINT_REGRESSOR_H36M = os.path.join(pymaf_data_dir, "J_regressor_h36m.npy") +VERTEX_TEXTURE_FILE = os.path.join(pymaf_data_dir, "vertex_texture.npy") +SMPL_MEAN_PARAMS = os.path.join(pymaf_data_dir, "smpl_mean_params.npz") +CHECKPOINT_FILE = os.path.join(pymaf_data_dir, + "pretrained_model/PyMAF_model_checkpoint.pt") + +# pare +pare_data_dir = os.path.join(os.path.dirname(__file__), + "../../../data/HPS/pare_data") +CFG = os.path.join(pare_data_dir, "pare/checkpoints/pare_w_3dpw_config.yaml") +CKPT = os.path.join(pare_data_dir, + "pare/checkpoints/pare_w_3dpw_checkpoint.ckpt") + +# hybrik +hybrik_data_dir = os.path.join(os.path.dirname(__file__), + "../../../data/HPS/hybrik_data") +HYBRIK_CFG = os.path.join(hybrik_data_dir, "hybrik_config.yaml") +HYBRIK_CKPT = os.path.join(hybrik_data_dir, "pretrained_w_cam.pth") diff --git a/lib/pymaf/core/train_options.py b/lib/pymaf/core/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3691fe8ab83c05b9f0e6bef5d372581b8beb0a --- /dev/null +++ b/lib/pymaf/core/train_options.py @@ -0,0 +1,139 @@ +import argparse + + +class TrainOptions(): + + def __init__(self): + self.parser = argparse.ArgumentParser() + + gen = self.parser.add_argument_group('General') + gen.add_argument( + '--resume', + dest='resume', + default=False, + action='store_true', + help='Resume from checkpoint (Use latest checkpoint by default') + + io = self.parser.add_argument_group('io') + io.add_argument('--log_dir', + default='logs', + help='Directory to store logs') + io.add_argument( + '--pretrained_checkpoint', + default=None, + help='Load a pretrained checkpoint at the beginning training') + + train = self.parser.add_argument_group('Training Options') + train.add_argument('--num_epochs', + type=int, + default=200, + help='Total number of training epochs') + train.add_argument('--regressor', + type=str, + choices=['hmr', 'pymaf_net'], + default='pymaf_net', + help='Name of the SMPL regressor.') + train.add_argument('--cfg_file', + type=str, + default='./configs/pymaf_config.yaml', + help='config file path for PyMAF.') + train.add_argument( + '--img_res', + type=int, + default=224, + help= + 'Rescale bounding boxes to size [img_res, img_res] before feeding them in the network' + ) + train.add_argument( + '--rot_factor', + type=float, + default=30, + help='Random rotation in the range [-rot_factor, rot_factor]') + train.add_argument( + '--noise_factor', + type=float, + default=0.4, + help= + 'Randomly multiply pixel values with factor in the range [1-noise_factor, 1+noise_factor]' + ) + train.add_argument( + '--scale_factor', + type=float, + default=0.25, + help= + 'Rescale bounding boxes by a factor of [1-scale_factor,1+scale_factor]' + ) + train.add_argument( + '--openpose_train_weight', + default=0., + help='Weight for OpenPose keypoints during training') + train.add_argument('--gt_train_weight', + default=1., + help='Weight for GT keypoints during training') + 
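+        # evaluation and dataset options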
train.add_argument('--eval_dataset', + type=str, + default='h36m-p2-mosh', + help='Name of the evaluation dataset.') + train.add_argument('--single_dataset', + default=False, + action='store_true', + help='Use a single dataset') + train.add_argument('--single_dataname', + type=str, + default='h36m', + help='Name of the single dataset.') + train.add_argument('--eval_pve', + default=False, + action='store_true', + help='evaluate PVE') + train.add_argument('--overwrite', + default=False, + action='store_true', + help='overwrite the latest checkpoint') + + train.add_argument('--distributed', + action='store_true', + help='Use distributed training') + train.add_argument('--dist_backend', + default='nccl', + type=str, + help='distributed backend') + train.add_argument('--dist_url', + default='tcp://127.0.0.1:10356', + type=str, + help='url used to set up distributed training') + train.add_argument('--world_size', + default=1, + type=int, + help='number of nodes for distributed training') + train.add_argument("--local_rank", default=0, type=int) + train.add_argument('--rank', + default=0, + type=int, + help='node rank for distributed training') + train.add_argument( + '--multiprocessing_distributed', + action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + + misc = self.parser.add_argument_group('Misc Options') + misc.add_argument('--misc', + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER) + return + + def parse_args(self): + """Parse input arguments.""" + self.args = self.parser.parse_args() + self.save_dump() + return self.args + + def save_dump(self): + """Store all argument values to a json file. + The default location is logs/expname/args.json. 
+ """ + pass diff --git a/lib/pymaf/models/__init__.py b/lib/pymaf/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85ca9042485c0af8b1f29a4d4eaa1547935a40f --- /dev/null +++ b/lib/pymaf/models/__init__.py @@ -0,0 +1,3 @@ +from .hmr import hmr +from .pymaf_net import pymaf_net +from .smpl import SMPL diff --git a/lib/pymaf/models/hmr.py b/lib/pymaf/models/hmr.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb1cd6b7e2e4581f2c5d9cb5e952049b5a075e1 --- /dev/null +++ b/lib/pymaf/models/hmr.py @@ -0,0 +1,303 @@ +# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/hmr.py + +import torch +import torch.nn as nn +import torchvision.models.resnet as resnet +import numpy as np +import math +from lib.pymaf.utils.geometry import rot6d_to_rotmat + +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +class Bottleneck(nn.Module): + """ Redefinition of Bottleneck residual block + Adapted from the official PyTorch implementation + """ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet_Backbone(nn.Module): + """ Feature Extrator with ResNet backbone + """ + + def __init__(self, model='res50', pretrained=True): + if model == 'res50': + block, layers = Bottleneck, [3, 4, 6, 3] + else: + pass # TODO + + self.inplanes = 64 + super().__init__() + npose = 24 * 6 + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + + if pretrained: + resnet_imagenet = resnet.resnet50(pretrained=True) + self.load_state_dict(resnet_imagenet.state_dict(), strict=False) + logger.info('loaded resnet50 imagenet pretrained model') + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def 
_make_deconv_layer(self, num_layers, num_filters, num_kernels): + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + def _get_deconv_cfg(deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = _get_deconv_cfg( + num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d(in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias)) + layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def forward(self, x): + + batch_size = x.shape[0] + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + xf = self.avgpool(x4) + xf = xf.view(xf.size(0), -1) + + x_featmap = x4 + + return x_featmap, xf + + +class HMR(nn.Module): + """ SMPL Iterative Regressor with ResNet50 backbone + """ + + def __init__(self, block, layers, smpl_mean_params): + self.inplanes = 64 + super().__init__() + npose = 24 * 6 + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc1 = nn.Linear(512 * block.expansion + npose + 13, 1024) + self.drop1 = nn.Dropout() + self.fc2 = nn.Linear(1024, 1024) + self.drop2 = nn.Dropout() + self.decpose = nn.Linear(1024, npose) + self.decshape = nn.Linear(1024, 10) + self.deccam = nn.Linear(1024, 3) + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + mean_params = np.load(smpl_mean_params) + init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0) + init_shape = torch.from_numpy( + mean_params['shape'][:].astype('float32')).unsqueeze(0) + init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0) + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.register_buffer('init_cam', init_cam) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, + x, + init_pose=None, + init_shape=None, + init_cam=None, + n_iter=3): + + batch_size = x.shape[0] + + if init_pose is None: + init_pose = self.init_pose.expand(batch_size, -1) + if init_shape is None: + init_shape = self.init_shape.expand(batch_size, -1) + if init_cam is None: + init_cam = self.init_cam.expand(batch_size, -1) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + xf = self.avgpool(x4) + xf = xf.view(xf.size(0), -1) + + pred_pose = init_pose + pred_shape = init_shape + pred_cam = init_cam + for i in range(n_iter): + xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1) + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + pred_pose = self.decpose(xc) + pred_pose + pred_shape = self.decshape(xc) + pred_shape + pred_cam = self.deccam(xc) + pred_cam + + pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3) + + return pred_rotmat, pred_shape, pred_cam + + +def hmr(smpl_mean_params, pretrained=True, **kwargs): + """ Constructs an HMR model with ResNet50 backbone. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs) + if pretrained: + resnet_imagenet = resnet.resnet50(pretrained=True) + model.load_state_dict(resnet_imagenet.state_dict(), strict=False) + return model diff --git a/lib/pymaf/models/maf_extractor.py b/lib/pymaf/models/maf_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ce46004fea04b9caa1a71c877e6f9537de1389 --- /dev/null +++ b/lib/pymaf/models/maf_extractor.py @@ -0,0 +1,138 @@ +# This script is borrowed and extended from https://github.com/shunsukesaito/PIFu/blob/master/lib/model/SurfaceClassifier.py + +from packaging import version +import torch +import scipy +import os +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from lib.common.config import cfg +from lib.pymaf.utils.geometry import projection +from lib.pymaf.core.path_config import MESH_DOWNSAMPLEING + +import logging + +logger = logging.getLogger(__name__) + + +class MAF_Extractor(nn.Module): + ''' Mesh-aligned Feature Extrator + + As discussed in the paper, we extract mesh-aligned features based on 2D projection of the mesh vertices. 
+ The features extrated from spatial feature maps will go through a MLP for dimension reduction. + ''' + + def __init__(self, device=torch.device('cuda')): + super().__init__() + + self.device = device + self.filters = [] + self.num_views = 1 + filter_channels = cfg.MODEL.PyMAF.MLP_DIM + self.last_op = nn.ReLU(True) + + for l in range(0, len(filter_channels) - 1): + if 0 != l: + self.filters.append( + nn.Conv1d(filter_channels[l] + filter_channels[0], + filter_channels[l + 1], 1)) + else: + self.filters.append( + nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1)) + + self.add_module("conv%d" % l, self.filters[l]) + + self.im_feat = None + self.cam = None + + # downsample SMPL mesh and assign part labels + # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz + smpl_mesh_graph = np.load(MESH_DOWNSAMPLEING, + allow_pickle=True, + encoding='latin1') + + A = smpl_mesh_graph['A'] + U = smpl_mesh_graph['U'] + D = smpl_mesh_graph['D'] # shape: (2,) + + # downsampling + ptD = [] + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) + + # downsampling mapping from 6890 points to 431 points + # ptD[0].to_dense() - Size: [1723, 6890] + # ptD[1].to_dense() - Size: [431. 1723] + Dmap = torch.matmul(ptD[1].to_dense(), + ptD[0].to_dense()) # 6890 -> 431 + self.register_buffer('Dmap', Dmap) + + def reduce_dim(self, feature): + ''' + Dimension reduction by multi-layer perceptrons + :param feature: list of [B, C_s, N] point-wise features before dimension reduction + :return: [B, C_p x N] concatantion of point-wise features after dimension reduction + ''' + y = feature + tmpy = feature + for i, f in enumerate(self.filters): + y = self._modules['conv' + + str(i)](y if i == 0 else torch.cat([y, tmpy], 1)) + if i != len(self.filters) - 1: + y = F.leaky_relu(y) + if self.num_views > 1 and i == len(self.filters) // 2: + y = y.view(-1, self.num_views, y.shape[1], + y.shape[2]).mean(dim=1) + tmpy = feature.view(-1, self.num_views, feature.shape[1], + feature.shape[2]).mean(dim=1) + + y = self.last_op(y) + + y = y.view(y.shape[0], -1) + return y + + def sampling(self, points, im_feat=None, z_feat=None): + ''' + Given 2D points, sample the point-wise features for each point, + the dimension of point-wise features will be reduced from C_s to C_p by MLP. + Image features should be pre-computed before this call. + :param points: [B, N, 2] image coordinates of points + :im_feat: [B, C_s, H_s, W_s] spatial feature maps + :return: [B, C_p x N] concatantion of point-wise features after dimension reduction + ''' + if im_feat is None: + im_feat = self.im_feat + + batch_size = im_feat.shape[0] + + if version.parse(torch.__version__) >= version.parse('1.3.0'): + # Default grid_sample behavior has changed to align_corners=False since 1.3.0. + point_feat = torch.nn.functional.grid_sample( + im_feat, points.unsqueeze(2), align_corners=True)[..., 0] + else: + point_feat = torch.nn.functional.grid_sample( + im_feat, points.unsqueeze(2))[..., 0] + + mesh_align_feat = self.reduce_dim(point_feat) + return mesh_align_feat + + def forward(self, p, s_feat=None, cam=None, **kwargs): + ''' Returns mesh-aligned features for the 3D mesh points. 
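+        Vertices are projected onto the image plane with the given camera,
+        then point-wise features are sampled from s_feat and reduced by the MLP.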
+ + Args: + p (tensor): [B, N_m, 3] mesh vertices + s_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps + cam (tensor): [B, 3] camera + Return: + mesh_align_feat (tensor): [B, C_p x N_m] mesh-aligned features + ''' + if cam is None: + cam = self.cam + p_proj_2d = projection(p, cam, retain_z=False) + mesh_align_feat = self.sampling(p_proj_2d, s_feat) + return mesh_align_feat diff --git a/lib/pymaf/models/pymaf_net.py b/lib/pymaf/models/pymaf_net.py new file mode 100644 index 0000000000000000000000000000000000000000..340b37dcee9101aa4a4f52f4ea73dafce6628d0a --- /dev/null +++ b/lib/pymaf/models/pymaf_net.py @@ -0,0 +1,363 @@ +import torch +import torch.nn as nn +import numpy as np + +from lib.pymaf.utils.geometry import rot6d_to_rotmat, projection, rotation_matrix_to_angle_axis +from .maf_extractor import MAF_Extractor +from .smpl import SMPL, SMPL_MODEL_DIR, SMPL_MEAN_PARAMS, H36M_TO_J14 +from .hmr import ResNet_Backbone +from .res_module import IUV_predict_layer +from lib.common.config import cfg +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +class Regressor(nn.Module): + + def __init__(self, feat_dim, smpl_mean_params): + super().__init__() + + npose = 24 * 6 + + self.fc1 = nn.Linear(feat_dim + npose + 13, 1024) + self.drop1 = nn.Dropout() + self.fc2 = nn.Linear(1024, 1024) + self.drop2 = nn.Dropout() + self.decpose = nn.Linear(1024, npose) + self.decshape = nn.Linear(1024, 10) + self.deccam = nn.Linear(1024, 3) + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + self.smpl = SMPL(SMPL_MODEL_DIR, batch_size=64, create_transl=False) + + mean_params = np.load(smpl_mean_params) + init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0) + init_shape = torch.from_numpy( + mean_params['shape'][:].astype('float32')).unsqueeze(0) + init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0) + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.register_buffer('init_cam', init_cam) + + def forward(self, + x, + init_pose=None, + init_shape=None, + init_cam=None, + n_iter=1, + J_regressor=None): + batch_size = x.shape[0] + + if init_pose is None: + init_pose = self.init_pose.expand(batch_size, -1) + if init_shape is None: + init_shape = self.init_shape.expand(batch_size, -1) + if init_cam is None: + init_cam = self.init_cam.expand(batch_size, -1) + + pred_pose = init_pose + pred_shape = init_shape + pred_cam = init_cam + for i in range(n_iter): + xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1) + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + pred_pose = self.decpose(xc) + pred_pose + pred_shape = self.decshape(xc) + pred_shape + pred_cam = self.deccam(xc) + pred_cam + + pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3) + + pred_output = self.smpl(betas=pred_shape, + body_pose=pred_rotmat[:, 1:], + global_orient=pred_rotmat[:, 0].unsqueeze(1), + pose2rot=False) + + pred_vertices = pred_output.vertices + pred_joints = pred_output.joints + pred_smpl_joints = pred_output.smpl_joints + pred_keypoints_2d = projection(pred_joints, pred_cam) + pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, + 3)).reshape( + -1, 72) + + if J_regressor is not None: + pred_joints = torch.matmul(J_regressor, pred_vertices) + pred_pelvis = pred_joints[:, [0], :].clone() + pred_joints = pred_joints[:, H36M_TO_J14, :] + pred_joints = 
pred_joints - pred_pelvis + + output = { + 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1), + 'verts': pred_vertices, + 'kp_2d': pred_keypoints_2d, + 'kp_3d': pred_joints, + 'smpl_kp_3d': pred_smpl_joints, + 'rotmat': pred_rotmat, + 'pred_cam': pred_cam, + 'pred_shape': pred_shape, + 'pred_pose': pred_pose, + } + return output + + def forward_init(self, + x, + init_pose=None, + init_shape=None, + init_cam=None, + n_iter=1, + J_regressor=None): + batch_size = x.shape[0] + + if init_pose is None: + init_pose = self.init_pose.expand(batch_size, -1) + if init_shape is None: + init_shape = self.init_shape.expand(batch_size, -1) + if init_cam is None: + init_cam = self.init_cam.expand(batch_size, -1) + + pred_pose = init_pose + pred_shape = init_shape + pred_cam = init_cam + + pred_rotmat = rot6d_to_rotmat(pred_pose.contiguous()).view( + batch_size, 24, 3, 3) + + pred_output = self.smpl(betas=pred_shape, + body_pose=pred_rotmat[:, 1:], + global_orient=pred_rotmat[:, 0].unsqueeze(1), + pose2rot=False) + + pred_vertices = pred_output.vertices + pred_joints = pred_output.joints + pred_smpl_joints = pred_output.smpl_joints + pred_keypoints_2d = projection(pred_joints, pred_cam) + pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, + 3)).reshape( + -1, 72) + + if J_regressor is not None: + pred_joints = torch.matmul(J_regressor, pred_vertices) + pred_pelvis = pred_joints[:, [0], :].clone() + pred_joints = pred_joints[:, H36M_TO_J14, :] + pred_joints = pred_joints - pred_pelvis + + output = { + 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1), + 'verts': pred_vertices, + 'kp_2d': pred_keypoints_2d, + 'kp_3d': pred_joints, + 'smpl_kp_3d': pred_smpl_joints, + 'rotmat': pred_rotmat, + 'pred_cam': pred_cam, + 'pred_shape': pred_shape, + 'pred_pose': pred_pose, + } + return output + + +class PyMAF(nn.Module): + """ PyMAF based Deep Regressor for Human Mesh Recovery + PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021 + """ + + def __init__(self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True): + super().__init__() + self.feature_extractor = ResNet_Backbone( + model=cfg.MODEL.PyMAF.BACKBONE, pretrained=pretrained) + + # deconv layers + self.inplanes = self.feature_extractor.inplanes + self.deconv_with_bias = cfg.RES_MODEL.DECONV_WITH_BIAS + self.deconv_layers = self._make_deconv_layer( + cfg.RES_MODEL.NUM_DECONV_LAYERS, + cfg.RES_MODEL.NUM_DECONV_FILTERS, + cfg.RES_MODEL.NUM_DECONV_KERNELS, + ) + + self.maf_extractor = nn.ModuleList() + for _ in range(cfg.MODEL.PyMAF.N_ITER): + self.maf_extractor.append(MAF_Extractor()) + ma_feat_len = self.maf_extractor[-1].Dmap.shape[ + 0] * cfg.MODEL.PyMAF.MLP_DIM[-1] + + grid_size = 21 + xv, yv = torch.meshgrid([ + torch.linspace(-1, 1, grid_size), + torch.linspace(-1, 1, grid_size) + ]) + points_grid = torch.stack([xv.reshape(-1), + yv.reshape(-1)]).unsqueeze(0) + self.register_buffer('points_grid', points_grid) + grid_feat_len = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1] + + self.regressor = nn.ModuleList() + for i in range(cfg.MODEL.PyMAF.N_ITER): + if i == 0: + ref_infeat_dim = grid_feat_len + else: + ref_infeat_dim = ma_feat_len + self.regressor.append( + Regressor(feat_dim=ref_infeat_dim, + smpl_mean_params=smpl_mean_params)) + + dp_feat_dim = 256 + self.with_uv = cfg.LOSS.POINT_REGRESSION_WEIGHTS > 0 + if cfg.MODEL.PyMAF.AUX_SUPV_ON: + self.dp_head = IUV_predict_layer(feat_dim=dp_feat_dim) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if 
stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """ + Deconv_layer used in Simple Baselines: + Xiao et al. Simple Baselines for Human Pose Estimation and Tracking + https://github.com/microsoft/human-pose-estimation.pytorch + """ + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + def _get_deconv_cfg(deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = _get_deconv_cfg( + num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d(in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias)) + layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def forward(self, x, J_regressor=None): + + batch_size = x.shape[0] + + # spatial features and global features + s_feat, g_feat = self.feature_extractor(x) + + assert cfg.MODEL.PyMAF.N_ITER >= 0 and cfg.MODEL.PyMAF.N_ITER <= 3 + if cfg.MODEL.PyMAF.N_ITER == 1: + deconv_blocks = [self.deconv_layers] + elif cfg.MODEL.PyMAF.N_ITER == 2: + deconv_blocks = [self.deconv_layers[0:6], self.deconv_layers[6:9]] + elif cfg.MODEL.PyMAF.N_ITER == 3: + deconv_blocks = [ + self.deconv_layers[0:3], self.deconv_layers[3:6], + self.deconv_layers[6:9] + ] + + out_list = {} + + # initial parameters + # TODO: remove the initial mesh generation during forward to reduce runtime + # by generating initial mesh the beforehand: smpl_output = self.init_smpl + smpl_output = self.regressor[0].forward_init(g_feat, + J_regressor=J_regressor) + + out_list['smpl_out'] = [smpl_output] + out_list['dp_out'] = [] + + # for visulization + vis_feat_list = [s_feat.detach()] + + # parameter predictions + for rf_i in range(cfg.MODEL.PyMAF.N_ITER): + pred_cam = smpl_output['pred_cam'] + pred_shape = smpl_output['pred_shape'] + pred_pose = smpl_output['pred_pose'] + + pred_cam = pred_cam.detach() + pred_shape = pred_shape.detach() + pred_pose = pred_pose.detach() + + s_feat_i = deconv_blocks[rf_i](s_feat) + s_feat = s_feat_i + vis_feat_list.append(s_feat_i.detach()) + + self.maf_extractor[rf_i].im_feat = s_feat_i + self.maf_extractor[rf_i].cam = pred_cam + + if rf_i == 0: + sample_points = torch.transpose( + self.points_grid.expand(batch_size, -1, -1), 1, 2) + ref_feature = self.maf_extractor[rf_i].sampling(sample_points) + else: + pred_smpl_verts = smpl_output['verts'].detach() + # TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration + pred_smpl_verts_ds = torch.matmul( + 
self.maf_extractor[rf_i].Dmap.unsqueeze(0), + pred_smpl_verts) # [B, 431, 3] + ref_feature = self.maf_extractor[rf_i]( + pred_smpl_verts_ds) # [B, 431 * n_feat] + + smpl_output = self.regressor[rf_i](ref_feature, + pred_pose, + pred_shape, + pred_cam, + n_iter=1, + J_regressor=J_regressor) + out_list['smpl_out'].append(smpl_output) + + if self.training and cfg.MODEL.PyMAF.AUX_SUPV_ON: + iuv_out_dict = self.dp_head(s_feat) + out_list['dp_out'].append(iuv_out_dict) + + return out_list + + +def pymaf_net(smpl_mean_params, pretrained=True): + """ Constructs an PyMAF model with ResNet50 backbone. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = PyMAF(smpl_mean_params, pretrained) + return model diff --git a/lib/pymaf/models/res_module.py b/lib/pymaf/models/res_module.py new file mode 100644 index 0000000000000000000000000000000000000000..28c70aa9122c409550ae44bb3ccf78492eea1c47 --- /dev/null +++ b/lib/pymaf/models/res_module.py @@ -0,0 +1,388 @@ +# code brought in part from https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict +from lib.pymaf.core.cfgs import cfg + +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes * groups, + out_planes * groups, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + groups=groups) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1): + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride, groups=groups) + self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, groups=groups) + self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1): + super().__init__() + self.conv1 = nn.Conv2d(inplanes * groups, + planes * groups, + kernel_size=1, + bias=False, + groups=groups) + self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes * groups, + planes * groups, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=groups) + self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes * groups, + planes * self.expansion * groups, + kernel_size=1, + bias=False, + groups=groups) + self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = 
self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +resnet_spec = { + 18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3]) +} + + +class IUV_predict_layer(nn.Module): + + def __init__(self, + feat_dim=256, + final_cov_k=3, + part_out_dim=25, + with_uv=True): + super().__init__() + + self.with_uv = with_uv + if self.with_uv: + self.predict_u = nn.Conv2d(in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0) + + self.predict_v = nn.Conv2d(in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0) + + self.predict_ann_index = nn.Conv2d( + in_channels=feat_dim, + out_channels=15, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0) + + self.predict_uv_index = nn.Conv2d(in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0) + + self.inplanes = feat_dim + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + return_dict = {} + + predict_uv_index = self.predict_uv_index(x) + predict_ann_index = self.predict_ann_index(x) + + return_dict['predict_uv_index'] = predict_uv_index + return_dict['predict_ann_index'] = predict_ann_index + + if self.with_uv: + predict_u = self.predict_u(x) + predict_v = self.predict_v(x) + return_dict['predict_u'] = predict_u + return_dict['predict_v'] = predict_v + else: + return_dict['predict_u'] = None + return_dict['predict_v'] = None + # return_dict['predict_u'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device) + # return_dict['predict_v'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device) + + return return_dict + + +class SmplResNet(nn.Module): + + def __init__(self, + resnet_nums, + in_channels=3, + num_classes=229, + last_stride=2, + n_extra_feat=0, + truncate=0, + **kwargs): + super().__init__() + + self.inplanes = 64 + self.truncate = truncate + # extra = cfg.MODEL.EXTRA + # self.deconv_with_bias = extra.DECONV_WITH_BIAS + block, layers = resnet_spec[resnet_nums] + + self.conv1 = nn.Conv2d(in_channels, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], + stride=2) if truncate < 2 else None + self.layer4 = self._make_layer( + block, 512, layers[3], + stride=last_stride) if truncate < 1 else None + + self.avg_pooling = nn.AdaptiveAvgPool2d(1) + + self.num_classes = num_classes + if num_classes > 0: + self.final_layer = nn.Linear(512 * block.expansion, 
num_classes) + nn.init.xavier_uniform_(self.final_layer.weight, gain=0.01) + + self.n_extra_feat = n_extra_feat + if n_extra_feat > 0: + self.trans_conv = nn.Sequential( + nn.Conv2d(n_extra_feat + 512 * block.expansion, + 512 * block.expansion, + kernel_size=1, + bias=False), + nn.BatchNorm2d(512 * block.expansion, momentum=BN_MOMENTUM), + nn.ReLU(True)) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, infeat=None): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) if self.truncate < 2 else x2 + x4 = self.layer4(x3) if self.truncate < 1 else x3 + + if infeat is not None: + x4 = self.trans_conv(torch.cat([infeat, x4], 1)) + + if self.num_classes > 0: + xp = self.avg_pooling(x4) + cls = self.final_layer(xp.view(xp.size(0), -1)) + if not cfg.DANET.USE_MEAN_PARA: + # for non-negative scale + scale = F.relu(cls[:, 0]).unsqueeze(1) + cls = torch.cat((scale, cls[:, 1:]), dim=1) + else: + cls = None + + return cls, {'x4': x4} + + def init_weights(self, pretrained=''): + if os.path.isfile(pretrained): + logger.info('=> loading pretrained model {}'.format(pretrained)) + # self.load_state_dict(pretrained_state_dict, strict=False) + checkpoint = torch.load(pretrained) + if isinstance(checkpoint, OrderedDict): + # state_dict = checkpoint + state_dict_old = self.state_dict() + for key in state_dict_old.keys(): + if key in checkpoint.keys(): + if state_dict_old[key].shape != checkpoint[key].shape: + del checkpoint[key] + state_dict = checkpoint + elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict_old = checkpoint['state_dict'] + state_dict = OrderedDict() + # delete 'module.' 
because it is saved from DataParallel module + for key in state_dict_old.keys(): + if key.startswith('module.'): + # state_dict[key[7:]] = state_dict[key] + # state_dict.pop(key) + state_dict[key[7:]] = state_dict_old[key] + else: + state_dict[key] = state_dict_old[key] + else: + raise RuntimeError( + 'No state_dict found in checkpoint file {}'.format( + pretrained)) + self.load_state_dict(state_dict, strict=False) + else: + logger.error('=> imagenet pretrained model dose not exist') + logger.error('=> please download it first') + raise ValueError('imagenet pretrained model does not exist') + + +class LimbResLayers(nn.Module): + + def __init__(self, + resnet_nums, + inplanes, + outplanes=None, + groups=1, + **kwargs): + super().__init__() + + self.inplanes = inplanes + block, layers = resnet_spec[resnet_nums] + self.outplanes = 512 if outplanes == None else outplanes + self.layer4 = self._make_layer(block, + self.outplanes, + layers[3], + stride=2, + groups=groups) + + self.avg_pooling = nn.AdaptiveAvgPool2d(1) + + def _make_layer(self, block, planes, blocks, stride=1, groups=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes * groups, + planes * block.expansion * groups, + kernel_size=1, + stride=stride, + bias=False, + groups=groups), + nn.BatchNorm2d(planes * block.expansion * groups, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, groups=groups)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.layer4(x) + x = self.avg_pooling(x) + + return x diff --git a/lib/pymaf/models/smpl.py b/lib/pymaf/models/smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..ad0059acc3d88d7bf13d6bca25ed9da1b82bb5fe --- /dev/null +++ b/lib/pymaf/models/smpl.py @@ -0,0 +1,92 @@ +# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/smpl.py + +import torch +import numpy as np +from lib.smplx import SMPL as _SMPL +from lib.smplx.body_models import ModelOutput +from lib.smplx.lbs import vertices2joints +from collections import namedtuple + +from lib.pymaf.core import path_config, constants + +SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS +SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR + +# Indices to get the 14 LSP joints from the 17 H36M joints +H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] +H36M_TO_J14 = H36M_TO_J17[:14] + + +class SMPL(_SMPL): + """ Extension of the official SMPL implementation to support more joints """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES] + J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA) + self.register_buffer( + 'J_regressor_extra', + torch.tensor(J_regressor_extra, dtype=torch.float32)) + self.joint_map = torch.tensor(joints, dtype=torch.long) + self.ModelOutput = namedtuple( + 'ModelOutput_', ModelOutput._fields + ( + 'smpl_joints', + 'joints_J19', + )) + self.ModelOutput.__new__.__defaults__ = (None, ) * len( + self.ModelOutput._fields) + + def forward(self, *args, **kwargs): + kwargs['get_skin'] = True + smpl_output = super().forward(*args, **kwargs) + extra_joints = vertices2joints(self.J_regressor_extra, + smpl_output.vertices) + # smpl_output.joints: [B, 45, 3] extra_joints: [B, 
9, 3] + vertices = smpl_output.vertices + joints = torch.cat([smpl_output.joints, extra_joints], dim=1) + smpl_joints = smpl_output.joints[:, :24] + joints = joints[:, self.joint_map, :] # [B, 49, 3] + joints_J24 = joints[:, -24:, :] + joints_J19 = joints_J24[:, constants.J24_TO_J19, :] + output = self.ModelOutput(vertices=vertices, + global_orient=smpl_output.global_orient, + body_pose=smpl_output.body_pose, + joints=joints, + joints_J19=joints_J19, + smpl_joints=smpl_joints, + betas=smpl_output.betas, + full_pose=smpl_output.full_pose) + return output + + +def get_smpl_faces(): + smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False) + return smpl.faces + + +def get_part_joints(smpl_joints): + batch_size = smpl_joints.shape[0] + + # part_joints = torch.zeros().to(smpl_joints.device) + + one_seg_pairs = [(0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), + (12, 15), (13, 16), (14, 17)] + two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), + (18, 20), (19, 21)] + + one_seg_pairs.extend(two_seg_pairs) + + single_joints = [(10), (11), (15), (22), (23)] + + part_joints = [] + + for j_p in one_seg_pairs: + new_joint = torch.mean(smpl_joints[:, j_p], dim=1, keepdim=True) + part_joints.append(new_joint) + + for j_p in single_joints: + part_joints.append(smpl_joints[:, j_p:j_p + 1]) + + part_joints = torch.cat(part_joints, dim=1) + + return part_joints diff --git a/lib/pymaf/utils/__init__.py b/lib/pymaf/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/pymaf/utils/geometry.py b/lib/pymaf/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e0f6ee352be8351e7acf838ae8a44fff32222b --- /dev/null +++ b/lib/pymaf/utils/geometry.py @@ -0,0 +1,452 @@ +import torch +import numpy as np +from torch.nn import functional as F +""" +Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" + + +def batch_rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. 
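+    The quaternion is normalized internally, so inputs need not be unit length.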
+ Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(B, 3, 3) + return rotMat + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3, 3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], + dtype=torch.float32, + device=rotation_matrix.device).reshape( + 1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. 
Got {}".format( + quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. 
Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def projection(pred_joints, pred_camera, retain_z=False): + pred_cam_t = torch.stack([ + pred_camera[:, 1], pred_camera[:, 2], 2 * 5000. / + (224. * pred_camera[:, 0] + 1e-9) + ], + dim=-1) + batch_size = pred_joints.shape[0] + camera_center = torch.zeros(batch_size, 2) + pred_keypoints_2d = perspective_projection( + pred_joints, + rotation=torch.eye(3).unsqueeze(0).expand(batch_size, -1, + -1).to(pred_joints.device), + translation=pred_cam_t, + focal_length=5000., + camera_center=camera_center, + retain_z=retain_z) + # Normalize keypoints to [-1,1] + pred_keypoints_2d = pred_keypoints_2d / (224. / 2.) + return pred_keypoints_2d + + +def perspective_projection(points, + rotation, + translation, + focal_length, + camera_center, + retain_z=False): + """ + This function computes the perspective projection of a set of points. + Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:, 0, 0] = focal_length + K[:, 1, 1] = focal_length + K[:, 2, 2] = 1. 
+ K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + if retain_z: + return projected_points + else: + return projected_points[:, :, :-1] + + +def estimate_translation_np(S, + joints_2d, + joints_conf, + focal_length=5000, + img_size=224): + """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length, focal_length]) + # optical center + center = np.array([img_size / 2., img_size / 2.]) + + # transformations + Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1) + XY = np.reshape(S[:, 0:2], -1) + O = np.tile(center, num_joints) + F = np.tile(f, num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1) + + # least squares + Q = np.array([ + F * np.tile(np.array([1, 0]), num_joints), + F * np.tile(np.array([0, 1]), num_joints), + O - np.reshape(joints_2d, -1) + ]).T + c = (np.reshape(joints_2d, -1) - O) * Z - F * XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W, Q) + c = np.dot(W, c) + + # square matrix + A = np.dot(Q.T, Q) + b = np.dot(Q.T, c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.): + """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
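+    The translation for each batch element is obtained from the weighted
+    least-squares fit in estimate_translation_np.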
+ Input: + S: (B, 49, 3) 3D joint locations + joints: (B, 49, 3) 2D joint locations and confidence + Returns: + (B, 3) camera translation vectors + """ + + device = S.device + # Use only joints 25:49 (GT joints) + S = S[:, 25:, :].cpu().numpy() + joints_2d = joints_2d[:, 25:, :].cpu().numpy() + joints_conf = joints_2d[:, :, -1] + joints_2d = joints_2d[:, :, :-1] + trans = np.zeros((S.shape[0], 3), dtype=np.float32) + # Find the translation for each example in the batch + for i in range(S.shape[0]): + S_i = S[i] + joints_i = joints_2d[i] + conf_i = joints_conf[i] + trans[i] = estimate_translation_np(S_i, + joints_i, + conf_i, + focal_length=focal_length, + img_size=img_size) + return torch.from_numpy(trans).to(device) + + +def Rot_y(angle, category='torch', prepend_dim=True, device=None): + '''Rotate around y-axis by angle + Args: + category: 'torch' or 'numpy' + prepend_dim: prepend an extra dimension + Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) + ''' + m = np.array([[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], + [-np.sin(angle), 0., np.cos(angle)]]) + if category == 'torch': + if prepend_dim: + return torch.tensor(m, dtype=torch.float, + device=device).unsqueeze(0) + else: + return torch.tensor(m, dtype=torch.float, device=device) + elif category == 'numpy': + if prepend_dim: + return np.expand_dims(m, 0) + else: + return m + else: + raise ValueError("category must be 'torch' or 'numpy'") + + +def Rot_x(angle, category='torch', prepend_dim=True, device=None): + '''Rotate around x-axis by angle + Args: + category: 'torch' or 'numpy' + prepend_dim: prepend an extra dimension + Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) + ''' + m = np.array([[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], + [0., np.sin(angle), np.cos(angle)]]) + if category == 'torch': + if prepend_dim: + return torch.tensor(m, dtype=torch.float, + device=device).unsqueeze(0) + else: + return torch.tensor(m, dtype=torch.float, device=device) + elif category == 'numpy': + if prepend_dim: + return np.expand_dims(m, 0) + else: + return m + else: + raise ValueError("category must be 'torch' or 'numpy'") + + +def Rot_z(angle, category='torch', prepend_dim=True, device=None): + '''Rotate around z-axis by angle + Args: + category: 'torch' or 'numpy' + prepend_dim: prepend an extra dimension + Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) + ''' + m = np.array([[np.cos(angle), -np.sin(angle), 0.], + [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]]) + if category == 'torch': + if prepend_dim: + return torch.tensor(m, dtype=torch.float, + device=device).unsqueeze(0) + else: + return torch.tensor(m, dtype=torch.float, device=device) + elif category == 'numpy': + if prepend_dim: + return np.expand_dims(m, 0) + else: + return m + else: + raise ValueError("category must be 'torch' or 'numpy'") diff --git a/lib/pymaf/utils/imutils.py b/lib/pymaf/utils/imutils.py new file mode 100644 index 0000000000000000000000000000000000000000..165b6923e9070b87917644d09d04a95f37b8590b --- /dev/null +++ b/lib/pymaf/utils/imutils.py @@ -0,0 +1,510 @@ +""" +This file contains functions that are used to perform data augmentation. 
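+It also provides image loading, person detection and cropping, uncropping,
+keypoint/pose flipping and Gaussian heatmap generation helpers.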
+""" +from turtle import reset +import cv2 +import io +import torch +import numpy as np +import scipy.misc +from PIL import Image +from rembg.bg import remove +from torchvision.models import detection + +from lib.pymaf.core import constants +from lib.pymaf.utils.streamer import aug_matrix +from lib.common.cloth_extraction import load_segmentation +from torchvision import transforms + + +def load_img(img_file): + + img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED) + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if not img_file.endswith("png"): + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + + return img + + +def get_bbox(img, det): + + input = np.float32(img) + input = (input / 255.0 - + (0.5, 0.5, 0.5)) / (0.5, 0.5, 0.5) # TO [-1.0, 1.0] + input = input.transpose(2, 0, 1) # TO [3 x H x W] + bboxes, probs = det(torch.from_numpy(input).float().unsqueeze(0)) + + probs = probs.unsqueeze(3) + bboxes = (bboxes * probs).sum(dim=1, keepdim=True) / probs.sum( + dim=1, keepdim=True) + bbox = bboxes[0, 0, 0].cpu().numpy() + + return bbox + + +def get_transformer(input_res): + + image_to_tensor = transforms.Compose([ + transforms.Resize(input_res), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + mask_to_tensor = transforms.Compose([ + transforms.Resize(input_res), + transforms.ToTensor(), + transforms.Normalize((0.0, ), (1.0, )) + ]) + + image_to_pymaf_tensor = transforms.Compose([ + transforms.Resize(size=224), + transforms.Normalize(mean=constants.IMG_NORM_MEAN, + std=constants.IMG_NORM_STD) + ]) + + image_to_pixie_tensor = transforms.Compose([transforms.Resize(224)]) + + def image_to_hybrik_tensor(img): + # mean + img[0].add_(-0.406) + img[1].add_(-0.457) + img[2].add_(-0.480) + + # std + img[0].div_(0.225) + img[1].div_(0.224) + img[2].div_(0.229) + return img + + return [ + image_to_tensor, mask_to_tensor, image_to_pymaf_tensor, + image_to_pixie_tensor, image_to_hybrik_tensor + ] + + +def process_image(img_file, + hps_type, + input_res=512, + device=None, + seg_path=None): + """Read image, do preprocessing and possibly crop it according to the bounding box. + If there are bounding box annotations, use them to crop the image. + If no bounding box is specified but openpose detections are available, use them to get the bounding box. 
+ """ + + [ + image_to_tensor, mask_to_tensor, image_to_pymaf_tensor, + image_to_pixie_tensor, image_to_hybrik_tensor + ] = get_transformer(input_res) + + img_ori = load_img(img_file) + + in_height, in_width, _ = img_ori.shape + M = aug_matrix(in_width, in_height, input_res * 2, input_res * 2) + + # from rectangle to square + img_for_crop = cv2.warpAffine(img_ori, + M[0:2, :], (input_res * 2, input_res * 2), + flags=cv2.INTER_CUBIC) + + # detection for bbox + detector = detection.maskrcnn_resnet50_fpn(pretrained=True) + detector.eval() + predictions = detector( + [torch.from_numpy(img_for_crop).permute(2, 0, 1) / 255.])[0] + human_ids = torch.logical_and( + predictions["labels"] == 1, + predictions["scores"] == predictions["scores"].max()).nonzero().squeeze(1) + bbox = predictions["boxes"][human_ids, :].flatten().detach().cpu().numpy() + + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + center = np.array([(bbox[0] + bbox[2]) / 2.0, + (bbox[1] + bbox[3]) / 2.0]) + + scale = max(height, width) / 180 + + if hps_type == 'hybrik': + img_np = crop_for_hybrik(img_for_crop, center, + np.array([scale * 180, scale * 180])) + else: + img_np, cropping_parameters = crop(img_for_crop, center, scale, + (input_res, input_res)) + + with torch.no_grad(): + buf = io.BytesIO() + Image.fromarray(img_np).save(buf, format='png') + img_pil = Image.open(io.BytesIO(remove( + buf.getvalue()))).convert("RGBA") + + # for icon + img_rgb = image_to_tensor(img_pil.convert("RGB")) + img_mask = torch.tensor(1.0) - (mask_to_tensor(img_pil.split()[-1]) < + torch.tensor(0.5)).float() + img_tensor = img_rgb * img_mask + + # for hps + img_hps = img_np.astype(np.float32) / 255. + img_hps = torch.from_numpy(img_hps).permute(2, 0, 1) + + if hps_type == 'bev': + img_hps = img_np[:, :, [2, 1, 0]] + elif hps_type == 'hybrik': + img_hps = image_to_hybrik_tensor(img_hps).unsqueeze(0).to(device) + elif hps_type != 'pixie': + img_hps = image_to_pymaf_tensor(img_hps).unsqueeze(0).to(device) + else: + img_hps = image_to_pixie_tensor(img_hps).unsqueeze(0).to(device) + + # uncrop params + uncrop_param = { + 'center': center, + 'scale': scale, + 'ori_shape': img_ori.shape, + 'box_shape': img_np.shape, + 'crop_shape': img_for_crop.shape, + 'M': M + } + + if not (seg_path is None): + segmentations = load_segmentation(seg_path, (in_height, in_width)) + seg_coord_normalized = [] + for seg in segmentations: + coord_normalized = [] + for xy in seg['coordinates']: + xy_h = np.vstack((xy[:, 0], xy[:, 1], np.ones(len(xy)))).T + warped_indeces = M[0:2, :] @ xy_h[:, :, None] + warped_indeces = np.array(warped_indeces).astype(int) + warped_indeces.resize((warped_indeces.shape[:2])) + + # cropped_indeces = crop_segmentation(warped_indeces, center, scale, (input_res, input_res), img_np.shape) + cropped_indeces = crop_segmentation(warped_indeces, + (input_res, input_res), + cropping_parameters) + + indices = np.vstack( + (cropped_indeces[:, 0], cropped_indeces[:, 1])).T + + # Convert to NDC coordinates + seg_cropped_normalized = 2 * (indices / input_res) - 1 + # Don't know why we need to divide by 50 but it works ¯\_(ツ)_/¯ (probably some scaling factor somewhere) + # Divide only by 45 on the horizontal axis to take the curve of the human body into account + seg_cropped_normalized[:, + 0] = (1 / + 40) * seg_cropped_normalized[:, 0] + seg_cropped_normalized[:, + 1] = (1 / + 50) * seg_cropped_normalized[:, 1] + coord_normalized.append(seg_cropped_normalized) + + seg['coord_normalized'] = coord_normalized + seg_coord_normalized.append(seg) + + return 
img_tensor, img_hps, img_ori, img_mask, uncrop_param, seg_coord_normalized + + return img_tensor, img_hps, img_ori, img_mask, uncrop_param + + +def get_transform(center, scale, res): + """Generate transformation matrix.""" + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + + return t + + +def transform(pt, center, scale, res, invert=0): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T + new_pt = np.dot(t, new_pt) + return np.around(new_pt[:2]).astype(np.int16) + + +def crop(img, center, scale, res): + """Crop image according to the supplied bounding box.""" + + # Upper left point + ul = np.array(transform([0, 0], center, scale, res, invert=1)) + + # Bottom right point + br = np.array(transform(res, center, scale, res, invert=1)) + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + if len(img.shape) == 2: + new_img = np.array(Image.fromarray(new_img).resize(res)) + else: + new_img = np.array( + Image.fromarray(new_img.astype(np.uint8)).resize(res)) + + return new_img, (old_x, new_x, old_y, new_y, new_shape) + + +def crop_segmentation(org_coord, res, cropping_parameters): + old_x, new_x, old_y, new_y, new_shape = cropping_parameters + + new_coord = np.zeros((org_coord.shape)) + new_coord[:, 0] = new_x[0] + (org_coord[:, 0] - old_x[0]) + new_coord[:, 1] = new_y[0] + (org_coord[:, 1] - old_y[0]) + + new_coord[:, 0] = res[0] * (new_coord[:, 0] / new_shape[1]) + new_coord[:, 1] = res[1] * (new_coord[:, 1] / new_shape[0]) + + return new_coord + + +def crop_for_hybrik(img, center, scale): + inp_h, inp_w = (256, 256) + trans = get_affine_transform(center, scale, 0, [inp_w, inp_h]) + new_img = cv2.warpAffine(img, + trans, (int(inp_w), int(inp_h)), + flags=cv2.INTER_LINEAR) + return new_img + + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + + def get_dir(src_point, rot_rad): + """Rotate the point by `rot_rad` degree.""" + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result + + def get_3rd_point(a, b): + """Return vector c that perpendicular to (a - b).""" + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + scale = np.array([scale, scale]) + + scale_tmp = scale + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * 
shift + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def corner_align(ul, br): + + if ul[1] - ul[0] != br[1] - br[0]: + ul[1] = ul[0] + br[1] - br[0] + + return ul, br + + +def uncrop(img, center, scale, orig_shape): + """'Undo' the image cropping/resizing. + This function is used when evaluating mask/part segmentation. + """ + + res = img.shape[:2] + + # Upper left point + ul = np.array(transform([0, 0], center, scale, res, invert=1)) + # Bottom right point + br = np.array(transform(res, center, scale, res, invert=1)) + + # quick fix + ul, br = corner_align(ul, br) + + # size of cropped image + crop_shape = [br[1] - ul[1], br[0] - ul[0]] + new_img = np.zeros(orig_shape, dtype=np.uint8) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] + new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] + + # Range to sample from original image + old_x = max(0, ul[0]), min(orig_shape[1], br[0]) + old_y = max(0, ul[1]), min(orig_shape[0], br[1]) + + img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape)) + + new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], + new_x[0]:new_x[1]] + + return new_img + + +def rot_aa(aa, rot): + """Rotate axis angle parameters.""" + # pose parameters + R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), + np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]]) + # find the rotation of the body in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg)) + aa = (resrot.T)[0] + return aa + + +def flip_img(img): + """Flip rgb images or masks. + channels come last, e.g. (256,256,3). + """ + img = np.fliplr(img) + return img + + +def flip_kp(kp, is_smpl=False): + """Flip keypoints.""" + if len(kp) == 24: + if is_smpl: + flipped_parts = constants.SMPL_JOINTS_FLIP_PERM + else: + flipped_parts = constants.J24_FLIP_PERM + elif len(kp) == 49: + if is_smpl: + flipped_parts = constants.SMPL_J49_FLIP_PERM + else: + flipped_parts = constants.J49_FLIP_PERM + kp = kp[flipped_parts] + kp[:, 0] = -kp[:, 0] + return kp + + +def flip_pose(pose): + """Flip pose. + The flipping is based on SMPL parameters. 
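+    Left/right joints are swapped with SMPL_POSE_FLIP_PERM and the y/z
+    axis-angle components are negated.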
+ """ + flipped_parts = constants.SMPL_POSE_FLIP_PERM + pose = pose[flipped_parts] + # we also negate the second and the third dimension of the axis-angle + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + return pose + + +def normalize_2d_kp(kp_2d, crop_size=224, inv=False): + # Normalize keypoints between -1, 1 + if not inv: + ratio = 1.0 / crop_size + kp_2d = 2.0 * kp_2d * ratio - 1.0 + else: + ratio = 1.0 / crop_size + kp_2d = (kp_2d + 1.0) / (2 * ratio) + + return kp_2d + + +def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None): + ''' + param joints: [num_joints, 3] + param joints_vis: [num_joints, 3] + return: target, target_weight(1: visible, 0: invisible) + ''' + num_joints = joints.shape[0] + device = joints.device + cur_device = torch.device(device.type, device.index) + if not hasattr(heatmap_size, '__len__'): + # width height + heatmap_size = [heatmap_size, heatmap_size] + assert len(heatmap_size) == 2 + target_weight = np.ones((num_joints, 1), dtype=np.float32) + if joints_vis is not None: + target_weight[:, 0] = joints_vis[:, 0] + target = torch.zeros((num_joints, heatmap_size[1], heatmap_size[0]), + dtype=torch.float32, + device=cur_device) + + tmp_size = sigma * 3 + + for joint_id in range(num_joints): + mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5) + mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5) + # Check that any part of the gaussian is in-bounds + ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] + br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] + if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \ + or br[0] < 0 or br[1] < 0: + # If not, just return the image as is + target_weight[joint_id] = 0 + continue + + # # Generate gaussian + size = 2 * tmp_size + 1 + # x = np.arange(0, size, 1, np.float32) + # y = x[:, np.newaxis] + # x0 = y0 = size // 2 + # # The gaussian is not normalized, we want the center value to equal 1 + # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + # g = torch.from_numpy(g.astype(np.float32)) + + x = torch.arange(0, size, dtype=torch.float32, device=cur_device) + y = x.unsqueeze(-1) + x0 = y0 = size // 2 + # The gaussian is not normalized, we want the center value to equal 1 + g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) + + # Usable gaussian range + g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0] + g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1] + # Image range + img_x = max(0, ul[0]), min(br[0], heatmap_size[0]) + img_y = max(0, ul[1]), min(br[1], heatmap_size[1]) + + v = target_weight[joint_id] + if v > 0.5: + target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \ + g[g_y[0]:g_y[1], g_x[0]:g_x[1]] + + return target, target_weight diff --git a/lib/pymaf/utils/streamer.py b/lib/pymaf/utils/streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..4a310651b26ab496ed3a2e4dab19691c29cf7bae --- /dev/null +++ b/lib/pymaf/utils/streamer.py @@ -0,0 +1,144 @@ +import cv2 +import torch +import numpy as np +import imageio + + +def aug_matrix(w1, h1, w2, h2): + dx = (w2 - w1) / 2.0 + dy = (h2 - h1) / 2.0 + + matrix_trans = np.array([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0]]) + + scale = np.min([float(w2) / w1, float(h2) / h1]) + + M = get_affine_matrix(center=(w2 / 2.0, h2 / 2.0), + translate=(0, 0), + scale=scale) + + M = np.array(M + [0., 0., 1.]).reshape(3, 3) + M = M.dot(matrix_trans) + + return M + + +def get_affine_matrix(center, translate, scale): + cx, cy = center + tx, ty = translate + + M = [1, 0, 0, 0, 1, 0] + M = 
[x * scale for x in M] + + # Apply translation and of center translation: RSS * C^-1 + M[2] += M[0] * (-cx) + M[1] * (-cy) + M[5] += M[3] * (-cx) + M[4] * (-cy) + + # Apply center translation: T * C * RSS * C^-1 + M[2] += cx + tx + M[5] += cy + ty + return M + + +class BaseStreamer(): + """This streamer will return images at 512x512 size. + """ + + def __init__(self, + width=512, + height=512, + pad=True, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + **kwargs): + self.width = width + self.height = height + self.pad = pad + self.mean = np.array(mean) + self.std = np.array(std) + + self.loader = self.create_loader() + + def create_loader(self): + raise NotImplementedError + yield np.zeros((600, 400, 3)) # in RGB (0, 255) + + def __getitem__(self, index): + image = next(self.loader) + in_height, in_width, _ = image.shape + M = aug_matrix(in_width, in_height, self.width, self.height, self.pad) + image = cv2.warpAffine(image, + M[0:2, :], (self.width, self.height), + flags=cv2.INTER_CUBIC) + + input = np.float32(image) + input = (input / 255.0 - self.mean) / self.std # TO [-1.0, 1.0] + input = input.transpose(2, 0, 1) # TO [3 x H x W] + return torch.from_numpy(input).float() + + def __len__(self): + raise NotImplementedError + + +class CaptureStreamer(BaseStreamer): + """This streamer takes webcam as input. + """ + + def __init__(self, id=0, width=512, height=512, pad=True, **kwargs): + super().__init__(width, height, pad, **kwargs) + self.capture = cv2.VideoCapture(id) + + def create_loader(self): + while True: + _, image = self.capture.read() + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # RGB + yield image + + def __len__(self): + return 100_000_000 + + def __del__(self): + self.capture.release() + + +class VideoListStreamer(BaseStreamer): + """This streamer takes a list of video files as input. + """ + + def __init__(self, files, width=512, height=512, pad=True, **kwargs): + super().__init__(width, height, pad, **kwargs) + self.files = files + self.captures = [imageio.get_reader(f) for f in files] + self.nframes = sum([ + int(cap._meta["fps"] * cap._meta["duration"]) + for cap in self.captures + ]) + + def create_loader(self): + for capture in self.captures: + for image in capture: # RGB + yield image + + def __len__(self): + return self.nframes + + def __del__(self): + for capture in self.captures: + capture.close() + + +class ImageListStreamer(BaseStreamer): + """This streamer takes a list of image files as input. + """ + + def __init__(self, files, width=512, height=512, pad=True, **kwargs): + super().__init__(width, height, pad, **kwargs) + self.files = files + + def create_loader(self): + for f in self.files: + image = cv2.imread(f, cv2.IMREAD_UNCHANGED)[:, :, 0:3] + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # RGB + yield image + + def __len__(self): + return len(self.files) diff --git a/lib/pymaf/utils/transforms.py b/lib/pymaf/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..996742aa6defb02afcd59f7f2ae0eb47576bf314 --- /dev/null +++ b/lib/pymaf/utils/transforms.py @@ -0,0 +1,78 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import numpy as np + + +def transform_preds(coords, center, scale, output_size): + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale, 0, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + # print(scale) + scale = np.array([scale, scale]) + + scale_tmp = scale * 200.0 + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + + +def get_3rd_point(a, b): + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def get_dir(src_point, rot_rad): + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result diff --git a/lib/pymafx/configs/pymafx_config.yaml b/lib/pymafx/configs/pymafx_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf294a77dd8e77ad3c7448b97f59151c57efcc14 --- /dev/null +++ b/lib/pymafx/configs/pymafx_config.yaml @@ -0,0 +1,175 @@ +SOLVER: + MAX_ITER: 500000 + TYPE: Adam + BASE_LR: 0.00005 + GAMMA: 0.1 + STEPS: [0] + EPOCHS: [0] +# DEBUG: False +LOGDIR: '' +DEVICE: cuda +# NUM_WORKERS: 8 +SEED_VALUE: -1 +LOSS: + KP_2D_W: 300.0 + KP_3D_W: 300.0 + HF_KP_2D_W: 1000.0 + HF_KP_3D_W: 1000.0 + GL_HF_KP_2D_W: 30. + FEET_KP_2D_W: 0. + SHAPE_W: 0.06 + POSE_W: 60.0 + VERT_W: 0.0 + VERT_REG_W: 300.0 + INDEX_WEIGHTS: 2.0 + # Loss weights for surface parts. (24 Parts) + PART_WEIGHTS: 0.3 + # Loss weights for UV regression. 
+ POINT_REGRESSION_WEIGHTS: 0.5 +TRAIN: + NUM_WORKERS: 8 + BATCH_SIZE: 64 + LOG_FERQ: 100 + SHUFFLE: True + PIN_MEMORY: True + BHF_MODE: 'full_body' +TEST: + BATCH_SIZE: 32 +MODEL: + # IWP, Identity rotation and Weak Perspective Camera + USE_IWP_CAM: True + USE_GT_FL: False + PRED_PITCH: False + MESH_MODEL: 'smplx' + ALL_GENDER: False + EVAL_MODE: True + PyMAF: + BACKBONE: 'hr48' + HF_BACKBONE: 'res50' + MAF_ON: True + MLP_DIM: [256, 128, 64, 5] + HF_MLP_DIM: [256, 128, 64, 5] + MLP_VT_DIM: [256, 128, 64, 3] + N_ITER: 3 + SUPV_LAST: False + AUX_SUPV_ON: True + HF_AUX_SUPV_ON: False + HF_BOX_CENTER: True + DP_HEATMAP_SIZE: 56 + GRID_FEAT: False + USE_CAM_FEAT: True + HF_IMG_SIZE: 224 + HF_DET: 'pifpaf' + OPT_WRIST: True + ADAPT_INTEGR: True + PRED_VIS_H: True + HAND_VIS_TH: 0.1 + GRID_ALIGN: + USE_ATT: True + USE_FC: False + ATT_FEAT_IDX: 2 + ATT_HEAD: 1 + ATT_STARTS: 1 +RES_MODEL: + DECONV_WITH_BIAS: False + NUM_DECONV_LAYERS: 3 + NUM_DECONV_FILTERS: + - 256 + - 256 + - 256 + NUM_DECONV_KERNELS: + - 4 + - 4 + - 4 +POSE_RES_MODEL: + INIT_WEIGHTS: True + NAME: 'pose_resnet' + PRETR_SET: 'imagenet' # 'none' 'imagenet' 'coco' + # PRETRAINED: 'data/pretrained_model/resnet50-19c8e357.pth' + PRETRAINED_IM: 'data/pretrained_model/resnet50-19c8e357.pth' + PRETRAINED_COCO: 'data/pretrained_model/pose_resnet_50_256x192.pth.tar' + EXTRA: + TARGET_TYPE: 'gaussian' + HEATMAP_SIZE: + - 48 + - 64 + SIGMA: 2 + FINAL_CONV_KERNEL: 1 + DECONV_WITH_BIAS: False + NUM_DECONV_LAYERS: 3 + NUM_DECONV_FILTERS: + - 256 + - 256 + - 256 + NUM_DECONV_KERNELS: + - 4 + - 4 + - 4 + NUM_LAYERS: 50 +HR_MODEL: + INIT_WEIGHTS: True + NAME: pose_hrnet + PRETR_SET: 'coco' # 'none' 'imagenet' 'coco' + PRETRAINED_IM: 'data/pretrained_model/hrnet_w48-imgnet-8ef0771d.pth' + PRETRAINED_COCO: 'data/pretrained_model/pose_hrnet_w48_256x192.pth' + TARGET_TYPE: gaussian + IMAGE_SIZE: + - 256 + - 256 + HEATMAP_SIZE: + - 64 + - 64 + SIGMA: 2 + EXTRA: + PRETRAINED_LAYERS: + - 'conv1' + - 'bn1' + - 'conv2' + - 'bn2' + - 'layer1' + - 'transition1' + - 'stage2' + - 'transition2' + - 'stage3' + - 'transition3' + - 'stage4' + FINAL_CONV_KERNEL: 1 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + - 384 + FUSE_METHOD: SUM diff --git a/lib/pymafx/core/__init__.py b/lib/pymafx/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/pymafx/core/cfgs.py b/lib/pymafx/core/cfgs.py new file mode 100644 index 0000000000000000000000000000000000000000..c970c6c0caafe7a4c2f3abbb311adcd0cef42b94 --- /dev/null +++ b/lib/pymafx/core/cfgs.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import os +import json +import random +import string +import argparse +from datetime import datetime +from yacs.config import CfgNode as CN + +# Configuration variables +cfg = CN(new_allowed=True) + +cfg.OUTPUT_DIR = 'results' +cfg.DEVICE = 'cuda' +cfg.DEBUG = False +cfg.LOGDIR = '' +cfg.VAL_VIS_BATCH_FREQ = 200 +cfg.TRAIN_VIS_ITER_FERQ = 1000 +cfg.SEED_VALUE = -1 + +cfg.TRAIN = CN(new_allowed=True) + +cfg.LOSS = CN(new_allowed=True) +cfg.LOSS.KP_2D_W = 300.0 +cfg.LOSS.KP_3D_W = 300.0 +cfg.LOSS.SHAPE_W = 0.06 +cfg.LOSS.POSE_W = 60.0 +cfg.LOSS.VERT_W = 0.0 + +# Loss weights for dense correspondences +cfg.LOSS.INDEX_WEIGHTS = 2.0 +# Loss weights for surface parts. (24 Parts) +cfg.LOSS.PART_WEIGHTS = 0.3 +# Loss weights for UV regression. +cfg.LOSS.POINT_REGRESSION_WEIGHTS = 0.5 + +cfg.MODEL = CN(new_allowed=True) + +cfg.MODEL.PyMAF = CN(new_allowed=True) + +## switch +cfg.TRAIN.BATCH_SIZE = 64 +cfg.TRAIN.VAL_LOOP = True + +cfg.TEST = CN(new_allowed=True) + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + # return cfg.clone() + return cfg + + +def update_cfg(cfg_file): + # cfg = get_cfg_defaults() + cfg.merge_from_file(cfg_file) + # return cfg.clone() + return cfg + + +def parse_args(args): + cfg_file = args.cfg_file + if args.cfg_file is not None: + cfg = update_cfg(args.cfg_file) + else: + cfg = get_cfg_defaults() + + if args.misc is not None: + cfg.merge_from_list(args.misc) + + return cfg + + +def parse_args_extend(args): + if args.resume: + if not os.path.exists(args.log_dir): + raise ValueError('Experiment are set to resume mode, but log directory does not exist.') + + if args.cfg_file is not None: + cfg = update_cfg(args.cfg_file) + else: + cfg = get_cfg_defaults() + # load log's cfg + cfg_file = os.path.join(args.log_dir, 'cfg.yaml') + cfg = update_cfg(cfg_file) + + if args.misc is not None: + cfg.merge_from_list(args.misc) + else: + parse_args(args) diff --git a/lib/pymafx/core/constants.py b/lib/pymafx/core/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5354a289f892a764a16221b469fc49794ff54127 --- /dev/null +++ b/lib/pymafx/core/constants.py @@ -0,0 +1,351 @@ +# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/constants.py +FOCAL_LENGTH = 5000.0 +IMG_RES = 224 + +# Mean and standard deviation for normalizing input image +IMG_NORM_MEAN = [0.485, 0.456, 0.406] +IMG_NORM_STD = [0.229, 0.224, 0.225] +""" +We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides. +We keep a superset of 24 joints such that we include all joints from every dataset. +If a dataset doesn't provide annotations for a specific joint, we simply ignore it. 
+The joints used here are the following: +""" +OP_JOINT_NAMES = [ + # 25 OpenPose joints (in the order provided by OpenPose) + 'OP Nose', + 'OP Neck', + 'OP RShoulder', + 'OP RElbow', + 'OP RWrist', + 'OP LShoulder', + 'OP LElbow', + 'OP LWrist', + 'OP MidHip', + 'OP RHip', + 'OP RKnee', + 'OP RAnkle', + 'OP LHip', + 'OP LKnee', + 'OP LAnkle', + 'OP REye', + 'OP LEye', + 'OP REar', + 'OP LEar', + 'OP LBigToe', + 'OP LSmallToe', + 'OP LHeel', + 'OP RBigToe', + 'OP RSmallToe', + 'OP RHeel', +] +SPIN_JOINT_NAMES = [ + # 24 Ground Truth joints (superset of joints from different datasets) + 'Right Ankle', + 'Right Knee', + 'Right Hip', # 2 + 'Left Hip', + 'Left Knee', # 4 + 'Left Ankle', + 'Right Wrist', # 6 + 'Right Elbow', + 'Right Shoulder', # 8 + 'Left Shoulder', + 'Left Elbow', # 10 + 'Left Wrist', + 'Neck (LSP)', # 12 + 'Top of Head (LSP)', + 'Pelvis (MPII)', # 14 + 'Thorax (MPII)', + 'Spine (H36M)', # 16 + 'Jaw (H36M)', + 'Head (H36M)', # 18 + 'Nose', + 'Left Eye', + 'Right Eye', + 'Left Ear', + 'Right Ear' +] +JOINT_NAMES = OP_JOINT_NAMES + SPIN_JOINT_NAMES + +COCO_KEYPOINTS = [ + 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', + 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee', + 'right_knee', 'left_ankle', 'right_ankle' +] + +# Dict containing the joints in numerical order +JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))} + +# Map joints to SMPL joints +JOINT_MAP = { + 'OP Nose': 24, + 'OP Neck': 12, + 'OP RShoulder': 17, + 'OP RElbow': 19, + 'OP RWrist': 21, + 'OP LShoulder': 16, + 'OP LElbow': 18, + 'OP LWrist': 20, + 'OP MidHip': 0, + 'OP RHip': 2, + 'OP RKnee': 5, + 'OP RAnkle': 8, + 'OP LHip': 1, + 'OP LKnee': 4, + 'OP LAnkle': 7, + 'OP REye': 25, + 'OP LEye': 26, + 'OP REar': 27, + 'OP LEar': 28, + 'OP LBigToe': 29, + 'OP LSmallToe': 30, + 'OP LHeel': 31, + 'OP RBigToe': 32, + 'OP RSmallToe': 33, + 'OP RHeel': 34, + 'Right Ankle': 8, + 'Right Knee': 5, + 'Right Hip': 45, + 'Left Hip': 46, + 'Left Knee': 4, + 'Left Ankle': 7, + 'Right Wrist': 21, + 'Right Elbow': 19, + 'Right Shoulder': 17, + 'Left Shoulder': 16, + 'Left Elbow': 18, + 'Left Wrist': 20, + 'Neck (LSP)': 47, + 'Top of Head (LSP)': 48, + 'Pelvis (MPII)': 49, + 'Thorax (MPII)': 50, + 'Spine (H36M)': 51, + 'Jaw (H36M)': 52, + 'Head (H36M)': 53, + 'Nose': 24, + 'Left Eye': 26, + 'Right Eye': 25, + 'Left Ear': 28, + 'Right Ear': 27 +} + +# Joint selectors +# Indices to get the 14 LSP joints from the 17 H36M joints +H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] +H36M_TO_J14 = H36M_TO_J17[:14] +# Indices to get the 14 LSP joints from the ground truth joints +J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17] +J24_TO_J14 = J24_TO_J17[:14] +J24_TO_J19 = J24_TO_J17[:14] + [19, 20, 21, 22, 23] +# COCO with also 17 joints +J24_TO_JCOCO = [19, 20, 21, 22, 23, 9, 8, 10, 7, 11, 6, 3, 2, 4, 1, 5, 0] + +# Permutation of SMPL pose parameters when flipping the shape +SMPL_JOINTS_FLIP_PERM = [ + 0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22 +] +SMPL_POSE_FLIP_PERM = [] +for i in SMPL_JOINTS_FLIP_PERM: + SMPL_POSE_FLIP_PERM.append(3 * i) + SMPL_POSE_FLIP_PERM.append(3 * i + 1) + SMPL_POSE_FLIP_PERM.append(3 * i + 2) +# Permutation indices for the 24 ground truth joints +J24_FLIP_PERM = [ + 5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22 +] +# Permutation indices for the full set of 49 joints +J49_FLIP_PERM = 
[0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\ + + [25+i for i in J24_FLIP_PERM] +SMPL_J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\ + + [25+i for i in SMPL_JOINTS_FLIP_PERM] + +SMPLX2SMPL_J45 = [i for i in range(22)] + [30, 45] + [i for i in range(55, 55 + 21)] + +SMPL_PART_ID = { + 'rightHand': 1, + 'rightUpLeg': 2, + 'leftArm': 3, + 'leftLeg': 4, + 'leftToeBase': 5, + 'leftFoot': 6, + 'spine1': 7, + 'spine2': 8, + 'leftShoulder': 9, + 'rightShoulder': 10, + 'rightFoot': 11, + 'head': 12, + 'rightArm': 13, + 'leftHandIndex1': 14, + 'rightLeg': 15, + 'rightHandIndex1': 16, + 'leftForeArm': 17, + 'rightForeArm': 18, + 'neck': 19, + 'rightToeBase': 20, + 'spine': 21, + 'leftUpLeg': 22, + 'leftHand': 23, + 'hips': 24 +} + +# MANO_NAMES = [ +# 'wrist', +# 'index1', +# 'index2', +# 'index3', +# 'middle1', +# 'middle2', +# 'middle3', +# 'pinky1', +# 'pinky2', +# 'pinky3', +# 'ring1', +# 'ring2', +# 'ring3', +# 'thumb1', +# 'thumb2', +# 'thumb3', +# ] + +HAND_NAMES = [ + 'wrist', + 'thumb1', + 'thumb2', + 'thumb3', + 'thumb', + 'index1', + 'index2', + 'index3', + 'index', + 'middle1', + 'middle2', + 'middle3', + 'middle', + 'ring1', + 'ring2', + 'ring3', + 'ring', + 'pinky1', + 'pinky2', + 'pinky3', + 'pinky', +] + +import lib.smplx.joint_names as smplx_joint_name + +SMPLX_JOINT_NAMES = smplx_joint_name.JOINT_NAMES +SMPLX_JOINT_IDS = {SMPLX_JOINT_NAMES[i]: i for i in range(len(SMPLX_JOINT_NAMES))} + +FOOT_NAMES = ['big_toe', 'small_toe', 'heel'] + +FACIAL_LANDMARKS = [ + 'right_eye_brow1', + 'right_eye_brow2', + 'right_eye_brow3', + 'right_eye_brow4', + 'right_eye_brow5', + 'left_eye_brow5', + 'left_eye_brow4', + 'left_eye_brow3', + 'left_eye_brow2', + 'left_eye_brow1', + 'nose1', + 'nose2', + 'nose3', + 'nose4', + 'right_nose_2', + 'right_nose_1', + 'nose_middle', + 'left_nose_1', + 'left_nose_2', + 'right_eye1', + 'right_eye2', + 'right_eye3', + 'right_eye4', + 'right_eye5', + 'right_eye6', + 'left_eye4', + 'left_eye3', + 'left_eye2', + 'left_eye1', + 'left_eye6', + 'left_eye5', + 'right_mouth_1', + 'right_mouth_2', + 'right_mouth_3', + 'mouth_top', + 'left_mouth_3', + 'left_mouth_2', + 'left_mouth_1', + 'left_mouth_5', # 59 in OpenPose output + 'left_mouth_4', # 58 in OpenPose output + 'mouth_bottom', + 'right_mouth_4', + 'right_mouth_5', + 'right_lip_1', + 'right_lip_2', + 'lip_top', + 'left_lip_2', + 'left_lip_1', + 'left_lip_3', + 'lip_bottom', + 'right_lip_3', + 'right_contour_1', + 'right_contour_2', + 'right_contour_3', + 'right_contour_4', + 'right_contour_5', + 'right_contour_6', + 'right_contour_7', + 'right_contour_8', + 'contour_middle', + 'left_contour_8', + 'left_contour_7', + 'left_contour_6', + 'left_contour_5', + 'left_contour_4', + 'left_contour_3', + 'left_contour_2', + 'left_contour_1', +] + +# LRHAND_FLIP_PERM = [i for i in range(16, 32)] + [i for i in range(16)] +LRHAND_FLIP_PERM = [i for i in range(len(HAND_NAMES), + len(HAND_NAMES) * 2)] + [i for i in range(len(HAND_NAMES))] + +SINGLE_HAND_FLIP_PERM = [i for i in range(len(HAND_NAMES))] + +FEEF_FLIP_PERM = [i for i in range(len(FOOT_NAMES), + len(FOOT_NAMES) * 2)] + [i for i in range(len(FOOT_NAMES))] + +# matchedParts = ( +# [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], +# [21],[20],[19],[18],[17], +# [27], [28], [29], [30], +# [31, 35], [32, 34], [33], +# [32],[31], +# [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46], +# [39],[38], [37],[36],[41],[40], +# [48, 54], [49, 53], [50, 52], [51], +# 
[50],[49],[48], +# [55, 59], [56, 58], [57], +# [56],[55], +# [60, 64], [61, 63], [62], +# [61],[60], +# [65, 67], [66], +# [65], +# ) + +# matchedParts = ( +# [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],[8], +# ) + +FACE_FLIP_PERM = [ + 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 11, 12, 13, 18, 17, 16, 15, 14, 28, 27, 26, 25, 30, 29, 22, + 21, 20, 19, 24, 23, 37, 36, 35, 34, 33, 32, 31, 42, 41, 40, 39, 38, 47, 46, 45, 44, 43, 50, 49, + 48 +] +FACE_FLIP_PERM = FACE_FLIP_PERM + [ + 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51 +] diff --git a/lib/pymafx/core/path_config.py b/lib/pymafx/core/path_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7190076a46258757e14c6d6acf010208fab56f04 --- /dev/null +++ b/lib/pymafx/core/path_config.py @@ -0,0 +1,13 @@ +import os.path as osp + +pymafx_data_dir = osp.join(osp.dirname(__file__), "../../../data/HPS/pymafx_data") + +JOINT_REGRESSOR_TRAIN_EXTRA = osp.join(pymafx_data_dir, 'J_regressor_extra.npy') +JOINT_REGRESSOR_H36M = osp.join(pymafx_data_dir, 'J_regressor_h36m.npy') +SMPL_MEAN_PARAMS = osp.join(pymafx_data_dir, 'smpl_mean_params.npz') +SMPL_MODEL_DIR = osp.join(pymafx_data_dir, 'smpl') +CHECKPOINT_FILE = osp.join(pymafx_data_dir, 'PyMAF-X_model_checkpoint.pt') +PARTIAL_MESH_DIR = osp.join(pymafx_data_dir, "partial_mesh") + +MANO_DOWNSAMPLING = osp.join(pymafx_data_dir, 'mano_downsampling.npz') +SMPL_DOWNSAMPLING = osp.join(pymafx_data_dir, 'smpl_downsampling.npz') diff --git a/lib/pymafx/models/__init__.py b/lib/pymafx/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85ca9042485c0af8b1f29a4d4eaa1547935a40f --- /dev/null +++ b/lib/pymafx/models/__init__.py @@ -0,0 +1,3 @@ +from .hmr import hmr +from .pymaf_net import pymaf_net +from .smpl import SMPL diff --git a/lib/pymafx/models/attention.py b/lib/pymafx/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f7d3c5c63ba1471ff15ee1a3cf0d8c94a17699 --- /dev/null +++ b/lib/pymafx/models/attention.py @@ -0,0 +1,384 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
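+
+Defines the BertSelfAttention, BertAttention, AttLayer and AttEncoder modules
+built on the bundled BERT transformer components.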
+ +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import math +import os +import code +import torch +from torch import nn +from .transformers.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput +# import src.modeling.data.config as cfg +# from src.modeling._gcnn import GraphConvolution, GraphResBlock +from .transformers.bert.modeling_utils import prune_linear_layer + +LayerNormClass = torch.nn.LayerNorm +BertLayerNorm = torch.nn.LayerNorm +from .transformers.bert import BertConfig + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): + if history_state is not None: + raise + x_states = torch.cat([history_state, hidden_states], dim=1) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(x_states) + mixed_value_layer = self.value(x_states) + else: + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + # print('mixed_query_layer', mixed_query_layer.shape, mixed_key_layer.shape, mixed_value_layer.shape) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + # print('query_layer', query_layer.shape, key_layer.shape, value_layer.shape) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + raise + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, ) + return outputs + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + for head in heads: + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + # Update hyper params + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + + def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class AttLayer(nn.Module): + def __init__(self, config): + super(AttLayer, self).__init__() + self.attention = BertAttention(config) + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def MHA(self, hidden_states, attention_mask, head_mask=None, history_state=None): + attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state) + attention_output = attention_outputs[0] + + # print('attention_output', hidden_states.shape, attention_output.shape) + + intermediate_output = self.intermediate(attention_output) + # print('intermediate_output', intermediate_output.shape) + layer_output = self.output(intermediate_output, attention_output) + # print('layer_output', layer_output.shape) + outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them + return outputs + + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): + return self.MHA(hidden_states, attention_mask, head_mask, history_state) + + +class AttEncoder(nn.Module): + def __init__(self, config): + super(AttEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([AttLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + history_state = None if encoder_history_states is None else encoder_history_states[i] + layer_outputs = 
layer_module(hidden_states, attention_mask, head_mask[i], history_state) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + + return outputs # outputs, (hidden states), (attentions) + + +class EncoderBlock(BertPreTrainedModel): + def __init__(self, config): + super(EncoderBlock, self).__init__(config) + self.config = config + # self.embeddings = BertEmbeddings(config) + self.encoder = AttEncoder(config) + # self.pooler = BertPooler(config) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.img_dim = config.img_feature_dim + + try: + self.use_img_layernorm = config.use_img_layernorm + except: + self.use_img_layernorm = None + + self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True) + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + if self.use_img_layernorm: + self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps) + + self.apply(self.init_weights) + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None + ): + + batch_size = len(img_feats) + seq_length = len(img_feats[0]) + input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).to(img_feats.device) + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # print('-------------------') + # print('position_ids', seq_length, position_ids.shape) + # 494 torch.Size([2, 494]) + + position_embeddings = self.position_embeddings(position_ids) + # print('position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, self.config.hidden_size) + # torch.Size([2, 494, 1024]) 512 1024 + # torch.Size([2, 494, 256]) 512 256 + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + else: + raise + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + else: + raise + + if attention_mask.dim() == 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + extended_attention_mask = attention_mask.unsqueeze(1) + else: + raise NotImplementedError + + # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to( + dtype=img_feats.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + if head_mask is not None: + raise + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1 + ) # We can specify head_mask 
for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + # Project input token features to have spcified hidden size + # print('img_feats', img_feats.shape) # torch.Size([2, 494, 2051]) + img_embedding_output = self.img_embedding(img_feats) + # print('img_embedding_output', img_embedding_output.shape) # torch.Size([2, 494, 1024]) + + # We empirically observe that adding an additional learnable position embedding leads to more stable training + embeddings = position_embeddings + img_embedding_output + + if self.use_img_layernorm: + embeddings = self.LayerNorm(embeddings) + # embeddings = self.dropout(embeddings) + + # print('extended_attention_mask', extended_attention_mask.shape) # torch.Size([2, 1, 1, 494]) + encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask) + sequence_output = encoder_outputs[0] + + outputs = (sequence_output, ) + if self.config.output_hidden_states: + all_hidden_states = encoder_outputs[1] + outputs = outputs + (all_hidden_states, ) + if self.config.output_attentions: + all_attentions = encoder_outputs[-1] + outputs = outputs + (all_attentions, ) + + return outputs + + +def get_att_block( + img_feature_dim=2048, + output_feat_dim=512, + hidden_feat_dim=1024, + num_attention_heads=4, + num_hidden_layers=1 +): + + config_class = BertConfig + config = config_class.from_pretrained('lib/pymafx/models/transformers/bert/bert-base-uncased/') + + interm_size_scale = 2 + + config.output_attentions = False + # config.hidden_dropout_prob = args.drop_out + config.img_feature_dim = img_feature_dim + # config.output_feature_dim = output_feat_dim + config.hidden_size = hidden_feat_dim + config.intermediate_size = int(config.hidden_size * interm_size_scale) + config.num_hidden_layers = num_hidden_layers + config.num_attention_heads = num_attention_heads + config.max_position_embeddings = 900 + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + + att_model = EncoderBlock(config=config) + + return att_model + + +class Graphormer(BertPreTrainedModel): + ''' + The archtecture of a transformer encoder block we used in Graphormer + ''' + def __init__(self, config): + super(Graphormer, self).__init__(config) + self.config = config + self.bert = EncoderBlock(config) + self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim) + self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) + self.apply(self.init_weights) + + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): + ''' + # self.bert has three outputs + # predictions[0]: output tokens + # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states" + # predictions[2]: attentions, if enable "self.config.output_attentions" + ''' + predictions = self.bert( + img_feats=img_feats, + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + + # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. 
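The `(1.0 - mask) * -10000.0` trick used by `EncoderBlock.forward` above turns a 0/1 visibility mask into an additive bias that broadcasts over heads and query positions, so masked tokens get a large negative score before the softmax. A minimal sketch of that construction (the helper name is illustrative):

```
import torch

def extended_attention_mask(mask, dtype=torch.float32):
    # mask: (batch, seq) with 1 for visible tokens and 0 for padding
    ext = mask[:, None, None, :].to(dtype)    # (batch, 1, 1, seq)
    return (1.0 - ext) * -10000.0             # ~0 for visible, -10000 for masked

mask = torch.tensor([[1, 1, 1, 0]])           # last token is padding
print(extended_attention_mask(mask).shape)    # torch.Size([1, 1, 1, 4])
```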
+ pred_score = self.cls_head(predictions[0]) + res_img_feats = self.residual(img_feats) + pred_score = pred_score + res_img_feats + # print('pred_score', pred_score.shape) + + if self.config.output_attentions and self.config.output_hidden_states: + return pred_score, predictions[1], predictions[-1] + else: + return pred_score diff --git a/lib/pymafx/models/hmr.py b/lib/pymafx/models/hmr.py new file mode 100644 index 0000000000000000000000000000000000000000..da5459d355d3a3f00c53638a376ab3143b23c01e --- /dev/null +++ b/lib/pymafx/models/hmr.py @@ -0,0 +1,286 @@ +# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/hmr.py + +import torch +import torch.nn as nn +import torchvision.models.resnet as resnet +import numpy as np +import math +from lib.net.geometry import rot6d_to_rotmat + +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +class Bottleneck(nn.Module): + """ Redefinition of Bottleneck residual block + Adapted from the official PyTorch implementation + """ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet_Backbone(nn.Module): + """ Feature Extrator with ResNet backbone + """ + def __init__(self, model='res50', pretrained=True): + if model == 'res50': + block, layers = Bottleneck, [3, 4, 6, 3] + else: + pass # TODO + + self.inplanes = 64 + super().__init__() + npose = 24 * 6 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + + if pretrained: + resnet_imagenet = resnet.resnet50(pretrained=True) + self.load_state_dict(resnet_imagenet.state_dict(), strict=False) + logger.info('loaded resnet50 imagenet pretrained model') + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_deconv_layer(self, num_layers, 
num_filters, num_kernels): + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + def _get_deconv_cfg(deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = _get_deconv_cfg(num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d( + in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias + ) + ) + layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def forward(self, x): + + batch_size = x.shape[0] + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + xf = self.avgpool(x4) + xf = xf.view(xf.size(0), -1) + + x_featmap = x4 + + return x_featmap, xf + + +class HMR(nn.Module): + """ SMPL Iterative Regressor with ResNet50 backbone + """ + def __init__(self, block, layers, smpl_mean_params): + self.inplanes = 64 + super().__init__() + npose = 24 * 6 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc1 = nn.Linear(512 * block.expansion + npose + 13, 1024) + self.drop1 = nn.Dropout() + self.fc2 = nn.Linear(1024, 1024) + self.drop2 = nn.Dropout() + self.decpose = nn.Linear(1024, npose) + self.decshape = nn.Linear(1024, 10) + self.deccam = nn.Linear(1024, 3) + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + mean_params = np.load(smpl_mean_params) + init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0) + init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0) + init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0) + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.register_buffer('init_cam', init_cam) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3): + + batch_size = x.shape[0] + + if init_pose is None: + init_pose = self.init_pose.expand(batch_size, -1) + if init_shape is None: + init_shape = self.init_shape.expand(batch_size, -1) + if init_cam is None: + init_cam = self.init_cam.expand(batch_size, -1) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + xf = self.avgpool(x4) + xf = xf.view(xf.size(0), -1) + + pred_pose = init_pose + pred_shape = init_shape + pred_cam = init_cam + for i in range(n_iter): + xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1) + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + pred_pose = self.decpose(xc) + pred_pose + pred_shape = self.decshape(xc) + pred_shape + pred_cam = self.deccam(xc) + pred_cam + + pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3) + + return pred_rotmat, pred_shape, pred_cam + + +def hmr(smpl_mean_params, pretrained=True, **kwargs): + """ Constructs an HMR model with ResNet50 backbone. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs) + if pretrained: + resnet_imagenet = resnet.resnet50(pretrained=True) + model.load_state_dict(resnet_imagenet.state_dict(), strict=False) + return model diff --git a/lib/pymafx/models/hr_module.py b/lib/pymafx/models/hr_module.py new file mode 100644 index 0000000000000000000000000000000000000000..7396f1ea59860235db8fdd24434114381c4a7083 --- /dev/null +++ b/lib/pymafx/models/hr_module.py @@ -0,0 +1,463 @@ +import os +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +# from core.cfgs import cfg +from .res_module import BasicBlock, Bottleneck + +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +class HighResolutionModule(nn.Module): + def __init__( + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True + ): + super().__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample + ) + ) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + 
nn.Sequential( + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample(scale_factor=2**(j - i), mode='nearest') + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False + ), nn.BatchNorm2d(num_outchannels_conv3x3) + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False + ), nn.BatchNorm2d(num_outchannels_conv3x3), nn.ReLU(True) + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + +class PoseHighResolutionNet(nn.Module): + def __init__(self, cfg, pretrained=True, global_mode=False): + self.inplanes = 64 + extra = cfg.HR_MODEL.EXTRA + super().__init__() + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(Bottleneck, self.inplanes, 64, 4) + + self.stage2_cfg = cfg['HR_MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['HR_MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['HR_MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True + ) + + # Classification Head + self.global_mode = global_mode + if self.global_mode: + self.incre_modules, self.downsamp_modules, \ + self.final_layer = self._make_head(pre_stage_channels) + + self.pretrained_layers = cfg['HR_MODEL']['EXTRA']['PRETRAINED_LAYERS'] + + def 
_make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer(head_block, channels, head_channels[i], 1, stride=1) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i + 1] * head_block.expansion + + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1 + ), nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) + ) + + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0 + ), nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False + ), nn.BatchNorm2d(num_channels_cur_layer[i]), nn.ReLU(inplace=True) + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True) + ) + ) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule( + num_branches, block, num_blocks, 
num_inchannels, num_channels, fuse_method, + reset_multi_scale_output + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + s_feat_s2 = y_list[0] + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + s_feat_s3 = y_list[0] + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + s_feat = [y_list[-2], y_list[-3], y_list[-4]] + + # s_feat_s4 = y_list[0] + + # if cfg.MODEL.PyMAF.HR_FEAT_STAGE == 2: + # s_feat = s_feat_s2 + # elif cfg.MODEL.PyMAF.HR_FEAT_STAGE == 3: + # s_feat = s_feat_s3 + # elif cfg.MODEL.PyMAF.HR_FEAT_STAGE == 4: + # s_feat = s_feat_s4 + # else: + # raise ValueError('HR_FEAT_STAGE should be 2, 3, or 4.') + + # Classification Head + if self.global_mode: + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i + 1](y_list[i + 1]) + \ + self.downsamp_modules[i](y) + + y = self.final_layer(y) + + if torch._C._get_tracing_state(): + xf = y.flatten(start_dim=2).mean(dim=2) + else: + xf = F.avg_pool2d(y, kernel_size=y.size()[2:]).view(y.size(0), -1) + else: + xf = None + + return s_feat, xf + + def init_weights(self, pretrained=''): + # logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + + if os.path.isfile(pretrained): + pretrained_state_dict = torch.load(pretrained) + # logger.info('=> loading pretrained HRnet model {}'.format(pretrained)) + + need_init_state_dict = {} + for name, m in pretrained_state_dict.items(): + if name.split('.')[0] in self.pretrained_layers \ + or self.pretrained_layers[0] is '*': + need_init_state_dict[name] = m + self.load_state_dict(need_init_state_dict, strict=False) + elif pretrained: + logger.error('=> please download pre-trained models first!') + raise ValueError('{} is not exist!'.format(pretrained)) + + +def get_hrnet_encoder(cfg, init_weight=True, global_mode=False, **kwargs): + model = PoseHighResolutionNet(cfg, global_mode=global_mode) + + if init_weight: + if cfg.HR_MODEL.PRETR_SET in ['imagenet']: + model.init_weights(cfg.HR_MODEL.PRETRAINED_IM) + logger.info('loaded HRNet imagenet pretrained model') + elif cfg.HR_MODEL.PRETR_SET in ['coco']: + model.init_weights(cfg.HR_MODEL.PRETRAINED_COCO) + logger.info('loaded HRNet coco pretrained model') + else: + model.init_weights() + + return model diff --git 
a/lib/pymafx/models/maf_extractor.py b/lib/pymafx/models/maf_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..34237bc55663dcbcbd67beb4c5d0b6e693aae266 --- /dev/null +++ b/lib/pymafx/models/maf_extractor.py @@ -0,0 +1,272 @@ +# This script is borrowed and extended from https://github.com/shunsukesaito/PIFu/blob/master/lib/model/SurfaceClassifier.py + +import torch +import scipy +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from lib.pymafx.core import path_config +from lib.pymafx.utils.geometry import projection + +import logging + +logger = logging.getLogger(__name__) + +from .transformers.net_utils import PosEnSine +from .transformers.transformer_basics import OurMultiheadAttention + +from lib.pymafx.utils.imutils import j2d_processing + + +class TransformerDecoderUnit(nn.Module): + def __init__( + self, feat_dim, attri_dim=0, n_head=8, pos_en_flag=True, attn_type='softmax', P=None + ): + super(TransformerDecoderUnit, self).__init__() + self.feat_dim = feat_dim + self.attn_type = attn_type + self.pos_en_flag = pos_en_flag + self.P = P + + assert attri_dim == 0 + if self.pos_en_flag: + pe_dim = 10 + self.pos_en = PosEnSine(pe_dim) + else: + pe_dim = 0 + self.attn = OurMultiheadAttention( + feat_dim + attri_dim + pe_dim * 3, feat_dim + pe_dim * 3, feat_dim, n_head + ) # cross-attention + + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.activation = nn.ReLU(inplace=True) + + self.norm = nn.BatchNorm2d(self.feat_dim) + + def forward(self, q, k, v, pos=None): + if self.pos_en_flag: + q_pos_embed = self.pos_en(q, pos) + k_pos_embed = self.pos_en(k) + + q = torch.cat([q, q_pos_embed], dim=1) + k = torch.cat([k, k_pos_embed], dim=1) + # else: + # q_pos_embed = 0 + # k_pos_embed = 0 + + # cross-multi-head attention + out = self.attn(q=q, k=k, v=v, attn_type=self.attn_type, P=self.P)[0] + + # feed forward + out2 = self.linear2(self.activation(self.linear1(out))) + out = out + out2 + out = self.norm(out) + + return out + + +class Mesh_Sampler(nn.Module): + ''' Mesh Up/Down-sampling + ''' + def __init__(self, type='smpl', level=2, device=torch.device('cuda'), option=None): + super().__init__() + + # downsample SMPL mesh and assign part labels + if type == 'smpl': + # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz + smpl_mesh_graph = np.load( + path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1' + ) + + A = smpl_mesh_graph['A'] + U = smpl_mesh_graph['U'] + D = smpl_mesh_graph['D'] # shape: (2,) + elif type == 'mano': + # from https://github.com/microsoft/MeshGraphormer/blob/main/src/modeling/data/mano_downsampling.npz + mano_mesh_graph = np.load( + path_config.MANO_DOWNSAMPLING, allow_pickle=True, encoding='latin1' + ) + + A = mano_mesh_graph['A'] + U = mano_mesh_graph['U'] + D = mano_mesh_graph['D'] # shape: (2,) + + # downsampling + ptD = [] + for lv in range(len(D)): + d = scipy.sparse.coo_matrix(D[lv]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) + + # downsampling mapping from 6890 points to 431 points + # ptD[0].to_dense() - Size: [1723, 6890] , [195, 778] + # ptD[1].to_dense() - Size: [431, 1723] , [49, 195] + if level == 2: + Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431 + elif level == 1: + Dmap = ptD[0].to_dense() # + self.register_buffer('Dmap', Dmap) + + # upsampling + ptU = [] + for lv 
in range(len(U)): + d = scipy.sparse.coo_matrix(U[lv]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptU.append(torch.sparse.FloatTensor(i, v, d.shape)) + + # upsampling mapping from 431 points to 6890 points + # ptU[0].to_dense() - Size: [6890, 1723] + # ptU[1].to_dense() - Size: [1723, 431] + if level == 2: + Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890 + elif level == 1: + Umap = ptU[0].to_dense() # + self.register_buffer('Umap', Umap) + + def downsample(self, x): + return torch.matmul(self.Dmap.unsqueeze(0), x) # [B, 431, 3] + + def upsample(self, x): + return torch.matmul(self.Umap.unsqueeze(0), x) # [B, 6890, 3] + + def forward(self, x, mode='downsample'): + if mode == 'downsample': + return self.downsample(x) + elif mode == 'upsample': + return self.upsample(x) + + +class MAF_Extractor(nn.Module): + ''' Mesh-aligned Feature Extrator + As discussed in the paper, we extract mesh-aligned features based on 2D projection of the mesh vertices. + The features extrated from spatial feature maps will go through a MLP for dimension reduction. + ''' + def __init__( + self, filter_channels, device=torch.device('cuda'), iwp_cam_mode=True, option=None + ): + super().__init__() + + self.device = device + self.filters = [] + self.num_views = 1 + self.last_op = nn.ReLU(True) + + self.iwp_cam_mode = iwp_cam_mode + + for l in range(0, len(filter_channels) - 1): + if 0 != l: + self.filters.append( + nn.Conv1d(filter_channels[l] + filter_channels[0], filter_channels[l + 1], 1) + ) + else: + self.filters.append(nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1)) + + self.add_module("conv%d" % l, self.filters[l]) + + # downsample SMPL mesh and assign part labels + # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz + smpl_mesh_graph = np.load( + path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1' + ) + + A = smpl_mesh_graph['A'] + U = smpl_mesh_graph['U'] + D = smpl_mesh_graph['D'] # shape: (2,) + + # downsampling + ptD = [] + for level in range(len(D)): + d = scipy.sparse.coo_matrix(D[level]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) + + # downsampling mapping from 6890 points to 431 points + # ptD[0].to_dense() - Size: [1723, 6890] + # ptD[1].to_dense() - Size: [431. 
1723] + Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431 + self.register_buffer('Dmap', Dmap) + + # upsampling + ptU = [] + for level in range(len(U)): + d = scipy.sparse.coo_matrix(U[level]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptU.append(torch.sparse.FloatTensor(i, v, d.shape)) + + # upsampling mapping from 431 points to 6890 points + # ptU[0].to_dense() - Size: [6890, 1723] + # ptU[1].to_dense() - Size: [1723, 431] + Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890 + self.register_buffer('Umap', Umap) + + def reduce_dim(self, feature): + ''' + Dimension reduction by multi-layer perceptrons + :param feature: list of [B, C_s, N] point-wise features before dimension reduction + :return: [B, C_p x N] concatantion of point-wise features after dimension reduction + ''' + y = feature + tmpy = feature + for i, f in enumerate(self.filters): + y = self._modules['conv' + str(i)](y if i == 0 else torch.cat([y, tmpy], 1)) + if i != len(self.filters) - 1: + y = F.leaky_relu(y) + if self.num_views > 1 and i == len(self.filters) // 2: + y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1) + tmpy = feature.view(-1, self.num_views, feature.shape[1], + feature.shape[2]).mean(dim=1) + + y = self.last_op(y) + + # y = y.view(y.shape[0], -1) + + return y + + def sampling(self, points, im_feat=None, z_feat=None, add_att=False, reduce_dim=True): + ''' + Given 2D points, sample the point-wise features for each point, + the dimension of point-wise features will be reduced from C_s to C_p by MLP. + Image features should be pre-computed before this call. + :param points: [B, N, 2] image coordinates of points + :im_feat: [B, C_s, H_s, W_s] spatial feature maps + :return: [B, C_p x N] concatantion of point-wise features after dimension reduction + ''' + # if im_feat is None: + # im_feat = self.im_feat + + batch_size = im_feat.shape[0] + point_feat = torch.nn.functional.grid_sample( + im_feat, points.unsqueeze(2), align_corners=False + )[..., 0] + + if reduce_dim: + mesh_align_feat = self.reduce_dim(point_feat) + return mesh_align_feat + else: + return point_feat + + def forward(self, p, im_feat, cam=None, add_att=False, reduce_dim=True, **kwargs): + ''' Returns mesh-aligned features for the 3D mesh points. + Args: + p (tensor): [B, N_m, 3] mesh vertices + im_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps + cam (tensor): [B, 3] camera + Return: + mesh_align_feat (tensor): [B, C_p x N_m] mesh-aligned features + ''' + # if cam is None: + # cam = self.cam + p_proj_2d = projection(p, cam, retain_z=False, iwp_mode=self.iwp_cam_mode) + if self.iwp_cam_mode: + # Normalize keypoints to [-1,1] + p_proj_2d = p_proj_2d / (224. / 2.) + else: + p_proj_2d = j2d_processing(p_proj_2d, cam['kps_transf']) + mesh_align_feat = self.sampling(p_proj_2d, im_feat, add_att=add_att, reduce_dim=reduce_dim) + return mesh_align_feat diff --git a/lib/pymafx/models/pose_resnet.py b/lib/pymafx/models/pose_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d97b6609cf02fd2a94d2951f82f71de2be2356c0 --- /dev/null +++ b/lib/pymafx/models/pose_resnet.py @@ -0,0 +1,306 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
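The `MAF_Extractor` above implements mesh-aligned feature extraction: mesh vertices are projected into the image, normalized to `[-1, 1]`, bilinearly sampled from the spatial feature map with `grid_sample`, and then compressed by the 1D-conv MLP in `reduce_dim`. A minimal sketch of the sampling step, with made-up tensor sizes:

```
import torch
import torch.nn.functional as F

B, C, H, W = 1, 256, 56, 56
im_feat = torch.randn(B, C, H, W)            # spatial feature map
points_2d = torch.rand(B, 431, 2) * 2 - 1    # projected vertices in [-1, 1]

# One C-dimensional feature per projected vertex, as MAF_Extractor.sampling does.
point_feat = F.grid_sample(
    im_feat, points_2d.unsqueeze(2), align_corners=False
)[..., 0]                                    # (B, C, 431)
print(point_feat.shape)
```

The 431-vertex count comes from the two-level SMPL mesh downsampling (6890 -> 1723 -> 431) encoded in the sparse `Dmap`/`Umap` buffers built above.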
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging + +import torch +import torch.nn as nn + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class PoseResNet(nn.Module): + def __init__(self, block, layers, cfg, global_mode, **kwargs): + self.inplanes = 64 + extra = cfg.POSE_RES_MODEL.EXTRA + self.extra = extra + self.deconv_with_bias = extra.DECONV_WITH_BIAS + + super(PoseResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + self.global_mode = global_mode + if self.global_mode: + self.avgpool = nn.AvgPool2d(7, stride=1) + self.deconv_layers = None + else: + # used for deconv layers + self.deconv_layers = self._make_deconv_layer( + extra.NUM_DECONV_LAYERS, + extra.NUM_DECONV_FILTERS, + extra.NUM_DECONV_KERNELS, + ) + + # self.final_layer = nn.Conv2d( + # in_channels=extra.NUM_DECONV_FILTERS[-1], + # out_channels=17, + # kernel_size=extra.FINAL_CONV_KERNEL, + # stride=1, + # padding=1 
if extra.FINAL_CONV_KERNEL == 3 else 0 + # ) + self.final_layer = None + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _get_deconv_cfg(self, deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d( + in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias + ) + ) + layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + # x = self.deconv_layers(x) + # x = self.final_layer(x) + + if self.global_mode: + g_feat = self.avgpool(x) + g_feat = g_feat.view(g_feat.size(0), -1) + s_feat_list = [g_feat] + else: + g_feat = None + if self.extra.NUM_DECONV_LAYERS == 3: + deconv_blocks = [ + self.deconv_layers[0:3], self.deconv_layers[3:6], self.deconv_layers[6:9] + ] + + s_feat_list = [] + s_feat = x + for i in range(self.extra.NUM_DECONV_LAYERS): + s_feat = deconv_blocks[i](s_feat) + s_feat_list.append(s_feat) + + return s_feat_list, g_feat + + def init_weights(self, pretrained=''): + if os.path.isfile(pretrained): + # logger.info('=> init deconv weights from normal distribution') + if self.deconv_layers is not None: + for name, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + # logger.info('=> init {}.weight as normal(0, 0.001)'.format(name)) + # logger.info('=> init {}.bias as 0'.format(name)) + nn.init.normal_(m.weight, std=0.001) + if self.deconv_with_bias: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + # logger.info('=> init {}.weight as 1'.format(name)) + # logger.info('=> init {}.bias as 0'.format(name)) + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if self.final_layer is not None: + logger.info('=> init final conv weights from normal distribution') + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + logger.info('=> init {}.weight as normal(0, 0.001)'.format(name)) + logger.info('=> init {}.bias as 0'.format(name)) + 
nn.init.normal_(m.weight, std=0.001) + nn.init.constant_(m.bias, 0) + + pretrained_state_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + self.load_state_dict(pretrained_state_dict, strict=False) + elif pretrained: + logger.error('=> please download pre-trained models first!') + raise ValueError('{} is not exist!'.format(pretrained)) + else: + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.normal_(m.weight, std=0.001) + # nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.normal_(m.weight, std=0.001) + if self.deconv_with_bias: + nn.init.constant_(m.bias, 0) + + +resnet_spec = { + 18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3]) +} + + +def get_resnet_encoder(cfg, init_weight=True, global_mode=False, **kwargs): + num_layers = cfg.POSE_RES_MODEL.EXTRA.NUM_LAYERS + + block_class, layers = resnet_spec[num_layers] + + model = PoseResNet(block_class, layers, cfg, global_mode, **kwargs) + + if init_weight: + if num_layers == 50: + if cfg.POSE_RES_MODEL.PRETR_SET in ['imagenet']: + model.init_weights(cfg.POSE_RES_MODEL.PRETRAINED_IM) + logger.info('loaded ResNet imagenet pretrained model') + elif cfg.POSE_RES_MODEL.PRETR_SET in ['coco']: + model.init_weights(cfg.POSE_RES_MODEL.PRETRAINED_COCO) + logger.info('loaded ResNet coco pretrained model') + else: + raise NotImplementedError + + return model diff --git a/lib/pymafx/models/pymaf_net.py b/lib/pymafx/models/pymaf_net.py new file mode 100644 index 0000000000000000000000000000000000000000..ca57e4b1c8ce971d76ce53d02827f441016a19ab --- /dev/null +++ b/lib/pymafx/models/pymaf_net.py @@ -0,0 +1,1680 @@ +import torch +import torch.nn as nn +import numpy as np +from lib.pymafx.core import constants + +from lib.common.config import cfg +from lib.pymafx.utils.geometry import rot6d_to_rotmat, rotmat_to_rot6d, projection, rotation_matrix_to_angle_axis, compute_twist_rotation +from .maf_extractor import MAF_Extractor, Mesh_Sampler +from .smpl import SMPL, SMPL_MODEL_DIR, SMPL_MEAN_PARAMS, get_partial_smpl, SMPL_Family +from lib.smplx.lbs import batch_rodrigues +from .res_module import IUV_predict_layer +from .hr_module import get_hrnet_encoder +from .pose_resnet import get_resnet_encoder +from lib.pymafx.utils.imutils import j2d_processing +from lib.pymafx.utils.cam_params import homo_vector +from .attention import get_att_block + +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +class Regressor(nn.Module): + def __init__( + self, + feat_dim, + smpl_mean_params, + use_cam_feats=False, + feat_dim_hand=0, + feat_dim_face=0, + bhf_names=['body'], + smpl_models={} + ): + super().__init__() + + npose = 24 * 6 + shape_dim = 10 + cam_dim = 3 + hand_dim = 15 * 6 + face_dim = 3 * 6 + 10 + + self.body_feat_dim = feat_dim + + self.smpl_mode = (cfg.MODEL.MESH_MODEL == 'smpl') + self.smplx_mode = (cfg.MODEL.MESH_MODEL == 'smplx') + self.use_cam_feats = use_cam_feats + + cam_feat_len = 4 if self.use_cam_feats else 0 + + self.bhf_names = bhf_names + self.hand_only_mode = (cfg.TRAIN.BHF_MODE == 'hand_only') + self.face_only_mode = (cfg.TRAIN.BHF_MODE == 'face_only') + self.body_hand_mode = 
(cfg.TRAIN.BHF_MODE == 'body_hand') + self.full_body_mode = (cfg.TRAIN.BHF_MODE == 'full_body') + + # if self.use_cam_feats: + # assert cfg.MODEL.USE_IWP_CAM is False + if 'body' in self.bhf_names: + self.fc1 = nn.Linear(feat_dim + npose + cam_feat_len + shape_dim + cam_dim, 1024) + self.drop1 = nn.Dropout() + self.fc2 = nn.Linear(1024, 1024) + self.drop2 = nn.Dropout() + self.decpose = nn.Linear(1024, npose) + self.decshape = nn.Linear(1024, 10) + self.deccam = nn.Linear(1024, 3) + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + if not self.smpl_mode: + if self.hand_only_mode: + self.part_names = ['rhand'] + elif self.face_only_mode: + self.part_names = ['face'] + elif self.body_hand_mode: + self.part_names = ['lhand', 'rhand'] + elif self.full_body_mode: + self.part_names = ['lhand', 'rhand', 'face'] + else: + self.part_names = [] + + if 'rhand' in self.part_names: + # self.fc1_hand = nn.Linear(feat_dim_hand + hand_dim + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024) + self.fc1_hand = nn.Linear(feat_dim_hand + hand_dim, 1024) + self.drop1_hand = nn.Dropout() + self.fc2_hand = nn.Linear(1024, 1024) + self.drop2_hand = nn.Dropout() + + # self.declhand = nn.Linear(1024, 15*6) + self.decrhand = nn.Linear(1024, 15 * 6) + # nn.init.xavier_uniform_(self.declhand.weight, gain=0.01) + nn.init.xavier_uniform_(self.decrhand.weight, gain=0.01) + + if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST: + rh_cam_dim = 3 + rh_orient_dim = 6 + rh_shape_dim = 10 + self.fc3_hand = nn.Linear( + 1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024 + ) + self.drop3_hand = nn.Dropout() + + self.decshape_rhand = nn.Linear(1024, 10) + self.decorient_rhand = nn.Linear(1024, 6) + self.deccam_rhand = nn.Linear(1024, 3) + nn.init.xavier_uniform_(self.decshape_rhand.weight, gain=0.01) + nn.init.xavier_uniform_(self.decorient_rhand.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam_rhand.weight, gain=0.01) + + if 'face' in self.part_names: + self.fc1_face = nn.Linear(feat_dim_face + face_dim, 1024) + self.drop1_face = nn.Dropout() + self.fc2_face = nn.Linear(1024, 1024) + self.drop2_face = nn.Dropout() + + self.dechead = nn.Linear(1024, 3 * 6) + self.decexp = nn.Linear(1024, 10) + nn.init.xavier_uniform_(self.dechead.weight, gain=0.01) + nn.init.xavier_uniform_(self.decexp.weight, gain=0.01) + + if cfg.MODEL.MESH_MODEL == 'flame': + rh_cam_dim = 3 + rh_orient_dim = 6 + rh_shape_dim = 10 + self.fc3_face = nn.Linear( + 1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024 + ) + self.drop3_face = nn.Dropout() + + self.decshape_face = nn.Linear(1024, 10) + self.decorient_face = nn.Linear(1024, 6) + self.deccam_face = nn.Linear(1024, 3) + nn.init.xavier_uniform_(self.decshape_face.weight, gain=0.01) + nn.init.xavier_uniform_(self.decorient_face.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam_face.weight, gain=0.01) + + if self.smplx_mode and cfg.MODEL.PyMAF.PRED_VIS_H: + self.fc1_vis = nn.Linear(1024 + 1024 + 1024, 1024) + self.drop1_vis = nn.Dropout() + self.fc2_vis = nn.Linear(1024, 1024) + self.drop2_vis = nn.Dropout() + self.decvis = nn.Linear(1024, 2) + nn.init.xavier_uniform_(self.decvis.weight, gain=0.01) + + if 'body' in smpl_models: + self.smpl = smpl_models['body'] + if 'hand' in smpl_models: + self.mano = smpl_models['hand'] + if 'face' in smpl_models: + self.flame = smpl_models['face'] + + if cfg.MODEL.PyMAF.OPT_WRIST: + self.body_model = 
SMPL(model_path=SMPL_MODEL_DIR, batch_size=64, create_transl=False) + + mean_params = np.load(smpl_mean_params) + init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0) + init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0) + init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0) + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.register_buffer('init_cam', init_cam) + self.register_buffer('init_orient', init_pose[:, :6]) + + self.flip_vector = torch.ones((1, 9), dtype=torch.float32) + self.flip_vector[:, [1, 2, 3, 6]] *= -1 + self.flip_vector = self.flip_vector.reshape(1, 3, 3) + + if not self.smpl_mode: + lhand_mean_rot6d = rotmat_to_rot6d( + batch_rodrigues(self.smpl.model.model_neutral.left_hand_mean.view(-1, 3)).view( + [-1, 3, 3] + ) + ) + rhand_mean_rot6d = rotmat_to_rot6d( + batch_rodrigues(self.smpl.model.model_neutral.right_hand_mean.view(-1, 3)).view( + [-1, 3, 3] + ) + ) + init_lhand = lhand_mean_rot6d.reshape(-1).unsqueeze(0) + init_rhand = rhand_mean_rot6d.reshape(-1).unsqueeze(0) + # init_hand = torch.cat([init_lhand, init_rhand]).unsqueeze(0) + init_face = rotmat_to_rot6d(torch.stack([torch.eye(3)] * 3)).reshape(-1).unsqueeze(0) + init_exp = torch.zeros(10).unsqueeze(0) + + if self.smplx_mode or 'hand' in bhf_names: + # init_hand = torch.cat([init_lhand, init_rhand]).unsqueeze(0) + self.register_buffer('init_lhand', init_lhand) + self.register_buffer('init_rhand', init_rhand) + if self.smplx_mode or 'face' in bhf_names: + self.register_buffer('init_face', init_face) + self.register_buffer('init_exp', init_exp) + + def forward( + self, + x=None, + n_iter=1, + J_regressor=None, + rw_cam={}, + init_mode=False, + global_iter=-1, + **kwargs + ): + if x is not None: + batch_size = x.shape[0] + else: + if 'xc_rhand' in kwargs: + batch_size = kwargs['xc_rhand'].shape[0] + elif 'xc_face' in kwargs: + batch_size = kwargs['xc_face'].shape[0] + + if 'body' in self.bhf_names: + if 'init_pose' not in kwargs: + kwargs['init_pose'] = self.init_pose.expand(batch_size, -1) + if 'init_shape' not in kwargs: + kwargs['init_shape'] = self.init_shape.expand(batch_size, -1) + if 'init_cam' not in kwargs: + kwargs['init_cam'] = self.init_cam.expand(batch_size, -1) + + pred_cam = kwargs['init_cam'] + pred_pose = kwargs['init_pose'] + pred_shape = kwargs['init_shape'] + + if self.full_body_mode or self.body_hand_mode: + if cfg.MODEL.PyMAF.OPT_WRIST: + pred_rotmat_body = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) + if cfg.MODEL.PyMAF.PRED_VIS_H: + pred_vis_hands = None + + # if self.full_body_mode or 'hand' in self.bhf_names: + if self.smplx_mode or 'hand' in self.bhf_names: + if 'init_lhand' not in kwargs: + # kwargs['init_lhand'] = self.init_lhand.expand(batch_size, -1) + # init with **right** hand pose + kwargs['init_lhand'] = self.init_rhand.expand(batch_size, -1) + if 'init_rhand' not in kwargs: + kwargs['init_rhand'] = self.init_rhand.expand(batch_size, -1) + + pred_lhand, pred_rhand = kwargs['init_lhand'], kwargs['init_rhand'] + + if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST: + if 'init_orient_rh' not in kwargs: + kwargs['init_orient_rh'] = self.init_orient.expand(batch_size, -1) + if 'init_shape_rh' not in kwargs: + kwargs['init_shape_rh'] = self.init_shape.expand(batch_size, -1) + if 'init_cam_rh' not in kwargs: + kwargs['init_cam_rh'] = self.init_cam.expand(batch_size, -1) + pred_orient_rh = kwargs['init_orient_rh'] + 
pred_shape_rh = kwargs['init_shape_rh'] + pred_cam_rh = kwargs['init_cam_rh'] + if cfg.MODEL.PyMAF.OPT_WRIST: + if 'init_orient_lh' not in kwargs: + kwargs['init_orient_lh'] = self.init_orient.expand(batch_size, -1) + if 'init_shape_lh' not in kwargs: + kwargs['init_shape_lh'] = self.init_shape.expand(batch_size, -1) + if 'init_cam_lh' not in kwargs: + kwargs['init_cam_lh'] = self.init_cam.expand(batch_size, -1) + pred_orient_lh = kwargs['init_orient_lh'] + pred_shape_lh = kwargs['init_shape_lh'] + pred_cam_lh = kwargs['init_cam_lh'] + if cfg.MODEL.MESH_MODEL == 'mano': + pred_cam = torch.cat([pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:]], dim=1) + + # if self.full_body_mode or 'face' in self.bhf_names: + if self.smplx_mode or 'face' in self.bhf_names: + if 'init_face' not in kwargs: + kwargs['init_face'] = self.init_face.expand(batch_size, -1) + if 'init_hand' not in kwargs: + kwargs['init_exp'] = self.init_exp.expand(batch_size, -1) + + pred_face = kwargs['init_face'] + pred_exp = kwargs['init_exp'] + + if cfg.MODEL.MESH_MODEL == 'flame' or cfg.MODEL.PyMAF.OPT_WRIST: + if 'init_orient_fa' not in kwargs: + kwargs['init_orient_fa'] = self.init_orient.expand(batch_size, -1) + pred_orient_fa = kwargs['init_orient_fa'] + if 'init_shape_fa' not in kwargs: + kwargs['init_shape_fa'] = self.init_shape.expand(batch_size, -1) + if 'init_cam_fa' not in kwargs: + kwargs['init_cam_fa'] = self.init_cam.expand(batch_size, -1) + pred_shape_fa = kwargs['init_shape_fa'] + pred_cam_fa = kwargs['init_cam_fa'] + if cfg.MODEL.MESH_MODEL == 'flame': + pred_cam = torch.cat([pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:]], dim=1) + + if not init_mode: + for i in range(n_iter): + if 'body' in self.bhf_names: + xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1) + if self.use_cam_feats: + if cfg.MODEL.USE_IWP_CAM: + # for IWP camera, simply use pre-defined values + vfov = torch.ones((batch_size, 1)).to(xc) * 0.8 + crop_ratio = torch.ones((batch_size, 1)).to(xc) * 0.3 + crop_center = torch.ones((batch_size, 2)).to(xc) * 0.5 + else: + vfov = rw_cam['vfov'][:, None] + crop_ratio = rw_cam['crop_ratio'][:, None] + crop_center = rw_cam['bbox_center'] / torch.cat( + [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1 + ) + xc = torch.cat([xc, vfov, crop_ratio, crop_center], 1) + + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + + pred_cam = self.deccam(xc) + pred_cam + pred_pose = self.decpose(xc) + pred_pose + pred_shape = self.decshape(xc) + pred_shape + + if not self.smpl_mode: + if self.hand_only_mode: + xc_rhand = kwargs['xc_rhand'] + xc_rhand = torch.cat([xc_rhand, pred_rhand], 1) + elif self.face_only_mode: + xc_face = kwargs['xc_face'] + xc_face = torch.cat([xc_face, pred_face, pred_exp], 1) + elif self.body_hand_mode: + xc_lhand, xc_rhand = kwargs['xc_lhand'], kwargs['xc_rhand'] + xc_lhand = torch.cat([xc_lhand, pred_lhand], 1) + xc_rhand = torch.cat([xc_rhand, pred_rhand], 1) + elif self.full_body_mode: + xc_lhand, xc_rhand, xc_face = kwargs['xc_lhand'], kwargs['xc_rhand' + ], kwargs['xc_face'] + xc_lhand = torch.cat([xc_lhand, pred_lhand], 1) + xc_rhand = torch.cat([xc_rhand, pred_rhand], 1) + xc_face = torch.cat([xc_face, pred_face, pred_exp], 1) + + if 'lhand' in self.part_names: + xc_lhand = self.drop1_hand(self.fc1_hand(xc_lhand)) + xc_lhand = self.drop2_hand(self.fc2_hand(xc_lhand)) + pred_lhand = self.decrhand(xc_lhand) + pred_lhand + + if cfg.MODEL.PyMAF.OPT_WRIST: + xc_lhand = torch.cat( + [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1 + ) + 
xc_lhand = self.drop3_hand(self.fc3_hand(xc_lhand)) + + pred_shape_lh = self.decshape_rhand(xc_lhand) + pred_shape_lh + pred_orient_lh = self.decorient_rhand(xc_lhand) + pred_orient_lh + pred_cam_lh = self.deccam_rhand(xc_lhand) + pred_cam_lh + + if 'rhand' in self.part_names: + xc_rhand = self.drop1_hand(self.fc1_hand(xc_rhand)) + xc_rhand = self.drop2_hand(self.fc2_hand(xc_rhand)) + pred_rhand = self.decrhand(xc_rhand) + pred_rhand + + if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST: + xc_rhand = torch.cat( + [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1 + ) + xc_rhand = self.drop3_hand(self.fc3_hand(xc_rhand)) + + pred_shape_rh = self.decshape_rhand(xc_rhand) + pred_shape_rh + pred_orient_rh = self.decorient_rhand(xc_rhand) + pred_orient_rh + pred_cam_rh = self.deccam_rhand(xc_rhand) + pred_cam_rh + + if cfg.MODEL.MESH_MODEL == 'mano': + pred_cam = torch.cat( + [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1 + ) + + if 'face' in self.part_names: + xc_face = self.drop1_face(self.fc1_face(xc_face)) + xc_face = self.drop2_face(self.fc2_face(xc_face)) + pred_face = self.dechead(xc_face) + pred_face + pred_exp = self.decexp(xc_face) + pred_exp + + if cfg.MODEL.MESH_MODEL == 'flame': + xc_face = torch.cat( + [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1 + ) + xc_face = self.drop3_face(self.fc3_face(xc_face)) + + pred_shape_fa = self.decshape_face(xc_face) + pred_shape_fa + pred_orient_fa = self.decorient_face(xc_face) + pred_orient_fa + pred_cam_fa = self.deccam_face(xc_face) + pred_cam_fa + + if cfg.MODEL.MESH_MODEL == 'flame': + pred_cam = torch.cat( + [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1 + ) + + if self.full_body_mode or self.body_hand_mode: + if cfg.MODEL.PyMAF.PRED_VIS_H: + xc_vis = torch.cat([xc, xc_lhand, xc_rhand], 1) + + xc_vis = self.drop1_vis(self.fc1_vis(xc_vis)) + xc_vis = self.drop2_vis(self.fc2_vis(xc_vis)) + pred_vis_hands = self.decvis(xc_vis) + + pred_vis_lhand = pred_vis_hands[:, 0] > cfg.MODEL.PyMAF.HAND_VIS_TH + pred_vis_rhand = pred_vis_hands[:, 1] > cfg.MODEL.PyMAF.HAND_VIS_TH + + if cfg.MODEL.PyMAF.OPT_WRIST: + + pred_rotmat_body = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) + pred_lwrist = pred_rotmat_body[:, 20] + pred_rwrist = pred_rotmat_body[:, 21] + + pred_gl_body, body_joints = self.body_model.get_global_rotation( + global_orient=pred_rotmat_body[:, 0:1], + body_pose=pred_rotmat_body[:, 1:] + ) + pred_gl_lelbow = pred_gl_body[:, 18] + pred_gl_relbow = pred_gl_body[:, 19] + + target_gl_lwrist = rot6d_to_rotmat( + pred_orient_lh.reshape(batch_size, -1, 6) + ) + target_gl_lwrist *= self.flip_vector.to(target_gl_lwrist.device) + target_gl_rwrist = rot6d_to_rotmat( + pred_orient_rh.reshape(batch_size, -1, 6) + ) + + opt_lwrist = torch.bmm(pred_gl_lelbow.transpose(1, 2), target_gl_lwrist) + opt_rwrist = torch.bmm(pred_gl_relbow.transpose(1, 2), target_gl_rwrist) + + if cfg.MODEL.PyMAF.ADAPT_INTEGR: + # if cfg.MODEL.PyMAF.ADAPT_INTEGR and global_iter == (cfg.MODEL.PyMAF.N_ITER - 1): + tpose_joints = self.smpl.get_tpose(betas=pred_shape) + lelbow_twist_axis = nn.functional.normalize( + tpose_joints[:, 20] - tpose_joints[:, 18], dim=1 + ) + relbow_twist_axis = nn.functional.normalize( + tpose_joints[:, 21] - tpose_joints[:, 19], dim=1 + ) + + lelbow_twist, lelbow_twist_angle = compute_twist_rotation( + opt_lwrist, lelbow_twist_axis + ) + relbow_twist, relbow_twist_angle = compute_twist_rotation( + opt_rwrist, relbow_twist_axis + ) + + min_angle = -0.4 
* float(np.pi) + max_angle = 0.4 * float(np.pi) + + lelbow_twist_angle[lelbow_twist_angle == torch. + clamp(lelbow_twist_angle, min_angle, max_angle) + ] = 0 + relbow_twist_angle[relbow_twist_angle == torch. + clamp(relbow_twist_angle, min_angle, max_angle) + ] = 0 + lelbow_twist_angle[lelbow_twist_angle > max_angle] -= max_angle + lelbow_twist_angle[lelbow_twist_angle < min_angle] -= min_angle + relbow_twist_angle[relbow_twist_angle > max_angle] -= max_angle + relbow_twist_angle[relbow_twist_angle < min_angle] -= min_angle + + lelbow_twist = batch_rodrigues( + lelbow_twist_axis * lelbow_twist_angle + ) + relbow_twist = batch_rodrigues( + relbow_twist_axis * relbow_twist_angle + ) + + opt_lwrist = torch.bmm(lelbow_twist.transpose(1, 2), opt_lwrist) + opt_rwrist = torch.bmm(relbow_twist.transpose(1, 2), opt_rwrist) + + # left elbow: 18 + opt_lelbow = torch.bmm(pred_rotmat_body[:, 18], lelbow_twist) + # right elbow: 19 + opt_relbow = torch.bmm(pred_rotmat_body[:, 19], relbow_twist) + + if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == ( + cfg.MODEL.PyMAF.N_ITER - 1 + ): + opt_lwrist_filtered = [ + opt_lwrist[_i] + if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20] + for _i in range(batch_size) + ] + opt_rwrist_filtered = [ + opt_rwrist[_i] + if pred_vis_rhand[_i] else pred_rotmat_body[_i, 21] + for _i in range(batch_size) + ] + opt_lelbow_filtered = [ + opt_lelbow[_i] + if pred_vis_lhand[_i] else pred_rotmat_body[_i, 18] + for _i in range(batch_size) + ] + opt_relbow_filtered = [ + opt_relbow[_i] + if pred_vis_rhand[_i] else pred_rotmat_body[_i, 19] + for _i in range(batch_size) + ] + + opt_lwrist = torch.stack(opt_lwrist_filtered) + opt_rwrist = torch.stack(opt_rwrist_filtered) + opt_lelbow = torch.stack(opt_lelbow_filtered) + opt_relbow = torch.stack(opt_relbow_filtered) + + pred_rotmat_body = torch.cat( + [ + pred_rotmat_body[:, :18], + opt_lelbow.unsqueeze(1), + opt_relbow.unsqueeze(1), + opt_lwrist.unsqueeze(1), + opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:] + ], 1 + ) + else: + if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == ( + cfg.MODEL.PyMAF.N_ITER - 1 + ): + opt_lwrist_filtered = [ + opt_lwrist[_i] + if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20] + for _i in range(batch_size) + ] + opt_rwrist_filtered = [ + opt_rwrist[_i] + if pred_vis_rhand[_i] else pred_rotmat_body[_i, 21] + for _i in range(batch_size) + ] + + opt_lwrist = torch.stack(opt_lwrist_filtered) + opt_rwrist = torch.stack(opt_rwrist_filtered) + + pred_rotmat_body = torch.cat( + [ + pred_rotmat_body[:, :20], + opt_lwrist.unsqueeze(1), + opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:] + ], 1 + ) + + if self.hand_only_mode: + pred_rotmat_rh = rot6d_to_rotmat( + torch.cat([pred_orient_rh, pred_rhand], dim=1).reshape(batch_size, -1, 6) + ) # .view(batch_size, 16, 3, 3) + assert pred_rotmat_rh.shape[1] == 1 + 15 + elif self.face_only_mode: + pred_rotmat_fa = rot6d_to_rotmat( + torch.cat([pred_orient_fa, pred_face], dim=1).reshape(batch_size, -1, 6) + ) # .view(batch_size, 16, 3, 3) + assert pred_rotmat_fa.shape[1] == 1 + 3 + elif self.full_body_mode or self.body_hand_mode: + if cfg.MODEL.PyMAF.OPT_WRIST: + pred_rotmat = pred_rotmat_body + else: + pred_rotmat = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) + assert pred_rotmat.shape[1] == 24 + else: + pred_rotmat = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) + assert pred_rotmat.shape[1] == 24 + + # if self.full_body_mode: + if self.smplx_mode: + if cfg.MODEL.PyMAF.PRED_VIS_H 
and global_iter == (cfg.MODEL.PyMAF.N_ITER - 1): + pred_lhand_filtered = [ + pred_lhand[_i] if pred_vis_lhand[_i] else self.init_rhand[0] + for _i in range(batch_size) + ] + pred_rhand_filtered = [ + pred_rhand[_i] if pred_vis_rhand[_i] else self.init_rhand[0] + for _i in range(batch_size) + ] + pred_lhand_filtered = torch.stack(pred_lhand_filtered) + pred_rhand_filtered = torch.stack(pred_rhand_filtered) + pred_hf6d = torch.cat([pred_lhand_filtered, pred_rhand_filtered, pred_face], + dim=1).reshape(batch_size, -1, 6) + else: + pred_hf6d = torch.cat([pred_lhand, pred_rhand, pred_face], + dim=1).reshape(batch_size, -1, 6) + pred_hfrotmat = rot6d_to_rotmat(pred_hf6d) + assert pred_hfrotmat.shape[1] == (15 * 2 + 3) + + # flip left hand pose + pred_lhand_rotmat = pred_hfrotmat[:, :15] * self.flip_vector.to(pred_hfrotmat.device + ).unsqueeze(0) + pred_rhand_rotmat = pred_hfrotmat[:, 15:30] + pred_face_rotmat = pred_hfrotmat[:, 30:] + + if self.hand_only_mode: + pred_output = self.mano( + betas=pred_shape_rh, + right_hand_pose=pred_rotmat_rh[:, 1:], + global_orient=pred_rotmat_rh[:, 0].unsqueeze(1), + pose2rot=False, + ) + elif self.face_only_mode: + pred_output = self.flame( + betas=pred_shape_fa, + global_orient=pred_rotmat_fa[:, 0].unsqueeze(1), + jaw_pose=pred_rotmat_fa[:, 1:2], + leye_pose=pred_rotmat_fa[:, 2:3], + reye_pose=pred_rotmat_fa[:, 3:4], + expression=pred_exp, + pose2rot=False, + ) + else: + smplx_kwargs = {} + # if self.full_body_mode: + if self.smplx_mode: + smplx_kwargs['left_hand_pose'] = pred_lhand_rotmat + smplx_kwargs['right_hand_pose'] = pred_rhand_rotmat + smplx_kwargs['jaw_pose'] = pred_face_rotmat[:, 0:1] + smplx_kwargs['leye_pose'] = pred_face_rotmat[:, 1:2] + smplx_kwargs['reye_pose'] = pred_face_rotmat[:, 2:3] + smplx_kwargs['expression'] = pred_exp + + pred_output = self.smpl( + betas=pred_shape, + body_pose=pred_rotmat[:, 1:], + global_orient=pred_rotmat[:, 0].unsqueeze(1), + pose2rot=False, + **smplx_kwargs, + ) + + pred_vertices = pred_output.vertices + pred_joints = pred_output.joints + + if self.hand_only_mode: + pred_joints_full = pred_output.rhand_joints + elif self.face_only_mode: + pred_joints_full = pred_output.face_joints + elif self.smplx_mode: + pred_joints_full = torch.cat( + [ + pred_joints, pred_output.lhand_joints, pred_output.rhand_joints, + pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints + ], + dim=1 + ) + else: + pred_joints_full = pred_joints + pred_keypoints_2d = projection( + pred_joints_full, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) + if cfg.MODEL.USE_IWP_CAM: + # Normalize keypoints to [-1,1] + pred_keypoints_2d = pred_keypoints_2d / (224. / 2.) 
+ else: + pred_keypoints_2d = j2d_processing(pred_keypoints_2d, rw_cam['kps_transf']) + + len_b_kp = len(constants.JOINT_NAMES) + output = {} + if self.smpl_mode or self.smplx_mode: + if J_regressor is not None: + kp_3d = torch.matmul(J_regressor, pred_vertices) + pred_pelvis = kp_3d[:, [0], :].clone() + kp_3d = kp_3d[:, constants.H36M_TO_J14, :] + kp_3d = kp_3d - pred_pelvis + else: + kp_3d = pred_joints + pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72) + output.update( + { + 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1), + 'verts': pred_vertices, + 'kp_2d': pred_keypoints_2d[:, :len_b_kp], + 'kp_3d': kp_3d, + 'pred_joints': pred_joints, + 'smpl_kp_3d': pred_output.smpl_joints, + 'rotmat': pred_rotmat, + 'pred_cam': pred_cam, + 'pred_shape': pred_shape, + 'pred_pose': pred_pose, + } + ) + # if self.full_body_mode: + if self.smplx_mode: + # assert pred_keypoints_2d.shape[1] == 144 + len_h_kp = len(constants.HAND_NAMES) + len_f_kp = len(constants.FACIAL_LANDMARKS) + len_feet_kp = 2 * len(constants.FOOT_NAMES) + output.update( + { + 'smplx_verts': + pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None, + 'pred_lhand': + pred_lhand, + 'pred_rhand': + pred_rhand, + 'pred_face': + pred_face, + 'pred_exp': + pred_exp, + 'verts_lh': + pred_output.lhand_vertices, + 'verts_rh': + pred_output.rhand_vertices, + # 'pred_arm_rotmat': pred_arm_rotmat, + # 'pred_hfrotmat': pred_hfrotmat, + 'pred_lhand_rotmat': + pred_lhand_rotmat, + 'pred_rhand_rotmat': + pred_rhand_rotmat, + 'pred_face_rotmat': + pred_face_rotmat, + 'pred_lhand_kp3d': + pred_output.lhand_joints, + 'pred_rhand_kp3d': + pred_output.rhand_joints, + 'pred_face_kp3d': + pred_output.face_joints, + 'pred_lhand_kp2d': + pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp], + 'pred_rhand_kp2d': + pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2], + 'pred_face_kp2d': + pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 + + len_f_kp], + 'pred_feet_kp2d': + pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp + + len_h_kp * 2 + len_f_kp + len_feet_kp], + } + ) + if cfg.MODEL.PyMAF.OPT_WRIST: + output.update( + { + 'pred_orient_lh': pred_orient_lh, + 'pred_shape_lh': pred_shape_lh, + 'pred_orient_rh': pred_orient_rh, + 'pred_shape_rh': pred_shape_rh, + 'pred_cam_fa': pred_cam_fa, + 'pred_cam_lh': pred_cam_lh, + 'pred_cam_rh': pred_cam_rh, + } + ) + if cfg.MODEL.PyMAF.PRED_VIS_H: + output.update({'pred_vis_hands': pred_vis_hands}) + elif self.hand_only_mode: + # hand mesh out + assert pred_keypoints_2d.shape[1] == 21 + output.update( + { + 'theta': pred_cam, + 'pred_cam': pred_cam, + 'pred_rhand': pred_rhand, + 'pred_rhand_rotmat': pred_rotmat_rh[:, 1:], + 'pred_orient_rh': pred_orient_rh, + 'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0], + 'verts_rh': pred_output.rhand_vertices, + 'pred_cam_rh': pred_cam_rh, + 'pred_shape_rh': pred_shape_rh, + 'pred_rhand_kp3d': pred_output.rhand_joints, + 'pred_rhand_kp2d': pred_keypoints_2d, + } + ) + elif self.face_only_mode: + # face mesh out + assert pred_keypoints_2d.shape[1] == 68 + output.update( + { + 'theta': pred_cam, + 'pred_cam': pred_cam, + 'pred_face': pred_face, + 'pred_exp': pred_exp, + 'pred_face_rotmat': pred_rotmat_fa[:, 1:], + 'pred_orient_fa': pred_orient_fa, + 'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0], + 'verts_fa': pred_output.flame_vertices, + 'pred_cam_fa': pred_cam_fa, + 'pred_shape_fa': pred_shape_fa, + 'pred_face_kp3d': pred_output.face_joints, + 'pred_face_kp2d': pred_keypoints_2d, + } + 
) + return output + + +def get_attention_modules( + module_keys, img_feature_dim_list, hidden_feat_dim, n_iter, num_attention_heads=1 +): + + align_attention = nn.ModuleDict() + for k in module_keys: + align_attention[k] = nn.ModuleList() + for i in range(n_iter): + align_attention[k].append( + get_att_block( + img_feature_dim=img_feature_dim_list[k][i], + hidden_feat_dim=hidden_feat_dim, + num_attention_heads=num_attention_heads + ) + ) + + return align_attention + + +def get_fusion_modules(module_keys, ma_feat_dim, grid_feat_dim, n_iter, out_feat_len): + + feat_fusion = nn.ModuleDict() + for k in module_keys: + feat_fusion[k] = nn.ModuleList() + for i in range(n_iter): + feat_fusion[k].append(nn.Linear(grid_feat_dim + ma_feat_dim[k], out_feat_len[k])) + + return feat_fusion + + +class PyMAF(nn.Module): + """ PyMAF based Regression Network for Human Mesh Recovery / Full-body Mesh Recovery + PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021 + PyMAF-X: Towards Well-aligned Full-body Model Regression from Monocular Images, arXiv:2207.06400, 2022 + """ + def __init__( + self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True, device=torch.device('cuda') + ): + super().__init__() + + self.device = device + + self.smpl_mode = (cfg.MODEL.MESH_MODEL == 'smpl') + self.smplx_mode = (cfg.MODEL.MESH_MODEL == 'smplx') + + assert cfg.TRAIN.BHF_MODE in [ + 'body_only', 'hand_only', 'face_only', 'body_hand', 'full_body' + ] + self.hand_only_mode = (cfg.TRAIN.BHF_MODE == 'hand_only') + self.face_only_mode = (cfg.TRAIN.BHF_MODE == 'face_only') + self.body_hand_mode = (cfg.TRAIN.BHF_MODE == 'body_hand') + self.full_body_mode = (cfg.TRAIN.BHF_MODE == 'full_body') + + bhf_names = [] + if cfg.TRAIN.BHF_MODE in ['body_only', 'body_hand', 'full_body']: + bhf_names.append('body') + if cfg.TRAIN.BHF_MODE in ['hand_only', 'body_hand', 'full_body']: + bhf_names.append('hand') + if cfg.TRAIN.BHF_MODE in ['face_only', 'full_body']: + bhf_names.append('face') + self.bhf_names = bhf_names + + self.part_module_names = {'body': {}, 'hand': {}, 'face': {}, 'link': {}} + + # the limb parts need to be handled + if self.hand_only_mode: + self.part_names = ['rhand'] + elif self.face_only_mode: + self.part_names = ['face'] + elif self.body_hand_mode: + self.part_names = ['lhand', 'rhand'] + elif self.full_body_mode: + self.part_names = ['lhand', 'rhand', 'face'] + else: + self.part_names = [] + + # joint index info + if not self.smpl_mode: + h_root_idx = constants.HAND_NAMES.index('wrist') + h_idx = constants.HAND_NAMES.index('middle1') + f_idx = constants.FACIAL_LANDMARKS.index('nose_middle') + self.hf_center_idx = {'lhand': h_idx, 'rhand': h_idx, 'face': f_idx} + self.hf_root_idx = {'lhand': h_root_idx, 'rhand': h_root_idx, 'face': f_idx} + + lh_idx_coco = constants.COCO_KEYPOINTS.index('left_wrist') + rh_idx_coco = constants.COCO_KEYPOINTS.index('right_wrist') + f_idx_coco = constants.COCO_KEYPOINTS.index('nose') + self.hf_root_idx_coco = {'lhand': lh_idx_coco, 'rhand': rh_idx_coco, 'face': f_idx_coco} + + # create parametric mesh models + self.smpl_family = {} + if self.hand_only_mode and cfg.MODEL.MESH_MODEL == 'mano': + self.smpl_family['hand'] = SMPL_Family(model_type='mano') + self.smpl_family['body'] = SMPL_Family(model_type='smplx') + elif self.face_only_mode and cfg.MODEL.MESH_MODEL == 'flame': + self.smpl_family['face'] = SMPL_Family(model_type='flame') + self.smpl_family['body'] = SMPL_Family(model_type='smplx') + else: + self.smpl_family['body'] = SMPL_Family( + 
model_type=cfg.MODEL.MESH_MODEL, all_gender=cfg.MODEL.ALL_GENDER + ) + + self.init_mesh_output = None + self.batch_size = 1 + + self.encoders = nn.ModuleDict() + self.global_mode = not cfg.MODEL.PyMAF.MAF_ON + + # build encoders + global_feat_dim = 2048 + bhf_ma_feat_dim = {} + # encoder for the body part + if 'body' in bhf_names: + # if self.smplx_mode or 'hr' in cfg.MODEL.PyMAF.BACKBONE: + if cfg.MODEL.PyMAF.BACKBONE == 'res50': + body_encoder = get_resnet_encoder( + cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode + ) + body_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS) + elif cfg.MODEL.PyMAF.BACKBONE == 'hr48': + body_encoder = get_hrnet_encoder( + cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode + ) + body_sfeat_dim = list(cfg.HR_MODEL.EXTRA.STAGE4.NUM_CHANNELS) + body_sfeat_dim.reverse() + body_sfeat_dim = body_sfeat_dim[1:] + else: + raise NotImplementedError + self.encoders['body'] = body_encoder + self.part_module_names['body'].update({'encoders.body': self.encoders['body']}) + + self.mesh_sampler = Mesh_Sampler(type='smpl') + self.part_module_names['body'].update({'mesh_sampler': self.mesh_sampler}) + + if not cfg.MODEL.PyMAF.GRID_FEAT: + ma_feat_dim = self.mesh_sampler.Dmap.shape[0] * cfg.MODEL.PyMAF.MLP_DIM[-1] + else: + ma_feat_dim = 0 + bhf_ma_feat_dim['body'] = ma_feat_dim + + dp_feat_dim = body_sfeat_dim[-1] + self.with_uv = cfg.LOSS.POINT_REGRESSION_WEIGHTS > 0 + if cfg.MODEL.PyMAF.AUX_SUPV_ON: + assert cfg.MODEL.PyMAF.MAF_ON + self.dp_head = IUV_predict_layer(feat_dim=dp_feat_dim) + self.part_module_names['body'].update({'dp_head': self.dp_head}) + + # encoders for the hand / face parts + if 'hand' in self.bhf_names or 'face' in self.bhf_names: + for hf in ['hand', 'face']: + if hf in bhf_names: + if cfg.MODEL.PyMAF.HF_BACKBONE == 'res50': + self.encoders[hf] = get_resnet_encoder( + cfg, + init_weight=(not cfg.MODEL.EVAL_MODE), + global_mode=self.global_mode + ) + self.part_module_names[hf].update({f'encoders.{hf}': self.encoders[hf]}) + hf_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS) + else: + raise NotImplementedError + + if cfg.MODEL.PyMAF.HF_AUX_SUPV_ON: + assert cfg.MODEL.PyMAF.MAF_ON + self.dp_head_hf = nn.ModuleDict() + if 'hand' in bhf_names: + self.dp_head_hf['hand'] = IUV_predict_layer( + feat_dim=hf_sfeat_dim[-1], mode='pncc' + ) + self.part_module_names['hand'].update( + {'dp_head_hf.hand': self.dp_head_hf['hand']} + ) + if 'face' in bhf_names: + self.dp_head_hf['face'] = IUV_predict_layer( + feat_dim=hf_sfeat_dim[-1], mode='pncc' + ) + self.part_module_names['face'].update( + {'dp_head_hf.face': self.dp_head_hf['face']} + ) + + smpl2limb_vert_faces = get_partial_smpl() + + self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long() + self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long() + + # grid points for grid feature extraction + grid_size = 21 + xv, yv = torch.meshgrid( + [torch.linspace(-1, 1, grid_size), + torch.linspace(-1, 1, grid_size)] + ) + grid_points = torch.stack([xv.reshape(-1), yv.reshape(-1)]).unsqueeze(0) + self.register_buffer('grid_points', grid_points) + grid_feat_dim = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1] + + # the fusion of grid and mesh-aligned features + self.fuse_grid_align = cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT or cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC + assert not (cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT and cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC) + + if self.fuse_grid_align: + self.att_starts = 
cfg.MODEL.PyMAF.GRID_ALIGN.ATT_STARTS + n_iter_att = cfg.MODEL.PyMAF.N_ITER - self.att_starts + att_feat_dim_idx = -cfg.MODEL.PyMAF.GRID_ALIGN.ATT_FEAT_IDX + num_att_heads = cfg.MODEL.PyMAF.GRID_ALIGN.ATT_HEAD + hidden_feat_dim = cfg.MODEL.PyMAF.MLP_DIM[att_feat_dim_idx] + bhf_att_feat_dim = {'body': 2048} + + if 'hand' in self.bhf_names: + self.mano_sampler = Mesh_Sampler(type='mano', level=1) + self.mano_ds_len = self.mano_sampler.Dmap.shape[0] + self.part_module_names['hand'].update({'mano_sampler': self.mano_sampler}) + + bhf_ma_feat_dim.update({'hand': self.mano_ds_len * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]}) + + if self.fuse_grid_align: + bhf_att_feat_dim.update({'hand': 1024}) + + if 'face' in self.bhf_names: + bhf_ma_feat_dim.update( + {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]} + ) + if self.fuse_grid_align: + bhf_att_feat_dim.update({'face': 1024}) + + # spatial alignment attention + if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT: + hfimg_feat_dim_list = {} + if 'body' in bhf_names: + hfimg_feat_dim_list['body'] = body_sfeat_dim[-n_iter_att:] + + if 'hand' in self.bhf_names or 'face' in self.bhf_names: + if 'hand' in bhf_names: + hfimg_feat_dim_list['hand'] = hf_sfeat_dim[-n_iter_att:] + if 'face' in bhf_names: + hfimg_feat_dim_list['face'] = hf_sfeat_dim[-n_iter_att:] + + self.align_attention = get_attention_modules( + bhf_names, + hfimg_feat_dim_list, + hidden_feat_dim, + n_iter=n_iter_att, + num_attention_heads=num_att_heads + ) + + for part in bhf_names: + self.part_module_names[part].update( + {f'align_attention.{part}': self.align_attention[part]} + ) + + if self.fuse_grid_align: + self.att_feat_reduce = get_fusion_modules( + bhf_names, + bhf_ma_feat_dim, + grid_feat_dim, + n_iter=n_iter_att, + out_feat_len=bhf_att_feat_dim + ) + for part in bhf_names: + self.part_module_names[part].update( + {f'att_feat_reduce.{part}': self.att_feat_reduce[part]} + ) + + # build regressor for parameter prediction + self.regressor = nn.ModuleList() + for i in range(cfg.MODEL.PyMAF.N_ITER): + ref_infeat_dim = 0 + if 'body' in self.bhf_names: + if cfg.MODEL.PyMAF.MAF_ON: + if self.fuse_grid_align: + if i >= self.att_starts: + ref_infeat_dim = bhf_att_feat_dim['body'] + elif i == 0 or cfg.MODEL.PyMAF.GRID_FEAT: + ref_infeat_dim = grid_feat_dim + else: + ref_infeat_dim = ma_feat_dim + else: + if i == 0 or cfg.MODEL.PyMAF.GRID_FEAT: + ref_infeat_dim = grid_feat_dim + else: + ref_infeat_dim = ma_feat_dim + else: + ref_infeat_dim = global_feat_dim + + if self.smpl_mode: + self.regressor.append( + Regressor( + feat_dim=ref_infeat_dim, + smpl_mean_params=smpl_mean_params, + use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT, + smpl_models=self.smpl_family + ) + ) + else: + if cfg.MODEL.PyMAF.MAF_ON: + if 'hand' in self.bhf_names or 'face' in self.bhf_names: + if i == 0: + feat_dim_hand = grid_feat_dim if 'hand' in self.bhf_names else None + feat_dim_face = grid_feat_dim if 'face' in self.bhf_names else None + else: + if self.fuse_grid_align: + feat_dim_hand = bhf_att_feat_dim[ + 'hand'] if 'hand' in self.bhf_names else None + feat_dim_face = bhf_att_feat_dim[ + 'face'] if 'face' in self.bhf_names else None + else: + feat_dim_hand = bhf_ma_feat_dim[ + 'hand'] if 'hand' in self.bhf_names else None + feat_dim_face = bhf_ma_feat_dim[ + 'face'] if 'face' in self.bhf_names else None + else: + feat_dim_hand = ref_infeat_dim + feat_dim_face = ref_infeat_dim + else: + ref_infeat_dim = global_feat_dim + feat_dim_hand = global_feat_dim + feat_dim_face = global_feat_dim + + self.regressor.append( + 
Regressor( + feat_dim=ref_infeat_dim, + smpl_mean_params=smpl_mean_params, + use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT, + feat_dim_hand=feat_dim_hand, + feat_dim_face=feat_dim_face, + bhf_names=bhf_names, + smpl_models=self.smpl_family + ) + ) + + # assign sub-regressor to each part + for dec_name, dec_module in self.regressor[-1].named_children(): + if 'hand' in dec_name: + self.part_module_names['hand'].update( + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) + elif 'face' in dec_name or 'head' in dec_name or 'exp' in dec_name: + self.part_module_names['face'].update( + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) + elif 'res' in dec_name or 'vis' in dec_name: + self.part_module_names['link'].update( + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) + elif 'body' in self.part_module_names: + self.part_module_names['body'].update( + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) + + # mesh-aligned feature extractor + self.maf_extractor = nn.ModuleDict() + for part in bhf_names: + self.maf_extractor[part] = nn.ModuleList() + filter_channels_default = cfg.MODEL.PyMAF.MLP_DIM if part == 'body' else cfg.MODEL.PyMAF.HF_MLP_DIM + sfeat_dim = body_sfeat_dim if part == 'body' else hf_sfeat_dim + for i in range(cfg.MODEL.PyMAF.N_ITER): + for f_i, f_dim in enumerate(filter_channels_default): + if sfeat_dim[i] > f_dim: + filter_start = f_i + break + filter_channels = [sfeat_dim[i]] + filter_channels_default[filter_start:] + + if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT and i >= self.att_starts: + self.maf_extractor[part].append( + MAF_Extractor( + filter_channels=filter_channels_default[att_feat_dim_idx:], + iwp_cam_mode=cfg.MODEL.USE_IWP_CAM + ) + ) + else: + self.maf_extractor[part].append( + MAF_Extractor( + filter_channels=filter_channels, iwp_cam_mode=cfg.MODEL.USE_IWP_CAM + ) + ) + self.part_module_names[part].update({f'maf_extractor.{part}': self.maf_extractor[part]}) + + # check all modules have been added to part_module_names + model_dict_all = dict.fromkeys(self.state_dict().keys()) + for key in self.part_module_names.keys(): + for name in list(model_dict_all.keys()): + for k in self.part_module_names[key].keys(): + if name.startswith(k): + del model_dict_all[name] + # if name.startswith('regressor.') and '.smpl.' in name: + # del model_dict_all[name] + # if name.startswith('regressor.') and '.mano.' 
in name: + # del model_dict_all[name] + if name.startswith('regressor.') and '.init_' in name: + del model_dict_all[name] + if name == 'grid_points': + del model_dict_all[name] + assert (len(model_dict_all.keys()) == 0) + + def init_mesh(self, batch_size, J_regressor=None, rw_cam={}): + """ initialize the mesh model with default poses and shapes + """ + if self.init_mesh_output is None or self.batch_size != batch_size: + self.init_mesh_output = self.regressor[0]( + torch.zeros(batch_size), J_regressor=J_regressor, rw_cam=rw_cam, init_mode=True + ) + self.batch_size = batch_size + return self.init_mesh_output + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """ + Deconv_layer used in Simple Baselines: + Xiao et al. Simple Baselines for Human Pose Estimation and Tracking + https://github.com/microsoft/human-pose-estimation.pytorch + """ + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + def _get_deconv_cfg(deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = _get_deconv_cfg(num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d( + in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias + ) + ) + layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + for m in ['body', 'hand', 'face']: + if m in self.smpl_family: + self.smpl_family[m].model.to(*args, **kwargs) + return self + + def cuda(self, *args, **kwargs): + super().cuda(*args, **kwargs) + for m in ['body', 'hand', 'face']: + if m in self.smpl_family: + self.smpl_family[m].model.cuda(*args, **kwargs) + return self + + def forward(self, batch={}, J_regressor=None, rw_cam={}): + ''' + Args: + batch: input dictionary, including + images: 'img_{part}', for part in body, hand, and face if applicable + inversed affine transformation for the cropping of hand/face images: '{part}_theta_inv' for part in lhand, rhand, and face if applicable + J_regressor: joint regression matrix + rw_cam: real-world camera information, applied when cfg.MODEL.USE_IWP_CAM is False + Returns: + out_dict: the list containing the predicted parameters + vis_feat_list: the list containing features for visualization + ''' + + # batch keys: ['img_body', 'orig_height', 'orig_width', 'person_id', 'img_lhand', + # 'lhand_theta_inv', 'img_rhand', 
'rhand_theta_inv', 'img_face', 'face_theta_inv'] + + # extract spatial features or global features + # run encoder for body + if 'body' in self.bhf_names: + img_body = batch['img_body'] + batch_size = img_body.shape[0] + s_feat_body, g_feat = self.encoders['body'](batch['img_body']) + if cfg.MODEL.PyMAF.MAF_ON: + assert len(s_feat_body) == cfg.MODEL.PyMAF.N_ITER + + # run encoders for hand / face + if 'hand' in self.bhf_names or 'face' in self.bhf_names: + limb_feat_dict = {} + limb_gfeat_dict = {} + if 'face' in self.bhf_names: + img_face = batch['img_face'] + batch_size = img_face.shape[0] + limb_feat_dict['face'], limb_gfeat_dict['face'] = self.encoders['face'](img_face) + + if 'hand' in self.bhf_names: + if 'lhand' in self.part_names: + img_rhand = batch['img_rhand'] + batch_size = img_rhand.shape[0] + # flip left hand images + img_lhand = torch.flip(batch['img_lhand'], [3]) + img_hands = torch.cat([img_rhand, img_lhand]) + s_feat_hands, g_feat_hands = self.encoders['hand'](img_hands) + limb_feat_dict['rhand'] = [feat[:batch_size] for feat in s_feat_hands] + limb_feat_dict['lhand'] = [feat[batch_size:] for feat in s_feat_hands] + if g_feat_hands is not None: + limb_gfeat_dict['rhand'] = g_feat_hands[:batch_size] + limb_gfeat_dict['lhand'] = g_feat_hands[batch_size:] + else: + img_rhand = batch['img_rhand'] + batch_size = img_rhand.shape[0] + limb_feat_dict['rhand'], limb_gfeat_dict['rhand'] = self.encoders['hand']( + img_rhand + ) + + if cfg.MODEL.PyMAF.MAF_ON: + for k in limb_feat_dict.keys(): + assert len(limb_feat_dict[k]) == cfg.MODEL.PyMAF.N_ITER + + out_dict = {} + + # grid-pattern points + grid_points = torch.transpose(self.grid_points.expand(batch_size, -1, -1), 1, 2) + + # initial parameters + mesh_output = self.init_mesh(batch_size, J_regressor, rw_cam) + + out_dict['mesh_out'] = [mesh_output] + out_dict['dp_out'] = [] + + # for visulization + vis_feat_list = [] + + # dense prediction during training + if not cfg.MODEL.EVAL_MODE: + if 'body' in self.bhf_names: + if cfg.MODEL.PyMAF.AUX_SUPV_ON: + iuv_out_dict = self.dp_head(s_feat_body[-1]) + out_dict['dp_out'].append(iuv_out_dict) + elif self.hand_only_mode: + if cfg.MODEL.PyMAF.HF_AUX_SUPV_ON: + out_dict['rhand_dpout'] = [] + dphand_out_dict = self.dp_head_hf['hand'](limb_feat_dict['rhand'][-1]) + out_dict['rhand_dpout'].append(dphand_out_dict) + elif self.face_only_mode: + if cfg.MODEL.PyMAF.HF_AUX_SUPV_ON: + out_dict['face_dpout'] = [] + dpface_out_dict = self.dp_head_hf['face'](limb_feat_dict['face'][-1]) + out_dict['face_dpout'].append(dpface_out_dict) + + # parameter predictions + for rf_i in range(cfg.MODEL.PyMAF.N_ITER): + current_states = {} + if 'body' in self.bhf_names: + pred_cam = mesh_output['pred_cam'].detach() + pred_shape = mesh_output['pred_shape'].detach() + pred_pose = mesh_output['pred_pose'].detach() + + current_states['init_cam'] = pred_cam + current_states['init_shape'] = pred_shape + current_states['init_pose'] = pred_pose + + pred_smpl_verts = mesh_output['verts'].detach() + + if cfg.MODEL.PyMAF.MAF_ON: + s_feat_i = s_feat_body[rf_i] + + # re-project mesh on the image plane + if self.hand_only_mode: + pred_cam = mesh_output['pred_cam'].detach() + pred_rhand_v = self.mano_sampler(mesh_output['verts_rh']) + pred_rhand_proj = projection( + pred_rhand_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) + if cfg.MODEL.USE_IWP_CAM: + pred_rhand_proj = pred_rhand_proj / (224. / 2.) 
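+                    # IWP camera: divide by half of the assumed 224-pixel crop so the projected right-hand points lie in the normalized [-1, 1] frame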
+ else: + pred_rhand_proj = j2d_processing(pred_rhand_proj, rw_cam['kps_transf']) + proj_hf_center = { + 'rhand': + mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1) + } + proj_hf_pts = { + 'rhand': torch.cat([proj_hf_center['rhand'], pred_rhand_proj], dim=1) + } + elif self.face_only_mode: + pred_cam = mesh_output['pred_cam'].detach() + pred_face_v = mesh_output['pred_face_kp3d'] + pred_face_proj = projection( + pred_face_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) + if cfg.MODEL.USE_IWP_CAM: + pred_face_proj = pred_face_proj / (224. / 2.) + else: + pred_face_proj = j2d_processing(pred_face_proj, rw_cam['kps_transf']) + proj_hf_center = { + 'face': mesh_output['pred_face_kp2d'][:, self.hf_root_idx['face']].unsqueeze(1) + } + proj_hf_pts = {'face': torch.cat([proj_hf_center['face'], pred_face_proj], dim=1)} + elif self.body_hand_mode: + pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand]) + pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand]) + pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1) + pred_hand_proj = projection( + pred_hand_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) + if cfg.MODEL.USE_IWP_CAM: + pred_hand_proj = pred_hand_proj / (224. / 2.) + else: + pred_hand_proj = j2d_processing(pred_hand_proj, rw_cam['kps_transf']) + + proj_hf_center = { + 'lhand': + mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1), + 'rhand': + mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1), + } + proj_hf_pts = { + 'lhand': + torch.cat( + [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1 + ), + 'rhand': + torch.cat( + [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1 + ), + } + elif self.full_body_mode: + pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand]) + pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand]) + pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1) + pred_hand_proj = projection( + pred_hand_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) + if cfg.MODEL.USE_IWP_CAM: + pred_hand_proj = pred_hand_proj / (224. / 2.) 
+ else: + pred_hand_proj = j2d_processing(pred_hand_proj, rw_cam['kps_transf']) + + proj_hf_center = { + 'lhand': + mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1), + 'rhand': + mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1), + 'face': + mesh_output['pred_face_kp2d'][:, self.hf_root_idx['face']].unsqueeze(1) + } + proj_hf_pts = { + 'lhand': + torch.cat( + [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1 + ), + 'rhand': + torch.cat( + [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1 + ), + 'face': + torch.cat([proj_hf_center['face'], mesh_output['pred_face_kp2d']], dim=1) + } + + # extract mesh-aligned features for the hand / face part + if 'hand' in self.bhf_names or 'face' in self.bhf_names: + limb_rf_i = rf_i + hand_face_feat = {} + + for hf_i, part_name in enumerate(self.part_names): + if 'hand' in part_name: + hf_key = 'hand' + elif 'face' in part_name: + hf_key = 'face' + + if cfg.MODEL.PyMAF.MAF_ON: + if cfg.MODEL.PyMAF.HF_BACKBONE == 'res50': + limb_feat_i = limb_feat_dict[part_name][limb_rf_i] + else: + raise NotImplementedError + + limb_reduce_dim = (not self.fuse_grid_align) or (rf_i < self.att_starts) + + if limb_rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT: + limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling( + grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim + ) + else: + if self.hand_only_mode or self.face_only_mode: + proj_hf_pts_crop = proj_hf_pts[part_name][:, :, :2] + + proj_hf_v_center = proj_hf_pts_crop[:, 0].unsqueeze(1) + + if cfg.MODEL.PyMAF.HF_BOX_CENTER: + part_box_ul = torch.min(proj_hf_pts_crop, dim=1)[0].unsqueeze(1) + part_box_br = torch.max(proj_hf_pts_crop, dim=1)[0].unsqueeze(1) + part_box_center = (part_box_ul + part_box_br) / 2. + proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:] - part_box_center + else: + proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:] + + elif self.full_body_mode or self.body_hand_mode: + # convert projection points to the space of cropped hand/face images + theta_i_inv = batch[f'{part_name}_theta_inv'] + proj_hf_pts_crop = torch.bmm( + theta_i_inv, + homo_vector(proj_hf_pts[part_name][:, :, :2]).permute(0, 2, 1) + ).permute(0, 2, 1) + + if part_name == 'lhand': + flip_x = torch.tensor([-1, 1])[None, + None, :].to(proj_hf_pts_crop) + proj_hf_pts_crop *= flip_x + + if cfg.MODEL.PyMAF.HF_BOX_CENTER: + # align projection points with the cropped image center + part_box_ul = torch.min(proj_hf_pts_crop, dim=1)[0].unsqueeze(1) + part_box_br = torch.max(proj_hf_pts_crop, dim=1)[0].unsqueeze(1) + part_box_center = (part_box_ul + part_box_br) / 2. 
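+                                # subtract the box center from every projected point except the root (index 0), centering the part inside its crop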
+ proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:] - part_box_center + else: + proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:] + + # 0 is the root point + proj_hf_v_center = proj_hf_pts_crop[:, 0].unsqueeze(1) + + limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling( + proj_hf_pts_crop_ctd.detach(), + im_feat=limb_feat_i, + reduce_dim=limb_reduce_dim + ) + + if self.fuse_grid_align and limb_rf_i >= self.att_starts: + + limb_grid_feature_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling( + grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim + ) + limb_grid_ref_feat_ctd = torch.cat( + [limb_grid_feature_ctd, limb_ref_feat_ctd], dim=-1 + ).permute(0, 2, 1) + + if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT: + att_ref_feat_ctd = self.align_attention[hf_key][ + limb_rf_i - self.att_starts](limb_grid_ref_feat_ctd)[0] + elif cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC: + att_ref_feat_ctd = limb_grid_ref_feat_ctd + + att_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].reduce_dim( + att_ref_feat_ctd.permute(0, 2, 1) + ).view(batch_size, -1) + limb_ref_feat_ctd = self.att_feat_reduce[hf_key][ + limb_rf_i - self.att_starts](att_ref_feat_ctd) + + else: + # limb_ref_feat = limb_ref_feat.view(batch_size, -1) + limb_ref_feat_ctd = limb_ref_feat_ctd.view(batch_size, -1) + hand_face_feat[part_name] = limb_ref_feat_ctd + else: + hand_face_feat[part_name] = limb_gfeat_dict[part_name] + + # extract mesh-aligned features for the body part + if 'body' in self.bhf_names: + if cfg.MODEL.PyMAF.MAF_ON: + reduce_dim = (not self.fuse_grid_align) or (rf_i < self.att_starts) + if rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT: + ref_feature = self.maf_extractor['body'][rf_i].sampling( + grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim + ) + else: + # TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration + pred_smpl_verts_ds = self.mesh_sampler.downsample( + pred_smpl_verts + ) # [B, 431, 3] + ref_feature = self.maf_extractor['body'][rf_i]( + pred_smpl_verts_ds, + im_feat=s_feat_i, + cam={ + **rw_cam, 'cam_sxy': pred_cam + }, + add_att=True, + reduce_dim=reduce_dim + ) # [B, 431 * n_feat] + + if self.fuse_grid_align and rf_i >= self.att_starts: + if rf_i > 0 and not cfg.MODEL.PyMAF.GRID_FEAT: + grid_feature = self.maf_extractor['body'][rf_i].sampling( + grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim + ) + grid_ref_feat = torch.cat([grid_feature, ref_feature], dim=-1) + else: + grid_ref_feat = ref_feature + grid_ref_feat = grid_ref_feat.permute(0, 2, 1) + + if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT: + att_ref_feat = self.align_attention['body'][ + rf_i - self.att_starts](grid_ref_feat)[0] + elif cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC: + att_ref_feat = grid_ref_feat + + att_ref_feat = self.maf_extractor['body'][rf_i].reduce_dim( + att_ref_feat.permute(0, 2, 1) + ) + att_ref_feat = att_ref_feat.view(batch_size, -1) + + ref_feature = self.att_feat_reduce['body'][rf_i - + self.att_starts](att_ref_feat) + else: + ref_feature = ref_feature.view(batch_size, -1) + else: + ref_feature = g_feat + else: + ref_feature = None + + if not self.smpl_mode: + if self.hand_only_mode: + current_states['xc_rhand'] = hand_face_feat['rhand'] + elif self.face_only_mode: + current_states['xc_face'] = hand_face_feat['face'] + elif self.body_hand_mode: + current_states['xc_lhand'] = hand_face_feat['lhand'] + current_states['xc_rhand'] = hand_face_feat['rhand'] + elif self.full_body_mode: + current_states['xc_lhand'] = hand_face_feat['lhand'] + current_states['xc_rhand'] = hand_face_feat['rhand'] + 
current_states['xc_face'] = hand_face_feat['face'] + + if rf_i > 0: + for part in self.part_names: + current_states[f'init_{part}'] = mesh_output[f'pred_{part}'].detach() + if part == 'face': + current_states['init_exp'] = mesh_output['pred_exp'].detach() + if self.hand_only_mode: + current_states['init_shape_rh'] = mesh_output['pred_shape_rh'].detach() + current_states['init_orient_rh'] = mesh_output['pred_orient_rh'].detach() + current_states['init_cam_rh'] = mesh_output['pred_cam_rh'].detach() + elif self.face_only_mode: + current_states['init_shape_fa'] = mesh_output['pred_shape_fa'].detach() + current_states['init_orient_fa'] = mesh_output['pred_orient_fa'].detach() + current_states['init_cam_fa'] = mesh_output['pred_cam_fa'].detach() + elif self.full_body_mode or self.body_hand_mode: + if cfg.MODEL.PyMAF.OPT_WRIST: + current_states['init_shape_lh'] = mesh_output['pred_shape_lh'].detach() + current_states['init_orient_lh'] = mesh_output['pred_orient_lh'].detach( + ) + current_states['init_cam_lh'] = mesh_output['pred_cam_lh'].detach() + + current_states['init_shape_rh'] = mesh_output['pred_shape_rh'].detach() + current_states['init_orient_rh'] = mesh_output['pred_orient_rh'].detach( + ) + current_states['init_cam_rh'] = mesh_output['pred_cam_rh'].detach() + + # update mesh parameters + mesh_output = self.regressor[rf_i]( + ref_feature, + n_iter=1, + J_regressor=J_regressor, + rw_cam=rw_cam, + global_iter=rf_i, + **current_states + ) + + out_dict['mesh_out'].append(mesh_output) + + return out_dict, vis_feat_list + + +def pymaf_net(smpl_mean_params, pretrained=True, device=torch.device('cuda')): + """ Constructs an PyMAF model with ResNet50 backbone. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = PyMAF(smpl_mean_params, pretrained, device) + return model diff --git a/lib/pymafx/models/res_module.py b/lib/pymafx/models/res_module.py new file mode 100644 index 0000000000000000000000000000000000000000..94de7ecaa2ba3ead51c5f960e0ae08b806d9cd80 --- /dev/null +++ b/lib/pymafx/models/res_module.py @@ -0,0 +1,480 @@ +# code brought in part from https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict +from lib.pymafx.core.cfgs import cfg +# from .transformers.tokenlearner import TokenLearner + +import logging + +logger = logging.getLogger(__name__) + +BN_MOMENTUM = 0.1 + + +def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes * groups, + out_planes * groups, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + groups=groups + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1): + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride, groups=groups) + self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, groups=groups) + self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample 
is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1): + super().__init__() + self.conv1 = nn.Conv2d( + inplanes * groups, planes * groups, kernel_size=1, bias=False, groups=groups + ) + self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes * groups, + planes * groups, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=groups + ) + self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes * groups, + planes * self.expansion * groups, + kernel_size=1, + bias=False, + groups=groups + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +resnet_spec = { + 18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3]) +} + + +class IUV_predict_layer(nn.Module): + def __init__(self, feat_dim=256, final_cov_k=3, out_channels=25, with_uv=True, mode='iuv'): + super().__init__() + + assert mode in ['iuv', 'seg', 'pncc'] + self.mode = mode + + if mode == 'seg': + self.predict_ann_index = nn.Conv2d( + in_channels=feat_dim, + out_channels=15, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.predict_uv_index = nn.Conv2d( + in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + elif mode == 'iuv': + self.predict_u = nn.Conv2d( + in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.predict_v = nn.Conv2d( + in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.predict_ann_index = nn.Conv2d( + in_channels=feat_dim, + out_channels=15, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.predict_uv_index = nn.Conv2d( + in_channels=feat_dim, + out_channels=25, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + elif mode in ['pncc']: + self.predict_pncc = nn.Conv2d( + in_channels=feat_dim, + out_channels=3, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.inplanes = feat_dim + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + return_dict = {} + + 
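+        # 'iuv' and 'seg' both predict UV-patch and part (ann) indices; 'iuv' additionally regresses the U/V maps, while 'pncc' outputs a 3-channel PNCC map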
if self.mode in ['iuv', 'seg']: + predict_uv_index = self.predict_uv_index(x) + predict_ann_index = self.predict_ann_index(x) + + return_dict['predict_uv_index'] = predict_uv_index + return_dict['predict_ann_index'] = predict_ann_index + + if self.mode == 'iuv': + predict_u = self.predict_u(x) + predict_v = self.predict_v(x) + return_dict['predict_u'] = predict_u + return_dict['predict_v'] = predict_v + else: + return_dict['predict_u'] = None + return_dict['predict_v'] = None + # return_dict['predict_u'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device) + # return_dict['predict_v'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device) + + if self.mode == 'pncc': + predict_pncc = self.predict_pncc(x) + return_dict['predict_pncc'] = predict_pncc + + return return_dict + + +class Seg_predict_layer(nn.Module): + def __init__(self, feat_dim=256, final_cov_k=3, out_channels=25): + super().__init__() + + self.predict_seg_index = nn.Conv2d( + in_channels=feat_dim, + out_channels=out_channels, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.inplanes = feat_dim + + def forward(self, x): + return_dict = {} + + predict_seg_index = self.predict_seg_index(x) + return_dict['predict_seg_index'] = predict_seg_index + + return return_dict + + +class Kps_predict_layer(nn.Module): + def __init__(self, feat_dim=256, final_cov_k=3, out_channels=3, add_module=None): + super().__init__() + + if add_module is not None: + conv = nn.Conv2d( + in_channels=feat_dim, + out_channels=out_channels, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + self.predict_kps = nn.Sequential( + add_module, + # nn.BatchNorm2d(feat_dim, momentum=BN_MOMENTUM), + # conv, + ) + else: + self.predict_kps = nn.Conv2d( + in_channels=feat_dim, + out_channels=out_channels, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) + + self.inplanes = feat_dim + + def forward(self, x): + return_dict = {} + + predict_kps = self.predict_kps(x) + return_dict['predict_kps'] = predict_kps + + return return_dict + + +class SmplResNet(nn.Module): + def __init__( + self, + resnet_nums, + in_channels=3, + num_classes=229, + last_stride=2, + n_extra_feat=0, + truncate=0, + **kwargs + ): + super().__init__() + + self.inplanes = 64 + self.truncate = truncate + # extra = cfg.MODEL.EXTRA + # self.deconv_with_bias = extra.DECONV_WITH_BIAS + block, layers = resnet_spec[resnet_nums] + + self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) if truncate < 2 else None + self.layer4 = self._make_layer( + block, 512, layers[3], stride=last_stride + ) if truncate < 1 else None + + self.avg_pooling = nn.AdaptiveAvgPool2d(1) + + self.num_classes = num_classes + if num_classes > 0: + self.final_layer = nn.Linear(512 * block.expansion, num_classes) + nn.init.xavier_uniform_(self.final_layer.weight, gain=0.01) + + self.n_extra_feat = n_extra_feat + if n_extra_feat > 0: + self.trans_conv = nn.Sequential( + nn.Conv2d( + n_extra_feat + 512 * block.expansion, + 512 * block.expansion, + kernel_size=1, + bias=False + ), nn.BatchNorm2d(512 * block.expansion, momentum=BN_MOMENTUM), 
nn.ReLU(True) + ) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, infeat=None): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) if self.truncate < 2 else x2 + x4 = self.layer4(x3) if self.truncate < 1 else x3 + + if infeat is not None: + x4 = self.trans_conv(torch.cat([infeat, x4], 1)) + + if self.num_classes > 0: + xp = self.avg_pooling(x4) + cls = self.final_layer(xp.view(xp.size(0), -1)) + if not cfg.DANET.USE_MEAN_PARA: + # for non-negative scale + scale = F.relu(cls[:, 0]).unsqueeze(1) + cls = torch.cat((scale, cls[:, 1:]), dim=1) + else: + cls = None + + return cls, {'x4': x4} + + def init_weights(self, pretrained=''): + if os.path.isfile(pretrained): + logger.info('=> loading pretrained model {}'.format(pretrained)) + # self.load_state_dict(pretrained_state_dict, strict=False) + checkpoint = torch.load(pretrained) + if isinstance(checkpoint, OrderedDict): + # state_dict = checkpoint + state_dict_old = self.state_dict() + for key in state_dict_old.keys(): + if key in checkpoint.keys(): + if state_dict_old[key].shape != checkpoint[key].shape: + del checkpoint[key] + state_dict = checkpoint + elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict_old = checkpoint['state_dict'] + state_dict = OrderedDict() + # delete 'module.' 
because it is saved from DataParallel module + for key in state_dict_old.keys(): + if key.startswith('module.'): + # state_dict[key[7:]] = state_dict[key] + # state_dict.pop(key) + state_dict[key[7:]] = state_dict_old[key] + else: + state_dict[key] = state_dict_old[key] + else: + raise RuntimeError('No state_dict found in checkpoint file {}'.format(pretrained)) + self.load_state_dict(state_dict, strict=False) + else: + logger.error('=> imagenet pretrained model dose not exist') + logger.error('=> please download it first') + raise ValueError('imagenet pretrained model does not exist') + + +class LimbResLayers(nn.Module): + def __init__(self, resnet_nums, inplanes, outplanes=None, groups=1, **kwargs): + super().__init__() + + self.inplanes = inplanes + block, layers = resnet_spec[resnet_nums] + self.outplanes = 256 if outplanes == None else outplanes + self.layer3 = self._make_layer(block, self.outplanes, layers[2], stride=2, groups=groups) + # self.outplanes = 512 if outplanes == None else outplanes + # self.layer4 = self._make_layer(block, self.outplanes, layers[3], stride=2, groups=groups) + + self.avg_pooling = nn.AdaptiveAvgPool2d(1) + + # self.tklr = TokenLearner(S=n_token) + + def _make_layer(self, block, planes, blocks, stride=1, groups=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes * groups, + planes * block.expansion * groups, + kernel_size=1, + stride=stride, + bias=False, + groups=groups + ), + nn.BatchNorm2d(planes * block.expansion * groups, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, groups=groups)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.layer3(x) + # x = self.layer4(x) + # x = self.avg_pooling(x) + # x_g = self.tklr(x.permute(0, 2, 3, 1).contiguous()) + # x_g = x_g.reshape(x.shape[0], -1) + + return x, None diff --git a/lib/pymafx/models/smpl.py b/lib/pymafx/models/smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3a7871fc41038df1fa9c3c9589bb8d6a8de25b --- /dev/null +++ b/lib/pymafx/models/smpl.py @@ -0,0 +1,927 @@ +# This script is extended based on https://github.com/nkolot/SPIN/blob/master/models/smpl.py + +from typing import Optional +from dataclasses import dataclass + +import os +import torch +import torch.nn as nn +import numpy as np +import pickle +from lib.smplx import SMPL as _SMPL +from lib.smplx import SMPLXLayer, MANOLayer, FLAMELayer +from lib.smplx.lbs import batch_rodrigues, transform_mat, vertices2joints, blend_shapes +from lib.smplx.body_models import SMPLXOutput +import json + +from lib.pymafx.core import path_config, constants + +SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS +SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR + + +@dataclass +class ModelOutput(SMPLXOutput): + smpl_joints: Optional[torch.Tensor] = None + joints_J19: Optional[torch.Tensor] = None + smplx_vertices: Optional[torch.Tensor] = None + flame_vertices: Optional[torch.Tensor] = None + lhand_vertices: Optional[torch.Tensor] = None + rhand_vertices: Optional[torch.Tensor] = None + lhand_joints: Optional[torch.Tensor] = None + rhand_joints: Optional[torch.Tensor] = None + face_joints: Optional[torch.Tensor] = None + lfoot_joints: Optional[torch.Tensor] = None + rfoot_joints: Optional[torch.Tensor] = None + + +class SMPL(_SMPL): + """ 
Extension of the official SMPL implementation to support more joints """ + def __init__( + self, + create_betas=False, + create_global_orient=False, + create_body_pose=False, + create_transl=False, + *args, + **kwargs + ): + super().__init__( + create_betas=create_betas, + create_global_orient=create_global_orient, + create_body_pose=create_body_pose, + create_transl=create_transl, + *args, + **kwargs + ) + joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES] + J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA) + self.register_buffer( + 'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32) + ) + self.joint_map = torch.tensor(joints, dtype=torch.long) + # self.ModelOutput = namedtuple('ModelOutput_', ModelOutput._fields + ('smpl_joints', 'joints_J19',)) + # self.ModelOutput.__new__.__defaults__ = (None,) * len(self.ModelOutput._fields) + + tpose_joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0)) + self.register_buffer('tpose_joints', tpose_joints) + + def forward(self, *args, **kwargs): + kwargs['get_skin'] = True + smpl_output = super().forward(*args, **kwargs) + extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) + # smpl_output.joints: [B, 45, 3] extra_joints: [B, 9, 3] + vertices = smpl_output.vertices + joints = torch.cat([smpl_output.joints, extra_joints], dim=1) + smpl_joints = smpl_output.joints[:, :24] + joints = joints[:, self.joint_map, :] # [B, 49, 3] + joints_J24 = joints[:, -24:, :] + joints_J19 = joints_J24[:, constants.J24_TO_J19, :] + output = ModelOutput( + vertices=vertices, + global_orient=smpl_output.global_orient, + body_pose=smpl_output.body_pose, + joints=joints, + joints_J19=joints_J19, + smpl_joints=smpl_joints, + betas=smpl_output.betas, + full_pose=smpl_output.full_pose + ) + return output + + def get_global_rotation( + self, + global_orient: Optional[torch.Tensor] = None, + body_pose: Optional[torch.Tensor] = None, + **kwargs + ): + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. It is expected to be in rotation matrix + format. (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. 
(default=None) + Returns + ------- + output: Global rotation matrix + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + + model_vars = [global_orient, body_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1 + ).contiguous() + + # Concatenate all pose vectors + full_pose = torch.cat( + [global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)], + dim=1 + ) + + rot_mats = full_pose.view(batch_size, -1, 3, 3) + + # Get the joints + # NxJx3 array + # joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0).expand(batch_size, -1, -1)) + # joints = torch.unsqueeze(joints, dim=-1) + + joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, self.parents[1:]] + + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, + 1)).reshape(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, self.parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + global_rotmat = transforms[:, :, :3, :3] + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + return global_rotmat, posed_joints + + +class SMPLX(SMPLXLayer): + """ Extension of the official SMPLX implementation to support more functions """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_global_rotation( + self, + global_orient: Optional[torch.Tensor] = None, + body_pose: Optional[torch.Tensor] = None, + left_hand_pose: Optional[torch.Tensor] = None, + right_hand_pose: Optional[torch.Tensor] = None, + jaw_pose: Optional[torch.Tensor] = None, + leye_pose: Optional[torch.Tensor] = None, + reye_pose: Optional[torch.Tensor] = None, + **kwargs + ): + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. It is expected to be in rotation matrix + format. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + Expression coefficients. + For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. 
(default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full pose vector (default=False) + Returns + ------- + output: ModelOutput + A data class that contains the posed vertices and joints + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + + model_vars = [global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1 + ).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, + -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, + 3).expand(batch_size, 15, -1, + -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() + + # Concatenate all pose vectors + full_pose = torch.cat( + [ + global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + jaw_pose.reshape(-1, 1, 3, 3), + leye_pose.reshape(-1, 1, 3, 3), + reye_pose.reshape(-1, 1, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3) + ], + dim=1 + ) + + rot_mats = full_pose.view(batch_size, -1, 3, 3) + + # Get the joints + # NxJx3 array + joints = vertices2joints( + self.J_regressor, + self.v_template.unsqueeze(0).expand(batch_size, -1, -1) + ) + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, self.parents[1:]] + + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, + 1)).reshape(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, self.parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + global_rotmat = 
transforms[:, :, :3, :3] + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + return global_rotmat, posed_joints + + +class SMPLX_ALL(nn.Module): + """ Extension of the official SMPLX implementation to support more joints """ + def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs): + super().__init__() + numBetas = 10 + self.use_face_contour = use_face_contour + if all_gender: + self.genders = ['male', 'female', 'neutral'] + else: + self.genders = ['neutral'] + for gender in self.genders: + assert gender in ['male', 'female', 'neutral'] + self.model_dict = nn.ModuleDict( + { + gender: SMPLX( + path_config.SMPL_MODEL_DIR, + gender=gender, + ext='npz', + num_betas=numBetas, + use_pca=False, + batch_size=batch_size, + use_face_contour=use_face_contour, + num_pca_comps=45, + **kwargs + ) + for gender in self.genders + } + ) + self.model_neutral = self.model_dict['neutral'] + joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES] + J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA) + self.register_buffer( + 'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32) + ) + self.joint_map = torch.tensor(joints, dtype=torch.long) + # smplx_to_smpl.pkl, file source: https://smpl-x.is.tue.mpg.de + smplx_to_smpl = pickle.load( + open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') + ) + self.register_buffer( + 'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32) + ) + + smpl2limb_vert_faces = get_partial_smpl('smpl') + self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long() + self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long() + + # left and right hand joint mapping + smplx2lhand_joints = [ + constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES + ] + smplx2rhand_joints = [ + constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES + ] + self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long) + self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long) + + # left and right foot joint mapping + smplx2lfoot_joints = [ + constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES + ] + smplx2rfoot_joints = [ + constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES + ] + self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long) + self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long) + + for g in self.genders: + J_template = torch.einsum( + 'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template] + ) + J_dirs = torch.einsum( + 'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs] + ) + + self.register_buffer(f'{g}_J_template', J_template) + self.register_buffer(f'{g}_J_dirs', J_dirs) + + def forward(self, *args, **kwargs): + batch_size = kwargs['body_pose'].shape[0] + kwargs['get_skin'] = True + if 'pose2rot' not in kwargs: + kwargs['pose2rot'] = True + if 'gender' not in kwargs: + kwargs['gender'] = 2 * torch.ones(batch_size).to(kwargs['body_pose'].device) + + # pose for 55 joints: 1, 21, 15, 15, 1, 1, 1 + pose_keys = [ + 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose', + 'leye_pose', 'reye_pose' + ] + param_keys = ['betas'] + pose_keys + if kwargs['pose2rot']: + for key in pose_keys: + if key in 
kwargs: + # if key == 'left_hand_pose': + # kwargs[key] += self.model_neutral.left_hand_mean + # elif key == 'right_hand_pose': + # kwargs[key] += self.model_neutral.right_hand_mean + kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view( + [batch_size, -1, 3, 3] + ) + if kwargs['body_pose'].shape[1] == 23: + # remove hand pose in the body_pose + kwargs['body_pose'] = kwargs['body_pose'][:, :21] + gender_idx_list = [] + smplx_vertices, smplx_joints = [], [] + for gi, g in enumerate(['male', 'female', 'neutral']): + gender_idx = ((kwargs['gender'] == gi).nonzero(as_tuple=True)[0]) + if len(gender_idx) == 0: + continue + gender_idx_list.extend([int(idx) for idx in gender_idx]) + gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']} + gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs}) + gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs) + smplx_vertices.append(gender_smplx_output.vertices) + smplx_joints.append(gender_smplx_output.joints) + + idx_rearrange = [gender_idx_list.index(i) for i in range(len(list(gender_idx_list)))] + idx_rearrange = torch.tensor(idx_rearrange).long().to(kwargs['body_pose'].device) + + smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange] + smplx_joints = torch.cat(smplx_joints)[idx_rearrange] + + # constants.HAND_NAMES + lhand_joints = smplx_joints[:, self.smplx2lh_joint_map] + rhand_joints = smplx_joints[:, self.smplx2rh_joint_map] + # constants.FACIAL_LANDMARKS + face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:] + # constants.FOOT_NAMES + lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map] + rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map] + + smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices) + lhand_vertices = smpl_vertices[:, self.smpl2lhand] + rhand_vertices = smpl_vertices[:, self.smpl2rhand] + extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices) + # smpl_output.joints: [B, 45, 3] extra_joints: [B, 9, 3] + smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45] + joints = torch.cat([smplx_j45, extra_joints], dim=1) + smpl_joints = smplx_j45[:, :24] + joints = joints[:, self.joint_map, :] # [B, 49, 3] + joints_J24 = joints[:, -24:, :] + joints_J19 = joints_J24[:, constants.J24_TO_J19, :] + output = ModelOutput( + vertices=smpl_vertices, + smplx_vertices=smplx_vertices, + lhand_vertices=lhand_vertices, + rhand_vertices=rhand_vertices, + # global_orient=smplx_output.global_orient, + # body_pose=smplx_output.body_pose, + joints=joints, + joints_J19=joints_J19, + smpl_joints=smpl_joints, + # betas=smplx_output.betas, + # full_pose=smplx_output.full_pose, + lhand_joints=lhand_joints, + rhand_joints=rhand_joints, + lfoot_joints=lfoot_joints, + rfoot_joints=rfoot_joints, + face_joints=face_joints, + ) + return output + + # def make_hand_regressor(self): + # # borrowed from https://github.com/mks0601/Hand4Whole_RELEASE/blob/main/common/utils/human_models.py + # regressor = self.model_neutral.J_regressor.numpy() + # vertex_num = self.model_neutral.J_regressor.shape[-1] + # lhand_regressor = np.concatenate((regressor[[20,37,38,39],:], + # np.eye(vertex_num)[5361,None], + # regressor[[25,26,27],:], + # np.eye(vertex_num)[4933,None], + # regressor[[28,29,30],:], + # np.eye(vertex_num)[5058,None], + # regressor[[34,35,36],:], + # np.eye(vertex_num)[5169,None], + # regressor[[31,32,33],:], + # np.eye(vertex_num)[5286,None])) + # rhand_regressor = 
np.concatenate((regressor[[21,52,53,54],:], + # np.eye(vertex_num)[8079,None], + # regressor[[40,41,42],:], + # np.eye(vertex_num)[7669,None], + # regressor[[43,44,45],:], + # np.eye(vertex_num)[7794,None], + # regressor[[49,50,51],:], + # np.eye(vertex_num)[7905,None], + # regressor[[46,47,48],:], + # np.eye(vertex_num)[8022,None])) + # return torch.from_numpy(lhand_regressor).float(), torch.from_numpy(rhand_regressor).float() + + def get_tpose(self, betas=None, gender=None): + kwargs = {} + if betas is None: + betas = torch.zeros(1, 10).to(self.J_regressor_extra.device) + kwargs['betas'] = betas + + batch_size = kwargs['betas'].shape[0] + device = kwargs['betas'].device + + if gender is None: + kwargs['gender'] = 2 * torch.ones(batch_size).to(device) + else: + kwargs['gender'] = gender + + param_keys = ['betas'] + + gender_idx_list = [] + smplx_joints = [] + for gi, g in enumerate(['male', 'female', 'neutral']): + gender_idx = ((kwargs['gender'] == gi).nonzero(as_tuple=True)[0]) + if len(gender_idx) == 0: + continue + gender_idx_list.extend([int(idx) for idx in gender_idx]) + gender_kwargs = {} + gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs}) + + J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes( + gender_kwargs['betas'], getattr(self, f'{g}_J_dirs') + ) + + smplx_joints.append(J) + + idx_rearrange = [gender_idx_list.index(i) for i in range(len(list(gender_idx_list)))] + idx_rearrange = torch.tensor(idx_rearrange).long().to(device) + + smplx_joints = torch.cat(smplx_joints)[idx_rearrange] + + return smplx_joints + + +class MANO(MANOLayer): + """ Extension of the official MANO implementation to support more joints """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): + if 'pose2rot' not in kwargs: + kwargs['pose2rot'] = True + pose_keys = ['global_orient', 'right_hand_pose'] + batch_size = kwargs['global_orient'].shape[0] + if kwargs['pose2rot']: + for key in pose_keys: + if key in kwargs: + kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view( + [batch_size, -1, 3, 3] + ) + kwargs['hand_pose'] = kwargs.pop('right_hand_pose') + mano_output = super().forward(*args, **kwargs) + th_verts = mano_output.vertices + th_jtr = mano_output.joints + # https://github.com/hassony2/manopth/blob/master/manopth/manolayer.py#L248-L260 + # In addition to MANO reference joints we sample vertices on each finger + # to serve as finger tips + tips = th_verts[:, [745, 317, 445, 556, 673]] + th_jtr = torch.cat([th_jtr, tips], 1) + # Reorder joints to match visualization utilities + th_jtr = th_jtr[:, + [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]] + output = ModelOutput( + rhand_vertices=th_verts, + rhand_joints=th_jtr, + ) + return output + + +class FLAME(FLAMELayer): + """ Extension of the official FLAME implementation to support more joints """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): + if 'pose2rot' not in kwargs: + kwargs['pose2rot'] = True + pose_keys = ['global_orient', 'jaw_pose', 'leye_pose', 'reye_pose'] + batch_size = kwargs['global_orient'].shape[0] + if kwargs['pose2rot']: + for key in pose_keys: + if key in kwargs: + kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view( + [batch_size, -1, 3, 3] + ) + flame_output = super().forward(*args, **kwargs) + output = ModelOutput( + flame_vertices=flame_output.vertices, + 
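+            # Assumption: slicing joints[:, 5:] drops FLAME's five non-landmark
+            # joints (global/neck/jaw/eyes) so that `face_joints` holds only the
+            # facial landmarks.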
face_joints=flame_output.joints[:, 5:], + ) + return output + + +class SMPL_Family(): + def __init__(self, model_type='smpl', *args, **kwargs): + if model_type == 'smpl': + self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs) + elif model_type == 'smplx': + self.model = SMPLX_ALL(*args, **kwargs) + elif model_type == 'mano': + self.model = MANO( + model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs + ) + elif model_type == 'flame': + self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def get_tpose(self, *args, **kwargs): + return self.model.get_tpose(*args, **kwargs) + + # def to(self, device): + # self.model.to(device) + + # def cuda(self, device=None): + # if device is None: + # self.model.cuda() + # else: + # self.model.cuda(device) + + +def get_smpl_faces(): + smpl = SMPL(model_path=SMPL_MODEL_DIR, batch_size=1) + return smpl.faces + + +def get_smplx_faces(): + smplx = SMPLX(SMPL_MODEL_DIR, batch_size=1) + return smplx.faces + + +def get_mano_faces(hand_type='right'): + assert hand_type in ['right', 'left'] + is_rhand = True if hand_type == 'right' else False + mano = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=is_rhand) + + return mano.faces + + +def get_flame_faces(): + flame = FLAME(SMPL_MODEL_DIR, batch_size=1) + + return flame.faces + + +def get_model_faces(type='smpl'): + if type == 'smpl': + return get_smpl_faces() + elif type == 'smplx': + return get_smplx_faces() + elif type == 'mano': + return get_mano_faces() + elif type == 'flame': + return get_flame_faces() + + +def get_model_tpose(type='smpl'): + if type == 'smpl': + return get_smpl_tpose() + elif type == 'smplx': + return get_smplx_tpose() + elif type == 'mano': + return get_mano_tpose() + elif type == 'flame': + return get_flame_tpose() + + +def get_smpl_tpose(): + smpl = SMPL( + create_betas=True, + create_global_orient=True, + create_body_pose=True, + model_path=SMPL_MODEL_DIR, + batch_size=1 + ) + vertices = smpl().vertices[0] + return vertices.detach() + + +def get_smpl_tpose_joint(): + smpl = SMPL( + create_betas=True, + create_global_orient=True, + create_body_pose=True, + model_path=SMPL_MODEL_DIR, + batch_size=1 + ) + tpose_joint = smpl().smpl_joints[0] + return tpose_joint.detach() + + +def get_smplx_tpose(): + smplx = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1) + vertices = smplx().vertices[0] + return vertices + + +def get_smplx_tpose_joint(): + smplx = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1) + tpose_joint = smplx().joints[0] + return tpose_joint + + +def get_mano_tpose(): + mano = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True) + vertices = mano(global_orient=torch.zeros(1, 3), + right_hand_pose=torch.zeros(1, 15 * 3)).rhand_vertices[0] + return vertices + + +def get_flame_tpose(): + flame = FLAME(SMPL_MODEL_DIR, batch_size=1) + vertices = flame(global_orient=torch.zeros(1, 3)).flame_vertices[0] + return vertices + + +def get_part_joints(smpl_joints): + batch_size = smpl_joints.shape[0] + + # part_joints = torch.zeros().to(smpl_joints.device) + + one_seg_pairs = [ + (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17) + ] + two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21)] + + one_seg_pairs.extend(two_seg_pairs) + + single_joints = [(10), (11), (15), (22), (23)] + + part_joints = [] + + for j_p in one_seg_pairs: + new_joint = torch.mean(smpl_joints[:, j_p], dim=1, keepdim=True) + 
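+        # Each joint pair is collapsed to its midpoint, giving one part-level joint.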
part_joints.append(new_joint) + + for j_p in single_joints: + part_joints.append(smpl_joints[:, j_p:j_p + 1]) + + part_joints = torch.cat(part_joints, dim=1) + + return part_joints + + +def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): + + body_model_faces = get_model_faces(body_model) + body_model_num_verts = len(get_model_tpose(body_model)) + + part_vert_faces = {} + + for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']: + part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part) + if os.path.exists(part_vid_fname): + part_vids = np.load(part_vid_fname) + part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']} + else: + if part in ['lhand', 'rhand']: + with open( + os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb' + ) as json_file: + smplx_mano_id = pickle.load(json_file) + with open( + os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb' + ) as json_file: + smplx_smpl_id = pickle.load(json_file) + + smplx_tpose = get_smplx_tpose() + smpl_tpose = np.matmul(smplx_smpl_id['matrix'], smplx_tpose) + + if part == 'lhand': + mano_vert = smplx_tpose[smplx_mano_id['left_hand']] + elif part == 'rhand': + mano_vert = smplx_tpose[smplx_mano_id['right_hand']] + + smpl2mano_id = [] + for vert in mano_vert: + v_diff = smpl_tpose - vert + v_diff = torch.sum(v_diff * v_diff, dim=1) + v_closest = torch.argmin(v_diff) + smpl2mano_id.append(int(v_closest)) + + smpl2mano_vids = np.array(smpl2mano_id).astype(np.longlong) + mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left' + ).astype(np.longlong) + + np.savez(part_vid_fname, vids=smpl2mano_vids, faces=mano_faces) + part_vert_faces[part] = {'vids': smpl2mano_vids, 'faces': mano_faces} + + elif part in ['face', 'arm', 'forearm', 'larm', 'rarm']: + with open( + os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), + 'rb' + ) as json_file: + smplx_part_id = json.load(json_file) + + # main_body_part = list(smplx_part_id.keys()) + # print('main_body_part', main_body_part) + + if part == 'face': + selected_body_part = ['head'] + elif part == 'arm': + selected_body_part = [ + 'rightHand', + 'leftArm', + 'leftShoulder', + 'rightShoulder', + 'rightArm', + 'leftHandIndex1', + 'rightHandIndex1', + 'leftForeArm', + 'rightForeArm', + 'leftHand', + ] + # selected_body_part = ['rightHand', 'leftArm', 'rightArm', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',] + elif part == 'forearm': + selected_body_part = [ + 'rightHand', + 'leftHandIndex1', + 'rightHandIndex1', + 'leftForeArm', + 'rightForeArm', + 'leftHand', + ] + elif part == 'arm_eval': + selected_body_part = ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'] + elif part == 'larm': + # selected_body_part = ['leftArm', 'leftForeArm'] + selected_body_part = ['leftForeArm'] + elif part == 'rarm': + # selected_body_part = ['rightArm', 'rightForeArm'] + selected_body_part = ['rightForeArm'] + + part_body_idx = [] + for k in selected_body_part: + part_body_idx.extend(smplx_part_id[k]) + + part_body_fid = [] + for f_id, face in enumerate(body_model_faces): + if any(f in part_body_idx for f in face): + part_body_fid.append(f_id) + + smpl2head_vids = np.unique(body_model_faces[part_body_fid]).astype(np.longlong) + + mesh_vid_raw = np.arange(body_model_num_verts) + head_vid_new = np.arange(len(smpl2head_vids)) + mesh_vid_raw[smpl2head_vids] = head_vid_new + + head_faces = 
body_model_faces[part_body_fid] + head_faces = mesh_vid_raw[head_faces].astype(np.longlong) + + np.savez(part_vid_fname, vids=smpl2head_vids, faces=head_faces) + part_vert_faces[part] = {'vids': smpl2head_vids, 'faces': head_faces} + + elif part in ['lwrist', 'rwrist']: + + if body_model == 'smplx': + body_model_verts = get_smplx_tpose() + tpose_joint = get_smplx_tpose_joint() + elif body_model == 'smpl': + body_model_verts = get_smpl_tpose() + tpose_joint = get_smpl_tpose_joint() + + wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21] + + dist = 0.005 + wrist_vids = [] + for vid, vt in enumerate(body_model_verts): + + v_j_dist = torch.sum((vt - wrist_joint)**2) + + if v_j_dist < dist: + wrist_vids.append(vid) + + wrist_vids = np.array(wrist_vids) + + part_body_fid = [] + for f_id, face in enumerate(body_model_faces): + if any(f in wrist_vids for f in face): + part_body_fid.append(f_id) + + smpl2part_vids = np.unique(body_model_faces[part_body_fid]).astype(np.longlong) + + mesh_vid_raw = np.arange(body_model_num_verts) + part_vid_new = np.arange(len(smpl2part_vids)) + mesh_vid_raw[smpl2part_vids] = part_vid_new + + part_faces = body_model_faces[part_body_fid] + part_faces = mesh_vid_raw[part_faces].astype(np.longlong) + + np.savez(part_vid_fname, vids=smpl2part_vids, faces=part_faces) + part_vert_faces[part] = {'vids': smpl2part_vids, 'faces': part_faces} + + # import trimesh + # mesh = trimesh.Trimesh(vertices=body_model_verts[smpl2part_vids], faces=part_faces, process=False) + # mesh.export(f'results/smplx_{part}.obj') + + # mesh = trimesh.Trimesh(vertices=body_model_verts, faces=body_model_faces, process=False) + # mesh.export(f'results/smplx_model.obj') + + return part_vert_faces diff --git a/lib/pymafx/models/transformers/__init__.py b/lib/pymafx/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/pymafx/models/transformers/bert/__init__.py b/lib/pymafx/models/transformers/bert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0432a1e92856c438e5fd2f550dc5029a78fa354c --- /dev/null +++ b/lib/pymafx/models/transformers/bert/__init__.py @@ -0,0 +1,19 @@ +__version__ = "1.0.0" + +from .modeling_bert import ( + BertConfig, BertModel, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + BERT_PRETRAINED_CONFIG_ARCHIVE_MAP +) + +from .modeling_graphormer import Graphormer + +# from .e2e_body_network import Graphormer_Body_Network + +# from .e2e_hand_network import Graphormer_Hand_Network + +from .modeling_utils import ( + WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer, + Conv1D +) + +from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path) diff --git a/lib/pymafx/models/transformers/bert/bert-base-uncased/config.json b/lib/pymafx/models/transformers/bert/bert-base-uncased/config.json new file mode 100644 index 0000000000000000000000000000000000000000..79276673252f15cea400800731e0d4e3d3cba64f --- /dev/null +++ b/lib/pymafx/models/transformers/bert/bert-base-uncased/config.json @@ -0,0 +1,16 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 2, + "vocab_size": 30522 +} diff --git 
a/lib/pymafx/models/transformers/bert/e2e_body_network.py b/lib/pymafx/models/transformers/bert/e2e_body_network.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1c75e276aa18fa1e8f2d865cbef7a275f71b8c --- /dev/null +++ b/lib/pymafx/models/transformers/bert/e2e_body_network.py @@ -0,0 +1,115 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import torch +import src.modeling.data.config as cfg + + +class Graphormer_Body_Network(torch.nn.Module): + ''' + End-to-end Graphormer network for human pose and mesh reconstruction from a single image. + ''' + def __init__(self, args, config, backbone, trans_encoder, mesh_sampler): + super(Graphormer_Body_Network, self).__init__() + self.config = config + self.config.device = args.device + self.backbone = backbone + self.trans_encoder = trans_encoder + self.upsampling = torch.nn.Linear(431, 1723) + self.upsampling2 = torch.nn.Linear(1723, 6890) + self.cam_param_fc = torch.nn.Linear(3, 1) + self.cam_param_fc2 = torch.nn.Linear(431, 250) + self.cam_param_fc3 = torch.nn.Linear(250, 3) + self.grid_feat_dim = torch.nn.Linear(1024, 2051) + + def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False): + batch_size = images.size(0) + # Generate T-pose template mesh + template_pose = torch.zeros((1, 72)) + template_pose[:, 0] = 3.1416 # Rectify "upside down" reference mesh in global coord + template_pose = template_pose.cuda(self.config.device) + template_betas = torch.zeros((1, 10)).cuda(self.config.device) + template_vertices = smpl(template_pose, template_betas) + + # template mesh simplification + template_vertices_sub = mesh_sampler.downsample(template_vertices) + template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2) + print( + 'template_vertices', template_vertices.shape, template_vertices_sub.shape, + template_vertices_sub2.shape + ) + + # template mesh-to-joint regression + template_3d_joints = smpl.get_h36m_joints(template_vertices) + template_pelvis = template_3d_joints[:, cfg.H36M_J17_NAME.index('Pelvis'), :] + template_3d_joints = template_3d_joints[:, cfg.H36M_J17_TO_J14, :] + num_joints = template_3d_joints.shape[1] + + # normalize + template_3d_joints = template_3d_joints - template_pelvis[:, None, :] + template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :] + + # concatinate template joints and template vertices, and then duplicate to batch size + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2], dim=1) + ref_vertices = ref_vertices.expand(batch_size, -1, -1) + print('ref_vertices', ref_vertices.shape) + + # extract grid features and global image features using a CNN backbone + image_feat, grid_feat = self.backbone(images) + print('image_feat, grid_feat', image_feat.shape, grid_feat.shape) + # concatinate image feat and 3d mesh template + image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) + print('image_feat', image_feat.shape) + # process grid features + grid_feat = torch.flatten(grid_feat, start_dim=2) + grid_feat = grid_feat.transpose(1, 2) + print('grid_feat bf', grid_feat.shape) + grid_feat = self.grid_feat_dim(grid_feat) + print('grid_feat', grid_feat.shape) + # concatinate image feat and template mesh to form the joint/vertex queries + features = torch.cat([ref_vertices, image_feat], dim=2) + print('features', features.shape, ref_vertices.shape, image_feat.shape) + # prepare input tokens including joint/vertex queries and grid features + features 
= torch.cat([features, grid_feat], dim=1) + print('features', features.shape) + + if is_train == True: + # apply mask vertex/joint modeling + # meta_masks is a tensor of all the masks, randomly generated in dataloader + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:, :-49, :]).cuda() * 0.01 + print('special_token', special_token.shape, meta_masks.shape) + print('meta_masks', torch.unique(meta_masks)) + features[:, :-49, : + ] = features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks) + + # forward pass + if self.config.output_attentions == True: + features, hidden_states, att = self.trans_encoder(features) + else: + features = self.trans_encoder(features) + + pred_3d_joints = features[:, :num_joints, :] + pred_vertices_sub2 = features[:, num_joints:-49, :] + + # learn camera parameters + x = self.cam_param_fc(pred_vertices_sub2) + x = x.transpose(1, 2) + x = self.cam_param_fc2(x) + x = self.cam_param_fc3(x) + cam_param = x.transpose(1, 2) + cam_param = cam_param.squeeze() + + temp_transpose = pred_vertices_sub2.transpose(1, 2) + pred_vertices_sub = self.upsampling(temp_transpose) + pred_vertices_full = self.upsampling2(pred_vertices_sub) + pred_vertices_sub = pred_vertices_sub.transpose(1, 2) + pred_vertices_full = pred_vertices_full.transpose(1, 2) + + if self.config.output_attentions == True: + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att + else: + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full diff --git a/lib/pymafx/models/transformers/bert/e2e_hand_network.py b/lib/pymafx/models/transformers/bert/e2e_hand_network.py new file mode 100644 index 0000000000000000000000000000000000000000..410968c4abc63e1ae8281b2e0297c8eef4e7bbcf --- /dev/null +++ b/lib/pymafx/models/transformers/bert/e2e_hand_network.py @@ -0,0 +1,94 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import torch +import src.modeling.data.config as cfg + + +class Graphormer_Hand_Network(torch.nn.Module): + ''' + End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. 
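+
+    forward() returns (cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices),
+    with (hidden_states, att) appended when config.output_attentions is enabled.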
+ ''' + def __init__(self, args, config, backbone, trans_encoder): + super(Graphormer_Hand_Network, self).__init__() + self.config = config + self.backbone = backbone + self.trans_encoder = trans_encoder + self.upsampling = torch.nn.Linear(195, 778) + self.cam_param_fc = torch.nn.Linear(3, 1) + self.cam_param_fc2 = torch.nn.Linear(195 + 21, 150) + self.cam_param_fc3 = torch.nn.Linear(150, 3) + self.grid_feat_dim = torch.nn.Linear(1024, 2051) + + def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False): + batch_size = images.size(0) + # Generate T-pose template mesh + template_pose = torch.zeros((1, 48)) + template_pose = template_pose.cuda() + template_betas = torch.zeros((1, 10)).cuda() + template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas) + template_vertices = template_vertices / 1000.0 + template_3d_joints = template_3d_joints / 1000.0 + + template_vertices_sub = mesh_sampler.downsample(template_vertices) + + # normalize + template_root = template_3d_joints[:, cfg.J_NAME.index('Wrist'), :] + template_3d_joints = template_3d_joints - template_root[:, None, :] + template_vertices = template_vertices - template_root[:, None, :] + template_vertices_sub = template_vertices_sub - template_root[:, None, :] + num_joints = template_3d_joints.shape[1] + + # concatinate template joints and template vertices, and then duplicate to batch size + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub], dim=1) + ref_vertices = ref_vertices.expand(batch_size, -1, -1) + + # extract grid features and global image features using a CNN backbone + image_feat, grid_feat = self.backbone(images) + # concatinate image feat and mesh template + image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) + # process grid features + grid_feat = torch.flatten(grid_feat, start_dim=2) + grid_feat = grid_feat.transpose(1, 2) + grid_feat = self.grid_feat_dim(grid_feat) + # concatinate image feat and template mesh to form the joint/vertex queries + features = torch.cat([ref_vertices, image_feat], dim=2) + # prepare input tokens including joint/vertex queries and grid features + features = torch.cat([features, grid_feat], dim=1) + + if is_train == True: + # apply mask vertex/joint modeling + # meta_masks is a tensor of all the masks, randomly generated in dataloader + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:, :-49, :]).cuda() * 0.01 + features[:, :-49, : + ] = features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks) + + # forward pass + if self.config.output_attentions == True: + features, hidden_states, att = self.trans_encoder(features) + else: + features = self.trans_encoder(features) + + pred_3d_joints = features[:, :num_joints, :] + pred_vertices_sub = features[:, num_joints:-49, :] + + # learn camera parameters + x = self.cam_param_fc(features[:, :-49, :]) + x = x.transpose(1, 2) + x = self.cam_param_fc2(x) + x = self.cam_param_fc3(x) + cam_param = x.transpose(1, 2) + cam_param = cam_param.squeeze() + + temp_transpose = pred_vertices_sub.transpose(1, 2) + pred_vertices = self.upsampling(temp_transpose) + pred_vertices = pred_vertices.transpose(1, 2) + + if self.config.output_attentions == True: + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att + else: + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices diff --git a/lib/pymafx/models/transformers/bert/file_utils.py 
b/lib/pymafx/models/transformers/bert/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee58bed427f90be254caee9a0733d81ae92c8711 --- /dev/null +++ b/lib/pymafx/models/transformers/bert/file_utils.py @@ -0,0 +1,258 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import sys +import json +import logging +import os +import shutil +import tempfile +import fnmatch +from functools import wraps +from hashlib import sha256 +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from torch.hub import _get_torch_home + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')) + ) +default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers') + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path( + os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path) + ) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. 
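+
+    A minimal usage sketch (illustrative; the URL is simply one of the archive-map
+    entries from modeling_bert.py, and any reachable URL or existing local file
+    path works the same way):
+
+        config_path = cached_path(
+            'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json')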
+ """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + if sys.version_info[0] == 2 and not isinstance(cache_dir, str): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. 
+ if url.startswith("s3://"): + etag = s3_etag(url) + else: + try: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except EnvironmentError: + etag = None + + if sys.version_info[0] == 2 and etag is not None: + etag = etag.decode('utf-8') + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') + matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w') as meta_file: + output_string = json.dumps(meta) + if sys.version_info[0] == 2 and isinstance(output_string, str): + output_string = unicode(output_string, 'utf-8') # The beauty of python 2 + meta_file.write(output_string) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path diff --git a/lib/pymafx/models/transformers/bert/modeling_bert.py b/lib/pymafx/models/transformers/bert/modeling_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a7f27f1bc0e69d87ac3747b8d8acfafb03b4b8 --- /dev/null +++ b/lib/pymafx/models/transformers/bert/modeling_bert.py @@ -0,0 +1,1436 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. 
""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import math +import os +import sys +from io import open + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from .modeling_utils import ( + WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer, + add_start_docstrings +) + +logger = logging.getLogger(__name__) + +BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", + 'bert-large-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", + 'bert-base-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", + 'bert-large-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", + 'bert-base-multilingual-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", + 'bert-base-multilingual-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", + 'bert-base-chinese': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", + 'bert-base-german-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", + 'bert-large-cased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-large-cased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-base-cased-finetuned-mrpc': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", +} + +BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'bert-base-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", + 'bert-large-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", + 'bert-base-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", + 'bert-large-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", + 'bert-base-multilingual-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", + 'bert-base-multilingual-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", + 'bert-base-chinese': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", + 'bert-base-german-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", + 'bert-large-uncased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", + 'bert-large-cased-whole-word-masking': + 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", + 'bert-large-uncased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", + 'bert-large-cased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", + 'bert-base-cased-finetuned-mrpc': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", +} + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model. + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m", "global_step"] for n in name): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + elif l[0] == 'squad': + pointer = getattr(pointer, 'classifier') + else: + try: + pointer = getattr(pointer, l[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertConfig(PretrainedConfig): + r""" + :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a + `BertModel`. + + + Arguments: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 
+ hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + """ + pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__( + self, + vocab_size_or_config_json_file=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + **kwargs + ): + super(BertConfig, self).__init__(**kwargs) + if isinstance( + vocab_size_or_config_json_file, str + ) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + else: + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) + + +# try: +# pass +# # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm +# except ImportError: +# logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") +class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). 
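+        Computes weight * (x - mean) / sqrt(var + eps) + bias over the last dimension,
+        where mean and var are taken per token (var is the biased variance), matching the
+        TensorFlow BERT implementation.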
+ """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, ) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + for head in heads: + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + # Update hyper params + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + + def forward(self, input_tensor, attention_mask, head_mask=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str + ) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, 
self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask, head_mask=None): + attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them + return outputs + + +class BertEncoder(nn.Module): + def __init__(self, config): + super(BertEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, head_mask=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + return outputs # outputs, (hidden states), (attentions) + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
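+        # hidden_states has shape [batch, seq_len, hidden]; hidden_states[:, 0] picks the
+        # first ([CLS]) position, which the dense + tanh below map to the pooled representation.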
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str + ) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + config_class = BertConfig + pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + + def __init__(self, *inputs, **kwargs): + super(BertPreTrainedModel, self).__init__(*inputs, **kwargs) + + def init_weights(self, module): + """ Initialize the weights. 
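+            Linear and Embedding weights are drawn from N(0, initializer_range**2);
+            LayerNorm weights are reset to 1 and LayerNorm/Linear biases to 0.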
+ """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +BERT_START_DOCSTRING = r""" The BERT model was proposed in + `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ + by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer + pre-trained using a combination of masked language modeling objective and next sentence prediction + on a large corpus comprising the Toronto Book Corpus and Wikipedia. + + This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matter related to general usage and behavior. + + .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`: + https://arxiv.org/abs/1810.04805 + + .. _`torch.nn.Module`: + https://pytorch.org/docs/stable/nn.html#module + + Parameters: + config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model. +""" + +BERT_INPUTS_DOCSTRING = r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs: + + ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` + + (b) For single sequences: + + ``tokens: [CLS] the dog is hairy . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0`` + + Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. + See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and + :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1[``. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. 
+""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputing raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) +class BertModel(BertPreTrainedModel): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the output of the last layer of the model. + **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during Bert pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertModel(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids) + >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + def __init__(self, config): + super(BertModel, self).__init__(config) + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + self.apply(self.init_weights) + + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.embeddings.word_embeddings + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + self.embeddings.word_embeddings = new_embeddings + return self.embeddings.word_embeddings + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. 
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1 + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids, position_ids=position_ids, token_type_ids=token_type_ids + ) + encoder_outputs = self.encoder( + embedding_output, extended_attention_mask, head_mask=head_mask + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with two heads on top as done during the pre-training: + a `masked language modeling` head and a `next sentence prediction (classification)` head. """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) +class BertForPreTraining(BertPreTrainedModel): + r""" + **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the masked language modeling loss. + Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. 
+ + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. + **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)`` + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForPreTraining(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids) + >>> prediction_scores, seq_relationship_scores = outputs[:2] + + """ + def __init__(self, config): + super(BertForPreTraining, self).__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + self.apply(self.init_weights) + self.tie_weights() + + def tie_weights(self): + """ Make sure we are sharing the input and output embeddings. + Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
+ """ + self._tie_or_clone_weights( + self.cls.predictions.decoder, self.bert.embeddings.word_embeddings + ) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + outputs = ( + prediction_scores, + seq_relationship_score, + ) + outputs[2:] # add hidden states and attention if they are here + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) + ) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) + total_loss = masked_lm_loss + next_sentence_loss + outputs = (total_loss, ) + outputs + + return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, + BERT_INPUTS_DOCSTRING +) +class BertForMaskedLM(BertPreTrainedModel): + r""" + **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the masked language modeling loss. + Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
+ + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForMaskedLM(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids, masked_lm_labels=input_ids) + >>> loss, prediction_scores = outputs[:2] + + """ + def __init__(self, config): + super(BertForMaskedLM, self).__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config) + + self.apply(self.init_weights) + self.tie_weights() + + def tie_weights(self): + """ Make sure we are sharing the input and output embeddings. + Export to TorchScript can't handle parameter sharing so we are cloning them instead. + """ + self._tie_or_clone_weights( + self.cls.predictions.decoder, self.bert.embeddings.word_embeddings + ) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + outputs = (prediction_scores, + ) + outputs[2:] # Add hidden states and attention is they are here + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) + ) + outputs = (masked_lm_loss, ) + outputs + + return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) +class BertForNextSentencePrediction(BertPreTrainedModel): + r""" + **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Next sequence prediction (classification) loss. + **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)`` + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
+ + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForNextSentencePrediction(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids) + >>> seq_relationship_scores = outputs[0] + + """ + def __init__(self, config): + super(BertForNextSentencePrediction, self).__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + self.apply(self.init_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + pooled_output = outputs[1] + + seq_relationship_score = self.cls(pooled_output) + + outputs = (seq_relationship_score, + ) + outputs[2:] # add hidden states and attention if they are here + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) + outputs = (next_sentence_loss, ) + outputs + + return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) +class BertForSequenceClassification(BertPreTrainedModel): + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the sequence classification/regression loss. + Indices should be in ``[0, ..., config.num_labels]``. + If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), + If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification (or regression if config.num_labels==1) loss. + **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` + Classification (or regression if config.num_labels==1) scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
+ + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForSequenceClassification(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids, labels=labels) + >>> loss, logits = outputs[:2] + + """ + def __init__(self, config): + super(BertForSequenceClassification, self).__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) + + self.apply(self.init_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + outputs = (logits, ) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss, ) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, BERT_START_DOCSTRING +) +class BertForMultipleChoice(BertPreTrainedModel): + r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs: + + ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` + + (b) For single sequences: + + ``tokens: [CLS] the dog is hairy . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0`` + + Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. + See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and + :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Segment token indices to indicate first and second portions of the inputs. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Mask to avoid performing attention on padding token indices. 
+ The second dimension of the input (`num_choices`) indicates the number of choices to score. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above) + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss. + **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above). + Classification scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
+ + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForMultipleChoice(config) + >>> choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] + >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + >>> labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids, labels=labels) + >>> loss, classification_scores = outputs[:2] + + """ + def __init__(self, config): + super(BertForMultipleChoice, self).__init__(config) + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.apply(self.init_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None + ): + num_choices = input_ids.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_position_ids = position_ids.view( + -1, position_ids.size(-1) + ) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view( + -1, token_type_ids.size(-1) + ) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view( + -1, attention_mask.size(-1) + ) if attention_mask is not None else None + outputs = self.bert( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + outputs = (reshaped_logits, + ) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + outputs = (loss, ) + outputs + + return outputs # (loss), reshaped_logits, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) +class BertForTokenClassification(BertPreTrainedModel): + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the token classification loss. + Indices should be in ``[0, ..., config.num_labels]``. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss. + **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)`` + Classification scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForTokenClassification(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids, labels=labels) + >>> loss, scores = outputs[:2] + + """ + def __init__(self, config): + super(BertForTokenClassification, self).__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.apply(self.init_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + outputs = (logits, ) + outputs[2:] # add hidden states and attention if they are here + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss, ) + outputs + + return outputs # (loss), scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) +class BertForQuestionAnswering(BertPreTrainedModel): + r""" + **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 
+ **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` + Span-start scores (before SoftMax). + **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` + Span-end scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> + >>> model = BertForQuestionAnswering(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> start_positions = torch.tensor([1]) + >>> end_positions = torch.tensor([3]) + >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) + >>> loss, start_scores, end_scores = outputs[:2] + + """ + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.apply(self.init_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + start_positions=None, + end_positions=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = ( + start_logits, + end_logits, + ) + outputs[2:] + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss, ) + outputs + + return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) diff --git a/lib/pymafx/models/transformers/bert/modeling_graphormer.py b/lib/pymafx/models/transformers/bert/modeling_graphormer.py new file mode 100644 index 0000000000000000000000000000000000000000..e318af8a45d34148e0db68f42181f692afbf8754 --- /dev/null +++ b/lib/pymafx/models/transformers/bert/modeling_graphormer.py @@ -0,0 +1,381 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
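+
+Graphormer-style transformer blocks used by PyMAF-X: BERT self-attention layers that can
+apply a graph residual block (GraphResBlock) to the mesh-vertex tokens before the
+feed-forward sublayer, for hand or body meshes.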
+ +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import math +import os +import code +import torch +from torch import nn +from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput +# import src.modeling.data.config as cfg +# from src.modeling._gcnn import GraphConvolution, GraphResBlock +from .modeling_utils import prune_linear_layer + +LayerNormClass = torch.nn.LayerNorm +BertLayerNorm = torch.nn.LayerNorm + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): + if history_state is not None: + raise + x_states = torch.cat([history_state, hidden_states], dim=1) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(x_states) + mixed_value_layer = self.value(x_states) + else: + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + print( + 'mixed_query_layer', mixed_query_layer.shape, mixed_key_layer.shape, + mixed_value_layer.shape + ) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + print('query_layer', query_layer.shape, key_layer.shape, value_layer.shape) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + raise + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, ) + return outputs + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + for head in heads: + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + # Update hyper params + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + + def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class GraphormerLayer(nn.Module): + def __init__(self, config): + super(GraphormerLayer, self).__init__() + self.attention = BertAttention(config) + self.has_graph_conv = config.graph_conv + self.mesh_type = config.mesh_type + + if self.has_graph_conv == True: + if self.mesh_type == 'hand': + self.graph_conv = GraphResBlock( + config.hidden_size, config.hidden_size, mesh_type=self.mesh_type + ) + elif self.mesh_type == 'body': + self.graph_conv = GraphResBlock( + config.hidden_size, config.hidden_size, mesh_type=self.mesh_type + ) + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, history_state=None): + attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state) + attention_output = attention_outputs[0] + + if self.has_graph_conv == True: + if self.mesh_type == 'body': + joints = attention_output[:, 0:14, :] + vertices = attention_output[:, 14:-49, :] + img_tokens = attention_output[:, -49:, :] + + elif self.mesh_type == 'hand': + joints = attention_output[:, 0:21, :] + vertices = attention_output[:, 21:-49, :] + img_tokens = attention_output[:, -49:, :] + + vertices = self.graph_conv(vertices) + joints_vertices = torch.cat([joints, vertices, img_tokens], dim=1) + else: + joints_vertices = attention_output + + intermediate_output = self.intermediate(joints_vertices) + layer_output = self.output(intermediate_output, joints_vertices) + print('layer_output', layer_output.shape) + outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them + return outputs + + def forward(self, hidden_states, attention_mask, head_mask=None, 
history_state=None): + return self.MHA_GCN(hidden_states, attention_mask, head_mask, history_state) + + +class GraphormerEncoder(nn.Module): + def __init__(self, config): + super(GraphormerEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList( + [GraphormerLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + history_state = None if encoder_history_states is None else encoder_history_states[i] + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], history_state) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + + return outputs # outputs, (hidden states), (attentions) + + +class EncoderBlock(BertPreTrainedModel): + def __init__(self, config): + super(EncoderBlock, self).__init__(config) + self.config = config + # self.embeddings = BertEmbeddings(config) + self.encoder = GraphormerEncoder(config) + # self.pooler = BertPooler(config) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.img_dim = config.img_feature_dim + + try: + self.use_img_layernorm = config.use_img_layernorm + except: + self.use_img_layernorm = None + + self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if self.use_img_layernorm: + self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps) + + self.apply(self.init_weights) + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None + ): + + batch_size = len(img_feats) + seq_length = len(img_feats[0]) + input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).cuda() + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + print('-------------------') + print('position_ids', seq_length, position_ids.shape) + + position_embeddings = self.position_embeddings(position_ids) + print( + 'position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, + self.config.hidden_size + ) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + else: + raise + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + else: + raise + + if attention_mask.dim() == 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + extended_attention_mask = attention_mask.unsqueeze(1) + else: + raise NotImplementedError + + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + if head_mask is not None: + raise + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1 + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + # Project input token features to have spcified hidden size + print('img_feats', img_feats.shape) + img_embedding_output = self.img_embedding(img_feats) + print('img_embedding_output', img_embedding_output.shape) + + # We empirically observe that adding an additional learnable position embedding leads to more stable training + embeddings = position_embeddings + img_embedding_output + + if self.use_img_layernorm: + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + print('extended_attention_mask', extended_attention_mask.shape) + encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask) + sequence_output = encoder_outputs[0] + + outputs = (sequence_output, ) + if self.config.output_hidden_states: + all_hidden_states = encoder_outputs[1] + outputs = outputs + (all_hidden_states, ) + if self.config.output_attentions: + all_attentions = encoder_outputs[-1] + outputs = outputs + (all_attentions, ) + + return outputs + + +class Graphormer(BertPreTrainedModel): + ''' + The archtecture of a transformer encoder block we used in Graphormer + ''' + def __init__(self, config): + super(Graphormer, self).__init__(config) + self.config = config + self.bert = EncoderBlock(config) + self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim) + self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) + 
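+        # cls_head reduces each output token from hidden_size to output_feature_dim, and residual
+        # maps the raw image features (img_feature_dim) to the same dimension so that forward()
+        # can add them back as a skip connection (pred_score = cls_head(tokens) + residual(img_feats)).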
self.apply(self.init_weights) + + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): + ''' + # self.bert has three outputs + # predictions[0]: output tokens + # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states" + # predictions[2]: attentions, if enable "self.config.output_attentions" + ''' + predictions = self.bert( + img_feats=img_feats, + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) + + # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. + pred_score = self.cls_head(predictions[0]) + res_img_feats = self.residual(img_feats) + pred_score = pred_score + res_img_feats + print('pred_score', pred_score.shape) + + if self.config.output_attentions and self.config.output_hidden_states: + return pred_score, predictions[1], predictions[-1] + else: + return pred_score diff --git a/lib/pymafx/models/transformers/bert/modeling_utils.py b/lib/pymafx/models/transformers/bert/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..40a0915822c8e736de8ac2466c075e6cc5ef7e83 --- /dev/null +++ b/lib/pymafx/models/transformers/bert/modeling_utils.py @@ -0,0 +1,973 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import copy +import json +import logging +import os +from io import open + +import six +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" +TF_WEIGHTS_NAME = 'model.ckpt' + +try: + from torch.nn import Identity +except ImportError: + # Older PyTorch compatibility + class Identity(nn.Module): + r"""A placeholder identity operator that is argument-insensitive. + """ + def __init__(self, *args, **kwargs): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +if not six.PY2: + + def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = ''.join(docstr) + fn.__doc__ + return fn + + return docstring_decorator +else: + # Not possible to update class docstrings on python2 + def add_start_docstrings(*docstr): + def docstring_decorator(fn): + return fn + + return docstring_decorator + + +class PretrainedConfig(object): + """ Base class for all configuration classes. + Handle a few common parameters and methods for loading/downloading/saving configurations. 
+ """ + pretrained_config_archive_map = {} + + def __init__(self, **kwargs): + self.finetuning_task = kwargs.pop('finetuning_task', None) + self.num_labels = kwargs.pop('num_labels', 2) + self.output_attentions = kwargs.pop('output_attentions', False) + self.output_hidden_states = kwargs.pop('output_hidden_states', False) + self.torchscript = kwargs.pop('torchscript', False) + + def save_pretrained(self, save_directory): + """ Save a configuration object to a directory, so that it + can be re-loaded using the `from_pretrained(save_directory)` class method. + """ + assert os.path.isdir( + save_directory + ), "Saving path should be a directory where the model and configuration can be saved" + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" Instantiate a PretrainedConfig from a pre-trained model configuration. + + Params: + **pretrained_model_name_or_path**: either: + - a string with the `shortcut name` of a pre-trained model configuration to load from cache + or download and cache if not already stored in cache (e.g. 'bert-base-uncased'). + - a path to a `directory` containing a configuration file saved + using the `save_pretrained(save_directory)` method. + - a path or url to a saved configuration `file`. + **cache_dir**: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + **return_unused_kwargs**: (`optional`) bool: + - If False, then this function returns just the final configuration object. + - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` + is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: + ie the part of kwargs which has not been used to update `config` and is otherwise ignored. + **kwargs**: (`optional`) dict: + Dictionary of key/value pairs with which to update the configuration object after loading. + - The values in kwargs of any keys which are configuration attributes will be used + to override the loaded values. + - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled + by the `return_unused_kwargs` keyword parameter. + + Examples:: + + >>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. + >>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` + >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') + >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) + >>> assert config.output_attention == True + >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, + >>> foo=False, return_unused_kwargs=True) + >>> assert config.output_attention == True + >>> assert unused_kwargs == {'foo': False} + + """ + cache_dir = kwargs.pop('cache_dir', None) + return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) + + if pretrained_model_name_or_path in cls.pretrained_config_archive_map: + config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] + elif os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + else: + config_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in cls.pretrained_config_archive_map: + logger.error( + "Couldn't reach server at '{}' to download pretrained model configuration file." + .format(config_file) + ) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(cls.pretrained_config_archive_map.keys()), config_file + ) + ) + return None + if resolved_config_file == config_file: + pass + # logger.info("loading configuration file {}".format(config_file)) + else: + logger.info( + "loading configuration file {} from cache at {}".format( + config_file, resolved_config_file + ) + ) + + # Load config + config = cls.from_json_file(resolved_config_file) + + # Update config with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + # logger.info("Model config %s", config) + if return_unused_kwargs: + return config, kwargs + else: + return config + + @classmethod + def from_dict(cls, json_object): + """Constructs a `Config` from a Python dictionary of parameters.""" + config = cls(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + + +class PreTrainedModel(nn.Module): + """ Base class for all models. 
Handle loading/storing model config and + a simple interface for dowloading and loading pretrained models. + """ + config_class = PretrainedConfig + pretrained_model_archive_map = {} + load_tf_weights = lambda model, config, path: None + base_model_prefix = "" + input_embeddings = None + + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedModel, self).__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " + "To create a model from a pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + # Save config in model + self.config = config + + def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): + """ Build a resized Embedding Module from a provided token Embedding Module. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + + Args: + new_num_tokens: (`optional`) int + New number of tokens in the embedding matrix. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + If not provided or None: return the provided token Embedding Module. + Return: ``torch.nn.Embeddings`` + Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None + """ + if new_num_tokens is None: + return old_embeddings + + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + if old_num_tokens == new_num_tokens: + return old_embeddings + + # Build new embeddings + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) + new_embeddings.to(old_embeddings.weight.device) + + # initialize all new embeddings (in particular added tokens) + self.init_weights(new_embeddings) + + # Copy word embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + new_embeddings.weight.data[:num_tokens_to_copy, : + ] = old_embeddings.weight.data[:num_tokens_to_copy, :] + + return new_embeddings + + def _tie_or_clone_weights(self, first_module, second_module): + """ Tie or clone module weights depending of weither we are using TorchScript or not + """ + if self.config.torchscript: + first_module.weight = nn.Parameter(second_module.weight.clone()) + else: + first_module.weight = second_module.weight + + def resize_token_embeddings(self, new_num_tokens=None): + """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. + Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Args: + new_num_tokens: (`optional`) int + New number of tokens in the embedding matrix. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + If not provided or None: does nothing and just returns a pointer to the input tokens Embedding Module of the model. 
+ + Return: ``torch.nn.Embeddings`` + Pointer to the input tokens Embedding Module of the model + """ + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + model_embeds = base_model._resize_token_embeddings(new_num_tokens) + if new_num_tokens is None: + return model_embeds + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + base_model.vocab_size = new_num_tokens + + # Tie weights again if needed + if hasattr(self, 'tie_weights'): + self.tie_weights() + + return model_embeds + + def prune_heads(self, heads_to_prune): + """ Prunes heads of the base model. + Args: + heads_to_prune: dict of {layer_num (int): list of heads to prune in this layer (list of int)} + """ + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + base_model._prune_heads(heads_to_prune) + + def save_pretrained(self, save_directory): + """ Save a model with its configuration file to a directory, so that it + can be re-loaded using the `from_pretrained(save_directory)` class method. + """ + assert os.path.isdir( + save_directory + ), "Saving path should be a directory where the model and configuration can be saved" + + # Only save the model it-self if we are using distributed training + model_to_save = self.module if hasattr(self, 'module') else self + + # Save configuration file + model_to_save.config.save_pretrained(save_directory) + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(save_directory, WEIGHTS_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Params: + **pretrained_model_name_or_path**: either: + - a string with the `shortcut name` of a pre-trained model to load from cache + or download and cache if not already stored in cache (e.g. 'bert-base-uncased'). + - a path to a `directory` containing a configuration file saved + using the `save_pretrained(save_directory)` method. + - a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`). + In this case, ``from_tf`` should be set to True and a configuration object should be + provided as `config` argument. This loading option is slower than converting the TensorFlow + checkpoint in a PyTorch model using the provided conversion scripts and loading + the PyTorch model afterwards. + **model_args**: (`optional`) Sequence: + All remaning positional arguments will be passed to the underlying model's __init__ function + **config**: an optional configuration for the model to use instead of an automatically loaded configuation. + Configuration can be automatically loaded when: + - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or + - the model was saved using the `save_pretrained(save_directory)` (loaded by suppling the save directory). + **state_dict**: an optional state dictionnary for the model to use instead of a state dictionary loaded + from saved weights file. + This option can be used if you want to create a model from a pretrained configuraton but load your own weights. 
+ In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not + a simpler option. + **cache_dir**: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + **output_loading_info**: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + **kwargs**: (`optional`) dict: + Dictionary of key, values to update the configuration object after loading. + Can be used to override selected configuration parameters. E.g. ``output_attention=True``. + + - If a configuration is provided with `config`, **kwargs will be directly passed + to the underlying model's __init__ method. + - If a configuration is not provided, **kwargs will be first passed to the pretrained + model configuration class loading function (`PretrainedConfig.from_pretrained`). + Each key of **kwargs that corresponds to a configuration attribute + will be used to override said attribute with the supplied **kwargs value. + Remaining keys that do not correspond to any configuration attribute will + be passed to the underlying model's __init__ function. + + Examples:: + + >>> model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + >>> model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + >>> model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + >>> assert model.config.output_attention == True + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) + >>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') + >>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + config = kwargs.pop('config', None) + state_dict = kwargs.pop('state_dict', None) + cache_dir = kwargs.pop('cache_dir', None) + from_tf = kwargs.pop('from_tf', False) + output_loading_info = kwargs.pop('output_loading_info', False) + + # Load config + if config is None: + config, model_kwargs = cls.config_class.from_pretrained( + pretrained_model_name_or_path, + *model_args, + cache_dir=cache_dir, + return_unused_kwargs=True, + **kwargs + ) + else: + model_kwargs = kwargs + + # Load model + if pretrained_model_name_or_path in cls.pretrained_model_archive_map: + archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] + elif os.path.isdir(pretrained_model_name_or_path): + if from_tf: + # Directly load from a TensorFlow checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index" + ) + else: + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + if from_tf: + # Directly load from a TensorFlow checkpoint + archive_file = pretrained_model_name_or_path + ".index" + else: + archive_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in cls.pretrained_model_archive_map: + logger.error( + "Couldn't reach server at '{}' to download pretrained weights.". + format(archive_file) + ) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). 
" + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(cls.pretrained_model_archive_map.keys()), archive_file + ) + ) + return None + if resolved_archive_file == archive_file: + logger.info("loading weights file {}".format(archive_file)) + else: + logger.info( + "loading weights file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) + + # Instantiate model. + model = cls(config, *model_args, **model_kwargs) + + if state_dict is None and not from_tf: + state_dict = torch.load(resolved_archive_file, map_location='cpu') + if from_tf: + # Directly load from a TensorFlow checkpoint + return cls.load_tf_weights( + model, config, resolved_archive_file[:-6] + ) # Remove the '.index' + + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # Load from a PyTorch state_dict + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = '' + model_to_load = model + if not hasattr(model, cls.base_model_prefix) and any( + s.startswith(cls.base_model_prefix) for s in state_dict.keys() + ): + start_prefix = cls.base_model_prefix + '.' + if hasattr(model, cls.base_model_prefix + ) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): + model_to_load = getattr(model, cls.base_model_prefix) + + load(model_to_load, prefix=start_prefix) + if len(missing_keys) > 0: + logger.info( + "Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys + ) + ) + if len(unexpected_keys) > 0: + logger.info( + "Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys + ) + ) + if len(error_msgs) > 0: + raise RuntimeError( + 'Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + + if hasattr(model, 'tie_weights'): + model.tie_weights() # make sure word embedding weights are still tied + + # Set model in evaluation mode to desactivate DropOut modules by default + model.eval() + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "error_msgs": error_msgs + } + return model, loading_info + + return model + + +class Conv1D(nn.Module): + def __init__(self, nf, nx): + """ Conv1D layer as defined by Radford et al. 
for OpenAI GPT (and also used in GPT-2) + Basically works like a Linear layer but the weights are transposed + """ + super(Conv1D, self).__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf, ) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(*size_out) + return x + + +class PoolerStartLogits(nn.Module): + """ Compute SQuAD start_logits from sequence hidden states. """ + def __init__(self, config): + super(PoolerStartLogits, self).__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward(self, hidden_states, p_mask=None): + """ Args: + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` + invalid position mask such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerEndLogits(nn.Module): + """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. + """ + def __init__(self, config): + super(PoolerEndLogits, self).__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): + """ Args: + One of ``start_states``, ``start_positions`` should be not None. + If both are set, ``start_positions`` overrides ``start_states``. + + **start_states**: ``torch.LongTensor`` of shape identical to hidden_states + hidden states of the first tokens for the labeled span. + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span: + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` + Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. + """ + assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, + None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerAnswerClass(nn.Module): + """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """ + def __init__(self, config): + super(PoolerAnswerClass, self).__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): + """ + Args: + One of ``start_states``, ``start_positions`` should be not None. + If both are set, ``start_positions`` overrides ``start_states``. 
+ + **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. + hidden states of the first tokens for the labeled span. + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span. + **cls_index**: torch.LongTensor of shape ``(batch_size,)`` + position of the CLS token. If None, take the last token. + + note(Original repo): + no dependency on end_feature so that we can obtain one single `cls_logits` + for each sample + """ + hsz = hidden_states.shape[-1] + assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" + if start_positions is not None: + start_positions = start_positions[:, None, + None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, + start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +class SQuADHead(nn.Module): + r""" A SQuAD head inspired by XLNet. + + Parameters: + config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model. + + Inputs: + **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` + hidden states of sequence tokens + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span. + **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the last token for the labeled span. + **cls_index**: torch.LongTensor of shape ``(batch_size,)`` + position of the CLS token. If None, take the last token. + **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` + Whether the question has a possible answer in the paragraph or not. + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` + Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. + **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` + Indices for the top config.start_n_top start token possibilities (beam-search). + **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size,)`` + Log probabilities for the ``is_impossible`` label of the answers. + """ + def __init__(self, config): + super(SQuADHead, self).__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = PoolerStartLogits(config) + self.end_logits = PoolerEndLogits(config) + self.answer_class = PoolerAnswerClass(config) + + def forward( + self, + hidden_states, + start_positions=None, + end_positions=None, + cls_index=None, + is_impossible=None, + p_mask=None + ): + outputs = () + + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits( + hidden_states, start_positions=start_positions, p_mask=p_mask + ) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class( + hidden_states, start_positions=start_positions, cls_index=cls_index + ) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + outputs = (total_loss, ) + outputs + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand( + -1, -1, hsz + ) # shape (bsz, start_n_top, hsz) + start_states = torch.gather( + hidden_states, -2, start_top_index_exp + ) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand( + -1, slen, -1, -1 + ) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits( + hidden_states_expanded, start_states=start_states, p_mask=p_mask + ) + end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", 
hidden_states, start_log_probs) + cls_logits = self.answer_class( + hidden_states, start_states=start_states, cls_index=cls_index + ) + + outputs = ( + start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits + ) + outputs + + # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits + # or (if labels are provided) (total_loss,) + return outputs + + +class SequenceSummary(nn.Module): + r""" Compute a single vector summary of a sequence hidden states according to various possibilities: + Args of the config class: + summary_type: + - 'last' => [default] take the last token hidden state (like XLNet) + - 'first' => take the first token hidden state (like Bert) + - 'mean' => take the mean of all tokens hidden states + - 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2) + - 'attn' => Not implemented now, use multi-head attention + summary_use_proj: Add a projection after the vector extraction + summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. + summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default + summary_first_dropout: Add a dropout before the projection and activation + summary_last_dropout: Add a dropout after the projection and activation + """ + def __init__(self, config): + super(SequenceSummary, self).__init__() + + self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' + if config.summary_type == 'attn': + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = Identity() + if hasattr(config, 'summary_use_proj') and config.summary_use_proj: + if hasattr( + config, 'summary_proj_to_labels' + ) and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + self.activation = Identity() + if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': + self.activation = nn.Tanh() + + self.first_dropout = Identity() + if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward(self, hidden_states, token_ids=None): + """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. + token_ids: [optional] index of the classification token if summary_type == 'token_ids', + shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. 
+ if summary_type == 'token_ids' and token_ids is None: + we take the last token of the sequence as classification token + """ + if self.summary_type == 'last': + output = hidden_states[:, -1] + elif self.summary_type == 'first': + output = hidden_states[:, 0] + elif self.summary_type == 'mean': + output = hidden_states.mean(dim=1) + elif self.summary_type == 'token_ids': + if token_ids is None: + token_ids = torch.full_like( + hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long + ) + else: + token_ids = token_ids.unsqueeze(-1).unsqueeze(-1) + token_ids = token_ids.expand( + (-1, ) * (token_ids.dim() - 1) + (hidden_states.size(-1), ) + ) + # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, + token_ids).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == 'attn': + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +def prune_linear_layer(layer, index, dim=0): + """ Prune a linear layer (a model parameters) to keep only entries in index. + Return the pruned layer as a new layer with requires_grad=True. + Used to remove heads. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if layer.bias is not None: + if dim == 1: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias + is not None).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + if layer.bias is not None: + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +def prune_conv1d_layer(layer, index, dim=1): + """ Prune a Conv1D layer (a model parameters) to keep only entries in index. + A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. + Return the pruned layer as a new layer with requires_grad=True. + Used to remove heads. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if dim == 0: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +def prune_layer(layer, index, dim=None): + """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. + Return the pruned layer as a new layer with requires_grad=True. + Used to remove heads. 
+ """ + if isinstance(layer, nn.Linear): + return prune_linear_layer(layer, index, dim=0 if dim is None else dim) + elif isinstance(layer, Conv1D): + return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) + else: + raise ValueError("Can't prune layer of class {}".format(layer.__class__)) diff --git a/lib/pymafx/models/transformers/net_utils.py b/lib/pymafx/models/transformers/net_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52782911e276705ec0dd908ce9676430c0a58d72 --- /dev/null +++ b/lib/pymafx/models/transformers/net_utils.py @@ -0,0 +1,216 @@ +import torch.nn as nn +import torch +import math +import torch.nn.functional as F + + +class single_conv(nn.Module): + def __init__(self, in_ch, out_ch): + super(single_conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.conv(x) + + +class double_conv(nn.Module): + def __init__(self, in_ch, out_ch): + super(double_conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.conv(x) + + +class double_conv_down(nn.Module): + def __init__(self, in_ch, out_ch): + super(double_conv_down, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.conv(x) + + +class double_conv_up(nn.Module): + def __init__(self, in_ch, out_ch): + super(double_conv_up, self).__init__() + self.conv = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.conv(x) + + +class PosEnSine(nn.Module): + """ + Code borrowed from DETR: models/positional_encoding.py + output size: b*(2.num_pos_feats)*h*w + """ + def __init__(self, num_pos_feats): + super(PosEnSine, self).__init__() + self.num_pos_feats = num_pos_feats + self.normalize = True + self.scale = 2 * math.pi + self.temperature = 10000 + + def forward(self, x, pt_coord=None): + b, c, h, w = x.shape + if pt_coord is not None: + z_embed = pt_coord[:, :, 2].unsqueeze(-1) + 1. + y_embed = pt_coord[:, :, 1].unsqueeze(-1) + 1. + x_embed = pt_coord[:, :, 0].unsqueeze(-1) + 1. 
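+        # When no explicit point coordinates are given, the else-branch below builds a dense (h, w)
+        # pixel grid via cumulative sums (DETR-style) plus a constant z channel; each axis is then
+        # normalized to (0, 2*pi] before the sinusoidal encoding is applied.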
+ else: + not_mask = torch.ones(1, h, w, device=x.device) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + z_embed = torch.ones_like(x_embed) + if self.normalize: + eps = 1e-6 + z_embed = z_embed / (torch.max(z_embed) + eps) * self.scale + y_embed = y_embed / (torch.max(y_embed) + eps) * self.scale + x_embed = x_embed / (torch.max(x_embed) + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_z = z_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_z = torch.stack((pos_z[:, :, :, 0::2].sin(), pos_z[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_x, pos_y, pos_z), dim=3).permute(0, 3, 1, 2) + # if pt_coord is None: + pos = pos.repeat(b, 1, 1, 1) + return pos + + +def softmax_attention(q, k, v): + # b x n x d x h x w + h, w = q.shape[-2], q.shape[-1] + + q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d + k = k.flatten(-2) # b x n x d x hw + v = v.flatten(-2).transpose(-2, -1) + + print('softmax', q.shape, k.shape, v.shape) + + N = k.shape[-1] # ?????? maybe change to k.shape[-2]???? + attn = torch.matmul(q / N**0.5, k) + attn = F.softmax(attn, dim=-1) + output = torch.matmul(attn, v) + + output = output.transpose(-2, -1) + output = output.view(*output.shape[:-1], h, w) + + return output, attn + + +def dotproduct_attention(q, k, v): + # b x n x d x h x w + h, w = q.shape[-2], q.shape[-1] + + q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d + k = k.flatten(-2) # b x n x d x hw + v = v.flatten(-2).transpose(-2, -1) + + N = k.shape[-1] + attn = None + tmp = torch.matmul(k, v) / N + output = torch.matmul(q, tmp) + + output = output.transpose(-2, -1) + output = output.view(*output.shape[:-1], h, w) + + return output, attn + + +def long_range_attention(q, k, v, P_h, P_w): # fixed patch size + B, N, C, qH, qW = q.size() + _, _, _, kH, kW = k.size() + + qQ_h, qQ_w = qH // P_h, qW // P_w + kQ_h, kQ_w = kH // P_h, kW // P_w + + q = q.reshape(B, N, C, qQ_h, P_h, qQ_w, P_w) + k = k.reshape(B, N, C, kQ_h, P_h, kQ_w, P_w) + v = v.reshape(B, N, -1, kQ_h, P_h, kQ_w, P_w) + + q = q.permute(0, 1, 4, 6, 2, 3, 5) # [b, n, Ph, Pw, d, Qh, Qw] + k = k.permute(0, 1, 4, 6, 2, 3, 5) + v = v.permute(0, 1, 4, 6, 2, 3, 5) + + output, attn = softmax_attention(q, k, v) # attn: [b, n, Ph, Pw, qQh*qQw, kQ_h*kQ_w] + output = output.permute(0, 1, 4, 5, 2, 6, 3) + output = output.reshape(B, N, -1, qH, qW) + return output, attn + + +def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number + B, N, C, qH, qW = q.size() + _, _, _, kH, kW = k.size() + + qP_h, qP_w = qH // Q_h, qW // Q_w + kP_h, kP_w = kH // Q_h, kW // Q_w + + q = q.reshape(B, N, C, Q_h, qP_h, Q_w, qP_w) + k = k.reshape(B, N, C, Q_h, kP_h, Q_w, kP_w) + v = v.reshape(B, N, -1, Q_h, kP_h, Q_w, kP_w) + + q = q.permute(0, 1, 3, 5, 2, 4, 6) # [b, n, Qh, Qw, d, Ph, Pw] + k = k.permute(0, 1, 3, 5, 2, 4, 6) + v = v.permute(0, 1, 3, 5, 2, 4, 6) + + output, attn = softmax_attention(q, k, v) # attn: [b, n, Qh, Qw, qPh*qPw, kPh*kPw] + output = output.permute(0, 1, 4, 2, 5, 3, 6) + output = output.reshape(B, N, -1, qH, qW) + return output, attn + + +def space_to_depth(x, block_size): + x_shape = x.shape + 
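+    # Rearranges each non-overlapping block_size x block_size spatial patch into the channel
+    # dimension via F.unfold: (c, h, w) -> (c*block_size**2, h/block_size, w/block_size).
+    # depth_to_space below performs the inverse rearrangement with pixel_shuffle.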
c, h, w = x_shape[-3:] + if len(x.shape) >= 5: + x = x.view(-1, c, h, w) + unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size) + return unfolded_x.view(*x_shape[0:-3], c * block_size**2, h // block_size, w // block_size) + + +def depth_to_space(x, block_size): + x_shape = x.shape + c, h, w = x_shape[-3:] + x = x.view(-1, c, h, w) + y = torch.nn.functional.pixel_shuffle(x, block_size) + return y.view(*x_shape[0:-3], -1, h * block_size, w * block_size) + + +def patch_attention(q, k, v, P): + # q: [b, nhead, c, h, w] + q_patch = space_to_depth(q, P) # [b, nhead, cP^2, h/P, w/P] + k_patch = space_to_depth(k, P) + v_patch = space_to_depth(v, P) + + # output: [b, nhead, cP^2, h/P, w/P] + # attn: [b, nhead, h/P*w/P, h/P*w/P] + output, attn = softmax_attention(q_patch, k_patch, v_patch) + output = depth_to_space(output, P) # output: [b, nhead, c, h, w] + return output, attn diff --git a/lib/pymafx/models/transformers/texformer.py b/lib/pymafx/models/transformers/texformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6ee6262571c84d236e1738a984c741baf2efdf --- /dev/null +++ b/lib/pymafx/models/transformers/texformer.py @@ -0,0 +1,148 @@ +import torch.nn as nn +from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine +from .transformer_basics import OurMultiheadAttention + + +class TransformerDecoderUnit(nn.Module): + def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): + super(TransformerDecoderUnit, self).__init__() + self.feat_dim = feat_dim + self.attn_type = attn_type + self.pos_en_flag = pos_en_flag + self.P = P + + self.pos_en = PosEnSine(self.feat_dim // 2) + self.attn = OurMultiheadAttention(feat_dim, n_head) # cross-attention + + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.activation = nn.ReLU(inplace=True) + + self.norm = nn.BatchNorm2d(self.feat_dim) + + def forward(self, q, k, v): + if self.pos_en_flag: + q_pos_embed = self.pos_en(q) + k_pos_embed = self.pos_en(k) + else: + q_pos_embed = 0 + k_pos_embed = 0 + + # cross-multi-head attention + out = self.attn( + q=q + q_pos_embed, k=k + k_pos_embed, v=v, attn_type=self.attn_type, P=self.P + )[0] + + # feed forward + out2 = self.linear2(self.activation(self.linear1(out))) + out = out + out2 + out = self.norm(out) + + return out + + +class Unet(nn.Module): + def __init__(self, in_ch, feat_ch, out_ch): + super().__init__() + self.conv_in = single_conv(in_ch, feat_ch) + + self.conv1 = double_conv_down(feat_ch, feat_ch) + self.conv2 = double_conv_down(feat_ch, feat_ch) + self.conv3 = double_conv(feat_ch, feat_ch) + self.conv4 = double_conv_up(feat_ch, feat_ch) + self.conv5 = double_conv_up(feat_ch, feat_ch) + self.conv6 = double_conv(feat_ch, out_ch) + + def forward(self, x): + feat0 = self.conv_in(x) # H + feat1 = self.conv1(feat0) # H/2 + feat2 = self.conv2(feat1) # H/4 + feat3 = self.conv3(feat2) # H/4 + feat3 = feat3 + feat2 # H/4 + feat4 = self.conv4(feat3) # H/2 + feat4 = feat4 + feat1 # H/2 + feat5 = self.conv5(feat4) # H + feat5 = feat5 + feat0 # H + feat6 = self.conv6(feat5) + + return feat0, feat1, feat2, feat3, feat4, feat6 + + +class Texformer(nn.Module): + def __init__(self, opts): + super().__init__() + self.feat_dim = opts.feat_dim + src_ch = opts.src_ch + tgt_ch = opts.tgt_ch + out_ch = opts.out_ch + self.mask_fusion = opts.mask_fusion + + if not self.mask_fusion: + v_ch = out_ch + else: + v_ch = 2 + 3 + + self.unet_q = 
Unet(tgt_ch, self.feat_dim, self.feat_dim) + self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim) + self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim) + + self.trans_dec = nn.ModuleList( + [ + None, None, None, + TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'), + TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'), + TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct') + ] + ) + + self.conv0 = double_conv(self.feat_dim, self.feat_dim) + self.conv1 = double_conv_down(self.feat_dim, self.feat_dim) + self.conv2 = double_conv_down(self.feat_dim, self.feat_dim) + self.conv3 = double_conv(self.feat_dim, self.feat_dim) + self.conv4 = double_conv_up(self.feat_dim, self.feat_dim) + self.conv5 = double_conv_up(self.feat_dim, self.feat_dim) + + if not self.mask_fusion: + self.conv6 = nn.Sequential( + single_conv(self.feat_dim, self.feat_dim), + nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1) + ) + else: + self.conv6 = nn.Sequential( + single_conv(self.feat_dim, self.feat_dim), + nn.Conv2d(self.feat_dim, 2 + 3 + 1, 3, 1, 1) + ) # mask*flow-sampling + (1-mask)*rgb + self.sigmoid = nn.Sigmoid() + + self.tanh = nn.Tanh() + + def forward(self, q, k, v): + print('qkv', q.shape, k.shape, v.shape) + q_feat = self.unet_q(q) + k_feat = self.unet_k(k) + v_feat = self.unet_v(v) + + print('q_feat', len(q_feat)) + outputs = [] + for i in range(3, len(q_feat)): + print(i, q_feat[i].shape, k_feat[i].shape, v_feat[i].shape) + outputs.append(self.trans_dec[i](q_feat[i], k_feat[i], v_feat[i])) + print('outputs', outputs[-1].shape) + + f0 = self.conv0(outputs[2]) # H + f1 = self.conv1(f0) # H/2 + f1 = f1 + outputs[1] + f2 = self.conv2(f1) # H/4 + f2 = f2 + outputs[0] + f3 = self.conv3(f2) # H/4 + f3 = f3 + outputs[0] + f2 + f4 = self.conv4(f3) # H/2 + f4 = f4 + outputs[1] + f1 + f5 = self.conv5(f4) # H + f5 = f5 + outputs[2] + f0 + if not self.mask_fusion: + out = self.tanh(self.conv6(f5)) + else: + out_ = self.conv6(f5) + out = [self.tanh(out_[:, :2]), self.tanh(out_[:, 2:5]), self.sigmoid(out_[:, 5:])] + return out diff --git a/lib/pymafx/models/transformers/tokenlearner.py b/lib/pymafx/models/transformers/tokenlearner.py new file mode 100644 index 0000000000000000000000000000000000000000..441b361a721f685f481e764c19b624b593124c1b --- /dev/null +++ b/lib/pymafx/models/transformers/tokenlearner.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SpatialAttention(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(2, 1, kernel_size=(1, 1), stride=1), nn.BatchNorm2d(1), nn.ReLU() + ) + + self.sgap = nn.AvgPool2d(2) + + def forward(self, x): + B, H, W, C = x.shape + x = x.reshape(B, C, H, W) + + mx = torch.max(x, 1)[0].unsqueeze(1) + avg = torch.mean(x, 1).unsqueeze(1) + combined = torch.cat([mx, avg], dim=1) + fmap = self.conv(combined) + weight_map = torch.sigmoid(fmap) + out = (x * weight_map).mean(dim=(-2, -1)) + + return out, x * weight_map + + +class TokenLearner(nn.Module): + def __init__(self, S) -> None: + super().__init__() + self.S = S + self.tokenizers = nn.ModuleList([SpatialAttention() for _ in range(S)]) + + def forward(self, x): + B, _, _, C = x.shape + Z = torch.Tensor(B, self.S, C).to(x) + for i in range(self.S): + Ai, _ = self.tokenizers[i](x) # [B, C] + Z[:, i, :] = Ai + return Z + + +class TokenFuser(nn.Module): + def __init__(self, H, W, C, S) -> None: + super().__init__() + self.projection = nn.Linear(S, S, bias=False) + self.Bi = nn.Linear(C, S) 
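The TokenLearner above pools a channels-last `[B, H, W, C]` map into `S` tokens of shape `[B, S, C]`. Below is a minimal stand-in that replaces the learned 1x1 convolution in `SpatialAttention` with a fixed sigmoid gate, purely to make the shape flow concrete (all `S` tokens come out identical here because the gate has no parameters). Note that the real `SpatialAttention.forward` converts channels-last to channels-first with `reshape`, which reinterprets memory rather than transposing axes; a `permute`, as used in this sketch, keeps each pixel's channels together.
```
import torch

def spatial_attention_pool(x_cf):
    # Minimal stand-in for SpatialAttention: max/avg channel maps -> a single
    # weight map, then weighted global average pooling. Returns one [B, C] token.
    mx = x_cf.max(dim=1, keepdim=True)[0]        # [B, 1, H, W]
    avg = x_cf.mean(dim=1, keepdim=True)         # [B, 1, H, W]
    weight = torch.sigmoid(mx + avg)             # stand-in for the learned 1x1 conv
    return (x_cf * weight).mean(dim=(-2, -1))    # [B, C]

x = torch.randn(2, 16, 16, 64)                   # channels-last, as TokenLearner expects
x_cf = x.permute(0, 3, 1, 2)                     # [B, C, H, W]
S = 8
tokens = torch.stack([spatial_attention_pool(x_cf) for _ in range(S)], dim=1)
print(tokens.shape)                              # torch.Size([2, 8, 64])
```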
+ self.spatial_attn = SpatialAttention() + self.S = S + + def forward(self, y, x): + B, S, C = y.shape + B, H, W, C = x.shape + + Y = self.projection(y.reshape(B, C, S)).reshape(B, S, C) + Bw = torch.sigmoid(self.Bi(x)).reshape(B, H * W, S) # [B, HW, S] + BwY = torch.matmul(Bw, Y) + + _, xj = self.spatial_attn(x) + xj = xj.reshape(B, H * W, C) + + out = (BwY + xj).reshape(B, H, W, C) + + return out diff --git a/lib/pymafx/models/transformers/transformer_basics.py b/lib/pymafx/models/transformers/transformer_basics.py new file mode 100644 index 0000000000000000000000000000000000000000..144ccd76b7e2f73189634ab551691c4262781b9d --- /dev/null +++ b/lib/pymafx/models/transformers/transformer_basics.py @@ -0,0 +1,283 @@ +import torch.nn as nn +from .net_utils import PosEnSine, softmax_attention, dotproduct_attention, long_range_attention, \ + short_range_attention, patch_attention + + +class OurMultiheadAttention(nn.Module): + def __init__(self, q_feat_dim, k_feat_dim, out_feat_dim, n_head, d_k=None, d_v=None): + super(OurMultiheadAttention, self).__init__() + if d_k is None: + d_k = out_feat_dim // n_head + if d_v is None: + d_v = out_feat_dim // n_head + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + # pre-attention projection + self.w_qs = nn.Conv2d(q_feat_dim, n_head * d_k, 1, bias=False) + self.w_ks = nn.Conv2d(k_feat_dim, n_head * d_k, 1, bias=False) + self.w_vs = nn.Conv2d(out_feat_dim, n_head * d_v, 1, bias=False) + + # after-attention combine heads + self.fc = nn.Conv2d(n_head * d_v, out_feat_dim, 1, bias=False) + + def forward(self, q, k, v, attn_type='softmax', **kwargs): + # input: b x d x h x w + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + + # Pass through the pre-attention projection: b x (nhead*dk) x h x w + # Separate different heads: b x nhead x dk x h x w + q = self.w_qs(q).view(q.shape[0], n_head, d_k, q.shape[2], q.shape[3]) + k = self.w_ks(k).view(k.shape[0], n_head, d_k, k.shape[2], k.shape[3]) + v = self.w_vs(v).view(v.shape[0], n_head, d_v, v.shape[2], v.shape[3]) + + # -------------- Attention ----------------- + if attn_type == 'softmax': + q, attn = softmax_attention(q, k, v) # b x n x dk x h x w --> b x n x dv x h x w + elif attn_type == 'dotproduct': + q, attn = dotproduct_attention(q, k, v) + elif attn_type == 'patch': + q, attn = patch_attention(q, k, v, P=kwargs['P']) + elif attn_type == 'sparse_long': + q, attn = long_range_attention(q, k, v, P_h=kwargs['ah'], P_w=kwargs['aw']) + elif attn_type == 'sparse_short': + q, attn = short_range_attention(q, k, v, Q_h=kwargs['ah'], Q_w=kwargs['aw']) + else: + raise NotImplementedError(f'Unknown attention type {attn_type}') + # ------------ end Attention --------------- + + # Concatenate all the heads together: b x (n*dv) x h x w + q = q.reshape(q.shape[0], -1, q.shape[3], q.shape[4]) + q = self.fc(q) # b x d x h x w + + return q, attn + + +class TransformerEncoderUnit(nn.Module): + def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): + super(TransformerEncoderUnit, self).__init__() + self.feat_dim = feat_dim + self.attn_type = attn_type + self.pos_en_flag = pos_en_flag + self.P = P + + self.pos_en = PosEnSine(self.feat_dim // 2) + self.attn = OurMultiheadAttention(feat_dim, n_head) + + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.activation = nn.ReLU(inplace=True) + + self.norm1 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) + + def forward(self, src): + if 
self.pos_en_flag: + pos_embed = self.pos_en(src) + else: + pos_embed = 0 + + # multi-head attention + src2 = self.attn( + q=src + pos_embed, k=src + pos_embed, v=src, attn_type=self.attn_type, P=self.P + )[0] + src = src + src2 + src = self.norm1(src) + + # feed forward + src2 = self.linear2(self.activation(self.linear1(src))) + src = src + src2 + src = self.norm2(src) + + return src + + +class TransformerEncoderUnitSparse(nn.Module): + def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None): + super(TransformerEncoderUnitSparse, self).__init__() + self.feat_dim = feat_dim + self.pos_en_flag = pos_en_flag + self.ahw = ahw # [Ph, Pw, Qh, Qw] + + self.pos_en = PosEnSine(self.feat_dim // 2) + self.attn1 = OurMultiheadAttention(feat_dim, n_head) # long range + self.attn2 = OurMultiheadAttention(feat_dim, n_head) # short range + + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.activation = nn.ReLU(inplace=True) + + self.norm1 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) + + def forward(self, src): + if self.pos_en_flag: + pos_embed = self.pos_en(src) + else: + pos_embed = 0 + + # multi-head long-range attention + src2 = self.attn1( + q=src + pos_embed, + k=src + pos_embed, + v=src, + attn_type='sparse_long', + ah=self.ahw[0], + aw=self.ahw[1] + )[0] + src = src + src2 # ? this might be ok to remove + + # multi-head short-range attention + src2 = self.attn2( + q=src + pos_embed, + k=src + pos_embed, + v=src, + attn_type='sparse_short', + ah=self.ahw[2], + aw=self.ahw[3] + )[0] + src = src + src2 + src = self.norm1(src) + + # feed forward + src2 = self.linear2(self.activation(self.linear1(src))) + src = src + src2 + src = self.norm2(src) + + return src + + +class TransformerDecoderUnit(nn.Module): + def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): + super(TransformerDecoderUnit, self).__init__() + self.feat_dim = feat_dim + self.attn_type = attn_type + self.pos_en_flag = pos_en_flag + self.P = P + + self.pos_en = PosEnSine(self.feat_dim // 2) + self.attn1 = OurMultiheadAttention(feat_dim, n_head) # self-attention + self.attn2 = OurMultiheadAttention(feat_dim, n_head) # cross-attention + + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.activation = nn.ReLU(inplace=True) + + self.norm1 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) + self.norm3 = nn.BatchNorm2d(self.feat_dim) + + def forward(self, tgt, src): + if self.pos_en_flag: + src_pos_embed = self.pos_en(src) + tgt_pos_embed = self.pos_en(tgt) + else: + src_pos_embed = 0 + tgt_pos_embed = 0 + + # self-multi-head attention + tgt2 = self.attn1( + q=tgt + tgt_pos_embed, k=tgt + tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P + )[0] + tgt = tgt + tgt2 + tgt = self.norm1(tgt) + + # cross-multi-head attention + tgt2 = self.attn2( + q=tgt + tgt_pos_embed, k=src + src_pos_embed, v=src, attn_type=self.attn_type, P=self.P + )[0] + tgt = tgt + tgt2 + tgt = self.norm2(tgt) + + # feed forward + tgt2 = self.linear2(self.activation(self.linear1(tgt))) + tgt = tgt + tgt2 + tgt = self.norm3(tgt) + + return tgt + + +class TransformerDecoderUnitSparse(nn.Module): + def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None): + super(TransformerDecoderUnitSparse, self).__init__() + self.feat_dim = feat_dim + self.ahw = ahw # [Ph_tgt, Pw_tgt, Qh_tgt, Qw_tgt, Ph_src, Pw_src, Qh_tgt, 
Qw_tgt] + self.pos_en_flag = pos_en_flag + + self.pos_en = PosEnSine(self.feat_dim // 2) + self.attn1_1 = OurMultiheadAttention(feat_dim, n_head) # self-attention: long + self.attn1_2 = OurMultiheadAttention(feat_dim, n_head) # self-attention: short + + self.attn2_1 = OurMultiheadAttention( + feat_dim, n_head + ) # cross-attention: self-attention-long + cross-attention-short + self.attn2_2 = OurMultiheadAttention(feat_dim, n_head) + + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) + self.activation = nn.ReLU(inplace=True) + + self.norm1 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) + self.norm3 = nn.BatchNorm2d(self.feat_dim) + + def forward(self, tgt, src): + if self.pos_en_flag: + src_pos_embed = self.pos_en(src) + tgt_pos_embed = self.pos_en(tgt) + else: + src_pos_embed = 0 + tgt_pos_embed = 0 + + # self-multi-head attention: sparse long + tgt2 = self.attn1_1( + q=tgt + tgt_pos_embed, + k=tgt + tgt_pos_embed, + v=tgt, + attn_type='sparse_long', + ah=self.ahw[0], + aw=self.ahw[1] + )[0] + tgt = tgt + tgt2 + # self-multi-head attention: sparse short + tgt2 = self.attn1_2( + q=tgt + tgt_pos_embed, + k=tgt + tgt_pos_embed, + v=tgt, + attn_type='sparse_short', + ah=self.ahw[2], + aw=self.ahw[3] + )[0] + tgt = tgt + tgt2 + tgt = self.norm1(tgt) + + # self-multi-head attention: sparse long + src2 = self.attn2_1( + q=src + src_pos_embed, + k=src + src_pos_embed, + v=src, + attn_type='sparse_long', + ah=self.ahw[4], + aw=self.ahw[5] + )[0] + src = src + src2 + # cross-multi-head attention: sparse short + tgt2 = self.attn2_2( + q=tgt + tgt_pos_embed, + k=src + src_pos_embed, + v=src, + attn_type='sparse_short', + ah=self.ahw[6], + aw=self.ahw[7] + )[0] + tgt = tgt + tgt2 + tgt = self.norm2(tgt) + + # feed forward + tgt2 = self.linear2(self.activation(self.linear1(tgt))) + tgt = tgt + tgt2 + tgt = self.norm3(tgt) + + return tgt diff --git a/lib/pymafx/utils/__init__.py b/lib/pymafx/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..539b22b33222e7617b267138f30fc5f1a7f82b6a --- /dev/null +++ b/lib/pymafx/utils/__init__.py @@ -0,0 +1,2 @@ +from .data_loader import CheckpointDataLoader +from .saver import CheckpointSaver \ No newline at end of file diff --git a/lib/pymafx/utils/binvox_rw.py b/lib/pymafx/utils/binvox_rw.py new file mode 100644 index 0000000000000000000000000000000000000000..947c3258691da908954f765bde07e0978cfb9f97 --- /dev/null +++ b/lib/pymafx/utils/binvox_rw.py @@ -0,0 +1,293 @@ +# Copyright (C) 2012 Daniel Maturana +# This file is part of binvox-rw-py. +# +# binvox-rw-py is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# binvox-rw-py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with binvox-rw-py. If not, see . +# +# Modified by Christopher B. Choy +# for python 3 support +""" +Binvox to Numpy and back. + + +>>> import numpy as np +>>> import binvox_rw +>>> with open('chair.binvox', 'rb') as f: +... m1 = binvox_rw.read_as_3d_array(f) +... 
+>>> m1.dims +[32, 32, 32] +>>> m1.scale +41.133000000000003 +>>> m1.translate +[0.0, 0.0, 0.0] +>>> with open('chair_out.binvox', 'wb') as f: +... m1.write(f) +... +>>> with open('chair_out.binvox', 'rb') as f: +... m2 = binvox_rw.read_as_3d_array(f) +... +>>> m1.dims==m2.dims +True +>>> m1.scale==m2.scale +True +>>> m1.translate==m2.translate +True +>>> np.all(m1.data==m2.data) +True + +>>> with open('chair.binvox', 'rb') as f: +... md = binvox_rw.read_as_3d_array(f) +... +>>> with open('chair.binvox', 'rb') as f: +... ms = binvox_rw.read_as_coord_array(f) +... +>>> data_ds = binvox_rw.dense_to_sparse(md.data) +>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) +>>> np.all(data_sd==md.data) +True +>>> # the ordering of elements returned by numpy.nonzero changes with axis +>>> # ordering, so to compare for equality we first lexically sort the voxels. +>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) +True +""" + +import numpy as np + + +class Voxels(object): + """ Holds a binvox model. + data is either a three-dimensional numpy boolean array (dense representation) + or a two-dimensional numpy float array (coordinate representation). + + dims, translate and scale are the model metadata. + + dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. + + scale and translate relate the voxels to the original model coordinates. + + To translate voxel coordinates i, j, k to original coordinates x, y, z: + + x_n = (i+.5)/dims[0] + y_n = (j+.5)/dims[1] + z_n = (k+.5)/dims[2] + x = scale*x_n + translate[0] + y = scale*y_n + translate[1] + z = scale*z_n + translate[2] + + """ + def __init__(self, data, dims, translate, scale, axis_order): + self.data = data + self.dims = dims + self.translate = translate + self.scale = scale + assert (axis_order in ('xzy', 'xyz')) + self.axis_order = axis_order + + def clone(self): + data = self.data.copy() + dims = self.dims[:] + translate = self.translate[:] + return Voxels(data, dims, translate, self.scale, self.axis_order) + + def write(self, fp): + write(self, fp) + + +def read_header(fp): + """ Read binvox header. Mostly meant for internal use. + """ + line = fp.readline().strip() + if not line.startswith(b'#binvox'): + raise IOError('Not a binvox file') + dims = [int(i) for i in fp.readline().strip().split(b' ')[1:]] + translate = [float(i) for i in fp.readline().strip().split(b' ')[1:]] + scale = [float(i) for i in fp.readline().strip().split(b' ')[1:]][0] + line = fp.readline() + return dims, translate, scale + + +def read_as_3d_array(fp, fix_coords=True): + """ Read binary binvox format as array. + + Returns the model with accompanying metadata. + + Voxels are stored in a three-dimensional numpy array, which is simple and + direct, but may use a lot of memory for large models. (Storage requirements + are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy + boolean arrays use a byte per element). + + Doesn't do any checks on input except for the '#binvox' line. 
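The binvox payload that `read_as_3d_array` decodes below is plain run-length encoding: alternating value and count bytes whose counts sum to `dims[0]*dims[1]*dims[2]`. A tiny sketch of that decoding with a made-up payload:
```
import numpy as np

# Toy RLE payload for a 3x3x3 grid: (value, count) byte pairs, counts sum to 27.
raw = np.array([0, 5, 1, 3, 0, 19], dtype=np.uint8)
values, counts = raw[::2], raw[1::2]
data = np.repeat(values, counts).astype(bool).reshape((3, 3, 3))
assert data.sum() == 3
# Note: the np.bool / np.int aliases used further down in this file were
# removed in NumPy 1.24, so recent NumPy needs plain bool / int there.
```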
+ """ + dims, translate, scale = read_header(fp) + raw_data = np.frombuffer(fp.read(), dtype=np.uint8) + # if just using reshape() on the raw data: + # indexing the array as array[i,j,k], the indices map into the + # coords as: + # i -> x + # j -> z + # k -> y + # if fix_coords is true, then data is rearranged so that + # mapping is + # i -> x + # j -> y + # k -> z + values, counts = raw_data[::2], raw_data[1::2] + data = np.repeat(values, counts).astype(np.bool) + data = data.reshape(dims) + if fix_coords: + # xzy to xyz TODO the right thing + data = np.transpose(data, (0, 2, 1)) + axis_order = 'xyz' + else: + axis_order = 'xzy' + return Voxels(data, dims, translate, scale, axis_order) + + +def read_as_coord_array(fp, fix_coords=True): + """ Read binary binvox format as coordinates. + + Returns binvox model with voxels in a "coordinate" representation, i.e. an + 3 x N array where N is the number of nonzero voxels. Each column + corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates + of the voxel. (The odd ordering is due to the way binvox format lays out + data). Note that coordinates refer to the binvox voxels, without any + scaling or translation. + + Use this to save memory if your model is very sparse (mostly empty). + + Doesn't do any checks on input except for the '#binvox' line. + """ + dims, translate, scale = read_header(fp) + raw_data = np.frombuffer(fp.read(), dtype=np.uint8) + + values, counts = raw_data[::2], raw_data[1::2] + + sz = np.prod(dims) + index, end_index = 0, 0 + end_indices = np.cumsum(counts) + indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) + + values = values.astype(np.bool) + indices = indices[values] + end_indices = end_indices[values] + + nz_voxels = [] + for index, end_index in zip(indices, end_indices): + nz_voxels.extend(range(index, end_index)) + nz_voxels = np.array(nz_voxels) + # TODO are these dims correct? + # according to docs, + # index = x * wxh + z * width + y; // wxh = width * height = d * d + + x = nz_voxels / (dims[0] * dims[1]) + zwpy = nz_voxels % (dims[0] * dims[1]) # z*w + y + z = zwpy / dims[0] + y = zwpy % dims[0] + if fix_coords: + data = np.vstack((x, y, z)) + axis_order = 'xyz' + else: + data = np.vstack((x, z, y)) + axis_order = 'xzy' + + #return Voxels(data, dims, translate, scale, axis_order) + return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) + + +def dense_to_sparse(voxel_data, dtype=np.int): + """ From dense representation to sparse (coordinate) representation. + No coordinate reordering. + """ + if voxel_data.ndim != 3: + raise ValueError('voxel_data is wrong shape; should be 3D array.') + return np.asarray(np.nonzero(voxel_data), dtype) + + +def sparse_to_dense(voxel_data, dims, dtype=np.bool): + if voxel_data.ndim != 2 or voxel_data.shape[0] != 3: + raise ValueError('voxel_data is wrong shape; should be 3xN array.') + if np.isscalar(dims): + dims = [dims] * 3 + dims = np.atleast_2d(dims).T + # truncate to integers + xyz = voxel_data.astype(np.int) + # discard voxels that fall outside dims + valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) + xyz = xyz[:, valid_ix] + out = np.zeros(dims.flatten(), dtype=dtype) + out[tuple(xyz)] = True + return out + + +#def get_linear_index(x, y, z, dims): +#""" Assuming xzy order. (y increasing fastest. +#TODO ensure this is right when dims are not all same +#""" +#return x*(dims[1]*dims[2]) + z*dims[1] + y + + +def write(voxel_model, fp): + """ Write binary binvox format. 
+ + Note that when saving a model in sparse (coordinate) format, it is first + converted to dense format. + + Doesn't check if the model is 'sane'. + + """ + if voxel_model.data.ndim == 2: + # TODO avoid conversion to dense + dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) + else: + dense_voxel_data = voxel_model.data + + fp.write('#binvox 1\n') + fp.write('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n') + fp.write('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n') + fp.write('scale ' + str(voxel_model.scale) + '\n') + fp.write('data\n') + if not voxel_model.axis_order in ('xzy', 'xyz'): + raise ValueError('Unsupported voxel model axis order') + + if voxel_model.axis_order == 'xzy': + voxels_flat = dense_voxel_data.flatten() + elif voxel_model.axis_order == 'xyz': + voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() + + # keep a sort of state machine for writing run length encoding + state = voxels_flat[0] + ctr = 0 + for c in voxels_flat: + if c == state: + ctr += 1 + # if ctr hits max, dump + if ctr == 255: + fp.write(chr(state)) + fp.write(chr(ctr)) + ctr = 0 + else: + # if switch state, dump + fp.write(chr(state)) + fp.write(chr(ctr)) + state = c + ctr = 1 + # flush out remainders + if ctr > 0: + fp.write(chr(state)) + fp.write(chr(ctr)) + + +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/lib/pymafx/utils/blob.py b/lib/pymafx/utils/blob.py new file mode 100644 index 0000000000000000000000000000000000000000..00123338e18a3fa74a6c3cb730cac9fb41b59ac5 --- /dev/null +++ b/lib/pymafx/utils/blob.py @@ -0,0 +1,174 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +# +# Based on: +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- +"""blob helper functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from six.moves import cPickle as pickle +import numpy as np +import cv2 + +from models.core.config import cfg + + +def get_image_blob(im, target_scale, target_max_size): + """Convert an image into a network input. + + Arguments: + im (ndarray): a color image in BGR order + + Returns: + blob (ndarray): a data blob holding an image pyramid + im_scale (float): image scale (target size) / (original size) + im_info (ndarray) + """ + processed_im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, [target_scale], target_max_size) + blob = im_list_to_blob(processed_im) + # NOTE: this height and width may be larger than actual scaled input image + # due to the FPN.COARSEST_STRIDE related padding in im_list_to_blob. 
We are + # maintaining this behavior for now to make existing results exactly + # reproducible (in practice using the true input image height and width + # yields nearly the same results, but they are sometimes slightly different + # because predictions near the edge of the image will be pruned more + # aggressively). + height, width = blob.shape[2], blob.shape[3] + im_info = np.hstack((height, width, im_scale))[np.newaxis, :] + return blob, im_scale, im_info.astype(np.float32) + + +def im_list_to_blob(ims): + """Convert a list of images into a network input. Assumes images were + prepared using prep_im_for_blob or equivalent: i.e. + - BGR channel order + - pixel means subtracted + - resized to the desired input size + - float32 numpy ndarray format + Output is a 4D HCHW tensor of the images concatenated along axis 0 with + shape. + """ + if not isinstance(ims, list): + ims = [ims] + max_shape = get_max_shape([im.shape[:2] for im in ims]) + + num_images = len(ims) + blob = np.zeros((num_images, max_shape[0], max_shape[1], 3), dtype=np.float32) + for i in range(num_images): + im = ims[i] + blob[i, 0:im.shape[0], 0:im.shape[1], :] = im + # Move channels (axis 3) to axis 1 + # Axis order will become: (batch elem, channel, height, width) + channel_swap = (0, 3, 1, 2) + blob = blob.transpose(channel_swap) + return blob + + +def get_max_shape(im_shapes): + """Calculate max spatial size (h, w) for batching given a list of image shapes + """ + max_shape = np.array(im_shapes).max(axis=0) + assert max_shape.size == 2 + # Pad the image so they can be divisible by a stride + if cfg.FPN.FPN_ON: + stride = float(cfg.FPN.COARSEST_STRIDE) + max_shape[0] = int(np.ceil(max_shape[0] / stride) * stride) + max_shape[1] = int(np.ceil(max_shape[1] / stride) * stride) + return max_shape + + +def prep_im_for_blob(im, pixel_means, target_sizes, max_size): + """Prepare an image for use as a network input blob. Specially: + - Subtract per-channel pixel mean + - Convert to float32 + - Rescale to each of the specified target size (capped at max_size) + Returns a list of transformed images, one for each target size. Also returns + the scale factors that were used to compute each returned image. 
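`prep_im_for_blob` rescales so that the short side reaches `target_size` unless that would push the long side past `max_size`; `get_target_scale` below implements exactly that cap. A small standalone illustration of the rule (the image sizes are made up):
```
import numpy as np

def target_scale(im_size_min, im_size_max, target_size, max_size):
    # Same rule as get_target_scale below: scale the short side to target_size,
    # but never let the long side exceed max_size.
    scale = float(target_size) / float(im_size_min)
    if np.round(scale * im_size_max) > max_size:
        scale = float(max_size) / float(im_size_max)
    return scale

# A 480x640 image with target 800 / cap 1333 is scaled by the short side ...
print(target_scale(480, 640, 800, 1333))   # ~1.667
# ... while a very elongated 480x3000 image is capped by the long side.
print(target_scale(480, 3000, 800, 1333))  # ~0.444
```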
+ """ + im = im.astype(np.float32, copy=False) + im -= pixel_means + im_shape = im.shape + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + + ims = [] + im_scales = [] + for target_size in target_sizes: + im_scale = get_target_scale(im_size_min, im_size_max, target_size, max_size) + im_resized = cv2.resize( + im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR + ) + ims.append(im_resized) + im_scales.append(im_scale) + return ims, im_scales + + +def get_im_blob_sizes(im_shape, target_sizes, max_size): + """Calculate im blob size for multiple target_sizes given original im shape + """ + im_size_min = np.min(im_shape) + im_size_max = np.max(im_shape) + im_sizes = [] + for target_size in target_sizes: + im_scale = get_target_scale(im_size_min, im_size_max, target_size, max_size) + im_sizes.append(np.round(im_shape * im_scale)) + return np.array(im_sizes) + + +def get_target_scale(im_size_min, im_size_max, target_size, max_size): + """Calculate target resize scale + """ + im_scale = float(target_size) / float(im_size_min) + # Prevent the biggest axis from being more than max_size + if np.round(im_scale * im_size_max) > max_size: + im_scale = float(max_size) / float(im_size_max) + return im_scale + + +def zeros(shape, int32=False): + """Return a blob of all zeros of the given shape with the correct float or + int data type. + """ + return np.zeros(shape, dtype=np.int32 if int32 else np.float32) + + +def ones(shape, int32=False): + """Return a blob of all ones of the given shape with the correct float or + int data type. + """ + return np.ones(shape, dtype=np.int32 if int32 else np.float32) + + +def serialize(obj): + """Serialize a Python object using pickle and encode it as an array of + float32 values so that it can be feed into the workspace. See deserialize(). + """ + return np.fromstring(pickle.dumps(obj), dtype=np.uint8).astype(np.float32) + + +def deserialize(arr): + """Unserialize a Python object from an array of float32 values fetched from + a workspace. See serialize(). + """ + return pickle.loads(arr.astype(np.uint8).tobytes()) diff --git a/lib/pymafx/utils/cam_params.py b/lib/pymafx/utils/cam_params.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6c1a8d89b2c80d72c90c841d02425df77aa4a5 --- /dev/null +++ b/lib/pymafx/utils/cam_params.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import os +from numpy.testing._private.utils import print_assert_equal +import torch +import numpy as np +import joblib + +from .geometry import batch_euler2matrix + + +def f_pix2vfov(f_pix, img_h): + + if torch.is_tensor(f_pix): + fov = 2. * torch.arctan(img_h / (2. * f_pix)) + else: + fov = 2. * np.arctan(img_h / (2. * f_pix)) + + return fov + + +def vfov2f_pix(fov, img_h): + + if torch.is_tensor(fov): + f_pix = img_h / 2. / torch.tan(fov / 2.) + else: + f_pix = img_h / 2. 
/ np.tan(fov / 2.) + + return f_pix + + +def read_cam_params(cam_params, orig_shape=None): + # These are predicted camera parameters + # cam_param_folder = CAM_PARAM_FOLDERS[dataset_name][cam_param_type] + + cam_pitch = cam_params['pitch'].item() + cam_roll = cam_params['roll'].item() if 'roll' in cam_params else None + + cam_vfov = cam_params['vfov'].item() if 'vfov' in cam_params else None + + cam_focal_length = cam_params['f_pix'] + + orig_shape = cam_params['orig_resolution'] + + # cam_rotmat = batch_euler2matrix(torch.tensor([[cam_pitch, 0., cam_roll]]).float())[0] + cam_rotmat = batch_euler2matrix(torch.tensor([[cam_pitch, 0., 0.]]).float())[0] + + pred_cam_int = torch.zeros(3, 3) + + cx, cy = orig_shape[1] / 2, orig_shape[0] / 2 + + pred_cam_int[0, 0] = cam_focal_length + pred_cam_int[1, 1] = cam_focal_length + + pred_cam_int[:-1, -1] = torch.tensor([cx, cy]) + + cam_int = pred_cam_int.float() + + return cam_rotmat, cam_int, cam_vfov, cam_pitch, cam_roll, cam_focal_length + + +def homo_vector(vector): + """ + vector: B x N x C + h_vector: B x N x (C + 1) + """ + + batch_size, n_pts = vector.shape[:2] + + h_vector = torch.cat([vector, torch.ones((batch_size, n_pts, 1)).to(vector)], dim=-1) + return h_vector diff --git a/lib/pymafx/utils/collections.py b/lib/pymafx/utils/collections.py new file mode 100644 index 0000000000000000000000000000000000000000..edd20a8c89d5d2221dc9d35948eda12c6304ba29 --- /dev/null +++ b/lib/pymafx/utils/collections.py @@ -0,0 +1,64 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""A simple attribute dictionary used for representing configuration options.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +class AttrDict(dict): + + IMMUTABLE = '__immutable__' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__[AttrDict.IMMUTABLE] = False + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + elif name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if not self.__dict__[AttrDict.IMMUTABLE]: + if name in self.__dict__: + self.__dict__[name] = value + else: + self[name] = value + else: + raise AttributeError( + 'Attempted to set "{}" to "{}", but AttrDict is immutable'.format(name, value) + ) + + def immutable(self, is_immutable): + """Set immutability to is_immutable and recursively apply the setting + to all nested AttrDicts. 
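For `cam_params.py` above: `f_pix2vfov` converts a focal length in pixels to a vertical field of view via fov = 2 * arctan(h / (2 f)), and `read_cam_params` assembles a pinhole intrinsic matrix with the focal length on the diagonal and the image centre as principal point. A small sketch with made-up values:
```
import numpy as np

def vfov_from_f_pix(f_pix, img_h):
    # Same relation as f_pix2vfov above: fov = 2 * arctan(h / (2 f))
    return 2.0 * np.arctan(img_h / (2.0 * f_pix))

def intrinsics(f_pix, img_w, img_h):
    # Pinhole K as assembled in read_cam_params: focal length on the diagonal,
    # principal point at the image centre. (The source leaves K[2, 2] at 0;
    # it is set to 1 here, the usual homogeneous convention.)
    K = np.zeros((3, 3), dtype=np.float32)
    K[0, 0] = K[1, 1] = f_pix
    K[0, 2], K[1, 2] = img_w / 2.0, img_h / 2.0
    K[2, 2] = 1.0
    return K

f_pix, H, W = 1000.0, 720, 1280
print(np.degrees(vfov_from_f_pix(f_pix, H)))   # ~39.6 degrees
print(intrinsics(f_pix, W, H))
```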
+ """ + self.__dict__[AttrDict.IMMUTABLE] = is_immutable + # Recursively set immutable state + for v in self.__dict__.values(): + if isinstance(v, AttrDict): + v.immutable(is_immutable) + for v in self.values(): + if isinstance(v, AttrDict): + v.immutable(is_immutable) + + def is_immutable(self): + return self.__dict__[AttrDict.IMMUTABLE] diff --git a/lib/pymafx/utils/colormap.py b/lib/pymafx/utils/colormap.py new file mode 100644 index 0000000000000000000000000000000000000000..44ef28c050021a6f03d088e9437de0c4adeb5ee5 --- /dev/null +++ b/lib/pymafx/utils/colormap.py @@ -0,0 +1,53 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""An awesome colormap for really neat visualizations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np + + +def colormap(rgb=False): + color_list = np.array( + [ + 0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, 0.184, 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, 1.000, 0.500, 0.000, 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, 0.000, 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000, 0.500, 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, 0.333, 1.000, 0.500, 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, 0.667, 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, 0.286, 0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list diff --git a/lib/pymafx/utils/common.py b/lib/pymafx/utils/common.py new file mode 100644 index 
0000000000000000000000000000000000000000..f3330ea18c4783ccacb21657808b8b8ce2301f86 --- /dev/null +++ b/lib/pymafx/utils/common.py @@ -0,0 +1,844 @@ +import torch +import numpy as np +import logging +from copy import deepcopy +from .utils.libkdtree import KDTree + +logger_py = logging.getLogger(__name__) + + +def compute_iou(occ1, occ2): + ''' Computes the Intersection over Union (IoU) value for two sets of + occupancy values. + Args: + occ1 (tensor): first set of occupancy values + occ2 (tensor): second set of occupancy values + ''' + occ1 = np.asarray(occ1) + occ2 = np.asarray(occ2) + + # Put all data in second dimension + # Also works for 1-dimensional data + if occ1.ndim >= 2: + occ1 = occ1.reshape(occ1.shape[0], -1) + if occ2.ndim >= 2: + occ2 = occ2.reshape(occ2.shape[0], -1) + + # Convert to boolean values + occ1 = (occ1 >= 0.5) + occ2 = (occ2 >= 0.5) + + # Compute IOU + area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) + area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) + + iou = (area_intersect / area_union) + + return iou + + +def rgb2gray(rgb): + ''' rgb of size B x h x w x 3 + ''' + r, g, b = rgb[:, :, :, 0], rgb[:, :, :, 1], rgb[:, :, :, 2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + + return gray + + +def sample_patch_points( + batch_size, n_points, patch_size=1, image_resolution=(128, 128), continuous=True +): + ''' Returns sampled points in the range [-1, 1]. + + Args: + batch_size (int): required batch size + n_points (int): number of points to sample + patch_size (int): size of patch; if > 1, patches of size patch_size + are sampled instead of individual points + image_resolution (tuple): image resolution (required for calculating + the pixel distances) + continuous (bool): whether to sample continuously or only on pixel + locations + ''' + assert (patch_size > 0) + # Calculate step size for [-1, 1] that is equivalent to a pixel in + # original resolution + h_step = 1. / image_resolution[0] + w_step = 1. / image_resolution[1] + # Get number of patches + patch_size_squared = patch_size**2 + n_patches = int(n_points / patch_size_squared) + if continuous: + p = torch.rand(batch_size, n_patches, 2) # [0, 1] + else: + px = torch.randint(0, image_resolution[1], + size=(batch_size, n_patches, 1)).float() / (image_resolution[1] - 1) + py = torch.randint(0, image_resolution[0], + size=(batch_size, n_patches, 1)).float() / (image_resolution[0] - 1) + p = torch.cat([px, py], dim=-1) + # Scale p to [0, (1 - (patch_size - 1) * step) ] + p[:, :, 0] *= 1 - (patch_size - 1) * w_step + p[:, :, 1] *= 1 - (patch_size - 1) * h_step + + # Add points + patch_arange = torch.arange(patch_size) + x_offset, y_offset = torch.meshgrid(patch_arange, patch_arange) + patch_offsets = torch.stack([x_offset.reshape(-1), y_offset.reshape(-1)], + dim=1).view(1, 1, -1, 2).repeat(batch_size, n_patches, 1, 1).float() + + patch_offsets[:, :, :, 0] *= w_step + patch_offsets[:, :, :, 1] *= h_step + + # Add patch_offsets to points + p = p.view(batch_size, n_patches, 1, 2) + patch_offsets + + # Scale to [-1, x] + p = p * 2 - 1 + + p = p.view(batch_size, -1, 2) + + amax, amin = p.max(), p.min() + assert (amax <= 1. and amin >= -1.) + + return p + + +def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, n_steps=40): + ''' Returns n_steps equally spaced points inside the unit cube on the rays + cast from ray0 with direction ray_direction. 
+ + This function is used to get the ray marching points {p^ray_j} for a given + camera position ray0 and + a given ray direction ray_direction which goes from the camera_position to + the pixel location. + + NOTE: The returned values d_proposal are the lengths of the ray: + p^ray_j = ray0 + d_proposal_j * ray_direction + + Args: + ray0 (tensor): Start positions of the rays + ray_direction (tensor): Directions of rays + padding (float): Padding which is applied to the unit cube + eps (float): The epsilon value for numerical stability + n_steps (int): number of steps + ''' + batch_size, n_pts, _ = ray0.shape + device = ray0.device + + p_intervals, d_intervals, mask_inside_cube = \ + check_ray_intersection_with_unit_cube(ray0, ray_direction, padding, + eps) + d_proposal = d_intervals[:, :, 0].unsqueeze(-1) + \ + torch.linspace(0, 1, steps=n_steps).to(device).view(1, 1, -1) * \ + (d_intervals[:, :, 1] - d_intervals[:, :, 0]).unsqueeze(-1) + d_proposal = d_proposal.unsqueeze(-1) + + return d_proposal, mask_inside_cube + + +def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, scale=2.0): + ''' Checks if rays ray0 + d * ray_direction intersect with unit cube with + padding padding. + + It returns the two intersection points as well as the sorted ray lengths d. + + Args: + ray0 (tensor): Start positions of the rays + ray_direction (tensor): Directions of rays + padding (float): Padding which is applied to the unit cube + eps (float): The epsilon value for numerical stability + scale (float): cube size + ''' + batch_size, n_pts, _ = ray0.shape + device = ray0.device + + # calculate intersections with unit cube (< . , . > is the dot product) + # = = 0 + # d = - / + + # Get points on plane p_e + p_distance = (scale * 0.5) + padding / 2 + p_e = torch.ones(batch_size, n_pts, 6).to(device) * p_distance + p_e[:, :, 3:] *= -1. 
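`check_ray_intersection_with_unit_cube` intersects each ray with the six axis-aligned planes x, y, z = +/-(scale/2 + padding/2) and keeps the two hits that actually lie on the cube surface. A single-ray NumPy sketch of the same slab test (the ray values are made up; with the defaults the planes sit at +/-1.05):
```
import numpy as np

def ray_cube_hits(ray0, ray_dir, half_extent):
    # Same slab idea as check_ray_intersection_with_unit_cube above: solve for
    # the parameter t of each of the six planes x, y, z = +/- half_extent and
    # keep the candidate points that actually lie on the cube surface.
    t = np.concatenate([(half_extent - ray0) / ray_dir,
                        (-half_extent - ray0) / ray_dir])   # 6 candidate parameters
    pts = ray0 + t[:, None] * ray_dir                       # 6 candidate points
    on_cube = np.all(np.abs(pts) <= half_extent + 1e-6, axis=1)
    return t[on_cube], pts[on_cube]

ray0 = np.array([0.0, 0.0, -2.0])
ray_dir = np.array([0.05, 0.05, 1.0])        # no zero component, so no division by zero
t, pts = ray_cube_hits(ray0, ray_dir, 1.05)  # 1.05 = scale*0.5 + padding/2 with the defaults
print(t)                                     # [3.05 0.95] -> enters at t=0.95, exits at t=3.05
```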
+ + # Calculate the intersection points with given formula + nominator = p_e - ray0.repeat(1, 1, 2) + denominator = ray_direction.repeat(1, 1, 2) + d_intersect = nominator / denominator + p_intersect = ray0.unsqueeze(-2) + d_intersect.unsqueeze(-1) * \ + ray_direction.unsqueeze(-2) + + # Calculate mask where points intersect unit cube + p_mask_inside_cube = ( + (p_intersect[:, :, :, 0] <= p_distance + eps) & + (p_intersect[:, :, :, 1] <= p_distance + eps) & + (p_intersect[:, :, :, 2] <= p_distance + eps) & + (p_intersect[:, :, :, 0] >= -(p_distance + eps)) & + (p_intersect[:, :, :, 1] >= -(p_distance + eps)) & + (p_intersect[:, :, :, 2] >= -(p_distance + eps)) + ).cpu() + + # Correct rays are these which intersect exactly 2 times + mask_inside_cube = p_mask_inside_cube.sum(-1) == 2 + + # Get interval values for p's which are valid + p_intervals = p_intersect[mask_inside_cube][p_mask_inside_cube[mask_inside_cube]].view(-1, 2, 3) + p_intervals_batch = torch.zeros(batch_size, n_pts, 2, 3).to(device) + p_intervals_batch[mask_inside_cube] = p_intervals + + # Calculate ray lengths for the interval points + d_intervals_batch = torch.zeros(batch_size, n_pts, 2).to(device) + norm_ray = torch.norm(ray_direction[mask_inside_cube], dim=-1) + d_intervals_batch[mask_inside_cube] = torch.stack( + [ + torch.norm(p_intervals[:, 0] - ray0[mask_inside_cube], dim=-1) / norm_ray, + torch.norm(p_intervals[:, 1] - ray0[mask_inside_cube], dim=-1) / norm_ray, + ], + dim=-1 + ) + + # Sort the ray lengths + d_intervals_batch, indices_sort = d_intervals_batch.sort() + p_intervals_batch = p_intervals_batch[torch.arange(batch_size).view(-1, 1, 1), + torch.arange(n_pts).view(1, -1, 1), indices_sort] + + return p_intervals_batch, d_intervals_batch, mask_inside_cube + + +def intersect_camera_rays_with_unit_cube( + pixels, camera_mat, world_mat, scale_mat, padding=0.1, eps=1e-6, use_ray_length_as_depth=True +): + ''' Returns the intersection points of ray cast from camera origin to + pixel points p on the image plane. + + The function returns the intersection points as well the depth values and + a mask specifying which ray intersects the unit cube. + + Args: + pixels (tensor): Pixel points on image plane (range [-1, 1]) + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + padding (float): Padding which is applied to the unit cube + eps (float): The epsilon value for numerical stability + + ''' + batch_size, n_points, _ = pixels.shape + + pixel_world = image_points_to_world(pixels, camera_mat, world_mat, scale_mat) + camera_world = origin_to_world(n_points, camera_mat, world_mat, scale_mat) + ray_vector = (pixel_world - camera_world) + + p_cube, d_cube, mask_cube = check_ray_intersection_with_unit_cube( + camera_world, ray_vector, padding=padding, eps=eps + ) + if not use_ray_length_as_depth: + p_cam = transform_to_camera_space( + p_cube.view(batch_size, -1, 3), camera_mat, world_mat, scale_mat + ).view(batch_size, n_points, -1, 3) + d_cube = p_cam[:, :, :, -1] + return p_cube, d_cube, mask_cube + + +def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), subsample_to=None): + ''' Arranges pixels for given resolution in range image_range. + + The function returns the unscaled pixel locations as integers and the + scaled float values. 
+ + Args: + resolution (tuple): image resolution + batch_size (int): batch size + image_range (tuple): range of output points (default [-1, 1]) + subsample_to (int): if integer and > 0, the points are randomly + subsampled to this value + ''' + h, w = resolution + n_points = resolution[0] * resolution[1] + + # Arrange pixel location in scale resolution + pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h)) + pixel_locations = torch.stack([pixel_locations[0], pixel_locations[1]], + dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1) + pixel_scaled = pixel_locations.clone().float() + + # Shift and scale points to match image_range + scale = (image_range[1] - image_range[0]) + loc = scale / 2 + pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc + pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc + + # Subsample points if subsample_to is not None and > 0 + if (subsample_to is not None and subsample_to > 0 and subsample_to < n_points): + idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to, ), replace=False) + pixel_scaled = pixel_scaled[:, idx] + pixel_locations = pixel_locations[:, idx] + + return pixel_locations, pixel_scaled + + +def to_pytorch(tensor, return_type=False): + ''' Converts input tensor to pytorch. + + Args: + tensor (tensor): Numpy or Pytorch tensor + return_type (bool): whether to return input type + ''' + is_numpy = False + if type(tensor) == np.ndarray: + tensor = torch.from_numpy(tensor) + is_numpy = True + tensor = tensor.clone() + if return_type: + return tensor, is_numpy + return tensor + + +def get_mask(tensor): + ''' Returns mask of non-illegal values for tensor. + + Args: + tensor (tensor): Numpy or Pytorch tensor + ''' + tensor, is_numpy = to_pytorch(tensor, True) + mask = ((abs(tensor) != np.inf) & (torch.isnan(tensor) == False)) + mask = mask.to(torch.bool) + if is_numpy: + mask = mask.numpy() + + return mask + + +def transform_mesh(mesh, transform): + ''' Transforms a mesh with given transformation. + + Args: + mesh (trimesh mesh): mesh + transform (tensor): transformation matrix of size 4 x 4 + ''' + mesh = deepcopy(mesh) + v = np.asarray(mesh.vertices).astype(np.float32) + v_transformed = transform_pointcloud(v, transform) + mesh.vertices = v_transformed + return mesh + + +def transform_pointcloud(pointcloud, transform): + ''' Transforms a point cloud with given transformation. + + Args: + pointcloud (tensor): tensor of size N x 3 + transform (tensor): transformation of size 4 x 4 + ''' + + assert (transform.shape == (4, 4) and pointcloud.shape[-1] == 3) + + pcl, is_numpy = to_pytorch(pointcloud, True) + transform = to_pytorch(transform) + + # Transform point cloud to homogen coordinate system + pcl_hom = torch.cat([pcl, torch.ones(pcl.shape[0], 1)], dim=-1).transpose(1, 0) + + # Apply transformation to point cloud + pcl_hom_transformed = transform @ pcl_hom + + # Transform back to 3D coordinates + pcl_out = pcl_hom_transformed[:3].transpose(1, 0) + if is_numpy: + pcl_out = pcl_out.numpy() + + return pcl_out + + +def transform_points_batch(p, transform): + ''' Transform points tensor with given transform. 
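`transform_pointcloud` above (and the batched `transform_points_batch` that follows) apply a 4x4 transform by lifting points to homogeneous coordinates, multiplying, and dropping the homogeneous row. A minimal NumPy version of the same pattern:
```
import numpy as np

def transform_pts(pts, T):
    # Same homogeneous-coordinate pattern as transform_pointcloud above:
    # append a 1, multiply by the 4x4 matrix, keep the first three rows.
    pts_h = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1)  # N x 4
    return (T @ pts_h.T).T[:, :3]

T = np.eye(4)
T[:3, 3] = [0.0, 0.0, 1.0]                     # pure translation along z
pts = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
print(transform_pts(pts, T))                   # [[0. 0. 1.] [1. 2. 4.]]
```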
+ + Args: + p (tensor): tensor of size B x N x 3 + transform (tensor): transformation of size B x 4 x 4 + ''' + device = p.device + assert (transform.shape[1:] == (4, 4) and p.shape[-1] == 3 and p.shape[0] == transform.shape[0]) + + # Transform points to homogen coordinates + pcl_hom = torch.cat([p, torch.ones(p.shape[0], p.shape[1], 1).to(device)], + dim=-1).transpose(2, 1) + + # Apply transformation + pcl_hom_transformed = transform @ pcl_hom + + # Transform back to 3D coordinates + pcl_out = pcl_hom_transformed[:, :3].transpose(2, 1) + return pcl_out + + +def get_tensor_values( + tensor, p, grid_sample=True, mode='nearest', with_mask=False, squeeze_channel_dim=False +): + ''' + Returns values from tensor at given location p. + + Args: + tensor (tensor): tensor of size B x C x H x W + p (tensor): position values scaled between [-1, 1] and + of size B x N x 2 + grid_sample (boolean): whether to use grid sampling + mode (string): what mode to perform grid sampling in + with_mask (bool): whether to return the mask for invalid values + squeeze_channel_dim (bool): whether to squeeze the channel dimension + (only applicable to 1D data) + ''' + p = to_pytorch(p) + tensor, is_numpy = to_pytorch(tensor, True) + batch_size, _, h, w = tensor.shape + + if grid_sample: + p = p.unsqueeze(1) + values = torch.nn.functional.grid_sample(tensor, p, mode=mode) + values = values.squeeze(2) + values = values.permute(0, 2, 1) + else: + p[:, :, 0] = (p[:, :, 0] + 1) * (w) / 2 + p[:, :, 1] = (p[:, :, 1] + 1) * (h) / 2 + p = p.long() + values = tensor[torch.arange(batch_size).unsqueeze(-1), :, p[:, :, 1], p[:, :, 0]] + + if with_mask: + mask = get_mask(values) + if squeeze_channel_dim: + mask = mask.squeeze(-1) + if is_numpy: + mask = mask.numpy() + + if squeeze_channel_dim: + values = values.squeeze(-1) + + if is_numpy: + values = values.numpy() + + if with_mask: + return values, mask + return values + + +def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat, invert=True): + ''' Transforms pixel positions p with given depth value d to world coordinates. + + Args: + pixels (tensor): pixel tensor of size B x N x 2 + depth (tensor): depth tensor of size B x N x 1 + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + invert (bool): whether to invert matrices (default: true) + ''' + assert (pixels.shape[-1] == 2) + + # Convert to pytorch + pixels, is_numpy = to_pytorch(pixels, True) + depth = to_pytorch(depth) + camera_mat = to_pytorch(camera_mat) + world_mat = to_pytorch(world_mat) + scale_mat = to_pytorch(scale_mat) + + # Invert camera matrices + if invert: + camera_mat = torch.inverse(camera_mat) + world_mat = torch.inverse(world_mat) + scale_mat = torch.inverse(scale_mat) + + # Transform pixels to homogen coordinates + pixels = pixels.permute(0, 2, 1) + pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1) + + # Project pixels into camera space + pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1) + + # Transform pixels to world space + p_world = scale_mat @ world_mat @ camera_mat @ pixels + + # Transform p_world back to 3D coordinates + p_world = p_world[:, :3].permute(0, 2, 1) + + if is_numpy: + p_world = p_world.numpy() + return p_world + + +def transform_to_camera_space(p_world, camera_mat, world_mat, scale_mat): + ''' Transforms world points to camera space. 
+ Args: + p_world (tensor): world points tensor of size B x N x 3 + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + ''' + batch_size, n_p, _ = p_world.shape + device = p_world.device + + # Transform world points to homogen coordinates + p_world = torch.cat([p_world, torch.ones(batch_size, n_p, 1).to(device)], + dim=-1).permute(0, 2, 1) + + # Apply matrices to transform p_world to camera space + p_cam = camera_mat @ world_mat @ scale_mat @ p_world + + # Transform points back to 3D coordinates + p_cam = p_cam[:, :3].permute(0, 2, 1) + return p_cam + + +def origin_to_world(n_points, camera_mat, world_mat, scale_mat, invert=True): + ''' Transforms origin (camera location) to world coordinates. + + Args: + n_points (int): how often the transformed origin is repeated in the + form (batch_size, n_points, 3) + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + invert (bool): whether to invert the matrices (default: true) + ''' + batch_size = camera_mat.shape[0] + device = camera_mat.device + + # Create origin in homogen coordinates + p = torch.zeros(batch_size, 4, n_points).to(device) + p[:, -1] = 1. + + # Invert matrices + if invert: + camera_mat = torch.inverse(camera_mat) + world_mat = torch.inverse(world_mat) + scale_mat = torch.inverse(scale_mat) + + # Apply transformation + p_world = scale_mat @ world_mat @ camera_mat @ p + + # Transform points back to 3D coordinates + p_world = p_world[:, :3].permute(0, 2, 1) + return p_world + + +def image_points_to_world(image_points, camera_mat, world_mat, scale_mat, invert=True): + ''' Transforms points on image plane to world coordinates. + + In contrast to transform_to_world, no depth value is needed as points on + the image plane have a fixed depth of 1. + + Args: + image_points (tensor): image points tensor of size B x N x 2 + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + invert (bool): whether to invert matrices (default: true) + ''' + batch_size, n_pts, dim = image_points.shape + assert (dim == 2) + device = image_points.device + + d_image = torch.ones(batch_size, n_pts, 1).to(device) + return transform_to_world( + image_points, d_image, camera_mat, world_mat, scale_mat, invert=invert + ) + + +def check_weights(params): + ''' Checks weights for illegal values. + + Args: + params (tensor): parameter tensor + ''' + for k, v in params.items(): + if torch.isnan(v).any(): + logger_py.warn('NaN Values detected in model weight %s.' % k) + + +def check_tensor(tensor, tensorname='', input_tensor=None): + ''' Checks tensor for illegal values. + + Args: + tensor (tensor): tensor + tensorname (string): name of tensor + input_tensor (tensor): previous input + ''' + if torch.isnan(tensor).any(): + logger_py.warn('Tensor %s contains nan values.' % tensorname) + if input_tensor is not None: + logger_py.warn(f'Input was: {input_tensor}') + + +def get_prob_from_logits(logits): + ''' Returns probabilities for logits + + Args: + logits (tensor): logits + ''' + odds = np.exp(logits) + probs = odds / (1 + odds) + return probs + + +def get_logits_from_prob(probs, eps=1e-4): + ''' Returns logits for probabilities. 
+ + Args: + probs (tensor): probability tensor + eps (float): epsilon value for numerical stability + ''' + probs = np.clip(probs, a_min=eps, a_max=1 - eps) + logits = np.log(probs / (1 - probs)) + return logits + + +def chamfer_distance(points1, points2, use_kdtree=True, give_id=False): + ''' Returns the chamfer distance for the sets of points. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + use_kdtree (bool): whether to use a kdtree + give_id (bool): whether to return the IDs of nearest points + ''' + if use_kdtree: + return chamfer_distance_kdtree(points1, points2, give_id=give_id) + else: + return chamfer_distance_naive(points1, points2) + + +def chamfer_distance_naive(points1, points2): + ''' Naive implementation of the Chamfer distance. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + ''' + assert (points1.size() == points2.size()) + batch_size, T, _ = points1.size() + + points1 = points1.view(batch_size, T, 1, 3) + points2 = points2.view(batch_size, 1, T, 3) + + distances = (points1 - points2).pow(2).sum(-1) + + chamfer1 = distances.min(dim=1)[0].mean(dim=1) + chamfer2 = distances.min(dim=2)[0].mean(dim=1) + + chamfer = chamfer1 + chamfer2 + return chamfer + + +def chamfer_distance_kdtree(points1, points2, give_id=False): + ''' KD-tree based implementation of the Chamfer distance. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + give_id (bool): whether to return the IDs of the nearest points + ''' + # Points have size batch_size x T x 3 + batch_size = points1.size(0) + + # First convert points to numpy + points1_np = points1.detach().cpu().numpy() + points2_np = points2.detach().cpu().numpy() + + # Get list of nearest neighbors indices + idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np) + idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device) + # Expands it as batch_size x 1 x 3 + idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1) + + # Get list of nearest neighbors indices + idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np) + idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device) + # Expands it as batch_size x T x 3 + idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2) + + # Compute nearest neighbors in points2 to points in points1 + # points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k] + points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand) + + # Compute nearest neighbors in points1 to points in points2 + # points_21[i, j, k] = points2[i, idx_nn_21_expand[i, j, k], k] + points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand) + + # Compute chamfer distance + chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1) + chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1) + + # Take sum + chamfer = chamfer1 + chamfer2 + + # If required, also return nearest neighbors + if give_id: + return chamfer1, chamfer2, idx_nn_12, idx_nn_21 + + return chamfer + + +def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1): + ''' Returns the nearest neighbors for point sets batchwise. 
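`chamfer_distance_naive` above sums the mean squared nearest-neighbour distance in both directions; the KD-tree variant computes the same quantity via nearest-neighbour lookups. A tiny self-contained check of that formulation (toy point sets, written directly in torch rather than with the repository's KDTree):
```
import torch

def chamfer_naive(p1, p2):
    # Same formulation as chamfer_distance_naive above: mean of the squared
    # nearest-neighbour distances in both directions, summed. p1: B x N x 3, p2: B x M x 3.
    d = ((p1[:, :, None, :] - p2[:, None, :, :]) ** 2).sum(-1)  # B x N x M
    return d.min(dim=2)[0].mean(dim=1) + d.min(dim=1)[0].mean(dim=1)

p1 = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])
p2 = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 1.0]]])
print(chamfer_naive(p1, p2))   # tensor([1.])  (0.5 in each direction)
```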
+ + Args: + points_src (numpy array): source points + points_tgt (numpy array): target points + k (int): number of nearest neighbors to return + ''' + indices = [] + distances = [] + + for (p1, p2) in zip(points_src, points_tgt): + kdtree = KDTree(p2) + dist, idx = kdtree.query(p1, k=k) + indices.append(idx) + distances.append(dist) + + return indices, distances + + +def normalize_imagenet(x): + ''' Normalize input images according to ImageNet standards. + + Args: + x (tensor): input images + ''' + x = x.clone() + x[:, 0] = (x[:, 0] - 0.485) / 0.229 + x[:, 1] = (x[:, 1] - 0.456) / 0.224 + x[:, 2] = (x[:, 2] - 0.406) / 0.225 + return x + + +def make_3d_grid(bb_min, bb_max, shape): + ''' Makes a 3D grid. + + Args: + bb_min (tuple): bounding box minimum + bb_max (tuple): bounding box maximum + shape (tuple): output shape + ''' + size = shape[0] * shape[1] * shape[2] + + pxs = torch.linspace(bb_min[0], bb_max[0], shape[0]) + pys = torch.linspace(bb_min[1], bb_max[1], shape[1]) + pzs = torch.linspace(bb_min[2], bb_max[2], shape[2]) + + pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size) + pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size) + pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size) + p = torch.stack([pxs, pys, pzs], dim=1) + + return p + + +def get_occupancy_loss_points( + pixels, + camera_mat, + world_mat, + scale_mat, + depth_image=None, + use_cube_intersection=True, + occupancy_random_normal=False, + depth_range=[0, 2.4] +): + ''' Returns 3D points for occupancy loss. + + Args: + pixels (tensor): sampled pixels in range [-1, 1] + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + depth_image tensor): if not None, these depth values are used for + initialization (e.g. depth or visual hull depth) + use_cube_intersection (bool): whether to check unit cube intersection + occupancy_random_normal (bool): whether to sample from a Normal + distribution instead of a uniform one + depth_range (float): depth range; important when no cube + intersection is used + ''' + device = pixels.device + batch_size, n_points, _ = pixels.shape + + if use_cube_intersection: + _, d_cube_intersection, mask_cube = \ + intersect_camera_rays_with_unit_cube( + pixels, camera_mat, world_mat, scale_mat, padding=0., + use_ray_length_as_depth=False) + d_cube = d_cube_intersection[mask_cube] + + d_occupancy = torch.rand(batch_size, n_points).to(device) * depth_range[1] + + if use_cube_intersection: + d_occupancy[mask_cube] = d_cube[:, 0] + \ + torch.rand(d_cube.shape[0]).to( + device) * (d_cube[:, 1] - d_cube[:, 0]) + if occupancy_random_normal: + d_occupancy = torch.randn(batch_size, n_points).to(device) \ + * (depth_range[1] / 8) + depth_range[1] / 2 + if use_cube_intersection: + mean_cube = d_cube.sum(-1) / 2 + std_cube = (d_cube[:, 1] - d_cube[:, 0]) / 8 + d_occupancy[mask_cube] = mean_cube + \ + torch.randn(mean_cube.shape[0]).to(device) * std_cube + + if depth_image is not None: + depth_gt, mask_gt_depth = get_tensor_values( + depth_image, pixels, squeeze_channel_dim=True, with_mask=True + ) + d_occupancy[mask_gt_depth] = depth_gt[mask_gt_depth] + + p_occupancy = transform_to_world( + pixels, d_occupancy.unsqueeze(-1), camera_mat, world_mat, scale_mat + ) + return p_occupancy + + +def get_freespace_loss_points( + pixels, camera_mat, world_mat, scale_mat, use_cube_intersection=True, depth_range=[0, 2.4] +): + ''' Returns 3D points for freespace loss. 
+ + Args: + pixels (tensor): sampled pixels in range [-1, 1] + camera_mat (tensor): camera matrix + world_mat (tensor): world matrix + scale_mat (tensor): scale matrix + use_cube_intersection (bool): whether to check unit cube intersection + depth_range (float): depth range; important when no cube + intersection is used + ''' + device = pixels.device + batch_size, n_points, _ = pixels.shape + + d_freespace = torch.rand(batch_size, n_points).to(device) * \ + depth_range[1] + + if use_cube_intersection: + _, d_cube_intersection, mask_cube = \ + intersect_camera_rays_with_unit_cube( + pixels, camera_mat, world_mat, scale_mat, + use_ray_length_as_depth=False) + d_cube = d_cube_intersection[mask_cube] + d_freespace[mask_cube] = d_cube[:, 0] + \ + torch.rand(d_cube.shape[0]).to( + device) * (d_cube[:, 1] - d_cube[:, 0]) + + p_freespace = transform_to_world( + pixels, d_freespace.unsqueeze(-1), camera_mat, world_mat, scale_mat + ) + return p_freespace + + +def normalize_tensor(tensor, min_norm=1e-5, feat_dim=-1): + ''' Normalizes the tensor. + + Args: + tensor (tensor): tensor + min_norm (float): minimum norm for numerical stability + feat_dim (int): feature dimension in tensor (default: -1) + ''' + norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True), min=min_norm) + normed_tensor = tensor / norm_tensor + return normed_tensor diff --git a/lib/pymafx/utils/data_loader.py b/lib/pymafx/utils/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..cc92ad223836e9de322bc80bbab887bb9ec3f17b --- /dev/null +++ b/lib/pymafx/utils/data_loader.py @@ -0,0 +1,77 @@ +from __future__ import division +import torch +from torch.utils.data import DataLoader +from torch.utils.data.sampler import Sampler + + +class RandomSampler(Sampler): + def __init__(self, data_source, checkpoint): + self.data_source = data_source + if checkpoint is not None and checkpoint['dataset_perm'] is not None: + self.dataset_perm = checkpoint['dataset_perm'] + self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:] + else: + self.dataset_perm = torch.randperm(len(self.data_source)).tolist() + self.perm = torch.randperm(len(self.data_source)).tolist() + + def __iter__(self): + return iter(self.perm) + + def __len__(self): + return len(self.perm) + + +class SequentialSampler(Sampler): + def __init__(self, data_source, checkpoint): + self.data_source = data_source + if checkpoint is not None and checkpoint['dataset_perm'] is not None: + self.dataset_perm = checkpoint['dataset_perm'] + self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:] + else: + self.dataset_perm = list(range(len(self.data_source))) + self.perm = self.dataset_perm + + def __iter__(self): + return iter(self.perm) + + def __len__(self): + return len(self.perm) + + +class CheckpointDataLoader(DataLoader): + """ + Extends torch.utils.data.DataLoader to handle resuming training from an arbitrary point within an epoch. 
+ """ + def __init__( + self, + dataset, + checkpoint=None, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=False, + drop_last=True, + timeout=0, + worker_init_fn=None + ): + + if shuffle: + sampler = RandomSampler(dataset, checkpoint) + else: + sampler = SequentialSampler(dataset, checkpoint) + if checkpoint is not None: + self.checkpoint_batch_idx = checkpoint['batch_idx'] + else: + self.checkpoint_batch_idx = 0 + + super(CheckpointDataLoader, self).__init__( + dataset, + sampler=sampler, + shuffle=False, + batch_size=batch_size, + num_workers=num_workers, + drop_last=drop_last, + pin_memory=pin_memory, + timeout=timeout, + worker_init_fn=None + ) diff --git a/lib/pymafx/utils/demo_utils.py b/lib/pymafx/utils/demo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ad8da91c7a7f6f67d4770c9866a02a78aa5275 --- /dev/null +++ b/lib/pymafx/utils/demo_utils.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import os +import cv2 +import time +import json +import torch +import subprocess +import numpy as np +import os.path as osp +# from pytube import YouTube +from collections import OrderedDict + +from utils.smooth_bbox import get_smooth_bbox_params, get_all_bbox_params +from datasets.data_utils.img_utils import get_single_image_crop_demo +from utils.geometry import rotation_matrix_to_angle_axis + + +def preprocess_video(video, joints2d, bboxes, frames, scale=1.0, crop_size=224): + """ + Read video, do normalize and crop it according to the bounding box. + If there are bounding box annotations, use them to crop the image. + If no bounding box is specified but openpose detections are available, use them to get the bounding box. + + :param video (ndarray): input video + :param joints2d (ndarray, NxJx3): openpose detections + :param bboxes (ndarray, Nx5): bbox detections + :param scale (float): bbox crop scaling factor + :param crop_size (int): crop width and height + :return: cropped video, cropped and normalized video, modified bboxes, modified joints2d + """ + + if joints2d is not None: + bboxes, time_pt1, time_pt2 = get_all_bbox_params(joints2d, vis_thresh=0.3) + bboxes[:, 2:] = 150. 
/ bboxes[:, 2:] + bboxes = np.stack([bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 2]]).T + + video = video[time_pt1:time_pt2] + joints2d = joints2d[time_pt1:time_pt2] + frames = frames[time_pt1:time_pt2] + + shape = video.shape + + temp_video = np.zeros((shape[0], crop_size, crop_size, shape[-1])) + norm_video = torch.zeros(shape[0], shape[-1], crop_size, crop_size) + + for idx in range(video.shape[0]): + + img = video[idx] + bbox = bboxes[idx] + + j2d = joints2d[idx] if joints2d is not None else None + + norm_img, raw_img, kp_2d = get_single_image_crop_demo( + img, bbox, kp_2d=j2d, scale=scale, crop_size=crop_size + ) + + if joints2d is not None: + joints2d[idx] = kp_2d + + temp_video[idx] = raw_img + norm_video[idx] = norm_img + + temp_video = temp_video.astype(np.uint8) + + return temp_video, norm_video, bboxes, joints2d, frames + + +def download_youtube_clip(url, download_folder): + return YouTube(url).streams.first().download(output_path=download_folder) + + +def smplify_runner( + pred_rotmat, + pred_betas, + pred_cam, + j2d, + device, + batch_size, + lr=1.0, + opt_steps=1, + use_lbfgs=True, + pose2aa=True +): + smplify = TemporalSMPLify( + step_size=lr, + batch_size=batch_size, + num_iters=opt_steps, + focal_length=5000., + use_lbfgs=use_lbfgs, + device=device, + # max_iter=10, + ) + # Convert predicted rotation matrices to axis-angle + if pose2aa: + pred_pose = rotation_matrix_to_angle_axis(pred_rotmat.detach()).reshape(batch_size, -1) + else: + pred_pose = pred_rotmat + + # Calculate camera parameters for smplify + pred_cam_t = torch.stack( + [pred_cam[:, 1], pred_cam[:, 2], 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9)], dim=-1 + ) + + gt_keypoints_2d_orig = j2d + # Before running compute reprojection error of the network + opt_joint_loss = smplify.get_fitting_loss( + pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(), + 0.5 * 224 * torch.ones(batch_size, 2, device=device), gt_keypoints_2d_orig + ).mean(dim=-1) + + best_prediction_id = torch.argmin(opt_joint_loss).item() + pred_betas = pred_betas[best_prediction_id].unsqueeze(0) + # pred_betas = pred_betas[best_prediction_id:best_prediction_id+2] # .unsqueeze(0) + # top5_best_idxs = torch.topk(opt_joint_loss, 5, largest=False)[1] + # breakpoint() + + start = time.time() + # Run SMPLify optimization initialized from the network prediction + # new_opt_vertices, new_opt_joints, \ + # new_opt_pose, new_opt_betas, \ + # new_opt_cam_t, \ + output, new_opt_joint_loss = smplify( + pred_pose.detach(), + pred_betas.detach(), + pred_cam_t.detach(), + 0.5 * 224 * torch.ones(batch_size, 2, device=device), + gt_keypoints_2d_orig, + ) + new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1) + # smplify_time = time.time() - start + # print(f'Smplify time: {smplify_time}') + # Will update the dictionary for the examples where the new loss is less than the current one + update = (new_opt_joint_loss < opt_joint_loss) + + new_opt_vertices = output['verts'] + new_opt_cam_t = output['theta'][:, :3] + new_opt_pose = output['theta'][:, 3:75] + new_opt_betas = output['theta'][:, 75:] + new_opt_joints3d = output['kp_3d'] + + return_val = [ + update, + new_opt_vertices.cpu(), + new_opt_cam_t.cpu(), + new_opt_pose.cpu(), + new_opt_betas.cpu(), + new_opt_joints3d.cpu(), + new_opt_joint_loss, + opt_joint_loss, + ] + + return return_val + + +def trim_videos(filename, start_time, end_time, output_filename): + command = [ + 'ffmpeg', '-i', + '"%s"' % filename, '-ss', + str(start_time), '-t', + str(end_time - start_time), '-c:v', 'libx264', '-c:a', 
'copy', '-threads', '1', '-loglevel', + 'panic', + '"%s"' % output_filename + ] + # command = ' '.join(command) + subprocess.call(command) + + +def video_to_images(vid_file, img_folder=None, return_info=False): + if img_folder is None: + img_folder = osp.join(osp.expanduser('~'), 'tmp', osp.basename(vid_file).replace('.', '_')) + # img_folder = osp.join('/tmp', osp.basename(vid_file).replace('.', '_')) + + print(img_folder) + os.makedirs(img_folder, exist_ok=True) + + command = ['ffmpeg', '-i', vid_file, '-f', 'image2', '-v', 'error', f'{img_folder}/%06d.png'] + print(f'Running \"{" ".join(command)}\"') + + try: + subprocess.call(command) + except: + subprocess.call(f'{" ".join(command)}', shell=True) + + print(f'Images saved to \"{img_folder}\"') + + img_shape = cv2.imread(osp.join(img_folder, '000001.png')).shape + + if return_info: + return img_folder, len(os.listdir(img_folder)), img_shape + else: + return img_folder + + +def download_url(url, outdir): + print(f'Downloading files from {url}') + cmd = ['wget', '-c', url, '-P', outdir] + subprocess.call(cmd) + + +def download_ckpt(outdir='data/vibe_data', use_3dpw=False): + os.makedirs(outdir, exist_ok=True) + + if use_3dpw: + ckpt_file = 'data/vibe_data/vibe_model_w_3dpw.pth.tar' + url = 'https://www.dropbox.com/s/41ozgqorcp095ja/vibe_model_w_3dpw.pth.tar' + if not os.path.isfile(ckpt_file): + download_url(url=url, outdir=outdir) + else: + ckpt_file = 'data/vibe_data/vibe_model_wo_3dpw.pth.tar' + url = 'https://www.dropbox.com/s/amj2p8bmf6g56k6/vibe_model_wo_3dpw.pth.tar' + if not os.path.isfile(ckpt_file): + download_url(url=url, outdir=outdir) + + return ckpt_file + + +def images_to_video(img_folder, output_vid_file): + os.makedirs(img_folder, exist_ok=True) + + command = [ + 'ffmpeg', + '-y', + '-threads', + '16', + '-i', + f'{img_folder}/%06d.png', + '-profile:v', + 'baseline', + '-level', + '3.0', + '-c:v', + 'libx264', + '-pix_fmt', + 'yuv420p', + '-an', + '-v', + 'error', + output_vid_file, + ] + + print(f'Running \"{" ".join(command)}\"') + try: + subprocess.call(command) + except: + subprocess.call(f'{" ".join(command)}', shell=True) + + +def convert_crop_cam_to_orig_img(cam, bbox, img_width, img_height): + ''' + Convert predicted camera from cropped image coordinates + to original image coordinates + :param cam (ndarray, shape=(3,)): weak perspective camera in cropped img coordinates + :param bbox (ndarray, shape=(4,)): bbox coordinates (c_x, c_y, h) + :param img_width (int): original image width + :param img_height (int): original image height + :return: + ''' + cx, cy, h = bbox[:, 0], bbox[:, 1], bbox[:, 2] + hw, hh = img_width / 2., img_height / 2. + sx = cam[:, 0] * (1. / (img_width / h)) + sy = cam[:, 0] * (1. 
/ (img_height / h)) + tx = ((cx - hw) / hw / sx) + cam[:, 1] + ty = ((cy - hh) / hh / sy) + cam[:, 2] + orig_cam = np.stack([sx, sy, tx, ty]).T + return orig_cam + + +def prepare_rendering_results(results_dict, nframes): + frame_results = [{} for _ in range(nframes)] + for person_id, person_data in results_dict.items(): + for idx, frame_id in enumerate(person_data['frame_ids']): + frame_results[frame_id][person_id] = { + 'verts': + person_data['verts'][idx], + 'smplx_verts': + person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None, + 'cam': + person_data['orig_cam'][idx], + 'cam_t': + person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None, + # 'cam': person_data['pred_cam'][idx], + } + + # naive depth ordering based on the scale of the weak perspective camera + for frame_id, frame_data in enumerate(frame_results): + # sort based on y-scale of the cam in original image coords + sort_idx = np.argsort([v['cam'][1] for k, v in frame_data.items()]) + frame_results[frame_id] = OrderedDict( + {list(frame_data.keys())[i]: frame_data[list(frame_data.keys())[i]] + for i in sort_idx} + ) + + return frame_results diff --git a/lib/pymafx/utils/densepose_methods.py b/lib/pymafx/utils/densepose_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..93fdf66a6651dcfe05f6e95c55379eaa00c52cb0 --- /dev/null +++ b/lib/pymafx/utils/densepose_methods.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import copy +import cv2 +from scipy.io import loadmat +import scipy.spatial.distance +import os + + +class DensePoseMethods: + def __init__(self): + # + ALP_UV = loadmat(os.path.join('./data/UV_data', 'UV_Processed.mat')) + self.FaceIndices = np.array(ALP_UV['All_FaceIndices']).squeeze() + self.FacesDensePose = ALP_UV['All_Faces'] - 1 + self.U_norm = ALP_UV['All_U_norm'].squeeze() + self.V_norm = ALP_UV['All_V_norm'].squeeze() + self.All_vertices = ALP_UV['All_vertices'][0] + ## Info to compute symmetries. + self.SemanticMaskSymmetries = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14] + self.Index_Symmetry_List = [ + 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23 + ] + UV_symmetry_filename = os.path.join('./data/UV_data', 'UV_symmetry_transforms.mat') + self.UV_symmetry_transformations = loadmat(UV_symmetry_filename) + + def get_symmetric_densepose(self, I, U, V, x, y, Mask): + ### This is a function to get the mirror symmetric UV labels. 
+ Labels_sym = np.zeros(I.shape) + U_sym = np.zeros(U.shape) + V_sym = np.zeros(V.shape) + ### + for i in (range(24)): + if i + 1 in I: + Labels_sym[I == (i + 1)] = self.Index_Symmetry_List[i] + jj = np.where(I == (i + 1)) + ### + U_loc = (U[jj] * 255).astype(np.int64) + V_loc = (V[jj] * 255).astype(np.int64) + ### + V_sym[jj] = self.UV_symmetry_transformations['V_transforms'][0, i][V_loc, U_loc] + U_sym[jj] = self.UV_symmetry_transformations['U_transforms'][0, i][V_loc, U_loc] + ## + Mask_flip = np.fliplr(Mask) + Mask_flipped = np.zeros(Mask.shape) + # + for i in (range(14)): + Mask_flipped[Mask_flip == (i + 1)] = self.SemanticMaskSymmetries[i + 1] + # + [y_max, x_max] = Mask_flip.shape + y_sym = y + x_sym = x_max - x + # + return Labels_sym, U_sym, V_sym, x_sym, y_sym, Mask_flipped + + def barycentric_coordinates_exists(self, P0, P1, P2, P): + u = P1 - P0 + v = P2 - P0 + w = P - P0 + # + vCrossW = np.cross(v, w) + vCrossU = np.cross(v, u) + if (np.dot(vCrossW, vCrossU) < 0): + return False + # + uCrossW = np.cross(u, w) + uCrossV = np.cross(u, v) + # + if (np.dot(uCrossW, uCrossV) < 0): + return False + # + denom = np.sqrt((uCrossV**2).sum()) + r = np.sqrt((vCrossW**2).sum()) / denom + t = np.sqrt((uCrossW**2).sum()) / denom + # + return ((r <= 1) & (t <= 1) & (r + t <= 1)) + + def barycentric_coordinates(self, P0, P1, P2, P): + u = P1 - P0 + v = P2 - P0 + w = P - P0 + # + vCrossW = np.cross(v, w) + vCrossU = np.cross(v, u) + # + uCrossW = np.cross(u, w) + uCrossV = np.cross(u, v) + # + denom = np.sqrt((uCrossV**2).sum()) + r = np.sqrt((vCrossW**2).sum()) / denom + t = np.sqrt((uCrossW**2).sum()) / denom + # + return (1 - (r + t), r, t) + + def IUV2FBC(self, I_point, U_point, V_point): + P = [U_point, V_point, 0] + FaceIndicesNow = np.where(self.FaceIndices == I_point) + FacesNow = self.FacesDensePose[FaceIndicesNow] + # + P_0 = np.vstack( + ( + self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0], + np.zeros(self.U_norm[FacesNow][:, 0].shape) + ) + ).transpose() + P_1 = np.vstack( + ( + self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1], + np.zeros(self.U_norm[FacesNow][:, 1].shape) + ) + ).transpose() + P_2 = np.vstack( + ( + self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2], + np.zeros(self.U_norm[FacesNow][:, 2].shape) + ) + ).transpose() + # + + for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)): + if (self.barycentric_coordinates_exists(P0, P1, P2, P)): + [bc1, bc2, bc3] = self.barycentric_coordinates(P0, P1, P2, P) + return (FaceIndicesNow[0][i], bc1, bc2, bc3) + # + # If the found UV is not inside any faces, select the vertex that is closest! + # + D1 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], + P_0[:, 0:2]).squeeze() + D2 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], + P_1[:, 0:2]).squeeze() + D3 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], + P_2[:, 0:2]).squeeze() + # + minD1 = D1.min() + minD2 = D2.min() + minD3 = D3.min() + # + if ((minD1 < minD2) & (minD1 < minD3)): + return (FaceIndicesNow[0][np.argmin(D1)], 1., 0., 0.) + elif ((minD2 < minD1) & (minD2 < minD3)): + return (FaceIndicesNow[0][np.argmin(D2)], 0., 1., 0.) + else: + return (FaceIndicesNow[0][np.argmin(D3)], 0., 0., 1.) 
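The two DensePose helpers here work as a pair: `IUV2FBC` maps a DensePose `(I, U, V)` annotation to a face index plus barycentric coordinates, and `FBC2PointOnSurface` (defined immediately below) turns those into a 3D point on the SMPL surface. A minimal usage sketch, not part of the diff, assuming the DensePose `UV_Processed.mat` and `UV_symmetry_transforms.mat` files are present under `./data/UV_data` and using a placeholder vertex array in place of real posed SMPL vertices:

```python
# Illustrative sketch only: lift a DensePose (I, U, V) annotation onto the
# SMPL surface by chaining IUV2FBC and FBC2PointOnSurface.
import numpy as np

dp = DensePoseMethods()            # loads ./data/UV_data/*.mat at construction

# Hypothetical DensePose annotation: part index I in [1, 24], U/V in [0, 1]
I_point, U_point, V_point = 2, 0.35, 0.6

# Face index and barycentric coordinates of the annotated point
face_idx, b1, b2, b3 = dp.IUV2FBC(I_point, U_point, V_point)

# 'vertices' stands in for an (N, 3) array of posed SMPL vertices
vertices = np.zeros((6890, 3))     # placeholder geometry, illustration only
p_surface = dp.FBC2PointOnSurface(face_idx, b1, b2, b3, vertices)
print(p_surface.shape)             # (3,)
```

If the queried `(U, V)` does not fall inside any face of the selected part, `IUV2FBC` falls back to the closest face vertex, so the chained call still returns a valid surface point.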
+ + def FBC2PointOnSurface(self, FaceIndex, bc1, bc2, bc3, Vertices): + ## + Vert_indices = self.All_vertices[self.FacesDensePose[FaceIndex]] - 1 + ## + p = Vertices[Vert_indices[0], :] * bc1 + \ + Vertices[Vert_indices[1], :] * bc2 + \ + Vertices[Vert_indices[2], :] * bc3 + ## + return (p) diff --git a/lib/pymafx/utils/geometry.py b/lib/pymafx/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..608288fc4d73a4918ab95938a7bf5dbe98ce606f --- /dev/null +++ b/lib/pymafx/utils/geometry.py @@ -0,0 +1,709 @@ +import torch +from torch.nn import functional as F +import numpy as np +import numbers +from einops.einops import rearrange +""" +Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" + + +def batch_rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 + ], + dim=1 + ).view(B, 3, 3) + return rotMat + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3, 3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape( + 1, 3, 1 + ).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. 
+ + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape) + ) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta) + ) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def quaternion_to_angle(quaternion: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion vector to angle of the rotation. + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 1)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle(quaternion) # Nx1 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape) + ) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta) + ) + + # theta: torch.Tensor = 2.0 * torch.atan2(sin_theta, cos_theta) + + # theta2 = torch.where(sin_squared_theta > 0.0, - theta, theta) + + return theta.unsqueeze(-1) + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. 
Got {}".format(type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape) + ) + # if not rotation_matrix.shape[-2:] == (3, 4): + # raise ValueError( + # "Input size must be a N x 3 x 4 tensor. Got {}".format( + # rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack( + [ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1 + ) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack( + [ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, + rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1 + ) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack( + [ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1 + ) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack( + [ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1 + ) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt( + t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3 + ) # noqa + q *= 0.5 + return q + + +def batch_euler2matrix(r): + return quaternion_to_rotation_matrix(euler_to_quaternion(r)) + + +def euler_to_quaternion(r): + x = r[..., 0] + y = r[..., 1] + z = r[..., 2] + + z = z / 2.0 + y = y / 2.0 + x = x / 2.0 + cz = torch.cos(z) + sz = torch.sin(z) + cy = torch.cos(y) + sy = torch.sin(y) + cx = torch.cos(x) + sx = torch.sin(x) + quaternion = torch.zeros_like(r.repeat(1, 2))[..., :4].to(r.device) + quaternion[..., 0] += cx * cy * cz - sx * sy * sz + quaternion[..., 1] += cx * sy * sz + cy * cz * sx + quaternion[..., 2] += cx * cz * sy - sx * cy * sz + quaternion[..., 3] += cx * cy * sz + sx * cz * sy + return quaternion + + +def quaternion_to_rotation_matrix(quat): + """Convert quaternion coefficients to rotation matrix. 
+ Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 + ], + dim=1 + ).view(B, 3, 3) + return rotMat + + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + if x.shape[-1] == 6: + batch_size = x.shape[0] + if len(x.shape) == 3: + num = x.shape[1] + x = rearrange(x, 'b n d -> (b n) d', d=6) + else: + num = 1 + x = rearrange(x, 'b (k l) -> b k l', k=3, l=2) + # x = x.view(-1,3,2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2, dim=-1) + + mat = torch.stack((b1, b2, b3), dim=-1) + if num > 1: + mat = rearrange(mat, '(b n) h w-> b n h w', b=batch_size, n=num, h=3, w=3) + else: + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2, dim=-1) + mat = torch.stack((b1, b2, b3), dim=-1) + return mat + + +def rotmat_to_rot6d(x): + """Convert 3x3 rotation matrix to 6D rotation representation. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,3,3) Batch of corresponding rotation matrices + Output: + (B,6) Batch of 6-D rotation representations + """ + batch_size = x.shape[0] + x = x[:, :, :2] + x = x.reshape(batch_size, 6) + return x + + +def rotmat_to_angle(x): + """Convert rotation to one-D angle. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,2) Batch of corresponding rotation + Output: + (B,1) Batch of 1-D angle + """ + a = F.normalize(x) + angle = torch.atan2(a[:, 0], a[:, 1]).unsqueeze(-1) + + return angle + + +def projection(pred_joints, pred_camera, retain_z=False, iwp_mode=True): + """ Project 3D points on the image plane based on the given camera info, + Identity rotation and Weak Perspective (IWP) camera is used when iwp_mode=True, more about camera settings: + SPEC: Seeing People in the Wild with an Estimated Camera, ICCV 2021 + """ + + batch_size = pred_joints.shape[0] + if iwp_mode: + cam_sxy = pred_camera['cam_sxy'] + pred_cam_t = torch.stack( + [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1 + ) + + camera_center = torch.zeros(batch_size, 2) + pred_keypoints_2d = perspective_projection( + pred_joints, + rotation=torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(pred_joints.device), + translation=pred_cam_t, + focal_length=5000., + camera_center=camera_center, + retain_z=retain_z + ) + # # Normalize keypoints to [-1,1] + # pred_keypoints_2d = pred_keypoints_2d / (224. / 2.) 
+ else: + assert type(pred_camera) is dict + + bbox_scale, bbox_center = pred_camera['bbox_scale'], pred_camera['bbox_center'] + img_w, img_h, crop_res = pred_camera['img_w'], pred_camera['img_h'], pred_camera['crop_res'] + cam_sxy, cam_rotmat, cam_intrinsics = pred_camera['cam_sxy'], pred_camera[ + 'cam_rotmat'], pred_camera['cam_intrinsics'] + if 'cam_t' in pred_camera: + cam_t = pred_camera['cam_t'] + else: + cam_t = convert_to_full_img_cam( + pare_cam=cam_sxy, + bbox_height=bbox_scale * 200., + bbox_center=bbox_center, + img_w=img_w, + img_h=img_h, + focal_length=cam_intrinsics[:, 0, 0], + ) + + pred_keypoints_2d = perspective_projection( + pred_joints, + rotation=cam_rotmat, + translation=cam_t, + cam_intrinsics=cam_intrinsics, + ) + + return pred_keypoints_2d + + +def perspective_projection( + points, + rotation, + translation, + focal_length=None, + camera_center=None, + cam_intrinsics=None, + retain_z=False +): + """ + This function computes the perspective projection of a set of points. + Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + if cam_intrinsics is not None: + K = cam_intrinsics + else: + # raise + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:, 0, 0] = focal_length + K[:, 1, 1] = focal_length + K[:, 2, 2] = 1. + K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + if retain_z: + return projected_points + else: + return projected_points[:, :, :-1] + + +def convert_to_full_img_cam(pare_cam, bbox_height, bbox_center, img_w, img_h, focal_length): + # Converts weak perspective camera estimated by PARE in + # bbox coords to perspective camera in full image coordinates + # from https://arxiv.org/pdf/2009.06549.pdf + s, tx, ty = pare_cam[:, 0], pare_cam[:, 1], pare_cam[:, 2] + res = 224 + r = bbox_height / res + tz = 2 * focal_length / (r * res * s) + + cx = 2 * (bbox_center[:, 0] - (img_w / 2.)) / (s * bbox_height) + cy = 2 * (bbox_center[:, 1] - (img_h / 2.)) / (s * bbox_height) + + if torch.is_tensor(pare_cam): + cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) + else: + cam_t = np.stack([tx + cx, ty + cy, tz], axis=-1) + + return cam_t + + +def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_size=(224., 224.)): + """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
+ Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length, focal_length]) + # optical center + center = np.array([img_size[1] / 2., img_size[0] / 2.]) + + # transformations + Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1) + XY = np.reshape(S[:, 0:2], -1) + O = np.tile(center, num_joints) + F = np.tile(f, num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1) + + # least squares + Q = np.array( + [ + F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints), + O - np.reshape(joints_2d, -1) + ] + ).T + c = (np.reshape(joints_2d, -1) - O) * Z - F * XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W, Q) + c = np.dot(W, c) + + # square matrix + A = np.dot(Q.T, Q) + b = np.dot(Q.T, c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_all_kps=False): + """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (B, 49, 3) 3D joint locations + joints: (B, 49, 3) 2D joint locations and confidence + Returns: + (B, 3) camera translation vectors + """ + if isinstance(focal_length, numbers.Number): + focal_length = [ + focal_length, + ] * S.shape[0] + # print(len(focal_length), focal_length) + + if isinstance(img_size, numbers.Number): + img_size = [ + (img_size, img_size), + ] * S.shape[0] + # print(len(img_size), img_size) + + device = S.device + if use_all_kps: + S = S.cpu().numpy() + joints_2d = joints_2d.cpu().numpy() + else: + # Use only joints 25:49 (GT joints) + S = S[:, 25:, :].cpu().numpy() + joints_2d = joints_2d[:, 25:, :].cpu().numpy() + joints_conf = joints_2d[:, :, -1] + joints_2d = joints_2d[:, :, :-1] + trans = np.zeros((S.shape[0], 3), dtype=np.float32) + # Find the translation for each example in the batch + for i in range(S.shape[0]): + S_i = S[i] + joints_i = joints_2d[i] + conf_i = joints_conf[i] + trans[i] = estimate_translation_np( + S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i] + ) + return torch.from_numpy(trans).to(device) + + +def Rot_y(angle, category='torch', prepend_dim=True, device=None): + '''Rotate around y-axis by angle + Args: + category: 'torch' or 'numpy' + prepend_dim: prepend an extra dimension + Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) + ''' + m = np.array( + [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], [-np.sin(angle), 0., + np.cos(angle)]] + ) + if category == 'torch': + if prepend_dim: + return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) + else: + return torch.tensor(m, dtype=torch.float, device=device) + elif category == 'numpy': + if prepend_dim: + return np.expand_dims(m, 0) + else: + return m + else: + raise ValueError("category must be 'torch' or 'numpy'") + + +def Rot_x(angle, category='torch', prepend_dim=True, device=None): + '''Rotate around x-axis by angle + Args: + category: 'torch' or 'numpy' + prepend_dim: prepend an extra dimension + Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) + ''' + m = np.array( + [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], [0., np.sin(angle), + np.cos(angle)]] + ) + if category == 'torch': + if prepend_dim: + return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) + else: + return torch.tensor(m, dtype=torch.float, device=device) + elif 
category == 'numpy': + if prepend_dim: + return np.expand_dims(m, 0) + else: + return m + else: + raise ValueError("category must be 'torch' or 'numpy'") + + +def Rot_z(angle, category='torch', prepend_dim=True, device=None): + '''Rotate around z-axis by angle + Args: + category: 'torch' or 'numpy' + prepend_dim: prepend an extra dimension + Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) + ''' + m = np.array( + [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]] + ) + if category == 'torch': + if prepend_dim: + return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) + else: + return torch.tensor(m, dtype=torch.float, device=device) + elif category == 'numpy': + if prepend_dim: + return np.expand_dims(m, 0) + else: + return m + else: + raise ValueError("category must be 'torch' or 'numpy'") + + +def compute_twist_rotation(rotation_matrix, twist_axis): + ''' + Compute the twist component of given rotation and twist axis + https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis + Parameters + ---------- + rotation_matrix : Tensor (B, 3, 3,) + The rotation to convert + twist_axis : Tensor (B, 3,) + The twist axis + Returns + ------- + Tensor (B, 3, 3) + The twist rotation + ''' + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + + twist_axis = twist_axis / (torch.norm(twist_axis, dim=1, keepdim=True) + 1e-9) + + projection = torch.einsum('bi,bi->b', twist_axis, quaternion[:, 1:]).unsqueeze(-1) * twist_axis + + twist_quaternion = torch.cat([quaternion[:, 0:1], projection], dim=1) + twist_quaternion = twist_quaternion / (torch.norm(twist_quaternion, dim=1, keepdim=True) + 1e-9) + + twist_rotation = quaternion_to_rotation_matrix(twist_quaternion) + twist_aa = quaternion_to_angle_axis(twist_quaternion) + + twist_angle = torch.sum(twist_aa, dim=1, + keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True) + + return twist_rotation, twist_angle diff --git a/lib/pymafx/utils/imutils.py b/lib/pymafx/utils/imutils.py new file mode 100644 index 0000000000000000000000000000000000000000..b3522fee118cf47c5101bfd8e16991e5c30f58ad --- /dev/null +++ b/lib/pymafx/utils/imutils.py @@ -0,0 +1,294 @@ +""" +This file contains functions that are used to perform data augmentation. 
+""" +import torch +import numpy as np +import cv2 +import skimage.transform +from PIL import Image + +from lib.pymafx.core import constants + + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + t = np.dot(get_rot_transf(res, rot), t) + return t + + +def get_rot_transf(res, rot): + """Generate rotation transformation matrix.""" + if rot == 0: + return np.identity(3) + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3)) + rot_rad = rot * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0, 2] = -res[1] / 2 + t_mat[1, 2] = -res[0] / 2 + t_inv = t_mat.copy() + t_inv[:2, 2] *= -1 + rot_transf = np.dot(t_inv, np.dot(rot_mat, t_mat)) + return rot_transf + + +def transform(pt, center, scale, res, invert=0, rot=0): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2].astype(int) + 1 + + +def transform_pts(coords, center, scale, res, invert=0, rot=0): + """Transform coordinates (N x 2) to different reference.""" + new_coords = coords.copy() + for p in range(coords.shape[0]): + new_coords[p, 0:2] = transform(coords[p, 0:2], center, scale, res, invert, rot) + return new_coords + + +def crop(img, center, scale, res, rot=0): + """Crop image according to the supplied bounding box.""" + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 + # Bottom right point + br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1 + + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + if not rot == 0: + ul -= pad + br += pad + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] + + if not rot == 0: + # Remove padding + new_img = skimage.transform.rotate(new_img, rot).astype(np.uint8) + new_img = new_img[pad:-pad, pad:-pad] + + new_img_resized = np.array(Image.fromarray(new_img.astype(np.uint8)).resize(res)) + return new_img_resized, new_img, new_shape + + +def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): + """'Undo' the image cropping/resizing. + This function is used when evaluating mask/part segmentation. 
+ """ + res = img.shape[:2] + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 + # Bottom right point + br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1 + # size of cropped image + crop_shape = [br[1] - ul[1], br[0] - ul[0]] + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(orig_shape, dtype=np.uint8) + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] + new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(orig_shape[1], br[0]) + old_y = max(0, ul[1]), min(orig_shape[0], br[1]) + img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape)) + new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] + return new_img + + +def rot_aa(aa, rot): + """Rotate axis angle parameters.""" + # pose parameters + R = np.array( + [ + [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], [0, 0, 1] + ] + ) + # find the rotation of the body in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg)) + aa = (resrot.T)[0] + return aa + + +def flip_img(img): + """Flip rgb images or masks. + channels come last, e.g. (256,256,3). + """ + img = np.fliplr(img) + return img + + +def flip_kp(kp, is_smpl=False, type='body'): + """Flip keypoints.""" + assert type in ['body', 'hand', 'face', 'feet'] + if type == 'body': + if len(kp) == 24: + if is_smpl: + flipped_parts = constants.SMPL_JOINTS_FLIP_PERM + else: + flipped_parts = constants.J24_FLIP_PERM + elif len(kp) == 49: + if is_smpl: + flipped_parts = constants.SMPL_J49_FLIP_PERM + else: + flipped_parts = constants.J49_FLIP_PERM + elif type == 'hand': + if len(kp) == 21: + flipped_parts = constants.SINGLE_HAND_FLIP_PERM + elif len(kp) == 42: + flipped_parts = constants.LRHAND_FLIP_PERM + elif type == 'face': + flipped_parts = constants.FACE_FLIP_PERM + elif type == 'feet': + flipped_parts = constants.FEEF_FLIP_PERM + + kp = kp[flipped_parts] + kp[:, 0] = -kp[:, 0] + return kp + + +def flip_pose(pose): + """Flip pose. + The flipping is based on SMPL parameters. + """ + flipped_parts = constants.SMPL_POSE_FLIP_PERM + pose = pose[flipped_parts] + # we also negate the second and the third dimension of the axis-angle + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + return pose + + +def flip_aa(pose): + """Flip aa. + """ + # we also negate the second and the third dimension of the axis-angle + if len(pose.shape) == 1: + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + elif len(pose.shape) == 2: + pose[:, 1::3] = -pose[:, 1::3] + pose[:, 2::3] = -pose[:, 2::3] + else: + raise NotImplementedError + return pose + + +def normalize_2d_kp(kp_2d, crop_size=224, inv=False): + # Normalize keypoints between -1, 1 + if not inv: + ratio = 1.0 / crop_size + kp_2d = 2.0 * kp_2d * ratio - 1.0 + else: + ratio = 1.0 / crop_size + kp_2d = (kp_2d + 1.0) / (2 * ratio) + + return kp_2d + + +def j2d_processing(kp, transf): + """Process gt 2D keypoints and apply transforms.""" + # nparts = kp.shape[1] + bs, npart = kp.shape[:2] + kp_pad = torch.cat([kp, torch.ones((bs, npart, 1)).to(kp)], dim=-1) + kp_new = torch.bmm(transf, kp_pad.transpose(1, 2)) + kp_new = kp_new.transpose(1, 2) + kp_new[:, :, :-1] = 2. 
* kp_new[:, :, :-1] / constants.IMG_RES - 1. + return kp_new[:, :, :2] + + +def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None): + ''' + param joints: [num_joints, 3] + param joints_vis: [num_joints, 3] + return: target, target_weight(1: visible, 0: invisible) + ''' + num_joints = joints.shape[0] + device = joints.device + cur_device = torch.device(device.type, device.index) + if not hasattr(heatmap_size, '__len__'): + # width height + heatmap_size = [heatmap_size, heatmap_size] + assert len(heatmap_size) == 2 + target_weight = np.ones((num_joints, 1), dtype=np.float32) + if joints_vis is not None: + target_weight[:, 0] = joints_vis[:, 0] + target = torch.zeros( + (num_joints, heatmap_size[1], heatmap_size[0]), dtype=torch.float32, device=cur_device + ) + + tmp_size = sigma * 3 + + for joint_id in range(num_joints): + mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5) + mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5) + # Check that any part of the gaussian is in-bounds + ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] + br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] + if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \ + or br[0] < 0 or br[1] < 0: + # If not, just return the image as is + target_weight[joint_id] = 0 + continue + + # # Generate gaussian + size = 2 * tmp_size + 1 + # x = np.arange(0, size, 1, np.float32) + # y = x[:, np.newaxis] + # x0 = y0 = size // 2 + # # The gaussian is not normalized, we want the center value to equal 1 + # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + # g = torch.from_numpy(g.astype(np.float32)) + + x = torch.arange(0, size, dtype=torch.float32, device=cur_device) + y = x.unsqueeze(-1) + x0 = y0 = size // 2 + # The gaussian is not normalized, we want the center value to equal 1 + g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) + + # Usable gaussian range + g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0] + g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1] + # Image range + img_x = max(0, ul[0]), min(br[0], heatmap_size[0]) + img_y = max(0, ul[1]), min(br[1], heatmap_size[1]) + + v = target_weight[joint_id] + if v > 0.5: + target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \ + g[g_y[0]:g_y[1], g_x[0]:g_x[1]] + + return target, target_weight diff --git a/lib/pymafx/utils/io.py b/lib/pymafx/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..0926624ddeb1eccf2e9c6393595acfd34a62e84d --- /dev/null +++ b/lib/pymafx/utils/io.py @@ -0,0 +1,145 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################## +"""IO utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from six.moves import cPickle as pickle +import hashlib +import logging +import os +import re +import sys +try: + from urllib.request import urlopen +except ImportError: #python2 + from urllib2 import urlopen + +logger = logging.getLogger(__name__) + +_DETECTRON_S3_BASE_URL = 'https://s3-us-west-2.amazonaws.com/detectron' + + +def save_object(obj, file_name): + """Save a Python object by pickling it.""" + file_name = os.path.abspath(file_name) + with open(file_name, 'wb') as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + + +def cache_url(url_or_file, cache_dir): + """Download the file specified by the URL to the cache_dir and return the + path to the cached file. If the argument is not a URL, simply return it as + is. + """ + is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None + + if not is_url: + return url_or_file + + url = url_or_file + # assert url.startswith(_DETECTRON_S3_BASE_URL), \ + # ('Detectron only automatically caches URLs in the Detectron S3 ' + # 'bucket: {}').format(_DETECTRON_S3_BASE_URL) + # + # cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir) + Len_filename = len(url.split('/')[-1]) + BASE_URL = url[0:-Len_filename - 1] + # + cache_file_path = url.replace(BASE_URL, cache_dir) + if os.path.exists(cache_file_path): + # assert_cache_file_is_ok(url, cache_file_path) + return cache_file_path + + cache_file_dir = os.path.dirname(cache_file_path) + if not os.path.exists(cache_file_dir): + os.makedirs(cache_file_dir) + + logger.info('Downloading remote file {} to {}'.format(url, cache_file_path)) + download_url(url, cache_file_path) + # assert_cache_file_is_ok(url, cache_file_path) + return cache_file_path + + +def assert_cache_file_is_ok(url, file_path): + """Check that cache file has the correct hash.""" + # File is already in the cache, verify that the md5sum matches and + # return local path + cache_file_md5sum = _get_file_md5sum(file_path) + ref_md5sum = _get_reference_md5sum(url) + assert cache_file_md5sum == ref_md5sum, \ + ('Target URL {} appears to be downloaded to the local cache file ' + '{}, but the md5 hash of the local file does not match the ' + 'reference (actual: {} vs. expected: {}). You may wish to delete ' + 'the cached file and try again to trigger automatic ' + 'download.').format(url, file_path, cache_file_md5sum, ref_md5sum) + + +def _progress_bar(count, total): + """Report download progress. + Credit: + https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 + """ + bar_len = 60 + filled_len = int(round(bar_len * count / float(total))) + + percents = round(100.0 * count / float(total), 1) + bar = '=' * filled_len + '-' * (bar_len - filled_len) + + sys.stdout.write(' [{}] {}% of {:.1f}MB file \r'.format(bar, percents, total / 1024 / 1024)) + sys.stdout.flush() + if count >= total: + sys.stdout.write('\n') + + +def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): + """Download url and write it to dst_file_path. 
+ Credit: + https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook + """ + response = urlopen(url) + total_size = response.info().getheader('Content-Length').strip() + total_size = int(total_size) + bytes_so_far = 0 + + with open(dst_file_path, 'wb') as f: + while 1: + chunk = response.read(chunk_size) + bytes_so_far += len(chunk) + if not chunk: + break + if progress_hook: + progress_hook(bytes_so_far, total_size) + f.write(chunk) + + return bytes_so_far + + +def _get_file_md5sum(file_name): + """Compute the md5 hash of a file.""" + hash_obj = hashlib.md5() + with open(file_name, 'r') as f: + hash_obj.update(f.read()) + return hash_obj.hexdigest() + + +def _get_reference_md5sum(url): + """By convention the md5 hash for url is stored in url + '.md5sum'.""" + url_md5sum = url + '.md5sum' + md5sum = urlopen(url_md5sum).read().strip() + return md5sum diff --git a/lib/pymafx/utils/iuvmap.py b/lib/pymafx/utils/iuvmap.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7c25398e04e30b2b244d44badc83415d583852 --- /dev/null +++ b/lib/pymafx/utils/iuvmap.py @@ -0,0 +1,295 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def iuvmap_clean(U_uv, V_uv, Index_UV, AnnIndex=None): + + Index_UV_max = torch.argmax(Index_UV, dim=1).float() + recon_Index_UV = [] + for i in range(Index_UV.size(1)): + if i == 0: + recon_Index_UV_i = torch.min( + F.threshold(Index_UV_max + 1, 0.5, 0), -F.threshold(-Index_UV_max - 1, -1.5, 0) + ) + else: + recon_Index_UV_i = torch.min( + F.threshold(Index_UV_max, i - 0.5, 0), -F.threshold(-Index_UV_max, -i - 0.5, 0) + ) / float(i) + recon_Index_UV.append(recon_Index_UV_i) + recon_Index_UV = torch.stack(recon_Index_UV, dim=1) + + if AnnIndex is None: + recon_Ann_Index = None + else: + AnnIndex_max = torch.argmax(AnnIndex, dim=1).float() + recon_Ann_Index = [] + for i in range(AnnIndex.size(1)): + if i == 0: + recon_Ann_Index_i = torch.min( + F.threshold(AnnIndex_max + 1, 0.5, 0), -F.threshold(-AnnIndex_max - 1, -1.5, 0) + ) + else: + recon_Ann_Index_i = torch.min( + F.threshold(AnnIndex_max, i - 0.5, 0), -F.threshold(-AnnIndex_max, -i - 0.5, 0) + ) / float(i) + recon_Ann_Index.append(recon_Ann_Index_i) + recon_Ann_Index = torch.stack(recon_Ann_Index, dim=1) + + recon_U = recon_Index_UV * U_uv + recon_V = recon_Index_UV * V_uv + + return recon_U, recon_V, recon_Index_UV, recon_Ann_Index + + +def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=None, n_part=24): + device_id = U_uv.get_device() + batch_size = U_uv.size(0) + K = U_uv.size(1) + heatmap_size = U_uv.size(2) + + Index_UV_max = torch.argmax(Index_UV, dim=1) + if AnnIndex is None: + Index_UV_max = Index_UV_max.to(torch.int64) + else: + AnnIndex_max = torch.argmax(AnnIndex, dim=1) + Index_UV_max = Index_UV_max * (AnnIndex_max > 0).to(torch.int64) + + outputs = [] + + for batch_id in range(batch_size): + + output = torch.zeros([3, U_uv.size(2), U_uv.size(3)], dtype=torch.float32).cuda(device_id) + output[0] = Index_UV_max[batch_id].to(torch.float32) + if ind_mapping is None: + output[0] /= float(K - 1) + else: + for ind in range(len(ind_mapping)): + output[0][output[0] == ind] = ind_mapping[ind] * (1. 
/ n_part) + + for part_id in range(0, K): + CurrentU = U_uv[batch_id, part_id] + CurrentV = V_uv[batch_id, part_id] + output[1, + Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id] + output[2, + Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id] + + if uv_rois is None: + outputs.append(output.unsqueeze(0)) + else: + roi_fg = uv_rois[batch_id][1:] + + # x1 = roi_fg[0] + # x2 = roi_fg[2] + # y1 = roi_fg[1] + # y2 = roi_fg[3] + + w = roi_fg[2] - roi_fg[0] + h = roi_fg[3] - roi_fg[1] + + aspect_ratio = float(w) / h + + if aspect_ratio < 1: + new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)] + output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') + paddingleft = int(0.5 * (heatmap_size - new_size[1])) + output = F.pad( + output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0) + ) + else: + new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size] + output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') + paddingtop = int(0.5 * (heatmap_size - new_size[0])) + output = F.pad( + output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop) + ) + + outputs.append(output) + + return torch.cat(outputs, dim=0) + + +def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24): + device_id = uvimages.get_device() + batch_size = uvimages.size(0) + uvimg_size = uvimages.size(-1) + + Index2mask = [ + [0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18], + [19, 21], [20, 22], [23, 24] + ] + + part_ind = torch.round(uvimages[:, 0, :, :] * n_part) + part_u = uvimages[:, 1, :, :] + part_v = uvimages[:, 2, :, :] + + recon_U = [] + recon_V = [] + recon_Index_UV = [] + recon_Ann_Index = [] + + for i in range(n_part + 1): + if i == 0: + recon_Index_UV_i = torch.min( + F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0) + ) + else: + recon_Index_UV_i = torch.min( + F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0) + ) / float(i) + recon_U_i = recon_Index_UV_i * part_u + recon_V_i = recon_Index_UV_i * part_v + + recon_Index_UV.append(recon_Index_UV_i) + recon_U.append(recon_U_i) + recon_V.append(recon_V_i) + + for i in range(len(Index2mask)): + if len(Index2mask[i]) == 1: + recon_Ann_Index_i = recon_Index_UV[Index2mask[i][0]] + elif len(Index2mask[i]) == 2: + p_ind0 = Index2mask[i][0] + p_ind1 = Index2mask[i][1] + # recon_Ann_Index[:, i, :, :] = torch.where(recon_Index_UV[:, p_ind0, :, :] > 0.5, recon_Index_UV[:, p_ind0, :, :], recon_Index_UV[:, p_ind1, :, :]) + # recon_Ann_Index[:, i, :, :] = torch.eq(part_ind, p_ind0) | torch.eq(part_ind, p_ind1) + recon_Ann_Index_i = recon_Index_UV[p_ind0] + recon_Index_UV[p_ind1] + + recon_Ann_Index.append(recon_Ann_Index_i) + + recon_U = torch.stack(recon_U, dim=1) + recon_V = torch.stack(recon_V, dim=1) + recon_Index_UV = torch.stack(recon_Index_UV, dim=1) + recon_Ann_Index = torch.stack(recon_Ann_Index, dim=1) + + if uv_rois is None: + return recon_U, recon_V, recon_Index_UV, recon_Ann_Index + + recon_U_roi = [] + recon_V_roi = [] + recon_Index_UV_roi = [] + recon_Ann_Index_roi = [] + + if new_size is None: + M = uvimg_size + else: + M = new_size + + for i in range(batch_size): + roi_fg = uv_rois[i][1:] + + # x1 = roi_fg[0] + # x2 = roi_fg[2] + # y1 = roi_fg[1] + # y2 = roi_fg[3] + + w = roi_fg[2] - roi_fg[0] + h = roi_fg[3] - roi_fg[1] + + aspect_ratio = float(w) / h + + if aspect_ratio < 1: + w_size = max(int(uvimg_size * aspect_ratio), 
1) + w_margin = int((uvimg_size - w_size) / 2) + + recon_U_roi_i = recon_U[i, :, :, w_margin:w_margin + w_size] + recon_V_roi_i = recon_V[i, :, :, w_margin:w_margin + w_size] + recon_Index_UV_roi_i = recon_Index_UV[i, :, :, w_margin:w_margin + w_size] + recon_Ann_Index_roi_i = recon_Ann_Index[i, :, :, w_margin:w_margin + w_size] + else: + h_size = max(int(uvimg_size / aspect_ratio), 1) + h_margin = int((uvimg_size - h_size) / 2) + + recon_U_roi_i = recon_U[i, :, h_margin:h_margin + h_size, :] + recon_V_roi_i = recon_V[i, :, h_margin:h_margin + h_size, :] + recon_Index_UV_roi_i = recon_Index_UV[i, :, h_margin:h_margin + h_size, :] + recon_Ann_Index_roi_i = recon_Ann_Index[i, :, h_margin:h_margin + h_size, :] + + recon_U_roi_i = F.interpolate(recon_U_roi_i.unsqueeze(0), size=(M, M), mode='nearest') + recon_V_roi_i = F.interpolate(recon_V_roi_i.unsqueeze(0), size=(M, M), mode='nearest') + recon_Index_UV_roi_i = F.interpolate( + recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest' + ) + recon_Ann_Index_roi_i = F.interpolate( + recon_Ann_Index_roi_i.unsqueeze(0), size=(M, M), mode='nearest' + ) + + recon_U_roi.append(recon_U_roi_i) + recon_V_roi.append(recon_V_roi_i) + recon_Index_UV_roi.append(recon_Index_UV_roi_i) + recon_Ann_Index_roi.append(recon_Ann_Index_roi_i) + + recon_U_roi = torch.cat(recon_U_roi, dim=0) + recon_V_roi = torch.cat(recon_V_roi, dim=0) + recon_Index_UV_roi = torch.cat(recon_Index_UV_roi, dim=0) + recon_Ann_Index_roi = torch.cat(recon_Ann_Index_roi, dim=0) + + return recon_U_roi, recon_V_roi, recon_Index_UV_roi, recon_Ann_Index_roi + + +def seg_img2map(segimages, uv_rois=None, new_size=None, n_part=24): + device_id = segimages.get_device() + batch_size = segimages.size(0) + uvimg_size = segimages.size(-1) + + part_ind = torch.round(segimages[:, 0, :, :] * n_part) + + recon_Index_UV = [] + + for i in range(n_part + 1): + if i == 0: + recon_Index_UV_i = torch.min( + F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0) + ) + else: + recon_Index_UV_i = torch.min( + F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0) + ) / float(i) + + recon_Index_UV.append(recon_Index_UV_i) + + recon_Index_UV = torch.stack(recon_Index_UV, dim=1) + + if uv_rois is None: + return None, None, recon_Index_UV, None + + recon_Index_UV_roi = [] + + if new_size is None: + M = uvimg_size + else: + M = new_size + + for i in range(batch_size): + roi_fg = uv_rois[i][1:] + + # x1 = roi_fg[0] + # x2 = roi_fg[2] + # y1 = roi_fg[1] + # y2 = roi_fg[3] + + w = roi_fg[2] - roi_fg[0] + h = roi_fg[3] - roi_fg[1] + + aspect_ratio = float(w) / h + + if aspect_ratio < 1: + w_size = max(int(uvimg_size * aspect_ratio), 1) + w_margin = int((uvimg_size - w_size) / 2) + + recon_Index_UV_roi_i = recon_Index_UV[i, :, :, w_margin:w_margin + w_size] + else: + h_size = max(int(uvimg_size / aspect_ratio), 1) + h_margin = int((uvimg_size - h_size) / 2) + + recon_Index_UV_roi_i = recon_Index_UV[i, :, h_margin:h_margin + h_size, :] + + recon_Index_UV_roi_i = F.interpolate( + recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest' + ) + + recon_Index_UV_roi.append(recon_Index_UV_roi_i) + + recon_Index_UV_roi = torch.cat(recon_Index_UV_roi, dim=0) + + return None, None, recon_Index_UV_roi, None diff --git a/lib/pymafx/utils/keypoints.py b/lib/pymafx/utils/keypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab223c2bef79518adc523da1606cfc331ef8251 --- /dev/null +++ b/lib/pymafx/utils/keypoints.py @@ -0,0 +1,358 @@ +# Copyright (c) 2017-present, 
Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""Keypoint utilities (somewhat specific to COCO keypoints).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torch.cuda.comm + +# from core.config import cfg +# import utils.blob as blob_utils + + +def get_keypoints(): + """Get the COCO keypoints and their left/right flip coorespondence map.""" + # Keypoints are not available in the COCO json for the test split, so we + # provide them here. + keypoints = [ + 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', + 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', + 'left_knee', 'right_knee', 'left_ankle', 'right_ankle' + ] + keypoint_flip_map = { + 'left_eye': 'right_eye', + 'left_ear': 'right_ear', + 'left_shoulder': 'right_shoulder', + 'left_elbow': 'right_elbow', + 'left_wrist': 'right_wrist', + 'left_hip': 'right_hip', + 'left_knee': 'right_knee', + 'left_ankle': 'right_ankle' + } + return keypoints, keypoint_flip_map + + +def get_person_class_index(): + """Index of the person class in COCO.""" + return 1 + + +def flip_keypoints(keypoints, keypoint_flip_map, keypoint_coords, width): + """Left/right flip keypoint_coords. keypoints and keypoint_flip_map are + accessible from get_keypoints(). + """ + flipped_kps = keypoint_coords.copy() + for lkp, rkp in keypoint_flip_map.items(): + lid = keypoints.index(lkp) + rid = keypoints.index(rkp) + flipped_kps[:, :, lid] = keypoint_coords[:, :, rid] + flipped_kps[:, :, rid] = keypoint_coords[:, :, lid] + + # Flip x coordinates + flipped_kps[:, 0, :] = width - flipped_kps[:, 0, :] - 1 + # Maintain COCO convention that if visibility == 0, then x, y = 0 + inds = np.where(flipped_kps[:, 2, :] == 0) + flipped_kps[inds[0], 0, inds[1]] = 0 + return flipped_kps + + +def flip_heatmaps(heatmaps): + """Flip heatmaps horizontally.""" + keypoints, flip_map = get_keypoints() + heatmaps_flipped = heatmaps.copy() + for lkp, rkp in flip_map.items(): + lid = keypoints.index(lkp) + rid = keypoints.index(rkp) + heatmaps_flipped[:, rid, :, :] = heatmaps[:, lid, :, :] + heatmaps_flipped[:, lid, :, :] = heatmaps[:, rid, :, :] + heatmaps_flipped = heatmaps_flipped[:, :, :, ::-1] + return heatmaps_flipped + + +def heatmaps_to_keypoints(maps, rois): + """Extract predicted keypoint locations from heatmaps. Output has shape + (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) + for each keypoint. + """ + # This function converts a discrete image coordinate in a HEATMAP_SIZE x + # HEATMAP_SIZE image to a continuous keypoint coordinate. 
We maintain + # consistency with keypoints_to_heatmap_labels by using the conversion from + # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a + # continuous coordinate. + offset_x = rois[:, 0] + offset_y = rois[:, 1] + + widths = rois[:, 2] - rois[:, 0] + heights = rois[:, 3] - rois[:, 1] + widths = np.maximum(widths, 1) + heights = np.maximum(heights, 1) + widths_ceil = np.ceil(widths) + heights_ceil = np.ceil(heights) + + # NCHW to NHWC for use with OpenCV + maps = np.transpose(maps, [0, 2, 3, 1]) + min_size = cfg.KRCNN.INFERENCE_MIN_SIZE + xy_preds = np.zeros((len(rois), 4, cfg.KRCNN.NUM_KEYPOINTS), dtype=np.float32) + for i in range(len(rois)): + if min_size > 0: + roi_map_width = int(np.maximum(widths_ceil[i], min_size)) + roi_map_height = int(np.maximum(heights_ceil[i], min_size)) + else: + roi_map_width = widths_ceil[i] + roi_map_height = heights_ceil[i] + width_correction = widths[i] / roi_map_width + height_correction = heights[i] / roi_map_height + roi_map = cv2.resize( + maps[i], (roi_map_width, roi_map_height), interpolation=cv2.INTER_CUBIC + ) + # Bring back to CHW + roi_map = np.transpose(roi_map, [2, 0, 1]) + roi_map_probs = scores_to_probs(roi_map.copy()) + w = roi_map.shape[2] + for k in range(cfg.KRCNN.NUM_KEYPOINTS): + pos = roi_map[k, :, :].argmax() + x_int = pos % w + y_int = (pos - x_int) // w + assert (roi_map_probs[k, y_int, x_int] == roi_map_probs[k, :, :].max()) + x = (x_int + 0.5) * width_correction + y = (y_int + 0.5) * height_correction + xy_preds[i, 0, k] = x + offset_x[i] + xy_preds[i, 1, k] = y + offset_y[i] + xy_preds[i, 2, k] = roi_map[k, y_int, x_int] + xy_preds[i, 3, k] = roi_map_probs[k, y_int, x_int] + + return xy_preds + + +def keypoints_to_heatmap_labels(keypoints, rois): + """Encode keypoint location in the target heatmap for use in + SoftmaxWithLoss. + """ + # Maps keypoints from the half-open interval [x1, x2) on continuous image + # coordinates to the closed interval [0, HEATMAP_SIZE - 1] on discrete image + # coordinates. We use the continuous <-> discrete conversion from Heckbert + # 1990 ("What is the coordinate of a pixel?"): d = floor(c) and c = d + 0.5, + # where d is a discrete coordinate and c is a continuous coordinate. + assert keypoints.shape[2] == cfg.KRCNN.NUM_KEYPOINTS + + shape = (len(rois), cfg.KRCNN.NUM_KEYPOINTS) + heatmaps = blob_utils.zeros(shape) + weights = blob_utils.zeros(shape) + + offset_x = rois[:, 0] + offset_y = rois[:, 1] + scale_x = cfg.KRCNN.HEATMAP_SIZE / (rois[:, 2] - rois[:, 0]) + scale_y = cfg.KRCNN.HEATMAP_SIZE / (rois[:, 3] - rois[:, 1]) + + for kp in range(keypoints.shape[2]): + vis = keypoints[:, 2, kp] > 0 + x = keypoints[:, 0, kp].astype(np.float32) + y = keypoints[:, 1, kp].astype(np.float32) + # Since we use floor below, if a keypoint is exactly on the roi's right + # or bottom boundary, we shift it in by eps (conceptually) to keep it in + # the ground truth heatmap. 
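+        # Worked example (e.g. with cfg.KRCNN.HEATMAP_SIZE = 56): a keypoint
+        # lying exactly on the roi's right edge (x == rois[:, 2]) maps to
+        # floor((x2 - x1) * 56 / (x2 - x1)) = 56, one bin past the valid
+        # range [0, 55], so it is clamped to HEATMAP_SIZE - 1 below.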
+ x_boundary_inds = np.where(x == rois[:, 2])[0] + y_boundary_inds = np.where(y == rois[:, 3])[0] + x = (x - offset_x) * scale_x + x = np.floor(x) + if len(x_boundary_inds) > 0: + x[x_boundary_inds] = cfg.KRCNN.HEATMAP_SIZE - 1 + + y = (y - offset_y) * scale_y + y = np.floor(y) + if len(y_boundary_inds) > 0: + y[y_boundary_inds] = cfg.KRCNN.HEATMAP_SIZE - 1 + + valid_loc = np.logical_and( + np.logical_and(x >= 0, y >= 0), + np.logical_and(x < cfg.KRCNN.HEATMAP_SIZE, y < cfg.KRCNN.HEATMAP_SIZE) + ) + + valid = np.logical_and(valid_loc, vis) + valid = valid.astype(np.int32) + + lin_ind = y * cfg.KRCNN.HEATMAP_SIZE + x + heatmaps[:, kp] = lin_ind * valid + weights[:, kp] = valid + + return heatmaps, weights + + +def scores_to_probs(scores): + """Transforms CxHxW of scores to probabilities spatially.""" + channels = scores.shape[0] + for c in range(channels): + temp = scores[c, :, :] + max_score = temp.max() + temp = np.exp(temp - max_score) / np.sum(np.exp(temp - max_score)) + scores[c, :, :] = temp + return scores + + +def nms_oks(kp_predictions, rois, thresh): + """Nms based on kp predictions.""" + scores = np.mean(kp_predictions[:, 2, :], axis=1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = compute_oks(kp_predictions[i], rois[i], kp_predictions[order[1:]], rois[order[1:]]) + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def compute_oks(src_keypoints, src_roi, dst_keypoints, dst_roi): + """Compute OKS for predicted keypoints wrt gt_keypoints. + src_keypoints: 4xK + src_roi: 4x1 + dst_keypoints: Nx4xK + dst_roi: Nx4 + """ + + sigmas = np.array( + [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89] + ) / 10.0 + vars = (sigmas * 2)**2 + + # area + src_area = (src_roi[2] - src_roi[0] + 1) * (src_roi[3] - src_roi[1] + 1) + + # measure the per-keypoint distance if keypoints visible + dx = dst_keypoints[:, 0, :] - src_keypoints[0, :] + dy = dst_keypoints[:, 1, :] - src_keypoints[1, :] + + e = (dx**2 + dy**2) / vars / (src_area + np.spacing(1)) / 2 + e = np.sum(np.exp(-e), axis=1) / e.shape[1] + + return e + + +def get_max_preds(batch_heatmaps): + ''' + get predictions from score maps + heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) + ''' + assert isinstance(batch_heatmaps, np.ndarray), \ + 'batch_heatmaps should be numpy.ndarray' + assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim' + + batch_size = batch_heatmaps.shape[0] + num_joints = batch_heatmaps.shape[1] + width = batch_heatmaps.shape[3] + heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1)) + idx = np.argmax(heatmaps_reshaped, 2) + maxvals = np.amax(heatmaps_reshaped, 2) + + maxvals = maxvals.reshape((batch_size, num_joints, 1)) + idx = idx.reshape((batch_size, num_joints, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + + preds[:, :, 0] = (preds[:, :, 0]) % width + preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) + + pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) + pred_mask = pred_mask.astype(np.float32) + + preds *= pred_mask + return preds, maxvals + + +def generate_3d_integral_preds_tensor(heatmaps, num_joints, x_dim, y_dim, z_dim): + assert isinstance(heatmaps, torch.Tensor) + + if z_dim is not None: + heatmaps = heatmaps.reshape((heatmaps.shape[0], num_joints, z_dim, y_dim, x_dim)) + + accu_x = heatmaps.sum(dim=2) + accu_x = accu_x.sum(dim=2) + accu_y = heatmaps.sum(dim=2) + accu_y = accu_y.sum(dim=3) + accu_z = 
heatmaps.sum(dim=3) + accu_z = accu_z.sum(dim=3) + + accu_x = accu_x * torch.cuda.comm.broadcast( + torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index] + )[0] + accu_y = accu_y * torch.cuda.comm.broadcast( + torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index] + )[0] + accu_z = accu_z * torch.cuda.comm.broadcast( + torch.arange(z_dim, dtype=torch.float32), devices=[accu_z.device.index] + )[0] + + accu_x = accu_x.sum(dim=2, keepdim=True) + accu_y = accu_y.sum(dim=2, keepdim=True) + accu_z = accu_z.sum(dim=2, keepdim=True) + else: + heatmaps = heatmaps.reshape((heatmaps.shape[0], num_joints, y_dim, x_dim)) + + accu_x = heatmaps.sum(dim=2) + accu_y = heatmaps.sum(dim=3) + + accu_x = accu_x * torch.cuda.comm.broadcast( + torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index] + )[0] + accu_y = accu_y * torch.cuda.comm.broadcast( + torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index] + )[0] + + accu_x = accu_x.sum(dim=2, keepdim=True) + accu_y = accu_y.sum(dim=2, keepdim=True) + accu_z = None + + return accu_x, accu_y, accu_z + + +# integral pose estimation +# https://github.com/JimmySuen/integral-human-pose/blob/99647e40ec93dfa4e3b6a1382c935cebb35440da/pytorch_projects/common_pytorch/common_loss/integral.py#L28 +def softmax_integral_tensor(preds, num_joints, hm_width, hm_height, hm_depth=None): + # global soft max + preds = preds.reshape((preds.shape[0], num_joints, -1)) + preds = F.softmax(preds, 2) + + output_3d = False if hm_depth is None else True + + # integrate heatmap into joint location + if output_3d: + x, y, z = generate_3d_integral_preds_tensor( + preds, num_joints, hm_width, hm_height, hm_depth + ) + # x = x / float(hm_width) - 0.5 + # y = y / float(hm_height) - 0.5 + # z = z / float(hm_depth) - 0.5 + preds = torch.cat((x, y, z), dim=2) + # preds = preds.reshape((preds.shape[0], num_joints * 3)) + else: + x, y, _ = generate_3d_integral_preds_tensor( + preds, num_joints, hm_width, hm_height, z_dim=None + ) + # x = x / float(hm_width) - 0.5 + # y = y / float(hm_height) - 0.5 + preds = torch.cat((x, y), dim=2) + # preds = preds.reshape((preds.shape[0], num_joints * 2)) + + return preds diff --git a/lib/pymafx/utils/mesh_generation.py b/lib/pymafx/utils/mesh_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..2876209e7678d2906a84850208f6c288103d07c5 --- /dev/null +++ b/lib/pymafx/utils/mesh_generation.py @@ -0,0 +1,409 @@ +import time +import torch +import trimesh +import numpy as np +import torch.optim as optim +from torch import autograd +from torch.utils.data import TensorDataset, DataLoader + +from .common import make_3d_grid +from .utils import libmcubes +from .utils.libmise import MISE +from .utils.libsimplify import simplify_mesh +from .common import transform_pointcloud + + +class Generator3D(object): + ''' Generator class for DVRs. + + It provides functions to generate the final mesh as well refining options. 
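+    The overall pipeline is: encode the inputs, evaluate occupancies on a
+    grid (optionally refined with MISE), run marching cubes, then optionally
+    simplify and refine the extracted mesh.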
+ + Args: + model (nn.Module): trained DVR model + points_batch_size (int): batch size for points evaluation + threshold (float): threshold value + refinement_step (int): number of refinement steps + device (device): pytorch device + resolution0 (int): start resolution for MISE + upsampling steps (int): number of upsampling steps + with_normals (bool): whether normals should be estimated + padding (float): how much padding should be used for MISE + simplify_nfaces (int): number of faces the mesh should be simplified to + refine_max_faces (int): max number of faces which are used as batch + size for refinement process (we added this functionality in this + work) + ''' + def __init__( + self, + model, + points_batch_size=100000, + threshold=0.5, + refinement_step=0, + device=None, + resolution0=16, + upsampling_steps=3, + with_normals=False, + padding=0.1, + simplify_nfaces=None, + with_color=False, + refine_max_faces=10000 + ): + self.model = model.to(device) + self.points_batch_size = points_batch_size + self.refinement_step = refinement_step + self.threshold = threshold + self.device = device + self.resolution0 = resolution0 + self.upsampling_steps = upsampling_steps + self.with_normals = with_normals + self.padding = padding + self.simplify_nfaces = simplify_nfaces + self.with_color = with_color + self.refine_max_faces = refine_max_faces + + def generate_mesh(self, data, return_stats=True): + ''' Generates the output mesh. + + Args: + data (tensor): data tensor + return_stats (bool): whether stats should be returned + ''' + self.model.eval() + device = self.device + stats_dict = {} + + inputs = data.get('inputs', torch.empty(1, 0)).to(device) + kwargs = {} + + c = self.model.encode_inputs(inputs) + mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data, **kwargs) + + return mesh, stats_dict + + def generate_meshes(self, data, return_stats=True): + ''' Generates the output meshes with data of batch size >=1 + + Args: + data (tensor): data tensor + return_stats (bool): whether stats should be returned + ''' + self.model.eval() + device = self.device + stats_dict = {} + + inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device) + + meshes = [] + for i in range(inputs.shape[0]): + input_i = inputs[i].unsqueeze(0) + c = self.model.encode_inputs(input_i) + mesh = self.generate_from_latent(c, stats_dict=stats_dict) + meshes.append(mesh) + + return meshes + + def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True): + ''' Generates a point cloud from the mesh. + + Args: + mesh (trimesh): mesh + data (dict): data dictionary + n_points (int): number of point cloud points + scale_back (bool): whether to undo scaling (requires a scale + matrix in data dictionary) + ''' + pcl = mesh.sample(n_points).astype(np.float32) + + if scale_back: + scale_mat = data.get('camera.scale_mat_0', None) + if scale_mat is not None: + pcl = transform_pointcloud(pcl, scale_mat[0]) + else: + print('Warning: No scale_mat found!') + pcl_out = trimesh.Trimesh(vertices=pcl, process=False) + return pcl_out + + def generate_from_latent(self, c=None, pl=None, stats_dict={}, data=None, **kwargs): + ''' Generates mesh from latent. + + Args: + c (tensor): latent conditioned code c + pl (tensor): predicted plane parameters + stats_dict (dict): stats dictionary + ''' + threshold = np.log(self.threshold) - np.log(1. 
- self.threshold) + + t0 = time.time() + # Compute bounding box size + box_size = 1 + self.padding + + # Shortcut + if self.upsampling_steps == 0: + nx = self.resolution0 + pointsf = box_size * make_3d_grid((-0.5, ) * 3, (0.5, ) * 3, (nx, ) * 3) + values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy() + value_grid = values.reshape(nx, nx, nx) + else: + mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold) + + points = mesh_extractor.query() + + while points.shape[0] != 0: + # Query points + pointsf = torch.FloatTensor(points).to(self.device) + # Normalize to bounding box + pointsf = 2 * pointsf / mesh_extractor.resolution + pointsf = box_size * (pointsf - 1.0) + # Evaluate model and update + values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy() + + values = values.astype(np.float64) + mesh_extractor.update(points, values) + points = mesh_extractor.query() + + value_grid = mesh_extractor.to_dense() + + # Extract mesh + stats_dict['time (eval points)'] = time.time() - t0 + + mesh = self.extract_mesh(value_grid, c, stats_dict=stats_dict) + return mesh + + def eval_points(self, p, c=None, pl=None, **kwargs): + ''' Evaluates the occupancy values for the points. + + Args: + p (tensor): points + c (tensor): latent conditioned code c + ''' + p_split = torch.split(p, self.points_batch_size) + occ_hats = [] + + for pi in p_split: + pi = pi.unsqueeze(0).to(self.device) + with torch.no_grad(): + occ_hat = self.model.decode(pi, c, pl, **kwargs).logits + + occ_hats.append(occ_hat.squeeze(0).detach().cpu()) + + occ_hat = torch.cat(occ_hats, dim=0) + + return occ_hat + + def extract_mesh(self, occ_hat, c=None, stats_dict=dict()): + ''' Extracts the mesh from the predicted occupancy grid. + + Args: + occ_hat (tensor): value grid of occupancies + c (tensor): latent conditioned code c + stats_dict (dict): stats dictionary + ''' + # Some short hands + n_x, n_y, n_z = occ_hat.shape + box_size = 1 + self.padding + threshold = np.log(self.threshold) - np.log(1. - self.threshold) + # Make sure that mesh is watertight + t0 = time.time() + occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6) + vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold) + stats_dict['time (marching cubes)'] = time.time() - t0 + # Strange behaviour in libmcubes: vertices are shifted by 0.5 + vertices -= 0.5 + # Undo padding + vertices -= 1 + # Normalize to bounding box + vertices /= np.array([n_x - 1, n_y - 1, n_z - 1]) + vertices *= 2 + vertices = box_size * (vertices - 1) + + # mesh_pymesh = pymesh.form_mesh(vertices, triangles) + # mesh_pymesh = fix_pymesh(mesh_pymesh) + + # Estimate normals if needed + if self.with_normals and not vertices.shape[0] == 0: + t0 = time.time() + normals = self.estimate_normals(vertices, c) + stats_dict['time (normals)'] = time.time() - t0 + else: + normals = None + # Create mesh + mesh = trimesh.Trimesh( + vertices, + triangles, + vertex_normals=normals, + # vertex_colors=vertex_colors, + process=False + ) + + # Directly return if mesh is empty + if vertices.shape[0] == 0: + return mesh + + # TODO: normals are lost here + if self.simplify_nfaces is not None: + t0 = time.time() + mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.) 
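+            # simplify_mesh (libsimplify) reduces the mesh to roughly
+            # self.simplify_nfaces faces; the trailing 5. presumably controls
+            # the aggressiveness of the simplification.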
+ stats_dict['time (simplify)'] = time.time() - t0 + + # Refine mesh + if self.refinement_step > 0: + t0 = time.time() + self.refine_mesh(mesh, occ_hat, c) + stats_dict['time (refine)'] = time.time() - t0 + + # Estimate Vertex Colors + if self.with_color and not vertices.shape[0] == 0: + t0 = time.time() + vertex_colors = self.estimate_colors(np.array(mesh.vertices), c) + stats_dict['time (color)'] = time.time() - t0 + mesh = trimesh.Trimesh( + vertices=mesh.vertices, + faces=mesh.faces, + vertex_normals=mesh.vertex_normals, + vertex_colors=vertex_colors, + process=False + ) + + return mesh + + def estimate_colors(self, vertices, c=None): + ''' Estimates vertex colors by evaluating the texture field. + + Args: + vertices (numpy array): vertices of the mesh + c (tensor): latent conditioned code c + ''' + device = self.device + vertices = torch.FloatTensor(vertices) + vertices_split = torch.split(vertices, self.points_batch_size) + colors = [] + for vi in vertices_split: + vi = vi.to(device) + with torch.no_grad(): + ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0).cpu() + colors.append(ci) + + colors = np.concatenate(colors, axis=0) + colors = np.clip(colors, 0, 1) + colors = (colors * 255).astype(np.uint8) + colors = np.concatenate( + [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1 + ) + return colors + + def estimate_normals(self, vertices, c=None): + ''' Estimates the normals by computing the gradient of the objective. + + Args: + vertices (numpy array): vertices of the mesh + z (tensor): latent code z + c (tensor): latent conditioned code c + ''' + device = self.device + vertices = torch.FloatTensor(vertices) + vertices_split = torch.split(vertices, self.points_batch_size) + + normals = [] + c = c.unsqueeze(0) + for vi in vertices_split: + vi = vi.unsqueeze(0).to(device) + vi.requires_grad_() + occ_hat = self.model.decode(vi, c).logits + out = occ_hat.sum() + out.backward() + ni = -vi.grad + ni = ni / torch.norm(ni, dim=-1, keepdim=True) + ni = ni.squeeze(0).cpu().numpy() + normals.append(ni) + + normals = np.concatenate(normals, axis=0) + return normals + + def refine_mesh(self, mesh, occ_hat, c=None): + ''' Refines the predicted mesh. + + Args: + mesh (trimesh object): predicted mesh + occ_hat (tensor): predicted occupancy grid + c (tensor): latent conditioned code c + ''' + + self.model.eval() + + # Some shorthands + n_x, n_y, n_z = occ_hat.shape + assert (n_x == n_y == n_z) + # threshold = np.log(self.threshold) - np.log(1. - self.threshold) + threshold = self.threshold + + # Vertex parameter + v0 = torch.FloatTensor(mesh.vertices).to(self.device) + v = torch.nn.Parameter(v0.clone()) + + # Faces of mesh + faces = torch.LongTensor(mesh.faces) + + # detach c; otherwise graph needs to be retained + # caused by new Pytorch version? 
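+        # The loop below samples one random point per face (a Dirichlet-
+        # weighted combination of its three vertices), pulls the predicted
+        # occupancy at that point towards the threshold, and aligns the face
+        # normal with the negative occupancy gradient (weighted by 0.01).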
+ c = c.detach() + + # Start optimization + optimizer = optim.RMSprop([v], lr=1e-5) + + # Dataset + ds_faces = TensorDataset(faces) + dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True) + + # We updated the refinement algorithm to subsample faces; this is + # usefull when using a high extraction resolution / when working on + # small GPUs + it_r = 0 + while it_r < self.refinement_step: + for f_it in dataloader: + f_it = f_it[0].to(self.device) + optimizer.zero_grad() + + # Loss + face_vertex = v[f_it] + eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0]) + eps = torch.FloatTensor(eps).to(self.device) + face_point = (face_vertex * eps[:, :, None]).sum(dim=1) + + face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :] + face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :] + face_normal = torch.cross(face_v1, face_v2) + face_normal = face_normal / \ + (face_normal.norm(dim=1, keepdim=True) + 1e-10) + + face_value = torch.cat( + [ + torch.sigmoid(self.model.decode(p_split, c).logits) + for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1) + ], + dim=1 + ) + + normal_target = -autograd.grad([face_value.sum()], [face_point], + create_graph=True)[0] + + normal_target = \ + normal_target / \ + (normal_target.norm(dim=1, keepdim=True) + 1e-10) + loss_target = (face_value - threshold).pow(2).mean() + loss_normal = \ + (face_normal - normal_target).pow(2).sum(dim=1).mean() + + loss = loss_target + 0.01 * loss_normal + + # Update + loss.backward() + optimizer.step() + + # Update it_r + it_r += 1 + + if it_r >= self.refinement_step: + break + + mesh.vertices = v.data.cpu().numpy() + return mesh diff --git a/lib/pymafx/utils/part_utils.py b/lib/pymafx/utils/part_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..12f0de443fa11e90674761816a644cf82a48a786 --- /dev/null +++ b/lib/pymafx/utils/part_utils.py @@ -0,0 +1,71 @@ +import torch +import numpy as np +import neural_renderer as nr +from core import path_config + +from models import SMPL + + +class PartRenderer(): + """Renderer used to render segmentation masks and part segmentations. 
+ Internally it uses the Neural 3D Mesh Renderer + """ + def __init__(self, focal_length=5000., render_res=224): + # Parameters for rendering + self.focal_length = focal_length + self.render_res = render_res + # We use Neural 3D mesh renderer for rendering masks and part segmentations + self.neural_renderer = nr.Renderer( + dist_coeffs=None, + orig_size=self.render_res, + image_size=render_res, + light_intensity_ambient=1, + light_intensity_directional=0, + anti_aliasing=False + ) + self.faces = torch.from_numpy(SMPL(path_config.SMPL_MODEL_DIR).faces.astype(np.int32) + ).cuda() + textures = np.load(path_config.VERTEX_TEXTURE_FILE) + self.textures = torch.from_numpy(textures).cuda().float() + self.cube_parts = torch.cuda.FloatTensor(np.load(path_config.CUBE_PARTS_FILE)) + + def get_parts(self, parts, mask): + """Process renderer part image to get body part indices.""" + bn, c, h, w = parts.shape + mask = mask.view(-1, 1) + parts_index = torch.floor(100 * parts.permute(0, 2, 3, 1).contiguous().view(-1, 3)).long() + parts = self.cube_parts[parts_index[:, 0], parts_index[:, 1], parts_index[:, 2], None] + parts *= mask + parts = parts.view(bn, h, w).long() + return parts + + def __call__(self, vertices, camera): + """Wrapper function for rendering process.""" + # Estimate camera parameters given a fixed focal length + cam_t = torch.stack( + [ + camera[:, 1], camera[:, 2], 2 * self.focal_length / + (self.render_res * camera[:, 0] + 1e-9) + ], + dim=-1 + ) + batch_size = vertices.shape[0] + K = torch.eye(3, device=vertices.device) + K[0, 0] = self.focal_length + K[1, 1] = self.focal_length + K[2, 2] = 1 + K[0, 2] = self.render_res / 2. + K[1, 2] = self.render_res / 2. + K = K[None, :, :].expand(batch_size, -1, -1) + R = torch.eye(3, device=vertices.device)[None, :, :].expand(batch_size, -1, -1) + faces = self.faces[None, :, :].expand(batch_size, -1, -1) + parts, _, mask = self.neural_renderer( + vertices, + faces, + textures=self.textures.expand(batch_size, -1, -1, -1, -1, -1), + K=K, + R=R, + t=cam_t.unsqueeze(1) + ) + parts = self.get_parts(parts, mask) + return mask, parts diff --git a/lib/pymafx/utils/pose_tracker.py b/lib/pymafx/utils/pose_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..92c383cdb3dba6053a0595b9f03305c02e9fc277 --- /dev/null +++ b/lib/pymafx/utils/pose_tracker.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import os +import json +import shutil +import subprocess +import numpy as np +import os.path as osp + + +def run_openpose( + video_file, + output_folder, + staf_folder, + vis=False, +): + pwd = os.getcwd() + + os.chdir(staf_folder) + + render = 1 if vis else 0 + display = 2 if vis else 0 + cmd = [ + 'build/examples/openpose/openpose.bin', '--model_pose', 'BODY_21A', '--tracking', '1', + '--render_pose', + str(render), '--video', video_file, '--write_json', output_folder, '--display', + str(display) + ] + + print('Executing', ' '.join(cmd)) + subprocess.call(cmd) + os.chdir(pwd) + + +def read_posetrack_keypoints(output_folder): + + people = dict() + + for idx, result_file in enumerate(sorted(os.listdir(output_folder))): + json_file = osp.join(output_folder, result_file) + data = json.load(open(json_file)) + # print(idx, data) + for person in data['people']: + person_id = person['person_id'][0] + joints2d = person['pose_keypoints_2d'] + if person_id in people.keys(): + people[person_id]['joints2d'].append(joints2d) + people[person_id]['frames'].append(idx) + else: + people[person_id] = { + 'joints2d': [], + 'frames': [], + } + people[person_id]['joints2d'].append(joints2d) + people[person_id]['frames'].append(idx) + + for k in people.keys(): + people[k]['joints2d'] = np.array(people[k]['joints2d']).reshape( + (len(people[k]['joints2d']), -1, 3) + ) + people[k]['frames'] = np.array(people[k]['frames']) + + return people + + +def run_posetracker(video_file, staf_folder, posetrack_output_folder='/tmp', display=False): + posetrack_output_folder = os.path.join( + posetrack_output_folder, f'{os.path.basename(video_file)}_posetrack' + ) + + # run posetrack on video + run_openpose(video_file, posetrack_output_folder, vis=display, staf_folder=staf_folder) + + people_dict = read_posetrack_keypoints(posetrack_output_folder) + + shutil.rmtree(posetrack_output_folder) + + return people_dict diff --git a/lib/pymafx/utils/pose_utils.py b/lib/pymafx/utils/pose_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..55eb1d771376da71c864a715d1dd6b5d66e9894e --- /dev/null +++ b/lib/pymafx/utils/pose_utils.py @@ -0,0 +1,152 @@ +""" +Parts of the code are adapted from https://github.com/akanazawa/hmr +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import torch + + +def compute_similarity_transform(S1, S2): + """ + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.T + S2 = S2.T + transposed = True + assert (S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=1, keepdims=True) + mu2 = S2.mean(axis=1, keepdims=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = np.sum(X1**2) + + # 3. The outer product of X1 and X2. + K = X1.dot(X2.T) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, Vh = np.linalg.svd(K) + V = Vh.T + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = np.eye(U.shape[0]) + Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) + # Construct R. + R = V.dot(Z.dot(U.T)) + + # 5. Recover scale. + scale = np.trace(R.dot(K)) / var1 + + # 6. Recover translation. 
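+    # With the optimal rotation R and scale s fixed, the translation that
+    # minimizes ||s * R.dot(S1) + t - S2||^2 aligns the centroids:
+    # t = mu2 - s * R.dot(mu1).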
+ t = mu2 - scale * (R.dot(mu1)) + + # 7. Error: + S1_hat = scale * R.dot(S1) + t + + if transposed: + S1_hat = S1_hat.T + + return S1_hat + + +def compute_similarity_transform_batch(S1, S2): + """Batched version of compute_similarity_transform.""" + S1_hat = np.zeros_like(S1) + for i in range(S1.shape[0]): + S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) + return S1_hat + + +def reconstruction_error(S1, S2, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re, S1_hat + + +# https://math.stackexchange.com/questions/382760/composition-of-two-axis-angle-rotations +def axis_angle_add(theta, roll_axis, alpha): + """Composition of two axis-angle rotations (PyTorch version) + Args: + theta: N x 3 + roll_axis: N x 3 + alph: N x 1 + Returns: + equivalent axis-angle of the composition + """ + alpha = alpha / 2. + + l2norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(l2norm, -1) + + normalized = torch.div(theta, angle) + angle = angle * 0.5 + b_cos = torch.cos(angle).cpu() + b_sin = torch.sin(angle).cpu() + + a_cos = torch.cos(alpha) + a_sin = torch.sin(alpha) + + dot_mm = torch.sum(normalized * roll_axis, dim=1, keepdim=True) + cross_mm = torch.zeros_like(normalized) + cross_mm[:, 0] = roll_axis[:, 1] * normalized[:, 2] - roll_axis[:, 2] * normalized[:, 1] + cross_mm[:, 1] = roll_axis[:, 2] * normalized[:, 0] - roll_axis[:, 0] * normalized[:, 2] + cross_mm[:, 2] = roll_axis[:, 0] * normalized[:, 1] - roll_axis[:, 1] * normalized[:, 0] + + c_cos = a_cos * b_cos - a_sin * b_sin * dot_mm + c_sin_n = a_sin * b_cos * roll_axis + a_cos * b_sin * normalized + a_sin * b_sin * cross_mm + + c_angle = 2 * torch.acos(c_cos) + c_sin = torch.sin(c_angle * 0.5) + c_n = (c_angle / c_sin) * c_sin_n + + return c_n + + +def axis_angle_add_np(theta, roll_axis, alpha): + """Composition of two axis-angle rotations (NumPy version) + Args: + theta: N x 3 + roll_axis: N x 3 + alph: N x 1 + Returns: + equivalent axis-angle of the composition + """ + alpha = alpha / 2. 
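+    # This follows the quaternion half-angle composition rule:
+    #   cos(c/2)     = cos(a/2) cos(b/2) - sin(a/2) sin(b/2) (l . m)
+    #   sin(c/2) * n = sin(a/2) cos(b/2) l + cos(a/2) sin(b/2) m
+    #                  + sin(a/2) sin(b/2) (l x m)
+    # with l = roll_axis (angle alpha) and m = theta / |theta| (angle |theta|);
+    # the composed axis-angle vector c * n is recovered at the end.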
+ + angle = np.linalg.norm(theta + 1e-8, ord=2, axis=1, keepdims=True) + normalized = np.divide(theta, angle) + angle = angle * 0.5 + + b_cos = np.cos(angle) + b_sin = np.sin(angle) + a_cos = np.cos(alpha) + a_sin = np.sin(alpha) + + dot_mm = np.sum(normalized * roll_axis, axis=1, keepdims=True) + cross_mm = np.zeros_like(normalized) + cross_mm[:, 0] = roll_axis[:, 1] * normalized[:, 2] - roll_axis[:, 2] * normalized[:, 1] + cross_mm[:, 1] = roll_axis[:, 2] * normalized[:, 0] - roll_axis[:, 0] * normalized[:, 2] + cross_mm[:, 2] = roll_axis[:, 0] * normalized[:, 1] - roll_axis[:, 1] * normalized[:, 0] + + c_cos = a_cos * b_cos - a_sin * b_sin * dot_mm + c_sin_n = a_sin * b_cos * roll_axis + a_cos * b_sin * normalized + a_sin * b_sin * cross_mm + c_angle = 2 * np.arccos(c_cos) + c_sin = np.sin(c_angle * 0.5) + c_n = (c_angle / c_sin) * c_sin_n + + return c_n diff --git a/lib/pymafx/utils/renderer.py b/lib/pymafx/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb19568680b839f93c00a5288c94a5a52025242 --- /dev/null +++ b/lib/pymafx/utils/renderer.py @@ -0,0 +1,663 @@ +import imp +import os +from pickle import NONE +# os.environ['PYOPENGL_PLATFORM'] = 'osmesa' +import torch +import trimesh +import numpy as np +# import neural_renderer as nr +from skimage.transform import resize +from torchvision.utils import make_grid +import torch.nn.functional as F + +from models.smpl import get_smpl_faces, get_model_faces, get_model_tpose +from utils.densepose_methods import DensePoseMethods +from core import constants, path_config +import json +from .geometry import convert_to_full_img_cam +from utils.imutils import crop + +try: + import math + import pyrender + from pyrender.constants import RenderFlags +except: + pass +try: + from opendr.renderer import ColoredRenderer + from opendr.lighting import LambertianPointLight, SphericalHarmonics + from opendr.camera import ProjectPoints +except: + pass + +from pytorch3d.structures.meshes import Meshes +# from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments + +from pytorch3d.renderer import ( + look_at_view_transform, FoVPerspectiveCameras, PerspectiveCameras, AmbientLights, PointLights, + RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, SoftPhongShader, + SoftSilhouetteShader, HardPhongShader, HardGouraudShader, HardFlatShader, TexturesVertex +) + +import logging + +logger = logging.getLogger(__name__) + + +class WeakPerspectiveCamera(pyrender.Camera): + def __init__( + self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None + ): + super(WeakPerspectiveCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + self.scale = scale + self.translation = translation + + def get_projection_matrix(self, width=None, height=None): + P = np.eye(4) + P[0, 0] = self.scale[0] + P[1, 1] = self.scale[1] + P[0, 3] = self.translation[0] * self.scale[0] + P[1, 3] = -self.translation[1] * self.scale[1] + P[2, 2] = -1 + return P + + +class PyRenderer: + def __init__( + self, resolution=(224, 224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1. 
+ ): + self.resolution = (resolution[0] * scale_ratio, resolution[1] * scale_ratio) + # self.scale_ratio = scale_ratio + + self.faces = { + 'smplx': get_model_faces('smplx'), + 'smpl': get_model_faces('smpl'), + # 'mano': get_model_faces('mano'), + # 'flame': get_model_faces('flame'), + } + self.orig_img = orig_img + self.wireframe = wireframe + self.renderer = pyrender.OffscreenRenderer( + viewport_width=self.resolution[0], viewport_height=self.resolution[1], point_size=1.0 + ) + + self.vis_ratio = vis_ratio + + # set the scene + self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3)) + + light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=1) + + yrot = np.radians(120) # angle of lights + + light_pose = np.eye(4) + light_pose[:3, 3] = [0, -1, 1] + self.scene.add(light, pose=light_pose) + + light_pose[:3, 3] = [0, 1, 1] + self.scene.add(light, pose=light_pose) + + light_pose[:3, 3] = [1, 1, 2] + self.scene.add(light, pose=light_pose) + + spot_l = pyrender.SpotLight( + color=np.ones(3), intensity=15.0, innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2 + ) + + light_pose[:3, 3] = [1, 2, 2] + self.scene.add(spot_l, pose=light_pose) + + light_pose[:3, 3] = [-1, 2, 2] + self.scene.add(spot_l, pose=light_pose) + + # light_pose[:3, 3] = [-2, 2, 0] + # self.scene.add(spot_l, pose=light_pose) + + # light_pose[:3, 3] = [-2, 2, 0] + # self.scene.add(spot_l, pose=light_pose) + + self.colors_dict = { + 'red': np.array([0.5, 0.2, 0.2]), + 'pink': np.array([0.7, 0.5, 0.5]), + 'neutral': np.array([0.7, 0.7, 0.6]), + # 'purple': np.array([0.5, 0.5, 0.7]), + 'purple': np.array([0.55, 0.4, 0.9]), + 'green': np.array([0.5, 0.55, 0.3]), + 'sky': np.array([0.3, 0.5, 0.55]), + 'white': np.array([1.0, 0.98, 0.94]), + } + + def __call__( + self, + verts, + faces=None, + img=np.zeros((224, 224, 3)), + cam=np.array([1, 0, 0]), + focal_length=[5000, 5000], + camera_rotation=np.eye(3), + crop_info=None, + angle=None, + axis=None, + mesh_filename=None, + color_type=None, + color=[1.0, 1.0, 0.9], + iwp_mode=True, + crop_img=True, + mesh_type='smpl', + scale_ratio=1., + rgba_mode=False + ): + + if faces is None: + faces = self.faces[mesh_type] + mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False) + + Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0]) + mesh.apply_transform(Rx) + + if mesh_filename is not None: + mesh.export(mesh_filename) + + if angle and axis: + R = trimesh.transformations.rotation_matrix(math.radians(angle), axis) + mesh.apply_transform(R) + + cam = cam.copy() + if iwp_mode: + resolution = np.array(img.shape[:2]) * scale_ratio + if len(cam) == 4: + sx, sy, tx, ty = cam + # sy = sx + camera_translation = np.array( + [tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)] + ) + elif len(cam) == 3: + sx, tx, ty = cam + sy = sx + camera_translation = np.array( + [-tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)] + ) + render_res = resolution + self.renderer.viewport_width = render_res[1] + self.renderer.viewport_height = render_res[0] + else: + if crop_info['opt_cam_t'] is None: + camera_translation = convert_to_full_img_cam( + pare_cam=cam[None], + bbox_height=crop_info['bbox_scale'] * 200., + bbox_center=crop_info['bbox_center'], + img_w=crop_info['img_w'], + img_h=crop_info['img_h'], + focal_length=focal_length[0], + ) + else: + camera_translation = crop_info['opt_cam_t'] + if torch.is_tensor(camera_translation): + camera_translation = camera_translation[0].cpu().numpy() + camera_translation 
= camera_translation.copy() + camera_translation[0] *= -1 + if 'img_h' in crop_info and 'img_w' in crop_info: + render_res = (int(crop_info['img_h'][0]), int(crop_info['img_w'][0])) + else: + render_res = img.shape[:2] if type(img) is not list else img[0].shape[:2] + self.renderer.viewport_width = render_res[1] + self.renderer.viewport_height = render_res[0] + camera_rotation = camera_rotation.T + camera = pyrender.IntrinsicsCamera( + fx=focal_length[0], fy=focal_length[1], cx=render_res[1] / 2., cy=render_res[0] / 2. + ) + + if color_type != None: + color = self.colors_dict[color_type] + + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.2, + roughnessFactor=0.6, + alphaMode='OPAQUE', + baseColorFactor=(color[0], color[1], color[2], 1.0) + ) + + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + mesh_node = self.scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, :3] = camera_rotation + camera_pose[:3, 3] = camera_rotation @ camera_translation + cam_node = self.scene.add(camera, pose=camera_pose) + + if self.wireframe: + render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME | RenderFlags.SHADOWS_SPOT + else: + render_flags = RenderFlags.RGBA | RenderFlags.SHADOWS_SPOT + + rgb, _ = self.renderer.render(self.scene, flags=render_flags) + if crop_info is not None and crop_img: + crop_res = img.shape[:2] + rgb, _, _ = crop(rgb, crop_info['bbox_center'][0], crop_info['bbox_scale'][0], crop_res) + + valid_mask = (rgb[:, :, -1] > 0)[:, :, np.newaxis] + + image_list = [img] if type(img) is not list else img + + return_img = [] + for item in image_list: + if scale_ratio != 1: + orig_size = item.shape[:2] + item = resize( + item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio), + anti_aliasing=True + ) + item = (item * 255).astype(np.uint8) + output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + ( + 1 - valid_mask * self.vis_ratio + ) * item + # output_img[valid_mask < 0.5] = item[valid_mask < 0.5] + # if scale_ratio != 1: + # output_img = resize(output_img, (orig_size[0], orig_size[1]), anti_aliasing=True) + if rgba_mode: + output_img_rgba = np.zeros((output_img.shape[0], output_img.shape[1], 4)) + output_img_rgba[:, :, :3] = output_img + output_img_rgba[:, :, 3][valid_mask[:, :, 0]] = 255 + output_img = output_img_rgba.astype(np.uint8) + image = output_img.astype(np.uint8) + return_img.append(image) + return_img.append(item) + + if type(img) is not list: + # if scale_ratio == 1: + return_img = return_img[0] + + self.scene.remove_node(mesh_node) + self.scene.remove_node(cam_node) + + return return_img + + +class OpenDRenderer: + def __init__(self, resolution=(224, 224), ratio=1): + self.resolution = (resolution[0] * ratio, resolution[1] * ratio) + self.ratio = ratio + self.focal_length = 5000. + self.K = np.array( + [ + [self.focal_length, 0., self.resolution[1] / 2.], + [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.] 
+ ] + ) + self.colors_dict = { + 'red': np.array([0.5, 0.2, 0.2]), + 'pink': np.array([0.7, 0.5, 0.5]), + 'neutral': np.array([0.7, 0.7, 0.6]), + 'purple': np.array([0.5, 0.5, 0.7]), + 'green': np.array([0.5, 0.55, 0.3]), + 'sky': np.array([0.3, 0.5, 0.55]), + 'white': np.array([1.0, 0.98, 0.94]), + } + self.renderer = ColoredRenderer() + self.faces = get_smpl_faces() + + def reset_res(self, resolution): + self.resolution = (resolution[0] * self.ratio, resolution[1] * self.ratio) + self.K = np.array( + [ + [self.focal_length, 0., self.resolution[1] / 2.], + [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.] + ] + ) + + def __call__( + self, + verts, + faces=None, + color=None, + color_type='white', + R=None, + mesh_filename=None, + img=np.zeros((224, 224, 3)), + cam=np.array([1, 0, 0]), + rgba=False, + addlight=True + ): + '''Render mesh using OpenDR + verts: shape - (V, 3) + faces: shape - (F, 3) + img: shape - (224, 224, 3), range - [0, 255] (np.uint8) + axis: rotate along with X/Y/Z axis (by angle) + R: rotation matrix (used to manipulate verts) shape - [3, 3] + Return: + rendered img: shape - (224, 224, 3), range - [0, 255] (np.uint8) + ''' + ## Create OpenDR renderer + rn = self.renderer + h, w = self.resolution + K = self.K + + f = np.array([K[0, 0], K[1, 1]]) + c = np.array([K[0, 2], K[1, 2]]) + + if faces is None: + faces = self.faces + if len(cam) == 4: + t = np.array([cam[2], cam[3], 2 * K[0, 0] / (w * cam[0] + 1e-9)]) + elif len(cam) == 3: + t = np.array([cam[1], cam[2], 2 * K[0, 0] / (w * cam[0] + 1e-9)]) + + rn.camera = ProjectPoints(rt=np.array([0, 0, 0]), t=t, f=f, c=c, k=np.zeros(5)) + rn.frustum = {'near': 1., 'far': 1000., 'width': w, 'height': h} + + albedo = np.ones_like(verts) * .9 + + if color is not None: + color0 = np.array(color) + color1 = np.array(color) + color2 = np.array(color) + elif color_type == 'white': + color0 = np.array([1., 1., 1.]) + color1 = np.array([1., 1., 1.]) + color2 = np.array([0.7, 0.7, 0.7]) + color = np.ones_like(verts) * self.colors_dict[color_type][None, :] + else: + color0 = self.colors_dict[color_type] * 1.2 + color1 = self.colors_dict[color_type] * 1.2 + color2 = self.colors_dict[color_type] * 1.2 + color = np.ones_like(verts) * self.colors_dict[color_type][None, :] + + # render_smpl = rn.r + if R is not None: + assert R.shape == (3, 3), "Shape of rotation matrix should be (3, 3)" + verts = np.dot(verts, R) + + rn.set(v=verts, f=faces, vc=color, bgcolor=np.zeros(3)) + + if addlight: + yrot = np.radians(120) # angle of lights + # # 1. 1. 0.7 + rn.vc = LambertianPointLight( + f=rn.f, + v=rn.v, + num_verts=len(rn.v), + light_pos=rotateY(np.array([-200, -100, -100]), yrot), + vc=albedo, + light_color=color0 + ) + + # Construct Left Light + rn.vc += LambertianPointLight( + f=rn.f, + v=rn.v, + num_verts=len(rn.v), + light_pos=rotateY(np.array([800, 10, 300]), yrot), + vc=albedo, + light_color=color1 + ) + + # Construct Right Light + rn.vc += LambertianPointLight( + f=rn.f, + v=rn.v, + num_verts=len(rn.v), + light_pos=rotateY(np.array([-500, 500, 1000]), yrot), + vc=albedo, + light_color=color2 + ) + + rendered_image = rn.r + visibility_image = rn.visibility_image + + image_list = [img] if type(img) is not list else img + + return_img = [] + for item in image_list: + if self.ratio != 1: + img_resized = resize( + item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio), + anti_aliasing=True + ) + else: + img_resized = item / 255. 
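+            # visibility_image stores, per pixel, the id of the face OpenDR
+            # rasterized there (2**32 - 1 marks background); mesh pixels are
+            # composited over the resized input image below.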
+ + try: + img_resized[visibility_image != (2**32 - 1) + ] = rendered_image[visibility_image != (2**32 - 1)] + except: + logger.warning('Can not render mesh.') + + img_resized = (img_resized * 255).astype(np.uint8) + res = img_resized + + if rgba: + img_resized_rgba = np.zeros((img_resized.shape[0], img_resized.shape[1], 4)) + img_resized_rgba[:, :, :3] = img_resized + img_resized_rgba[:, :, 3][visibility_image != (2**32 - 1)] = 255 + res = img_resized_rgba.astype(np.uint8) + return_img.append(res) + + if type(img) is not list: + return_img = return_img[0] + + return return_img + + +# https://github.com/classner/up/blob/master/up_tools/camera.py +def rotateY(points, angle): + """Rotate all points in a 2D array around the y axis.""" + ry = np.array( + [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], [-np.sin(angle), 0., + np.cos(angle)]] + ) + return np.dot(points, ry) + + +def rotateX(points, angle): + """Rotate all points in a 2D array around the x axis.""" + rx = np.array( + [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], [0., np.sin(angle), + np.cos(angle)]] + ) + return np.dot(points, rx) + + +def rotateZ(points, angle): + """Rotate all points in a 2D array around the z axis.""" + rz = np.array( + [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]] + ) + return np.dot(points, rz) + + +class IUV_Renderer(object): + def __init__( + self, + focal_length=5000., + orig_size=224, + output_size=56, + mode='iuv', + device=torch.device('cuda'), + mesh_type='smpl' + ): + + self.focal_length = focal_length + self.orig_size = orig_size + self.output_size = output_size + + if mode in ['iuv']: + if mesh_type == 'smpl': + DP = DensePoseMethods() + + vert_mapping = DP.All_vertices.astype('int64') - 1 + self.vert_mapping = torch.from_numpy(vert_mapping) + + faces = DP.FacesDensePose + faces = faces[None, :, :] + self.faces = torch.from_numpy( + faces.astype(np.int32) + ) # [1, 13774, 3], torch.int32 + + num_part = float(np.max(DP.FaceIndices)) + self.num_part = num_part + + dp_vert_pid_fname = 'data/dp_vert_pid.npy' + if os.path.exists(dp_vert_pid_fname): + dp_vert_pid = list(np.load(dp_vert_pid_fname)) + else: + print('creating data/dp_vert_pid.npy') + dp_vert_pid = [] + for v in range(len(vert_mapping)): + for i, f in enumerate(DP.FacesDensePose): + if v in f: + dp_vert_pid.append(DP.FaceIndices[i]) + break + np.save(dp_vert_pid_fname, np.array(dp_vert_pid)) + + textures_vts = np.array( + [ + (dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i]) + for i in range(len(vert_mapping)) + ] + ) + self.textures_vts = torch.from_numpy( + textures_vts[None].astype(np.float32) + ) # (1, 7829, 3) + elif mode == 'pncc': + self.vert_mapping = None + self.faces = torch.from_numpy( + get_model_faces(mesh_type)[None].astype(np.int32) + ) # mano: torch.Size([1, 1538, 3]) + textures_vts = get_model_tpose(mesh_type).unsqueeze( + 0 + ) # mano: torch.Size([1, 778, 3]) + + texture_min = torch.min(textures_vts) - 0.001 + texture_range = torch.max(textures_vts) - texture_min + 0.001 + self.textures_vts = (textures_vts - texture_min) / texture_range + elif mode in ['seg']: + self.vert_mapping = None + body_model = 'smpl' + + self.faces = torch.from_numpy(get_smpl_faces().astype(np.int32)[None]) + + with open( + os.path.join( + path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model) + ), 'rb' + ) as json_file: + smpl_part_id = json.load(json_file) + + v_id = [] + for k in smpl_part_id.keys(): + v_id.extend(smpl_part_id[k]) + + v_id = torch.tensor(v_id) + n_verts = 
len(torch.unique(v_id)) + num_part = len(constants.SMPL_PART_ID.keys()) + self.num_part = num_part + + seg_vert_pid = np.zeros(n_verts) + for k in smpl_part_id.keys(): + seg_vert_pid[smpl_part_id[k]] = constants.SMPL_PART_ID[k] + + print('seg_vert_pid', seg_vert_pid.shape) + textures_vts = seg_vert_pid[:, None].repeat(3, axis=1) / num_part + print('textures_vts', textures_vts.shape) + # textures_vts = np.array( + # [(seg_vert_pid[i] / num_part,) * 3 for i in + # range(n_verts)]) + self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32)) + + K = np.array( + [ + [self.focal_length, 0., self.orig_size / 2.], + [0., self.focal_length, self.orig_size / 2.], [0., 0., 1.] + ] + ) + + R = np.array([[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]]) + + t = np.array([0, 0, 5]) + + if self.orig_size != 224: + rander_scale = self.orig_size / float(224) + K[0, 0] *= rander_scale + K[1, 1] *= rander_scale + K[0, 2] *= rander_scale + K[1, 2] *= rander_scale + + self.K = torch.FloatTensor(K[None, :, :]) + self.R = torch.FloatTensor(R[None, :, :]) + self.t = torch.FloatTensor(t[None, None, :]) + + camK = F.pad(self.K, (0, 1, 0, 1), "constant", 0) + camK[:, 2, 2] = 0 + camK[:, 3, 2] = 1 + camK[:, 2, 3] = 1 + + self.K = camK + + self.device = device + lights = AmbientLights(device=self.device) + + raster_settings = RasterizationSettings( + image_size=output_size, + blur_radius=0, + faces_per_pixel=1, + ) + self.renderer = MeshRenderer( + rasterizer=MeshRasterizer(raster_settings=raster_settings), + shader=HardFlatShader( + device=self.device, + lights=lights, + blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0) + ) + ) + + def camera_matrix(self, cam): + batch_size = cam.size(0) + + K = self.K.repeat(batch_size, 1, 1) + R = self.R.repeat(batch_size, 1, 1) + t = torch.stack( + [-cam[:, 1], -cam[:, 2], 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)], + dim=-1 + ) + + if cam.is_cuda: + # device_id = cam.get_device() + K = K.to(cam.device) + R = R.to(cam.device) + t = t.to(cam.device) + + return K, R, t + + def verts2iuvimg(self, verts, cam, iwp_mode=True): + batch_size = verts.size(0) + + K, R, t = self.camera_matrix(cam) + + if self.vert_mapping is None: + vertices = verts + else: + vertices = verts[:, self.vert_mapping, :] + + mesh = Meshes(vertices, self.faces.to(verts.device).expand(batch_size, -1, -1)) + mesh.textures = TexturesVertex( + verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1) + ) + + cameras = PerspectiveCameras( + device=verts.device, + R=R, + T=t, + K=K, + in_ndc=False, + image_size=[(self.orig_size, self.orig_size)] + ) + + iuv_image = self.renderer(mesh, cameras=cameras) + iuv_image = iuv_image[..., :3].permute(0, 3, 1, 2) + + return iuv_image diff --git a/lib/pymafx/utils/sample_mesh.py b/lib/pymafx/utils/sample_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2599bee12d2577b6826ea8bfad8c937f2bcc2db2 --- /dev/null +++ b/lib/pymafx/utils/sample_mesh.py @@ -0,0 +1,64 @@ +import os +import trimesh +import numpy as np +from .utils.libmesh import check_mesh_contains + + +def get_occ_gt( + in_path=None, + vertices=None, + faces=None, + pts_num=1000, + points_sigma=0.01, + with_dp=False, + points=None, + extra_points=None +): + if in_path is not None: + mesh = trimesh.load(in_path, process=False) + print(type(mesh.vertices), mesh.vertices.shape, mesh.faces.shape) + + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + + # print('get_occ_gt', type(mesh.vertices), 
mesh.vertices.shape, mesh.faces.shape) + + # points_size = 100000 + points_padding = 0.1 + # points_sigma = 0.01 + points_uniform_ratio = 0.5 + n_points_uniform = int(pts_num * points_uniform_ratio) + n_points_surface = pts_num - n_points_uniform + + if points is None: + points_scale = 2.0 + boxsize = points_scale + points_padding + points_uniform = np.random.rand(n_points_uniform, 3) + points_uniform = boxsize * (points_uniform - 0.5) + points_surface, index_surface = mesh.sample(n_points_surface, return_index=True) + points_surface += points_sigma * np.random.randn(n_points_surface, 3) + points = np.concatenate([points_uniform, points_surface], axis=0) + + if extra_points is not None: + extra_points += points_sigma * np.random.randn(len(extra_points), 3) + points = np.concatenate([points, extra_points], axis=0) + + occupancies = check_mesh_contains(mesh, points) + + index_surface = None + + # points = points.astype(dtype) + + # print('occupancies', occupancies.dtype, np.sum(occupancies), occupancies.shape) + # occupancies = np.packbits(occupancies) + # print('occupancies bit', occupancies.dtype, np.sum(occupancies), occupancies.shape) + + # print('occupancies', points.shape, occupancies.shape, occupancies.dtype, np.sum(occupancies), index_surface.shape) + + return_dict = {} + return_dict['points'] = points + return_dict['points.occ'] = occupancies + return_dict['sf_sidx'] = index_surface + + # export_pointcloud(mesh, modelname, loc, scale, args) + # export_points(mesh, modelname, loc, scale, args) + return return_dict diff --git a/lib/pymafx/utils/saver.py b/lib/pymafx/utils/saver.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6bd3a184cc658dbc666ad2dcf3bc15d8cc427b --- /dev/null +++ b/lib/pymafx/utils/saver.py @@ -0,0 +1,139 @@ +from __future__ import division +import os +import torch +import datetime +import logging + +logger = logging.getLogger(__name__) + + +class CheckpointSaver(): + """Class that handles saving and loading checkpoints during training.""" + def __init__(self, save_dir, save_steps=1000, overwrite=False): + self.save_dir = os.path.abspath(save_dir) + self.save_steps = save_steps + self.overwrite = overwrite + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + self.get_latest_checkpoint() + return + + def exists_checkpoint(self, checkpoint_file=None): + """Check if a checkpoint exists in the current directory.""" + if checkpoint_file is None: + return False if self.latest_checkpoint is None else True + else: + return os.path.isfile(checkpoint_file) + + def save_checkpoint( + self, + models, + optimizers, + epoch, + batch_idx, + batch_size, + total_step_count, + is_best=False, + save_by_step=False, + interval=5, + with_optimizer=True + ): + """Save checkpoint.""" + timestamp = datetime.datetime.now() + if self.overwrite: + checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt')) + elif save_by_step: + checkpoint_filename = os.path.abspath( + os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count)) + ) + else: + if epoch % interval == 0: + checkpoint_filename = os.path.abspath( + os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt') + ) + else: + checkpoint_filename = None + + checkpoint = {} + for model in models: + model_dict = models[model].state_dict() + for k in list(model_dict.keys()): + if '.smpl.' 
in k: + del model_dict[k] + checkpoint[model] = model_dict + if with_optimizer: + for optimizer in optimizers: + checkpoint[optimizer] = optimizers[optimizer].state_dict() + checkpoint['epoch'] = epoch + checkpoint['batch_idx'] = batch_idx + checkpoint['batch_size'] = batch_size + checkpoint['total_step_count'] = total_step_count + print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx) + + if checkpoint_filename is not None: + torch.save(checkpoint, checkpoint_filename) + print('Saving checkpoint file [' + checkpoint_filename + ']') + if is_best: # save the best + checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt')) + torch.save(checkpoint, checkpoint_filename) + print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx) + print('Saving checkpoint file [' + checkpoint_filename + ']') + torch.save(checkpoint, checkpoint_filename) + print('Saved checkpoint file [' + checkpoint_filename + ']') + + def load_checkpoint(self, models, optimizers, checkpoint_file=None): + """Load a checkpoint.""" + if checkpoint_file is None: + logger.info('Loading latest checkpoint [' + self.latest_checkpoint + ']') + checkpoint_file = self.latest_checkpoint + checkpoint = torch.load(checkpoint_file) + for model in models: + if model in checkpoint: + model_dict = models[model].state_dict() + pretrained_dict = { + k: v + for k, v in checkpoint[model].items() if k in model_dict.keys() + } + model_dict.update(pretrained_dict) + models[model].load_state_dict(model_dict) + + # models[model].load_state_dict(checkpoint[model]) + for optimizer in optimizers: + if optimizer in checkpoint: + optimizers[optimizer].load_state_dict(checkpoint[optimizer]) + return { + 'epoch': checkpoint['epoch'], + 'batch_idx': checkpoint['batch_idx'], + 'batch_size': checkpoint['batch_size'], + 'total_step_count': checkpoint['total_step_count'] + } + + def get_latest_checkpoint(self): + """Get filename of latest checkpoint if it exists.""" + checkpoint_list = [] + for dirpath, dirnames, filenames in os.walk(self.save_dir): + for filename in filenames: + if filename.endswith('.pt'): + checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename))) + # sort + import re + + def atof(text): + try: + retval = float(text) + except ValueError: + retval = text + return retval + + def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + float regex comes from https://stackoverflow.com/a/12643073/190597 + ''' + return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)] + + checkpoint_list.sort(key=natural_keys) + self.latest_checkpoint = None if (len(checkpoint_list) == 0) else checkpoint_list[-1] + return diff --git a/lib/pymafx/utils/segms.py b/lib/pymafx/utils/segms.py new file mode 100644 index 0000000000000000000000000000000000000000..44c617529d67323a8664c3e00872e5db091b8be6 --- /dev/null +++ b/lib/pymafx/utils/segms.py @@ -0,0 +1,268 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""Functions for interacting with segmentation masks in the COCO format. + +The following terms are used in this module + mask: a binary mask encoded as a 2D numpy array + segm: a segmentation mask in one of the two COCO formats (polygon or RLE) + polygon: COCO's polygon format + RLE: COCO's run length encoding format +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np + +import pycocotools.mask as mask_util + + +def GetDensePoseMask(Polys): + MaskGen = np.zeros([256, 256]) + for i in range(1, 15): + if (Polys[i - 1]): + current_mask = mask_util.decode(Polys[i - 1]) + MaskGen[current_mask > 0] = i + return MaskGen + + +def flip_segms(segms, height, width): + """Left/right flip each mask in a list of masks.""" + def _flip_poly(poly, width): + flipped_poly = np.array(poly) + flipped_poly[0::2] = width - np.array(poly[0::2]) - 1 + return flipped_poly.tolist() + + def _flip_rle(rle, height, width): + if 'counts' in rle and type(rle['counts']) == list: + # Magic RLE format handling painfully discovered by looking at the + # COCO API showAnns function. + rle = mask_util.frPyObjects([rle], height, width) + mask = mask_util.decode(rle) + mask = mask[:, ::-1, :] + rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) + return rle + + flipped_segms = [] + for segm in segms: + if type(segm) == list: + # Polygon format + flipped_segms.append([_flip_poly(poly, width) for poly in segm]) + else: + # RLE format + assert type(segm) == dict + flipped_segms.append(_flip_rle(segm, height, width)) + return flipped_segms + + +def polys_to_mask(polygons, height, width): + """Convert from the COCO polygon segmentation format to a binary mask + encoded as a 2D array of data type numpy.float32. The polygon segmentation + is understood to be enclosed inside a height x width image. The resulting + mask is therefore of shape (height, width). + """ + rle = mask_util.frPyObjects(polygons, height, width) + mask = np.array(mask_util.decode(rle), dtype=np.float32) + # Flatten in case polygons was a list + mask = np.sum(mask, axis=2) + mask = np.array(mask > 0, dtype=np.float32) + return mask + + +def mask_to_bbox(mask): + """Compute the tight bounding box of a binary mask.""" + xs = np.where(np.sum(mask, axis=0) > 0)[0] + ys = np.where(np.sum(mask, axis=1) > 0)[0] + + if len(xs) == 0 or len(ys) == 0: + return None + + x0 = xs[0] + x1 = xs[-1] + y0 = ys[0] + y1 = ys[-1] + return np.array((x0, y0, x1, y1), dtype=np.float32) + + +def polys_to_mask_wrt_box(polygons, box, M): + """Convert from the COCO polygon segmentation format to a binary mask + encoded as a 2D array of data type numpy.float32. The polygon segmentation + is understood to be enclosed in the given box and rasterized to an M x M + mask. The resulting mask is therefore of shape (M, M). 
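+
+    Illustrative usage (a minimal sketch; the polygon and box values below are
+    made up for the example):
+
+        >>> import numpy as np
+        >>> polys = [[0., 0., 10., 0., 10., 10.]]   # one triangle, COCO xy-interleaved
+        >>> mask = polys_to_mask_wrt_box(polys, box=np.array([0., 0., 10., 10.]), M=28)
+        >>> mask.shape
+        (28, 28)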
+ """ + w = box[2] - box[0] + h = box[3] - box[1] + + w = np.maximum(w, 1) + h = np.maximum(h, 1) + + polygons_norm = [] + for poly in polygons: + p = np.array(poly, dtype=np.float32) + p[0::2] = (p[0::2] - box[0]) * M / w + p[1::2] = (p[1::2] - box[1]) * M / h + polygons_norm.append(p) + + rle = mask_util.frPyObjects(polygons_norm, M, M) + mask = np.array(mask_util.decode(rle), dtype=np.float32) + # Flatten in case polygons was a list + mask = np.sum(mask, axis=2) + mask = np.array(mask > 0, dtype=np.float32) + return mask + + +def polys_to_boxes(polys): + """Convert a list of polygons into an array of tight bounding boxes.""" + boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32) + for i in range(len(polys)): + poly = polys[i] + x0 = min(min(p[::2]) for p in poly) + x1 = max(max(p[::2]) for p in poly) + y0 = min(min(p[1::2]) for p in poly) + y1 = max(max(p[1::2]) for p in poly) + boxes_from_polys[i, :] = [x0, y0, x1, y1] + + return boxes_from_polys + + +def rle_mask_voting(top_masks, all_masks, all_dets, iou_thresh, binarize_thresh, method='AVG'): + """Returns new masks (in correspondence with `top_masks`) by combining + multiple overlapping masks coming from the pool of `all_masks`. Two methods + for combining masks are supported: 'AVG' uses a weighted average of + overlapping mask pixels; 'UNION' takes the union of all mask pixels. + """ + if len(top_masks) == 0: + return + + all_not_crowd = [False] * len(all_masks) + top_to_all_overlaps = mask_util.iou(top_masks, all_masks, all_not_crowd) + decoded_all_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in all_masks] + decoded_top_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in top_masks] + all_boxes = all_dets[:, :4].astype(np.int32) + all_scores = all_dets[:, 4] + + # Fill box support with weights + mask_shape = decoded_all_masks[0].shape + mask_weights = np.zeros((len(all_masks), mask_shape[0], mask_shape[1])) + for k in range(len(all_masks)): + ref_box = all_boxes[k] + x_0 = max(ref_box[0], 0) + x_1 = min(ref_box[2] + 1, mask_shape[1]) + y_0 = max(ref_box[1], 0) + y_1 = min(ref_box[3] + 1, mask_shape[0]) + mask_weights[k, y_0:y_1, x_0:x_1] = all_scores[k] + mask_weights = np.maximum(mask_weights, 1e-5) + + top_segms_out = [] + for k in range(len(top_masks)): + # Corner case of empty mask + if decoded_top_masks[k].sum() == 0: + top_segms_out.append(top_masks[k]) + continue + + inds_to_vote = np.where(top_to_all_overlaps[k] >= iou_thresh)[0] + # Only matches itself + if len(inds_to_vote) == 1: + top_segms_out.append(top_masks[k]) + continue + + masks_to_vote = [decoded_all_masks[i] for i in inds_to_vote] + if method == 'AVG': + ws = mask_weights[inds_to_vote] + soft_mask = np.average(masks_to_vote, axis=0, weights=ws) + mask = np.array(soft_mask > binarize_thresh, dtype=np.uint8) + elif method == 'UNION': + # Any pixel that's on joins the mask + soft_mask = np.sum(masks_to_vote, axis=0) + mask = np.array(soft_mask > 1e-5, dtype=np.uint8) + else: + raise NotImplementedError('Method {} is unknown'.format(method)) + rle = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0] + top_segms_out.append(rle) + + return top_segms_out + + +def rle_mask_nms(masks, dets, thresh, mode='IOU'): + """Performs greedy non-maximum suppression based on an overlap measurement + between masks. The type of measurement is determined by `mode` and can be + either 'IOU' (standard intersection over union) or 'IOMA' (intersection over + mininum area). 
+ """ + if len(masks) == 0: + return [] + if len(masks) == 1: + return [0] + + if mode == 'IOU': + # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(union(m1, m2)) + all_not_crowds = [False] * len(masks) + ious = mask_util.iou(masks, masks, all_not_crowds) + elif mode == 'IOMA': + # Computes ious[m1, m2] = area(intersect(m1, m2)) / min(area(m1), area(m2)) + all_crowds = [True] * len(masks) + # ious[m1, m2] = area(intersect(m1, m2)) / area(m2) + ious = mask_util.iou(masks, masks, all_crowds) + # ... = max(area(intersect(m1, m2)) / area(m2), + # area(intersect(m2, m1)) / area(m1)) + ious = np.maximum(ious, ious.transpose()) + elif mode == 'CONTAINMENT': + # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(m2) + # Which measures how much m2 is contained inside m1 + all_crowds = [True] * len(masks) + ious = mask_util.iou(masks, masks, all_crowds) + else: + raise NotImplementedError('Mode {} is unknown'.format(mode)) + + scores = dets[:, 4] + order = np.argsort(-scores) + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = ious[i, order[1:]] + inds_to_keep = np.where(ovr <= thresh)[0] + order = order[inds_to_keep + 1] + + return keep + + +def rle_masks_to_boxes(masks): + """Computes the bounding box of each mask in a list of RLE encoded masks.""" + if len(masks) == 0: + return [] + + decoded_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in masks] + + def get_bounds(flat_mask): + inds = np.where(flat_mask > 0)[0] + return inds.min(), inds.max() + + boxes = np.zeros((len(decoded_masks), 4)) + keep = [True] * len(decoded_masks) + for i, mask in enumerate(decoded_masks): + if mask.sum() == 0: + keep[i] = False + continue + flat_mask = mask.sum(axis=0) + x0, x1 = get_bounds(flat_mask) + flat_mask = mask.sum(axis=1) + y0, y1 = get_bounds(flat_mask) + boxes[i, :] = (x0, y0, x1, y1) + + return boxes, np.where(keep)[0] diff --git a/lib/pymafx/utils/smooth_bbox.py b/lib/pymafx/utils/smooth_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..4393320e7f50128d6838d99c76b5d0f8f45f6efc --- /dev/null +++ b/lib/pymafx/utils/smooth_bbox.py @@ -0,0 +1,123 @@ +# This script is borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/util/smooth_bbox.py +# Adhere to their licence to use this script + +import numpy as np +import scipy.signal as signal +from scipy.ndimage.filters import gaussian_filter1d + + +def get_smooth_bbox_params(kps, vis_thresh=2, kernel_size=11, sigma=3): + """ + Computes smooth bounding box parameters from keypoints: + 1. Computes bbox by rescaling the person to be around 150 px. + 2. Linearly interpolates bbox params for missing annotations. + 3. Median filtering + 4. Gaussian filtering. + + Recommended thresholds: + * detect-and-track: 0 + * 3DPW: 0.1 + + Args: + kps (list): List of kps (Nx3) or None. + vis_thresh (float): Threshold for visibility. + kernel_size (int): Kernel size for median filtering (must be odd). + sigma (float): Sigma for gaussian smoothing. + + Returns: + Smooth bbox params [cx, cy, scale], start index, end index + """ + bbox_params, start, end = get_all_bbox_params(kps, vis_thresh) + smoothed = smooth_bbox_params(bbox_params, kernel_size, sigma) + smoothed = np.vstack((np.zeros((start, 3)), smoothed)) + return smoothed, start, end + + +def kp_to_bbox_param(kp, vis_thresh): + """ + Finds the bounding box parameters from the 2D keypoints. + + Args: + kp (Kx3): 2D Keypoints. + vis_thresh (float): Threshold for visibility. 
+ + Returns: + [center_x, center_y, scale] + """ + if kp is None: + return + vis = kp[:, 2] > vis_thresh + if not np.any(vis): + return + min_pt = np.min(kp[vis, :2], axis=0) + max_pt = np.max(kp[vis, :2], axis=0) + person_height = np.linalg.norm(max_pt - min_pt) + if person_height < 0.5: + return + center = (min_pt + max_pt) / 2. + scale = 150. / person_height + return np.append(center, scale) + + +def get_all_bbox_params(kps, vis_thresh=2): + """ + Finds bounding box parameters for all keypoints. + + Look for sequences in the middle with no predictions and linearly + interpolate the bbox params for those + + Args: + kps (list): List of kps (Kx3) or None. + vis_thresh (float): Threshold for visibility. + + Returns: + bbox_params, start_index (incl), end_index (excl) + """ + # keeps track of how many indices in a row with no prediction + num_to_interpolate = 0 + start_index = -1 + bbox_params = np.empty(shape=(0, 3), dtype=np.float32) + + for i, kp in enumerate(kps): + bbox_param = kp_to_bbox_param(kp, vis_thresh=vis_thresh) + if bbox_param is None: + num_to_interpolate += 1 + continue + + if start_index == -1: + # Found the first index with a prediction! + start_index = i + num_to_interpolate = 0 + + if num_to_interpolate > 0: + # Linearly interpolate each param. + previous = bbox_params[-1] + # This will be 3x(n+2) + interpolated = np.array( + [ + np.linspace(prev, curr, num_to_interpolate + 2) + for prev, curr in zip(previous, bbox_param) + ] + ) + bbox_params = np.vstack((bbox_params, interpolated.T[1:-1])) + num_to_interpolate = 0 + bbox_params = np.vstack((bbox_params, bbox_param)) + + return bbox_params, start_index, i - num_to_interpolate + 1 + + +def smooth_bbox_params(bbox_params, kernel_size=11, sigma=8): + """ + Applies median filtering and then gaussian filtering to bounding box + parameters. + + Args: + bbox_params (Nx3): [cx, cy, scale]. + kernel_size (int): Kernel size for median filtering (must be odd). + sigma (float): Sigma for gaussian smoothing. + + Returns: + Smoothed bounding box parameters (Nx3). + """ + smoothed = np.array([signal.medfilt(param, kernel_size) for param in bbox_params.T]).T + return np.array([gaussian_filter1d(traj, sigma) for traj in smoothed.T]).T diff --git a/lib/pymafx/utils/transforms.py b/lib/pymafx/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..25534674631d40b8b263b242d05339443b169dcb --- /dev/null +++ b/lib/pymafx/utils/transforms.py @@ -0,0 +1,119 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 + + +def flip_back(output_flipped, matched_parts): + ''' + ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width) + ''' + assert output_flipped.ndim == 4,\ + 'output_flipped should be [batch_size, num_joints, height, width]' + + output_flipped = output_flipped[:, :, :, ::-1] + + for pair in matched_parts: + tmp = output_flipped[:, pair[0], :, :].copy() + output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] + output_flipped[:, pair[1], :, :] = tmp + + return output_flipped + + +def fliplr_joints(joints, joints_vis, width, matched_parts): + """ + flip coords + """ + # Flip horizontal + joints[:, 0] = width - joints[:, 0] - 1 + + # Change left-right parts + for pair in matched_parts: + joints[pair[0], :], joints[pair[1], :] = \ + joints[pair[1], :], joints[pair[0], :].copy() + joints_vis[pair[0], :], joints_vis[pair[1], :] = \ + joints_vis[pair[1], :], joints_vis[pair[0], :].copy() + + return joints * joints_vis, joints_vis + + +def transform_preds(coords, center, scale, output_size): + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale, 0, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def get_affine_transform( + center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0 +): + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + # print(scale) + scale = np.array([scale, scale]) + + scale_tmp = scale * 200.0 + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + + +def get_3rd_point(a, b): + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def get_dir(src_point, rot_rad): + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result + + +def crop(img, center, scale, output_size, rot=0): + trans = get_affine_transform(center, scale, rot, output_size) + + dst_img = cv2.warpAffine( + img, trans, (int(output_size[0]), int(output_size[1])), flags=cv2.INTER_LINEAR + ) + + return dst_img diff --git a/lib/pymafx/utils/uv_vis.py b/lib/pymafx/utils/uv_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..86fdd33ddee774c2bbe02478b2d74f53f8522256 --- /dev/null +++ b/lib/pymafx/utils/uv_vis.py @@ -0,0 +1,152 @@ +import 
os +import torch +import numpy as np +import torch.nn.functional as F +from skimage.transform import resize +# Use a non-interactive backend +import matplotlib + +matplotlib.use('Agg') + +from .renderer import OpenDRenderer, PyRenderer + + +def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=None): + device_id = U_uv.get_device() + batch_size = U_uv.size(0) + K = U_uv.size(1) + heatmap_size = U_uv.size(2) + + Index_UV_max = torch.argmax(Index_UV, dim=1) + if AnnIndex is None: + Index_UV_max = Index_UV_max.to(torch.int64) + else: + AnnIndex_max = torch.argmax(AnnIndex, dim=1) + Index_UV_max = Index_UV_max * (AnnIndex_max > 0).to(torch.int64) + + outputs = [] + + for batch_id in range(batch_size): + output = torch.zeros([3, U_uv.size(2), U_uv.size(3)], dtype=torch.float32).cuda(device_id) + output[0] = Index_UV_max[batch_id].to(torch.float32) + if ind_mapping is None: + output[0] /= float(K - 1) + else: + for ind in range(len(ind_mapping)): + output[0][output[0] == ind] = ind_mapping[ind] * (1. / 24.) + + for part_id in range(1, K): + CurrentU = U_uv[batch_id, part_id] + CurrentV = V_uv[batch_id, part_id] + output[1, + Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id] + output[2, + Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id] + + if uv_rois is None: + outputs.append(output.unsqueeze(0)) + else: + roi_fg = uv_rois[batch_id][1:] + w = roi_fg[2] - roi_fg[0] + h = roi_fg[3] - roi_fg[1] + + aspect_ratio = float(w) / h + + if aspect_ratio < 1: + new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)] + output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') + paddingleft = int(0.5 * (heatmap_size - new_size[1])) + output = F.pad( + output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0) + ) + else: + new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size] + output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') + paddingtop = int(0.5 * (heatmap_size - new_size[0])) + output = F.pad( + output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop) + ) + + outputs.append(output) + + return torch.cat(outputs, dim=0) + + +def vis_smpl_iuv( + image, + cam_pred, + vert_pred, + face, + pred_uv, + vert_errors_batch, + image_name, + save_path, + opt, + ratio=1 +): + + # save_path = os.path.join('./notebooks/output/demo_results-wild', ids[f_id][0]) + if not os.path.exists(save_path): + os.makedirs(save_path) + # dr_render = OpenDRenderer(ratio=ratio) + dr_render = PyRenderer() + + focal_length = 5000. + orig_size = 224. 
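+    # The two constants above follow the weak-perspective ("IWP") camera convention
+    # used in pymafx: a fixed focal length of 5000 px for a 224 px crop, so the
+    # per-sample intrinsics built in the loop below are K = [[f, 0, c], [0, f, c],
+    # [0, 0, 1]] with c = orig_size / 2. It is assumed here that cam_pred holds
+    # weak-perspective parameters [s, tx, ty], which the renderer maps to a camera
+    # translation with depth t_z = 2 * focal_length / (orig_size * s), mirroring
+    # IUV_Renderer.camera_matrix elsewhere in these utils.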
+ + if pred_uv is not None: + iuv_img = iuv_map2img(*pred_uv) + + for draw_i in range(len(cam_pred)): + err_val = '{:06d}_'.format(int(10 * vert_errors_batch[draw_i])) + draw_name = err_val + image_name[draw_i] + K = np.array( + [[focal_length, 0., orig_size / 2.], [0., focal_length, orig_size / 2.], [0., 0., 1.]] + ) + + # img_orig, img_resized, img_smpl, render_smpl_rgba = dr_render( + # image[draw_i], + # cam_pred[draw_i], + # vert_pred[draw_i], + # face, + # draw_name[:-4] + # ) + if opt.save_obj: + os.makedirs(os.path.join(save_path, 'mesh'), exist_ok=True) + mesh_filename = os.path.join(save_path, 'mesh', draw_name[:-4] + '.obj') + else: + mesh_filename = None + + img_orig = np.moveaxis(image[draw_i], 0, -1) + img_smpl, img_resized = dr_render( + vert_pred[draw_i], + img=img_orig, + cam=cam_pred[draw_i], + iwp_mode=True, + scale_ratio=4., + mesh_filename=mesh_filename, + ) + + ones_img = np.ones(img_smpl.shape[:2]) * 255 + ones_img = ones_img[:, :, None] + img_smpl_rgba = np.concatenate((img_smpl, ones_img), axis=2) + img_resized_rgba = np.concatenate((img_resized, ones_img), axis=2) + + # render_img = np.concatenate((img_resized_rgba, img_smpl_rgba, render_smpl_rgba * 255), axis=1) + render_img = np.concatenate((img_resized_rgba, img_smpl_rgba), axis=1) + render_img[render_img < 0] = 0 + render_img[render_img > 255] = 255 + matplotlib.image.imsave( + os.path.join(save_path, draw_name[:-4] + '.png'), render_img.astype(np.uint8) + ) + + if pred_uv is not None: + # estimated global IUV + global_iuv = iuv_img[draw_i].cpu().numpy() + global_iuv = np.transpose(global_iuv, (1, 2, 0)) + global_iuv = resize(global_iuv, img_resized.shape[:2]) + global_iuv[global_iuv > 1] = 1 + global_iuv[global_iuv < 0] = 0 + matplotlib.image.imsave( + os.path.join(save_path, 'pred_uv_' + draw_name[:-4] + '.png'), global_iuv + ) diff --git a/lib/pymafx/utils/vis.py b/lib/pymafx/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..5273707c05f66275150e7cb2d86f44dcf4c92223 --- /dev/null +++ b/lib/pymafx/utils/vis.py @@ -0,0 +1,676 @@ +# Written by Roy Tseng +# +# Based on: +# -------------------------------------------------------- +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################## + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import cv2 +import numpy as np +import os +import pycocotools.mask as mask_util +import math +import torchvision + +from .colormap import colormap +from .keypoints import get_keypoints +from .imutils import normalize_2d_kp + +# Use a non-interactive backend +import matplotlib + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from matplotlib.patches import Polygon +from mpl_toolkits.mplot3d import Axes3D +from skimage.transform import resize + +plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator + +_GRAY = (218, 227, 218) +_GREEN = (18, 127, 15) +_WHITE = (255, 255, 255) + + +def get_colors(): + colors = { + 'pink': np.array([197, 27, 125]), # L lower leg + 'light_pink': np.array([233, 163, 201]), # L upper leg + 'light_green': np.array([161, 215, 106]), # L lower arm + 'green': np.array([77, 146, 33]), # L upper arm + 'red': np.array([215, 48, 39]), # head + 'light_red': np.array([252, 146, 114]), # head + 'light_orange': np.array([252, 141, 89]), # chest + 'purple': np.array([118, 42, 131]), # R lower leg + 'light_purple': np.array([175, 141, 195]), # R upper + 'light_blue': np.array([145, 191, 219]), # R lower arm + 'blue': np.array([69, 117, 180]), # R upper arm + 'gray': np.array([130, 130, 130]), # + 'white': np.array([255, 255, 255]), # + } + return colors + + +def kp_connections(keypoints): + kp_lines = [ + [keypoints.index('left_eye'), keypoints.index('right_eye')], + [keypoints.index('left_eye'), keypoints.index('nose')], + [keypoints.index('right_eye'), keypoints.index('nose')], + [keypoints.index('right_eye'), keypoints.index('right_ear')], + [keypoints.index('left_eye'), keypoints.index('left_ear')], + [keypoints.index('right_shoulder'), + keypoints.index('right_elbow')], + [keypoints.index('right_elbow'), + keypoints.index('right_wrist')], + [keypoints.index('left_shoulder'), + keypoints.index('left_elbow')], + [keypoints.index('left_elbow'), + keypoints.index('left_wrist')], + [keypoints.index('right_hip'), keypoints.index('right_knee')], + [keypoints.index('right_knee'), + keypoints.index('right_ankle')], + [keypoints.index('left_hip'), keypoints.index('left_knee')], + [keypoints.index('left_knee'), keypoints.index('left_ankle')], + [keypoints.index('right_shoulder'), + keypoints.index('left_shoulder')], + [keypoints.index('right_hip'), keypoints.index('left_hip')], + ] + return kp_lines + + +def convert_from_cls_format(cls_boxes, cls_segms, cls_keyps): + """Convert from the class boxes/segms/keyps format generated by the testing + code. 
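+
+    Returns a (boxes, segms, keyps, classes) tuple: `boxes` is an Nx5 array of
+    [x0, y0, x1, y1, score] rows (or None), `segms` and `keyps` are flat lists
+    (or None), and `classes[i]` is the class index of detection i (summary
+    inferred from the implementation below).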
+ """ + box_list = [b for b in cls_boxes if len(b) > 0] + if len(box_list) > 0: + boxes = np.concatenate(box_list) + else: + boxes = None + if cls_segms is not None: + segms = [s for slist in cls_segms for s in slist] + else: + segms = None + if cls_keyps is not None: + keyps = [k for klist in cls_keyps for k in klist] + else: + keyps = None + classes = [] + for j in range(len(cls_boxes)): + classes += [j] * len(cls_boxes[j]) + return boxes, segms, keyps, classes + + +def vis_bbox_opencv(img, bbox, thick=1): + """Visualizes a bounding box.""" + (x0, y0, w, h) = bbox + x1, y1 = int(x0 + w), int(y0 + h) + x0, y0 = int(x0), int(y0) + cv2.rectangle(img, (x0, y0), (x1, y1), _GREEN, thickness=thick) + return img + + +def get_class_string(class_index, score, dataset): + class_text = dataset.classes[class_index] if dataset is not None else \ + 'id{:d}'.format(class_index) + return class_text + ' {:0.2f}'.format(score).lstrip('0') + + +def vis_one_image( + im, + im_name, + output_dir, + boxes, + segms=None, + keypoints=None, + body_uv=None, + thresh=0.9, + kp_thresh=2, + dpi=200, + box_alpha=0.0, + dataset=None, + show_class=False, + ext='pdf' +): + """Visual debugging of detections.""" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if isinstance(boxes, list): + boxes, segms, keypoints, classes = convert_from_cls_format(boxes, segms, keypoints) + + if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh: + return + + if segms is not None: + masks = mask_util.decode(segms) + + color_list = colormap(rgb=True) / 255 + + dataset_keypoints, _ = get_keypoints() + + kp_lines = kp_connections(dataset_keypoints) + cmap = plt.get_cmap('rainbow') + colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)] + + fig = plt.figure(frameon=False) + fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.axis('off') + fig.add_axes(ax) + ax.imshow(im) + + # Display in largest to smallest order to reduce occlusion + areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + sorted_inds = np.argsort(-areas) + + mask_color_id = 0 + for i in sorted_inds: + bbox = boxes[i, :4] + score = boxes[i, -1] + if score < thresh: + continue + + print(dataset.classes[classes[i]], score) + # show box (off by default, box_alpha=0.0) + ax.add_patch( + plt.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + fill=False, + edgecolor='g', + linewidth=0.5, + alpha=box_alpha + ) + ) + + if show_class: + ax.text( + bbox[0], + bbox[1] - 2, + get_class_string(classes[i], score, dataset), + fontsize=3, + family='serif', + bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'), + color='white' + ) + + # show mask + if segms is not None and len(segms) > i: + img = np.ones(im.shape) + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + + w_ratio = .4 + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio + for c in range(3): + img[:, :, c] = color_mask[c] + e = masks[:, :, i] + + _, contour, hier = cv2.findContours(e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + + for c in contour: + polygon = Polygon( + c.reshape((-1, 2)), + fill=True, + facecolor=color_mask, + edgecolor='w', + linewidth=1.2, + alpha=0.5 + ) + ax.add_patch(polygon) + + # show keypoints + if keypoints is not None and len(keypoints) > i: + kps = keypoints[i] + plt.autoscale(False) + for l in range(len(kp_lines)): + i1 = kp_lines[l][0] + i2 = kp_lines[l][1] + if kps[2, i1] > kp_thresh and kps[2, i2] > 
kp_thresh: + x = [kps[0, i1], kps[0, i2]] + y = [kps[1, i1], kps[1, i2]] + line = ax.plot(x, y) + plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7) + if kps[2, i1] > kp_thresh: + ax.plot(kps[0, i1], kps[1, i1], '.', color=colors[l], markersize=3.0, alpha=0.7) + if kps[2, i2] > kp_thresh: + ax.plot(kps[0, i2], kps[1, i2], '.', color=colors[l], markersize=3.0, alpha=0.7) + + # add mid shoulder / mid hip for better visualization + mid_shoulder = ( + kps[:2, dataset_keypoints.index('right_shoulder')] + + kps[:2, dataset_keypoints.index('left_shoulder')] + ) / 2.0 + sc_mid_shoulder = np.minimum( + kps[2, dataset_keypoints.index('right_shoulder')], + kps[2, dataset_keypoints.index('left_shoulder')] + ) + mid_hip = ( + kps[:2, dataset_keypoints.index('right_hip')] + + kps[:2, dataset_keypoints.index('left_hip')] + ) / 2.0 + sc_mid_hip = np.minimum( + kps[2, dataset_keypoints.index('right_hip')], + kps[2, dataset_keypoints.index('left_hip')] + ) + if ( + sc_mid_shoulder > kp_thresh and kps[2, dataset_keypoints.index('nose')] > kp_thresh + ): + x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]] + y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]] + line = ax.plot(x, y) + plt.setp(line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7) + if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh: + x = [mid_shoulder[0], mid_hip[0]] + y = [mid_shoulder[1], mid_hip[1]] + line = ax.plot(x, y) + plt.setp(line, color=colors[len(kp_lines) + 1], linewidth=1.0, alpha=0.7) + + # DensePose Visualization Starts!! + ## Get full IUV image out + if body_uv is not None and len(body_uv) > 1: + IUV_fields = body_uv[1] + # + All_Coords = np.zeros(im.shape) + All_inds = np.zeros([im.shape[0], im.shape[1]]) + K = 26 + ## + inds = np.argsort(boxes[:, 4]) + ## + for i, ind in enumerate(inds): + entry = boxes[ind, :] + if entry[4] > 0.65: + entry = entry[0:4].astype(int) + #### + output = IUV_fields[ind] + #### + All_Coords_Old = All_Coords[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2], :] + All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, + 0])[All_Coords_Old == 0] + All_Coords[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2], :] = All_Coords_Old + ### + CurrentMask = (output[0, :, :] > 0).astype(np.float32) + All_inds_old = All_inds[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2]] + All_inds_old[All_inds_old == 0] = CurrentMask[All_inds_old == 0] * i + All_inds[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2]] = All_inds_old + # + All_Coords[:, :, 1:3] = 255. * All_Coords[:, :, 1:3] + All_Coords[All_Coords > 255] = 255. + All_Coords = All_Coords.astype(np.uint8) + All_inds = All_inds.astype(np.uint8) + # + IUV_SaveName = os.path.basename(im_name).split('.')[0] + '_IUV.png' + INDS_SaveName = os.path.basename(im_name).split('.')[0] + '_INDS.png' + cv2.imwrite(os.path.join(output_dir, '{}'.format(IUV_SaveName)), All_Coords) + cv2.imwrite(os.path.join(output_dir, '{}'.format(INDS_SaveName)), All_inds) + print('IUV written to: ', os.path.join(output_dir, '{}'.format(IUV_SaveName))) + ### + ### DensePose Visualization Done!! + # + output_name = os.path.basename(im_name) + '.' 
+ ext + fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi) + plt.close('all') + + # SMPL Visualization + if body_uv is not None and len(body_uv) > 2: + smpl_fields = body_uv[2] + # + All_Coords = np.zeros(im.shape) + # All_inds = np.zeros([im.shape[0], im.shape[1]]) + K = 26 + ## + inds = np.argsort(boxes[:, 4]) + ## + for i, ind in enumerate(inds): + entry = boxes[ind, :] + if entry[4] > 0.75: + entry = entry[0:4].astype(int) + center_roi = [(entry[2] + entry[0]) / 2., (entry[3] + entry[1]) / 2.] + #### + output, center_out = smpl_fields[ind] + #### + x1_img = max(int(center_roi[0] - center_out[0]), 0) + y1_img = max(int(center_roi[1] - center_out[1]), 0) + + x2_img = min(int(center_roi[0] - center_out[0]) + output.shape[2], im.shape[1]) + y2_img = min(int(center_roi[1] - center_out[1]) + output.shape[1], im.shape[0]) + + All_Coords_Old = All_Coords[y1_img:y2_img, x1_img:x2_img, :] + + x1_out = max(int(center_out[0] - center_roi[0]), 0) + y1_out = max(int(center_out[1] - center_roi[1]), 0) + + x2_out = x1_out + (x2_img - x1_img) + y2_out = y1_out + (y2_img - y1_img) + + output = output[:, y1_out:y2_out, x1_out:x2_out] + + # All_Coords_Old = All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2], + # :] + All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, + 0])[All_Coords_Old == 0] + All_Coords[y1_img:y2_img, x1_img:x2_img, :] = All_Coords_Old + ### + # CurrentMask = (output[0, :, :] > 0).astype(np.float32) + # All_inds_old = All_inds[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2]] + # All_inds_old[All_inds_old == 0] = CurrentMask[All_inds_old == 0] * i + # All_inds[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2]] = All_inds_old + # + All_Coords = 255. * All_Coords + All_Coords[All_Coords > 255] = 255. + All_Coords = All_Coords.astype(np.uint8) + + image_stacked = im[:, :, ::-1] + image_stacked[All_Coords > 20] = All_Coords[All_Coords > 20] + # All_inds = All_inds.astype(np.uint8) + # + SMPL_SaveName = os.path.basename(im_name).split('.')[0] + '_SMPL.png' + smpl_image_SaveName = os.path.basename(im_name).split('.')[0] + '_SMPLimg.png' + # INDS_SaveName = os.path.basename(im_name).split('.')[0] + '_INDS.png' + cv2.imwrite(os.path.join(output_dir, '{}'.format(SMPL_SaveName)), All_Coords) + cv2.imwrite(os.path.join(output_dir, '{}'.format(smpl_image_SaveName)), image_stacked) + # cv2.imwrite(os.path.join(output_dir, '{}'.format(INDS_SaveName)), All_inds) + print('SMPL written to: ', os.path.join(output_dir, '{}'.format(SMPL_SaveName))) + ### + ### SMPL Visualization Done!! + # + output_name = os.path.basename(im_name) + '.' 
+ ext + fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi) + plt.close('all') + + +def vis_batch_image_with_joints( + batch_image, + batch_joints, + batch_joints_vis, + file_name=None, + nrow=8, + padding=0, + pad_value=1, + add_text=True +): + ''' + batch_image: [batch_size, channel, height, width] + batch_joints: [batch_size, num_joints, 3], + batch_joints_vis: [batch_size, num_joints, 1], + } + ''' + grid = torchvision.utils.make_grid(batch_image, nrow, padding, True, pad_value=pad_value) + ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() + ndarr = ndarr.copy() + + nmaps = batch_image.size(0) + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height = int(batch_image.size(2) + padding) + width = int(batch_image.size(3) + padding) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + + joints = batch_joints[k] + joints_vis = batch_joints_vis[k] + + flip = 1 + count = -1 + + for joint, joint_vis in zip(joints, joints_vis): + joint[0] = x * width + padding + joint[0] + joint[1] = y * height + padding + joint[1] + flip *= -1 + count += 1 + if joint_vis[0]: + try: + if flip > 0: + cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 0, [255, 0, 0], -1) + else: + cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 0, [0, 255, 0], -1) + if add_text: + cv2.putText( + ndarr, str(count), (int(joint[0]), int(joint[1])), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1 + ) + except Exception as e: + print(e) + k = k + 1 + + return ndarr + + +def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None): + n_sample = joints.shape[0] + max_show = 2 + if n_sample > max_show: + if batch_img is not None: + batch_img = batch_img[:max_show] + joints = joints[:max_show] + n_sample = max_show + + color = ['#00B0F0', '#00B050', '#DC6464', '#207070', '#BC4484'] + + # color = ['g', 'b', 'r'] + + def m_l_r(idx): + + if joint_group is None: + return 1 + + for i in range(len(joint_group)): + if idx in joint_group[i]: + return i + + for i in range(n_sample): + if batch_img is not None: + # ax_img = plt.subplot(n_sample, 2, i * 2 + 1) + ax_img = plt.subplot(2, n_sample, i + 1) + img_np = batch_img[i].cpu().numpy() + img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C + ax_img.imshow(img_np) + ax_img.set_axis_off() + ax_pred = plt.subplot(2, n_sample, n_sample + i + 1, projection='3d') + + else: + ax_pred = plt.subplot(1, n_sample, i + 1, projection='3d') + + plot_kps = joints[i] + if plot_kps.shape[1] > 2: + if joint_group is None: + ax_pred.scatter(plot_kps[:, 2], plot_kps[:, 0], plot_kps[:, 1], s=10, marker='.') + ax_pred.scatter( + plot_kps[0, 2], plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.' 
+ ) + else: + for j in range(len(joint_group)): + ax_pred.scatter( + plot_kps[joint_group[j], 2], + plot_kps[joint_group[j], 0], + plot_kps[joint_group[j], 1], + s=30, + c=color[j], + marker='s' + ) + + if pairs is not None: + for p in pairs: + ax_pred.plot( + plot_kps[p, 2], + plot_kps[p, 0], + plot_kps[p, 1], + c=color[m_l_r(p[1])], + linewidth=2 + ) + + # ax_pred.set_axis_off() + + ax_pred.set_aspect('equal') + set_axes_equal(ax_pred) + + ax_pred.xaxis.set_ticks([]) + ax_pred.yaxis.set_ticks([]) + ax_pred.zaxis.set_ticks([]) + + +def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None): + n_sample = joints.shape[0] + max_show = 2 + if n_sample > max_show: + if batch_img is not None: + batch_img = batch_img[:max_show] + joints = joints[:max_show] + n_sample = max_show + + color = ['#00B0F0', '#00B050', '#DC6464', '#207070', '#BC4484'] + + # color = ['g', 'b', 'r'] + + def m_l_r(idx): + + if joint_group is None: + return 1 + + for i in range(len(joint_group)): + if idx in joint_group[i]: + return i + + for i in range(n_sample): + if batch_img is not None: + # ax_img = plt.subplot(n_sample, 2, i * 2 + 1) + ax_img = plt.subplot(2, n_sample, i + 1) + img_np = batch_img[i].cpu().numpy() + img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C + ax_img.imshow(img_np) + ax_img.set_axis_off() + ax_pred = plt.subplot(2, n_sample, n_sample + i + 1) + + else: + ax_pred = plt.subplot(1, n_sample, i + 1) + + plot_kps = joints[i] + if plot_kps.shape[1] > 1: + if joint_group is None: + ax_pred.scatter(plot_kps[:, 0], plot_kps[:, 1], s=300, c='#00B0F0', marker='.') + # ax_pred.scatter(plot_kps[:, 0], plot_kps[:, 1], s=10, marker='.') + # ax_pred.scatter(plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.') + else: + for j in range(len(joint_group)): + ax_pred.scatter( + plot_kps[joint_group[j], 0], + plot_kps[joint_group[j], 1], + s=100, + c=color[j], + marker='o' + ) + + if pairs is not None: + for p in pairs: + ax_pred.plot( + plot_kps[p, 0], + plot_kps[p, 1], + c=color[m_l_r(p[1])], + linestyle=':', + linewidth=3 + ) + + ax_pred.set_axis_off() + + ax_pred.set_aspect('equal') + ax_pred.axis('equal') + # set_axes_equal(ax_pred) + + ax_pred.xaxis.set_ticks([]) + ax_pred.yaxis.set_ticks([]) + # ax_pred.zaxis.set_ticks([]) + + +def draw_skeleton(image, kp_2d, dataset='common', unnormalize=True, thickness=2): + + if unnormalize: + kp_2d[:, :2] = normalize_2d_kp(kp_2d[:, :2], 224, inv=True) + + kp_2d[:, 2] = kp_2d[:, 2] > 0.3 + kp_2d = np.array(kp_2d, dtype=int) + + rcolor = get_colors()['red'].tolist() + pcolor = get_colors()['green'].tolist() + lcolor = get_colors()['blue'].tolist() + + common_lr = [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0] + for idx, pt in enumerate(kp_2d): + if pt[2] > 0: # if visible + if idx % 2 == 0: + color = rcolor + else: + color = pcolor + cv2.circle(image, (pt[0], pt[1]), 4, color, -1) + # cv2.putText(image, f'{idx}', (pt[0]+1, pt[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 0)) + + if dataset == 'common' and len(kp_2d) != 15: + return image + + skeleton = eval(f'kp_utils.get_{dataset}_skeleton')() + for i, (j1, j2) in enumerate(skeleton): + if kp_2d[j1, 2] > 0 and kp_2d[j2, 2] > 0: # if visible + if dataset == 'common': + color = rcolor if common_lr[i] == 0 else lcolor + else: + color = lcolor if i % 2 == 0 else rcolor + pt1, pt2 = (kp_2d[j1, 0], kp_2d[j1, 1]), (kp_2d[j2, 0], kp_2d[j2, 1]) + cv2.line(image, pt1=pt1, pt2=pt2, color=color, thickness=thickness) + + return image + + +# 
https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to +def set_axes_equal(ax): + '''Make axes of 3D plot have equal scale so that spheres appear as spheres, + cubes as cubes, etc.. This is one possible solution to Matplotlib's + ax.set_aspect('equal') and ax.axis('equal') not working for 3D. + + Input + ax: a matplotlib axis, e.g., as output from plt.gca(). + ''' + + x_limits = ax.get_xlim3d() + y_limits = ax.get_ylim3d() + z_limits = ax.get_zlim3d() + + x_range = abs(x_limits[1] - x_limits[0]) + x_middle = np.mean(x_limits) + y_range = abs(y_limits[1] - y_limits[0]) + y_middle = np.mean(y_limits) + z_range = abs(z_limits[1] - z_limits[0]) + z_middle = np.mean(z_limits) + + # The plot bounding box is a sphere in the sense of the infinity + # norm, hence I call half the max range the plot radius. + plot_radius = 0.5 * max([x_range, y_range, z_range]) + + ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius]) + ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius]) + ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) diff --git a/lib/renderer/__init__.py b/lib/renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/renderer/camera.py b/lib/renderer/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4bb15651f6007b0808825a80c4569c574f8861 --- /dev/null +++ b/lib/renderer/camera.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import cv2 +import numpy as np + +from .glm import ortho + + +class Camera: + + def __init__(self, width=1600, height=1200): + # Focal Length + # equivalent 50mm + focal = np.sqrt(width * width + height * height) + self.focal_x = focal + self.focal_y = focal + # Principal Point Offset + self.principal_x = width / 2 + self.principal_y = height / 2 + # Axis Skew + self.skew = 0 + # Image Size + self.width = width + self.height = height + + self.near = 1 + self.far = 10 + + # Camera Center + self.center = np.array([0, 0, 1.6]) + self.direction = np.array([0, 0, -1]) + self.right = np.array([1, 0, 0]) + self.up = np.array([0, 1, 0]) + + self.ortho_ratio = None + + def sanity_check(self): + self.center = self.center.reshape([-1]) + self.direction = self.direction.reshape([-1]) + self.right = self.right.reshape([-1]) + self.up = self.up.reshape([-1]) + + assert len(self.center) == 3 + assert len(self.direction) == 3 + assert len(self.right) == 3 + assert len(self.up) == 3 + + @staticmethod + def normalize_vector(v): + v_norm = np.linalg.norm(v) + return v if v_norm == 0 else v / v_norm + + def get_real_z_value(self, z): + z_near = self.near + z_far = self.far + z_n = 2.0 * z - 1.0 + z_e = 2.0 * z_near * z_far / (z_far + z_near - z_n * (z_far - z_near)) + return z_e + + def get_rotation_matrix(self): + rot_mat = np.eye(3) + s = self.right + s = self.normalize_vector(s) + rot_mat[0, :] = s + u = self.up + u = self.normalize_vector(u) + rot_mat[1, :] = -u + rot_mat[2, :] = self.normalize_vector(self.direction) + + return rot_mat + + def get_translation_vector(self): + rot_mat = self.get_rotation_matrix() + trans = -np.dot(rot_mat, self.center) + return trans + + def get_intrinsic_matrix(self): + int_mat = np.eye(3) + + int_mat[0, 0] = self.focal_x + int_mat[1, 1] = self.focal_y + int_mat[0, 1] = self.skew + int_mat[0, 2] = self.principal_x + int_mat[1, 2] = self.principal_y + + return int_mat + + def get_projection_matrix(self): + ext_mat = self.get_extrinsic_matrix() + int_mat = self.get_intrinsic_matrix() + + return np.matmul(int_mat, ext_mat) + + def get_extrinsic_matrix(self): + rot_mat = self.get_rotation_matrix() + int_mat = self.get_intrinsic_matrix() + trans = self.get_translation_vector() + + extrinsic = np.eye(4) + extrinsic[:3, :3] = rot_mat + extrinsic[:3, 3] = trans + + return extrinsic[:3, :] + + def set_rotation_matrix(self, rot_mat): + self.direction = rot_mat[2, :] + self.up = -rot_mat[1, :] + self.right = rot_mat[0, :] + + def set_intrinsic_matrix(self, int_mat): + self.focal_x = int_mat[0, 0] + self.focal_y = int_mat[1, 1] + self.skew = int_mat[0, 1] + self.principal_x = int_mat[0, 2] + self.principal_y = int_mat[1, 2] + + def set_projection_matrix(self, proj_mat): + res = cv2.decomposeProjectionMatrix(proj_mat) + int_mat, rot_mat, camera_center_homo = res[0], res[1], res[2] + camera_center = camera_center_homo[0:3] / camera_center_homo[3] + camera_center = camera_center.reshape(-1) + int_mat = int_mat / int_mat[2][2] + + self.set_intrinsic_matrix(int_mat) + self.set_rotation_matrix(rot_mat) + self.center = camera_center + + self.sanity_check() + + def get_gl_matrix(self): + z_near = self.near + z_far = self.far + rot_mat = self.get_rotation_matrix() + int_mat = self.get_intrinsic_matrix() + trans = self.get_translation_vector() + + extrinsic = np.eye(4) + extrinsic[:3, :3] = rot_mat + extrinsic[:3, 3] = trans + axis_adj = np.eye(4) + axis_adj[2, 2] = -1 + axis_adj[1, 1] = -1 + model_view = np.matmul(axis_adj, extrinsic) + + 
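+        # The remainder of this method turns the OpenCV-style intrinsics into an
+        # OpenGL projection: the 4x4 "projective" matrix below carries the focal
+        # lengths and principal point and encodes the [near, far] depth terms,
+        # and it is composed with an orthographic NDC matrix from the ortho()
+        # helper in .glm; when ortho_ratio is set, a pure orthographic projection
+        # is returned instead of the perspective one.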
projective = np.zeros([4, 4]) + projective[:2, :2] = int_mat[:2, :2] + projective[:2, 2:3] = -int_mat[:2, 2:3] + projective[3, 2] = -1 + projective[2, 2] = (z_near + z_far) + projective[2, 3] = (z_near * z_far) + + if self.ortho_ratio is None: + ndc = ortho(0, self.width, 0, self.height, z_near, z_far) + perspective = np.matmul(ndc, projective) + else: + perspective = ortho(-self.width * self.ortho_ratio / 2, + self.width * self.ortho_ratio / 2, + -self.height * self.ortho_ratio / 2, + self.height * self.ortho_ratio / 2, z_near, + z_far) + + return perspective, model_view + + +def KRT_from_P(proj_mat, normalize_K=True): + res = cv2.decomposeProjectionMatrix(proj_mat) + K, Rot, camera_center_homog = res[0], res[1], res[2] + camera_center = camera_center_homog[0:3] / camera_center_homog[3] + trans = -Rot.dot(camera_center) + if normalize_K: + K = K / K[2][2] + return K, Rot, trans + + +def MVP_from_P(proj_mat, width, height, near=0.1, far=10000): + ''' + Convert OpenCV camera calibration matrix to OpenGL projection and model view matrix + :param proj_mat: OpenCV camera projeciton matrix + :param width: Image width + :param height: Image height + :param near: Z near value + :param far: Z far value + :return: OpenGL projection matrix and model view matrix + ''' + res = cv2.decomposeProjectionMatrix(proj_mat) + K, Rot, camera_center_homog = res[0], res[1], res[2] + camera_center = camera_center_homog[0:3] / camera_center_homog[3] + trans = -Rot.dot(camera_center) + K = K / K[2][2] + + extrinsic = np.eye(4) + extrinsic[:3, :3] = Rot + extrinsic[:3, 3:4] = trans + axis_adj = np.eye(4) + axis_adj[2, 2] = -1 + axis_adj[1, 1] = -1 + model_view = np.matmul(axis_adj, extrinsic) + + zFar = far + zNear = near + projective = np.zeros([4, 4]) + projective[:2, :2] = K[:2, :2] + projective[:2, 2:3] = -K[:2, 2:3] + projective[3, 2] = -1 + projective[2, 2] = (zNear + zFar) + projective[2, 3] = (zNear * zFar) + + ndc = ortho(0, width, 0, height, zNear, zFar) + + perspective = np.matmul(ndc, projective) + + return perspective, model_view diff --git a/lib/renderer/gl/__init__.py b/lib/renderer/gl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/renderer/gl/cam_render.py b/lib/renderer/gl/cam_render.py new file mode 100644 index 0000000000000000000000000000000000000000..a4db3c1a23b4773c4e9248c43b08da2fb75798c2 --- /dev/null +++ b/lib/renderer/gl/cam_render.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from .render import Render + +GLUT = None + + +class CamRender(Render): + + def __init__(self, + width=1600, + height=1200, + name='Cam Renderer', + program_files=['simple.fs', 'simple.vs'], + color_size=1, + ms_rate=1, + egl=False): + Render.__init__(self, + width, + height, + name, + program_files, + color_size, + ms_rate=ms_rate, + egl=egl) + self.camera = None + + if not egl: + global GLUT + import OpenGL.GLUT as GLUT + GLUT.glutDisplayFunc(self.display) + GLUT.glutKeyboardFunc(self.keyboard) + + def set_camera(self, camera): + self.camera = camera + self.projection_matrix, self.model_view_matrix = camera.get_gl_matrix() + + def keyboard(self, key, x, y): + # up + eps = 1 + # print(key) + if key == b'w': + self.camera.center += eps * self.camera.direction + elif key == b's': + self.camera.center -= eps * self.camera.direction + if key == b'a': + self.camera.center -= eps * self.camera.right + elif key == b'd': + self.camera.center += eps * self.camera.right + if key == b' ': + self.camera.center += eps * self.camera.up + elif key == b'x': + self.camera.center -= eps * self.camera.up + elif key == b'i': + self.camera.near += 0.1 * eps + self.camera.far += 0.1 * eps + elif key == b'o': + self.camera.near -= 0.1 * eps + self.camera.far -= 0.1 * eps + + self.projection_matrix, self.model_view_matrix = self.camera.get_gl_matrix( + ) + + def show(self): + if GLUT is not None: + GLUT.glutMainLoop() diff --git a/lib/renderer/gl/color_render.py b/lib/renderer/gl/color_render.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d9ca10a1cffdfbaa545314ca9210f15f31d808 --- /dev/null +++ b/lib/renderer/gl/color_render.py @@ -0,0 +1,168 @@ +''' +MIT License + +Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+''' +import numpy as np +import random + +from .framework import * +from .cam_render import CamRender + + +class ColorRender(CamRender): + + def __init__(self, + width=1600, + height=1200, + name='Color Renderer', + egl=False): + program_files = ['color.vs', 'color.fs'] + CamRender.__init__(self, + width, + height, + name, + program_files=program_files, + color_size=3, + egl=egl) + + # WARNING: this differs from vertex_buffer and vertex_data in Render + self.vert_buffer = {} + self.vert_data = {} + + # normal + self.norm_buffer = {} + self.norm_data = {} + + self.color_buffer = {} + self.color_data = {} + + self.vertex_dim = {} + self.n_vertices = {} + + self.rot_mat_unif = glGetUniformLocation(self.program, 'RotMat') + self.rot_matrix = np.eye(3) + + self.norm_mat_unif = glGetUniformLocation(self.program, 'NormMat') + self.normalize_matrix = np.eye(4) + + def set_norm_mat(self, scale, center): + N = np.eye(4) + N[:3, :3] = scale * np.eye(3) + N[:3, 3] = -scale * center + + self.normalize_matrix = N + + def set_mesh(self, vertices, faces, color, normals, mat_name='all'): + + self.vert_data[mat_name] = vertices[faces.reshape([-1])] + self.n_vertices[mat_name] = self.vert_data[mat_name].shape[0] + self.vertex_dim[mat_name] = self.vert_data[mat_name].shape[1] + self.color_data[mat_name] = color[faces.reshape([-1])] + self.norm_data[mat_name] = normals[faces.reshape([-1])] + + if mat_name not in self.vert_buffer.keys(): + self.vert_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.vert_data[mat_name], GL_STATIC_DRAW) + + if mat_name not in self.color_buffer.keys(): + self.color_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.color_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.color_data[mat_name], + GL_STATIC_DRAW) + + if mat_name not in self.norm_buffer.keys(): + self.norm_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.norm_data[mat_name], GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def cleanup(self): + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + for key in self.vert_data: + glDeleteBuffers(1, [self.vert_buffer[key]]) + glDeleteBuffers(1, [self.color_buffer[key]]) + glDeleteBuffers(1, [self.norm_buffer[key]]) + + self.norm_buffer = {} + self.norm_data = {} + + self.vert_buffer = {} + self.vert_data = {} + + self.color_buffer = {} + self.color_data = {} + + self.render_texture_mat = {} + + self.vertex_dim = {} + self.n_vertices = {} + + def draw(self): + self.draw_init() + + glEnable(GL_MULTISAMPLE) + + glUseProgram(self.program) + glUniformMatrix4fv(self.norm_mat_unif, 1, GL_FALSE, + self.normalize_matrix.transpose()) + glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, + self.model_view_matrix.transpose()) + glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, + self.projection_matrix.transpose()) + glUniformMatrix3fv(self.rot_mat_unif, 1, GL_FALSE, + self.rot_matrix.transpose()) + + for mat in self.vert_buffer: + + # Handle vertex buffer + glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat]) + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, self.vertex_dim[mat], GL_DOUBLE, GL_FALSE, + 0, None) + + # Handle color buffer + glBindBuffer(GL_ARRAY_BUFFER, self.color_buffer[mat]) + glEnableVertexAttribArray(1) + glVertexAttribPointer(1, 3, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle normal buffer + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat]) + glEnableVertexAttribArray(2) + 
glVertexAttribPointer(2, 3, GL_DOUBLE, GL_FALSE, 0, None) + + glDrawArrays(GL_TRIANGLES, 0, self.n_vertices[mat]) + + glDisableVertexAttribArray(2) + glDisableVertexAttribArray(1) + glDisableVertexAttribArray(0) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + glUseProgram(0) + + glDisable(GL_MULTISAMPLE) + + self.draw_end() diff --git a/lib/renderer/gl/data/color.fs b/lib/renderer/gl/data/color.fs new file mode 100644 index 0000000000000000000000000000000000000000..96e904bdca2e813e7e7ae75e1cc5dda7b1d0b5be --- /dev/null +++ b/lib/renderer/gl/data/color.fs @@ -0,0 +1,20 @@ +#version 330 core + +layout (location = 0) out vec4 FragColor; +layout (location = 1) out vec4 FragNormal; +layout (location = 2) out vec4 FragDepth; + +in vec3 Color; +in vec3 CamNormal; +in vec3 depth; + + +void main() +{ + FragColor = vec4(Color,1.0); + + vec3 cam_norm_normalized = normalize(CamNormal); + vec3 rgb = (cam_norm_normalized + 1.0) / 2.0; + FragNormal = vec4(rgb, 1.0); + FragDepth = vec4(depth.xyz, 1.0); +} diff --git a/lib/renderer/gl/data/color.vs b/lib/renderer/gl/data/color.vs new file mode 100644 index 0000000000000000000000000000000000000000..1256f7eb3f3605b8f848d452fcf601d5a18b95e2 --- /dev/null +++ b/lib/renderer/gl/data/color.vs @@ -0,0 +1,29 @@ +#version 330 core + +layout (location = 0) in vec3 a_Position; +layout (location = 1) in vec3 a_Color; +layout (location = 2) in vec3 a_Normal; + +out vec3 CamNormal; +out vec3 CamPos; +out vec3 Color; +out vec3 depth; + + +uniform mat3 RotMat; +uniform mat4 NormMat; +uniform mat4 ModelMat; +uniform mat4 PerspMat; + +void main() +{ + vec3 a_Position = (NormMat * vec4(a_Position,1.0)).xyz; + gl_Position = PerspMat * ModelMat * vec4(RotMat * a_Position, 1.0); + Color = a_Color; + + mat3 R = mat3(ModelMat) * RotMat; + CamNormal = (R * a_Normal); + + depth = vec3(gl_Position.z / gl_Position.w); + +} \ No newline at end of file diff --git a/lib/renderer/gl/data/normal.fs b/lib/renderer/gl/data/normal.fs new file mode 100644 index 0000000000000000000000000000000000000000..9e2770952e27d9265ccb100833245beed3ebebe5 --- /dev/null +++ b/lib/renderer/gl/data/normal.fs @@ -0,0 +1,12 @@ +#version 330 + +out vec4 FragColor; + +in vec3 CamNormal; + +void main() +{ + vec3 cam_norm_normalized = normalize(CamNormal); + vec3 rgb = (cam_norm_normalized + 1.0) / 2.0; + FragColor = vec4(rgb, 1.0); +} \ No newline at end of file diff --git a/lib/renderer/gl/data/normal.vs b/lib/renderer/gl/data/normal.vs new file mode 100644 index 0000000000000000000000000000000000000000..a0f7f50b1cedfd677843b2a60cf9051b2134c347 --- /dev/null +++ b/lib/renderer/gl/data/normal.vs @@ -0,0 +1,15 @@ +#version 330 + +layout (location = 0) in vec3 Position; +layout (location = 1) in vec3 Normal; + +out vec3 CamNormal; + +uniform mat4 ModelMat; +uniform mat4 PerspMat; + +void main() +{ + gl_Position = PerspMat * ModelMat * vec4(Position, 1.0); + CamNormal = (ModelMat * vec4(Normal, 0.0)).xyz; +} \ No newline at end of file diff --git a/lib/renderer/gl/data/prt.fs b/lib/renderer/gl/data/prt.fs new file mode 100644 index 0000000000000000000000000000000000000000..3737e2a5b51d5b001cdb2f3d8157793caa98fd54 --- /dev/null +++ b/lib/renderer/gl/data/prt.fs @@ -0,0 +1,157 @@ +#version 330 + +uniform vec3 SHCoeffs[9]; +uniform uint analytic; + +uniform uint hasNormalMap; +uniform uint hasAlbedoMap; + +uniform sampler2D AlbedoMap; +uniform sampler2D NormalMap; + +in VertexData { + vec3 Position; + vec3 Depth; + vec3 ModelNormal; + vec2 Texcoord; + vec3 Tangent; + vec3 Bitangent; + vec3 PRT1; + vec3 PRT2; + vec3 PRT3; + 
vec3 Label; +} VertexIn; + +layout (location = 0) out vec4 FragColor; +layout (location = 1) out vec4 FragNormal; +layout (location = 2) out vec4 FragPosition; +layout (location = 3) out vec4 FragAlbedo; +layout (location = 4) out vec4 FragShading; +layout (location = 5) out vec4 FragPRT1; +layout (location = 6) out vec4 FragPRT2; +// layout (location = 7) out vec4 FragPRT3; +layout (location = 7) out vec4 FragLabel; + + +vec4 gammaCorrection(vec4 vec, float g) +{ + return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w); +} + +vec3 gammaCorrection(vec3 vec, float g) +{ + return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g)); +} + +void evaluateH(vec3 n, out float H[9]) +{ + float c1 = 0.429043, c2 = 0.511664, + c3 = 0.743125, c4 = 0.886227, c5 = 0.247708; + + H[0] = c4; + H[1] = 2.0 * c2 * n[1]; + H[2] = 2.0 * c2 * n[2]; + H[3] = 2.0 * c2 * n[0]; + H[4] = 2.0 * c1 * n[0] * n[1]; + H[5] = 2.0 * c1 * n[1] * n[2]; + H[6] = c3 * n[2] * n[2] - c5; + H[7] = 2.0 * c1 * n[2] * n[0]; + H[8] = c1 * (n[0] * n[0] - n[1] * n[1]); +} + +vec3 evaluateLightingModel(vec3 normal) +{ + float H[9]; + evaluateH(normal, H); + vec3 res = vec3(0.0); + for (int i = 0; i < 9; i++) { + res += H[i] * SHCoeffs[i]; + } + return res; +} + +// nC: coarse geometry normal, nH: fine normal from normal map +vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt) +{ + float HC[9], HH[9]; + evaluateH(nC, HC); + evaluateH(nH, HH); + + vec3 res = vec3(0.0); + vec3 shadow = vec3(0.0); + vec3 unshadow = vec3(0.0); + for(int i = 0; i < 3; ++i){ + for(int j = 0; j < 3; ++j){ + int id = i*3+j; + res += HH[id]* SHCoeffs[id]; + shadow += prt[i][j] * SHCoeffs[id]; + unshadow += HC[id] * SHCoeffs[id]; + } + } + vec3 ratio = clamp(shadow/unshadow,0.0,1.0); + res = ratio * res; + + return res; +} + +vec3 evaluateLightingModelPRT(mat3 prt) +{ + vec3 res = vec3(0.0); + for(int i = 0; i < 3; ++i){ + for(int j = 0; j < 3; ++j){ + res += prt[i][j] * SHCoeffs[i*3+j]; + } + } + + return res; +} + +void main() +{ + vec2 uv = VertexIn.Texcoord; + vec3 nC = normalize(VertexIn.ModelNormal); + vec3 nml = nC; + mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3); + + if(hasAlbedoMap == uint(0)) + FragAlbedo = vec4(1.0); + else + FragAlbedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2); + + if(hasNormalMap == uint(0)) + { + if(analytic == uint(0)) + FragShading = vec4(evaluateLightingModelPRT(prt), 1.0f); + else + FragShading = vec4(evaluateLightingModel(nC), 1.0f); + } + else + { + vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0)); + + mat3 TBN = mat3(normalize(VertexIn.Tangent),normalize(VertexIn.Bitangent),nC); + vec3 nH = normalize(TBN * n_tan); + + if(analytic == uint(0)) + FragShading = vec4(evaluateLightingModelHybrid(nC,nH,prt),1.0f); + else + FragShading = vec4(evaluateLightingModel(nH), 1.0f); + + nml = nH; + } + + FragShading = gammaCorrection(FragShading, 2.2); + FragColor = clamp(FragAlbedo * FragShading, 0.0, 1.0); + FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0); + FragPosition = vec4(VertexIn.Depth.xyz, 1.0); + FragShading = vec4(clamp(0.5*FragShading.xyz, 0.0, 1.0),1.0); + // FragColor = gammaCorrection(clamp(FragAlbedo * FragShading, 0.0, 1.0),2.2); + // FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0); + // FragPosition = vec4(VertexIn.Position,VertexIn.Depth.x); + // FragShading = vec4(gammaCorrection(clamp(0.5*FragShading.xyz, 0.0, 1.0),2.2),1.0); + // FragAlbedo = gammaCorrection(FragAlbedo,2.2); + FragPRT1 = vec4(VertexIn.PRT1,1.0); + FragPRT2 = 
vec4(VertexIn.PRT2,1.0); + // FragPRT3 = vec4(VertexIn.PRT3,1.0); + FragLabel = vec4(VertexIn.Label,1.0); +} \ No newline at end of file diff --git a/lib/renderer/gl/data/prt.vs b/lib/renderer/gl/data/prt.vs new file mode 100644 index 0000000000000000000000000000000000000000..4bb55cc6fa74539fa48aed054cac120ac5cf7ec2 --- /dev/null +++ b/lib/renderer/gl/data/prt.vs @@ -0,0 +1,171 @@ +#version 330 + +layout (location = 0) in vec3 a_Position; +layout (location = 1) in vec3 a_Normal; +layout (location = 2) in vec2 a_TextureCoord; +layout (location = 3) in vec3 a_Tangent; +layout (location = 4) in vec3 a_Bitangent; +layout (location = 5) in vec3 a_PRT1; +layout (location = 6) in vec3 a_PRT2; +layout (location = 7) in vec3 a_PRT3; +layout (location = 8) in vec3 a_Label; + +out VertexData { + vec3 Position; + vec3 Depth; + vec3 ModelNormal; + vec2 Texcoord; + vec3 Tangent; + vec3 Bitangent; + vec3 PRT1; + vec3 PRT2; + vec3 PRT3; + vec3 Label; +} VertexOut; + +uniform mat3 RotMat; +uniform mat4 NormMat; +uniform mat4 ModelMat; +uniform mat4 PerspMat; + +float s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi)) +float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi)) +float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi)) + +float s_c_scale = 1.0/0.91529123286551084; +float s_c_scale_inv = 0.91529123286551084; + +float s_rc2 = 1.5853309190550713*s_c_scale; +float s_c4_div_c3 = s_c4/s_c3; +float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0; + +float s_scale_dst2 = s_c3 * s_c_scale_inv; +float s_scale_dst4 = s_c5 * s_c_scale_inv; + +void OptRotateBand0(float x[1], mat3 R, out float dst[1]) +{ + dst[0] = x[0]; +} + +// 9 multiplies +void OptRotateBand1(float x[3], mat3 R, out float dst[3]) +{ + // derived from SlowRotateBand1 + dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2]; + dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2]; + dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2]; +} + +// 48 multiplies +void OptRotateBand2(float x[5], mat3 R, out float dst[5]) +{ + // Sparse matrix multiply + float sh0 = x[3] + x[4] + x[4] - x[1]; + float sh1 = x[0] + s_rc2*x[2] + x[3] + x[4]; + float sh2 = x[0]; + float sh3 = -x[3]; + float sh4 = -x[1]; + + // Rotations. 
R0 and R1 just use the raw matrix columns + float r2x = R[0][0] + R[0][1]; + float r2y = R[1][0] + R[1][1]; + float r2z = R[2][0] + R[2][1]; + + float r3x = R[0][0] + R[0][2]; + float r3y = R[1][0] + R[1][2]; + float r3z = R[2][0] + R[2][2]; + + float r4x = R[0][1] + R[0][2]; + float r4y = R[1][1] + R[1][2]; + float r4z = R[2][1] + R[2][2]; + + // dense matrix multiplication one column at a time + + // column 0 + float sh0_x = sh0 * R[0][0]; + float sh0_y = sh0 * R[1][0]; + float d0 = sh0_x * R[1][0]; + float d1 = sh0_y * R[2][0]; + float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3); + float d3 = sh0_x * R[2][0]; + float d4 = sh0_x * R[0][0] - sh0_y * R[1][0]; + + // column 1 + float sh1_x = sh1 * R[0][2]; + float sh1_y = sh1 * R[1][2]; + d0 += sh1_x * R[1][2]; + d1 += sh1_y * R[2][2]; + d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3); + d3 += sh1_x * R[2][2]; + d4 += sh1_x * R[0][2] - sh1_y * R[1][2]; + + // column 2 + float sh2_x = sh2 * r2x; + float sh2_y = sh2 * r2y; + d0 += sh2_x * r2y; + d1 += sh2_y * r2z; + d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2); + d3 += sh2_x * r2z; + d4 += sh2_x * r2x - sh2_y * r2y; + + // column 3 + float sh3_x = sh3 * r3x; + float sh3_y = sh3 * r3y; + d0 += sh3_x * r3y; + d1 += sh3_y * r3z; + d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2); + d3 += sh3_x * r3z; + d4 += sh3_x * r3x - sh3_y * r3y; + + // column 4 + float sh4_x = sh4 * r4x; + float sh4_y = sh4 * r4y; + d0 += sh4_x * r4y; + d1 += sh4_y * r4z; + d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2); + d3 += sh4_x * r4z; + d4 += sh4_x * r4x - sh4_y * r4y; + + // extra multipliers + dst[0] = d0; + dst[1] = -d1; + dst[2] = d2 * s_scale_dst2; + dst[3] = -d3; + dst[4] = d4 * s_scale_dst4; +} + +void main() +{ + // normalization + vec3 pos = (NormMat * vec4(a_Position,1.0)).xyz; + + mat3 R = mat3(ModelMat) * RotMat; + VertexOut.ModelNormal = (R * a_Normal); + VertexOut.Position = R * pos; + VertexOut.Texcoord = a_TextureCoord; + VertexOut.Tangent = (R * a_Tangent); + VertexOut.Bitangent = (R * a_Bitangent); + VertexOut.Label = a_Label; + + float PRT0, PRT1[3], PRT2[5]; + PRT0 = a_PRT1[0]; + PRT1[0] = a_PRT1[1]; + PRT1[1] = a_PRT1[2]; + PRT1[2] = a_PRT2[0]; + PRT2[0] = a_PRT2[1]; + PRT2[1] = a_PRT2[2]; + PRT2[2] = a_PRT3[0]; + PRT2[3] = a_PRT3[1]; + PRT2[4] = a_PRT3[2]; + + OptRotateBand1(PRT1, R, PRT1); + OptRotateBand2(PRT2, R, PRT2); + + VertexOut.PRT1 = vec3(PRT0,PRT1[0],PRT1[1]); + VertexOut.PRT2 = vec3(PRT1[2],PRT2[0],PRT2[1]); + VertexOut.PRT3 = vec3(PRT2[2],PRT2[3],PRT2[4]); + + gl_Position = PerspMat * ModelMat * vec4(RotMat * pos, 1.0); + + VertexOut.Depth = vec3(gl_Position.z / gl_Position.w); +} diff --git a/lib/renderer/gl/data/prt_uv.fs b/lib/renderer/gl/data/prt_uv.fs new file mode 100644 index 0000000000000000000000000000000000000000..6e90b25c62b41c8cf61afd29333372193047d5f1 --- /dev/null +++ b/lib/renderer/gl/data/prt_uv.fs @@ -0,0 +1,141 @@ +#version 330 + +uniform vec3 SHCoeffs[9]; +uniform uint analytic; + +uniform uint hasNormalMap; +uniform uint hasAlbedoMap; + +uniform sampler2D AlbedoMap; +uniform sampler2D NormalMap; + +in VertexData { + vec3 Position; + vec3 ModelNormal; + vec3 CameraNormal; + vec2 Texcoord; + vec3 Tangent; + vec3 Bitangent; + vec3 PRT1; + vec3 PRT2; + vec3 PRT3; +} VertexIn; + +layout (location = 0) out vec4 FragColor; +layout (location = 1) out vec4 FragPosition; +layout (location = 2) out vec4 FragNormal; + +vec4 gammaCorrection(vec4 vec, float g) +{ + return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w); +} + +vec3 gammaCorrection(vec3 vec, float g) +{ + 
return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g)); +} + +void evaluateH(vec3 n, out float H[9]) +{ + float c1 = 0.429043, c2 = 0.511664, + c3 = 0.743125, c4 = 0.886227, c5 = 0.247708; + + H[0] = c4; + H[1] = 2.0 * c2 * n[1]; + H[2] = 2.0 * c2 * n[2]; + H[3] = 2.0 * c2 * n[0]; + H[4] = 2.0 * c1 * n[0] * n[1]; + H[5] = 2.0 * c1 * n[1] * n[2]; + H[6] = c3 * n[2] * n[2] - c5; + H[7] = 2.0 * c1 * n[2] * n[0]; + H[8] = c1 * (n[0] * n[0] - n[1] * n[1]); +} + +vec3 evaluateLightingModel(vec3 normal) +{ + float H[9]; + evaluateH(normal, H); + vec3 res = vec3(0.0); + for (int i = 0; i < 9; i++) { + res += H[i] * SHCoeffs[i]; + } + return res; +} + +// nC: coarse geometry normal, nH: fine normal from normal map +vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt) +{ + float HC[9], HH[9]; + evaluateH(nC, HC); + evaluateH(nH, HH); + + vec3 res = vec3(0.0); + vec3 shadow = vec3(0.0); + vec3 unshadow = vec3(0.0); + for(int i = 0; i < 3; ++i){ + for(int j = 0; j < 3; ++j){ + int id = i*3+j; + res += HH[id]* SHCoeffs[id]; + shadow += prt[i][j] * SHCoeffs[id]; + unshadow += HC[id] * SHCoeffs[id]; + } + } + vec3 ratio = clamp(shadow/unshadow,0.0,1.0); + res = ratio * res; + + return res; +} + +vec3 evaluateLightingModelPRT(mat3 prt) +{ + vec3 res = vec3(0.0); + for(int i = 0; i < 3; ++i){ + for(int j = 0; j < 3; ++j){ + res += prt[i][j] * SHCoeffs[i*3+j]; + } + } + + return res; +} + +void main() +{ + vec2 uv = VertexIn.Texcoord; + vec3 nM = normalize(VertexIn.ModelNormal); + vec3 nC = normalize(VertexIn.CameraNormal); + vec3 nml = nC; + mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3); + + vec4 albedo, shading; + if(hasAlbedoMap == uint(0)) + albedo = vec4(1.0); + else + albedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2); + + if(hasNormalMap == uint(0)) + { + if(analytic == uint(0)) + shading = vec4(evaluateLightingModelPRT(prt), 1.0f); + else + shading = vec4(evaluateLightingModel(nC), 1.0f); + } + else + { + vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0)); + + mat3 TBN = mat3(normalize(VertexIn.Tangent),normalize(VertexIn.Bitangent),nC); + vec3 nH = normalize(TBN * n_tan); + + if(analytic == uint(0)) + shading = vec4(evaluateLightingModelHybrid(nC,nH,prt),1.0f); + else + shading = vec4(evaluateLightingModel(nH), 1.0f); + + nml = nH; + } + + shading = gammaCorrection(shading, 2.2); + FragColor = clamp(albedo * shading, 0.0, 1.0); + FragPosition = vec4(VertexIn.Position,1.0); + FragNormal = vec4(0.5*(nM+vec3(1.0)),1.0); +} \ No newline at end of file diff --git a/lib/renderer/gl/data/prt_uv.vs b/lib/renderer/gl/data/prt_uv.vs new file mode 100644 index 0000000000000000000000000000000000000000..22a03564bd95158c3fb9edf513c0717975b93ee0 --- /dev/null +++ b/lib/renderer/gl/data/prt_uv.vs @@ -0,0 +1,168 @@ +#version 330 + +layout (location = 0) in vec3 a_Position; +layout (location = 1) in vec3 a_Normal; +layout (location = 2) in vec2 a_TextureCoord; +layout (location = 3) in vec3 a_Tangent; +layout (location = 4) in vec3 a_Bitangent; +layout (location = 5) in vec3 a_PRT1; +layout (location = 6) in vec3 a_PRT2; +layout (location = 7) in vec3 a_PRT3; + +out VertexData { + vec3 Position; + vec3 ModelNormal; + vec3 CameraNormal; + vec2 Texcoord; + vec3 Tangent; + vec3 Bitangent; + vec3 PRT1; + vec3 PRT2; + vec3 PRT3; +} VertexOut; + +uniform mat3 RotMat; +uniform mat4 NormMat; +uniform mat4 ModelMat; +uniform mat4 PerspMat; + +#define pi 3.1415926535897932384626433832795 + +float s_c3 = 0.94617469575; // 
(3*sqrt(5))/(4*sqrt(pi)) +float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi)) +float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi)) + +float s_c_scale = 1.0/0.91529123286551084; +float s_c_scale_inv = 0.91529123286551084; + +float s_rc2 = 1.5853309190550713*s_c_scale; +float s_c4_div_c3 = s_c4/s_c3; +float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0; + +float s_scale_dst2 = s_c3 * s_c_scale_inv; +float s_scale_dst4 = s_c5 * s_c_scale_inv; + +void OptRotateBand0(float x[1], mat3 R, out float dst[1]) +{ + dst[0] = x[0]; +} + +// 9 multiplies +void OptRotateBand1(float x[3], mat3 R, out float dst[3]) +{ + // derived from SlowRotateBand1 + dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2]; + dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2]; + dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2]; +} + +// 48 multiplies +void OptRotateBand2(float x[5], mat3 R, out float dst[5]) +{ + // Sparse matrix multiply + float sh0 = x[3] + x[4] + x[4] - x[1]; + float sh1 = x[0] + s_rc2*x[2] + x[3] + x[4]; + float sh2 = x[0]; + float sh3 = -x[3]; + float sh4 = -x[1]; + + // Rotations. R0 and R1 just use the raw matrix columns + float r2x = R[0][0] + R[0][1]; + float r2y = R[1][0] + R[1][1]; + float r2z = R[2][0] + R[2][1]; + + float r3x = R[0][0] + R[0][2]; + float r3y = R[1][0] + R[1][2]; + float r3z = R[2][0] + R[2][2]; + + float r4x = R[0][1] + R[0][2]; + float r4y = R[1][1] + R[1][2]; + float r4z = R[2][1] + R[2][2]; + + // dense matrix multiplication one column at a time + + // column 0 + float sh0_x = sh0 * R[0][0]; + float sh0_y = sh0 * R[1][0]; + float d0 = sh0_x * R[1][0]; + float d1 = sh0_y * R[2][0]; + float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3); + float d3 = sh0_x * R[2][0]; + float d4 = sh0_x * R[0][0] - sh0_y * R[1][0]; + + // column 1 + float sh1_x = sh1 * R[0][2]; + float sh1_y = sh1 * R[1][2]; + d0 += sh1_x * R[1][2]; + d1 += sh1_y * R[2][2]; + d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3); + d3 += sh1_x * R[2][2]; + d4 += sh1_x * R[0][2] - sh1_y * R[1][2]; + + // column 2 + float sh2_x = sh2 * r2x; + float sh2_y = sh2 * r2y; + d0 += sh2_x * r2y; + d1 += sh2_y * r2z; + d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2); + d3 += sh2_x * r2z; + d4 += sh2_x * r2x - sh2_y * r2y; + + // column 3 + float sh3_x = sh3 * r3x; + float sh3_y = sh3 * r3y; + d0 += sh3_x * r3y; + d1 += sh3_y * r3z; + d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2); + d3 += sh3_x * r3z; + d4 += sh3_x * r3x - sh3_y * r3y; + + // column 4 + float sh4_x = sh4 * r4x; + float sh4_y = sh4 * r4y; + d0 += sh4_x * r4y; + d1 += sh4_y * r4z; + d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2); + d3 += sh4_x * r4z; + d4 += sh4_x * r4x - sh4_y * r4y; + + // extra multipliers + dst[0] = d0; + dst[1] = -d1; + dst[2] = d2 * s_scale_dst2; + dst[3] = -d3; + dst[4] = d4 * s_scale_dst4; +} + +void main() +{ + // normalization + mat3 R = mat3(ModelMat) * RotMat; + VertexOut.ModelNormal = a_Normal; + VertexOut.CameraNormal = (R * a_Normal); + VertexOut.Position = a_Position; + VertexOut.Texcoord = a_TextureCoord; + VertexOut.Tangent = (R * a_Tangent); + VertexOut.Bitangent = (R * a_Bitangent); + float PRT0, PRT1[3], PRT2[5]; + PRT0 = a_PRT1[0]; + PRT1[0] = a_PRT1[1]; + PRT1[1] = a_PRT1[2]; + PRT1[2] = a_PRT2[0]; + PRT2[0] = a_PRT2[1]; + PRT2[1] = a_PRT2[2]; + PRT2[2] = a_PRT3[0]; + PRT2[3] = a_PRT3[1]; + PRT2[4] = a_PRT3[2]; + + OptRotateBand1(PRT1, R, PRT1); + OptRotateBand2(PRT2, R, PRT2); + + VertexOut.PRT1 = vec3(PRT0,PRT1[0],PRT1[1]); + VertexOut.PRT2 = vec3(PRT1[2],PRT2[0],PRT2[1]); + VertexOut.PRT3 = vec3(PRT2[2],PRT2[3],PRT2[4]); 
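+    // Rasterize in UV space: the position below is taken from the texture
+    // coordinates (remapped from [0,1] to clip-space [-1,1]), so shading is
+    // baked into the texture atlas rather than the camera view.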
+ + gl_Position = vec4(a_TextureCoord, 0.0, 1.0) - vec4(0.5, 0.5, 0, 0); + gl_Position[0] *= 2.0; + gl_Position[1] *= 2.0; +} diff --git a/lib/renderer/gl/data/quad.fs b/lib/renderer/gl/data/quad.fs new file mode 100644 index 0000000000000000000000000000000000000000..f43502f2352ca2adf19d11e809946b51498df5a5 --- /dev/null +++ b/lib/renderer/gl/data/quad.fs @@ -0,0 +1,11 @@ +#version 330 core +out vec4 FragColor; + +in vec2 TexCoord; + +uniform sampler2D screenTexture; + +void main() +{ + FragColor = texture(screenTexture, TexCoord); +} \ No newline at end of file diff --git a/lib/renderer/gl/data/quad.vs b/lib/renderer/gl/data/quad.vs new file mode 100644 index 0000000000000000000000000000000000000000..811044631a1f29f5b45c490b2d40297f3127b6ea --- /dev/null +++ b/lib/renderer/gl/data/quad.vs @@ -0,0 +1,11 @@ +#version 330 core +layout (location = 0) in vec2 aPos; +layout (location = 1) in vec2 aTexCoord; + +out vec2 TexCoord; + +void main() +{ + gl_Position = vec4(aPos.x, aPos.y, 0.0, 1.0); + TexCoord = aTexCoord; +} \ No newline at end of file diff --git a/lib/renderer/gl/framework.py b/lib/renderer/gl/framework.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae6ad0d09e475e89b38633d457c4b3b68c881c7 --- /dev/null +++ b/lib/renderer/gl/framework.py @@ -0,0 +1,95 @@ +# Mario Rosasco, 2016 +# adapted from framework.cpp, Copyright (C) 2010-2012 by Jason L. McKesson +# This file is licensed under the MIT License. +# +# NB: Unlike in the framework.cpp organization, the main loop is contained +# in the tutorial files, not in this framework file. Additionally, a copy of +# this module file must exist in the same directory as the tutorial files +# to be imported properly. + +import os +from OpenGL.GL import * + + +# Function that creates and compiles shaders according to the given type (a GL enum value) and +# shader program (a file containing a GLSL program). +def loadShader(shaderType, shaderFile): + # check if file exists, get full path name + strFilename = findFileOrThrow(shaderFile) + shaderData = None + with open(strFilename, 'r') as f: + shaderData = f.read() + + shader = glCreateShader(shaderType) + glShaderSource( + shader, + shaderData) # note that this is a simpler function call than in C + + # This shader compilation is more explicit than the one used in + # framework.cpp, which relies on a glutil wrapper function. + # This is made explicit here mainly to decrease dependence on pyOpenGL + # utilities and wrappers, which docs caution may change in future versions. 
+    glCompileShader(shader)
+
+    status = glGetShaderiv(shader, GL_COMPILE_STATUS)
+    if status == GL_FALSE:
+        # Note that getting the error log is much simpler in Python than in C/C++
+        # and does not require explicit handling of the string buffer
+        strInfoLog = glGetShaderInfoLog(shader)
+        strShaderType = ""
+        if shaderType is GL_VERTEX_SHADER:
+            strShaderType = "vertex"
+        elif shaderType is GL_GEOMETRY_SHADER:
+            strShaderType = "geometry"
+        elif shaderType is GL_FRAGMENT_SHADER:
+            strShaderType = "fragment"
+
+        print("Compilation failure for " + strShaderType + " shader:\n" +
+              str(strInfoLog))
+
+    return shader
+
+
+# Function that accepts a list of shaders, compiles them, and returns a handle to the compiled program
+def createProgram(shaderList):
+    program = glCreateProgram()
+
+    for shader in shaderList:
+        glAttachShader(program, shader)
+
+    glLinkProgram(program)
+
+    status = glGetProgramiv(program, GL_LINK_STATUS)
+    if status == GL_FALSE:
+        # Note that getting the error log is much simpler in Python than in C/C++
+        # and does not require explicit handling of the string buffer
+        strInfoLog = glGetProgramInfoLog(program)
+        print("Linker failure: \n" + str(strInfoLog))
+
+    for shader in shaderList:
+        glDetachShader(program, shader)
+
+    return program
+
+
+# Helper function to locate and open the target file (passed in as a string).
+# Returns the full path to the file as a string.
+def findFileOrThrow(strBasename):
+    # Keep constant names in C-style convention, for readability
+    # when comparing to C(/C++) code.
+    if os.path.isfile(strBasename):
+        return strBasename
+
+    LOCAL_FILE_DIR = "data" + os.sep
+    GLOBAL_FILE_DIR = os.path.dirname(
+        os.path.abspath(__file__)) + os.sep + "data" + os.sep
+
+    strFilename = LOCAL_FILE_DIR + strBasename
+    if os.path.isfile(strFilename):
+        return strFilename
+
+    strFilename = GLOBAL_FILE_DIR + strBasename
+    if os.path.isfile(strFilename):
+        return strFilename
+
+    raise IOError('Could not find target file ' + strBasename)
diff --git a/lib/renderer/gl/glcontext.py b/lib/renderer/gl/glcontext.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ef6834aee12f29457310f3aa528b717d0095480
--- /dev/null
+++ b/lib/renderer/gl/glcontext.py
@@ -0,0 +1,136 @@
+"""Headless GPU-accelerated OpenGL context creation on Google Colaboratory.
+
+Typical usage:
+
+    # Optional PyOpenGL configuration can be done here.
+    # import OpenGL
+    # OpenGL.ERROR_CHECKING = True
+
+    # 'glcontext' must be imported before any OpenGL.* API.
+    from lucid.misc.gl.glcontext import create_opengl_context
+
+    # Now it's safe to import OpenGL and EGL functions
+    import OpenGL.GL as gl
+
+    # create_opengl_context() creates a GL context that is attached to an
+    # offscreen surface of the specified size. Note that rendering to buffers
+    # of other sizes and formats is still possible with OpenGL Framebuffers.
+    #
+    # Users are expected to directly use the EGL API in case more advanced
+    # context management is required.
+    width, height = 640, 480
+    create_opengl_context((width, height))
+
+    # OpenGL context is available here.
+
+"""
+
+from __future__ import print_function
+
+# pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports
+
+try:
+    import OpenGL
+except:
+    print('This module depends on PyOpenGL.')
+    print('Please run "\033[1m!pip install -q pyopengl\033[0m" '
+          'prior to importing this module.')
+    raise
+
+import ctypes
+from ctypes import pointer, util
+import os
+
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+
+# OpenGL loading workaround.
+#
+# * PyOpenGL tries to load libGL, but we need libOpenGL, see [1,2].
+#   This could have been solved by a symlink libGL->libOpenGL, but:
+#
+# * Python 2.7 can't find libGL and libEGL due to a bug (see [3])
+#   in ctypes.util, that was only fixed in Python 3.6.
+#
+# So, the only solution I've found is to monkeypatch ctypes.util
+# [1] https://devblogs.nvidia.com/egl-eye-opengl-visualization-without-x-server/
+# [2] https://devblogs.nvidia.com/linking-opengl-server-side-rendering/
+# [3] https://bugs.python.org/issue9998
+_find_library_old = ctypes.util.find_library
+try:
+
+    def _find_library_new(name):
+        return {
+            'GL': 'libOpenGL.so',
+            'EGL': 'libEGL.so',
+        }.get(name, _find_library_old(name))
+
+    util.find_library = _find_library_new
+    import OpenGL.GL as gl
+    import OpenGL.EGL as egl
+except:
+    print('Unable to load OpenGL libraries. '
+          'Make sure you use GPU-enabled backend.')
+    print('Press "Runtime->Change runtime type" and set '
+          '"Hardware accelerator" to GPU.')
+    raise
+finally:
+    util.find_library = _find_library_old
+
+
+def create_opengl_context(surface_size=(640, 480)):
+    """Create offscreen OpenGL context and make it current.
+
+    Users are expected to directly use EGL API in case more advanced
+    context management is required.
+
+    Args:
+      surface_size: (width, height), size of the offscreen rendering surface.
+    """
+    egl_display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
+
+    major, minor = egl.EGLint(), egl.EGLint()
+    egl.eglInitialize(egl_display, pointer(major), pointer(minor))
+
+    config_attribs = [
+        egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, egl.EGL_BLUE_SIZE, 8,
+        egl.EGL_GREEN_SIZE, 8, egl.EGL_RED_SIZE, 8, egl.EGL_DEPTH_SIZE, 24,
+        egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, egl.EGL_NONE
+    ]
+    config_attribs = (egl.EGLint * len(config_attribs))(*config_attribs)
+
+    num_configs = egl.EGLint()
+    egl_cfg = egl.EGLConfig()
+    egl.eglChooseConfig(egl_display, config_attribs, pointer(egl_cfg), 1,
+                        pointer(num_configs))
+
+    width, height = surface_size
+    pbuffer_attribs = [
+        egl.EGL_WIDTH,
+        width,
+        egl.EGL_HEIGHT,
+        height,
+        egl.EGL_NONE,
+    ]
+    pbuffer_attribs = (egl.EGLint * len(pbuffer_attribs))(*pbuffer_attribs)
+    egl_surf = egl.eglCreatePbufferSurface(egl_display, egl_cfg,
+                                           pbuffer_attribs)
+
+    egl.eglBindAPI(egl.EGL_OPENGL_API)
+
+    context_attribs = None
+    # context_attribs = [
+    #     egl.EGL_CONTEXT_MAJOR_VERSION,
+    #     4,
+    #     egl.EGL_CONTEXT_MINOR_VERSION,
+    #     1,
+    #     egl.EGL_NONE,
+    # ]
+
+    egl_context = egl.eglCreateContext(egl_display, egl_cfg,
+                                       egl.EGL_NO_CONTEXT, context_attribs)
+    egl.eglMakeCurrent(egl_display, egl_surf, egl_surf, egl_context)
+
+    buffer_type = egl.EGLint()
+    out = egl.eglQueryContext(egl_display, egl_context,
+                              egl.EGL_CONTEXT_CLIENT_VERSION, buffer_type)
+    # print(buffer_type)
diff --git a/lib/renderer/gl/init_gl.py b/lib/renderer/gl/init_gl.py
new file mode 100644
index 0000000000000000000000000000000000000000..613c1d034733cab395617e61c80fc3cc716e3759
--- /dev/null
+++ b/lib/renderer/gl/init_gl.py
@@ -0,0 +1,24 @@
+_glut_window = None
+_context_inited = None
+
+
+def initialize_GL_context(width=512, height=512, egl=False):
+    '''
+    default context uses GLUT
+    '''
+    if not egl:
+        import OpenGL.GLUT as GLUT
+        display_mode = GLUT.GLUT_DOUBLE | GLUT.GLUT_RGB | GLUT.GLUT_DEPTH
+        global _glut_window
+        if _glut_window is None:
+            GLUT.glutInit()
+            GLUT.glutInitDisplayMode(display_mode)
+            GLUT.glutInitWindowSize(width, height)
+            GLUT.glutInitWindowPosition(0, 0)
+            _glut_window = GLUT.glutCreateWindow("My Render.")
+    else:
+        from .glcontext
import create_opengl_context + global _context_inited + if _context_inited is None: + create_opengl_context((width, height)) + _context_inited = True diff --git a/lib/renderer/gl/norm_render.py b/lib/renderer/gl/norm_render.py new file mode 100644 index 0000000000000000000000000000000000000000..02c9f384e3637a11d50539c8bf4aaee4b22442eb --- /dev/null +++ b/lib/renderer/gl/norm_render.py @@ -0,0 +1,80 @@ +''' +MIT License + +Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from OpenGL.GLUT import * + +from .render2 import Render + + +class NormRender(Render): + + def __init__(self, + width=1600, + height=1200, + name='Cam Renderer', + program_files=['simple.fs', 'simple.vs'], + color_size=1, + ms_rate=1): + Render.__init__(self, width, height, name, program_files, color_size, + ms_rate) + self.camera = None + + glutDisplayFunc(self.display) + glutKeyboardFunc(self.keyboard) + + def set_camera(self, camera): + self.camera = camera + self.projection_matrix, self.model_view_matrix = camera.get_gl_matrix() + + def set_matrices(self, projection, modelview): + self.projection_matrix = projection + self.model_view_matrix = modelview + + def keyboard(self, key, x, y): + # up + eps = 1 + # print(key) + if key == b'w': + self.camera.center += eps * self.camera.direction + elif key == b's': + self.camera.center -= eps * self.camera.direction + if key == b'a': + self.camera.center -= eps * self.camera.right + elif key == b'd': + self.camera.center += eps * self.camera.right + if key == b' ': + self.camera.center += eps * self.camera.up + elif key == b'x': + self.camera.center -= eps * self.camera.up + elif key == b'i': + self.camera.near += 0.1 * eps + self.camera.far += 0.1 * eps + elif key == b'o': + self.camera.near -= 0.1 * eps + self.camera.far -= 0.1 * eps + + self.projection_matrix, self.model_view_matrix = self.camera.get_gl_matrix( + ) + + def show(self): + glutMainLoop() diff --git a/lib/renderer/gl/normal_render.py b/lib/renderer/gl/normal_render.py new file mode 100644 index 0000000000000000000000000000000000000000..efb5c4f77d4bdcfed0b50c78587d643fda2da9ed --- /dev/null +++ b/lib/renderer/gl/normal_render.py @@ -0,0 +1,98 @@ +''' +MIT License + +Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +import numpy as np +import math + +from .framework import * +from .norm_render import NormRender + + +class NormalRender(NormRender): + + def __init__(self, width=1600, height=1200, name='Normal Renderer'): + NormRender.__init__(self, + width, + height, + name, + program_files=['normal.vs', 'normal.fs']) + + self.norm_buffer = glGenBuffers(1) + + self.norm_data = None + + def set_normal_mesh(self, vertices, faces, norms, face_normals): + NormRender.set_mesh(self, vertices, faces) + + self.norm_data = norms[face_normals.reshape([-1])] + + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer) + glBufferData(GL_ARRAY_BUFFER, self.norm_data, GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def euler_to_rot_mat(self, r_x, r_y, r_z): + R_x = np.array([[1, 0, 0], [0, math.cos(r_x), -math.sin(r_x)], + [0, math.sin(r_x), math.cos(r_x)]]) + + R_y = np.array([[math.cos(r_y), 0, math.sin(r_y)], [0, 1, 0], + [-math.sin(r_y), 0, math.cos(r_y)]]) + + R_z = np.array([[math.cos(r_z), -math.sin(r_z), 0], + [math.sin(r_z), math.cos(r_z), 0], [0, 0, 1]]) + + R = np.dot(R_z, np.dot(R_y, R_x)) + + return R + + def draw(self): + self.draw_init() + + glUseProgram(self.program) + glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, + self.model_view_matrix.transpose()) + glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, + self.projection_matrix.transpose()) + + # Handle vertex buffer + glBindBuffer(GL_ARRAY_BUFFER, self.vertex_buffer) + + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, self.vertex_dim, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle normal buffer + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer) + + glEnableVertexAttribArray(1) + glVertexAttribPointer(1, 3, GL_DOUBLE, GL_FALSE, 0, None) + + glDrawArrays(GL_TRIANGLES, 0, self.n_vertices) + + glDisableVertexAttribArray(1) + glDisableVertexAttribArray(0) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + glUseProgram(0) + + self.draw_end() diff --git a/lib/renderer/gl/prt_render.py b/lib/renderer/gl/prt_render.py new file mode 100644 index 0000000000000000000000000000000000000000..245353f512237e036fb152aadbafb1438b8d79d1 --- /dev/null +++ b/lib/renderer/gl/prt_render.py @@ -0,0 +1,450 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np +import random + +from .framework import * +from .cam_render import CamRender + + +class PRTRender(CamRender): + + def __init__(self, + width=1600, + height=1200, + name='PRT Renderer', + uv_mode=False, + ms_rate=1, + egl=False): + program_files = ['prt.vs', 'prt.fs' + ] if not uv_mode else ['prt_uv.vs', 'prt_uv.fs'] + CamRender.__init__(self, + width, + height, + name, + program_files=program_files, + color_size=8, + ms_rate=ms_rate, + egl=egl) + + # WARNING: this differs from vertex_buffer and vertex_data in Render + self.vert_buffer = {} + self.vert_data = {} + + self.vert_label_buffer = {} + self.vert_label_data = {} + + self.norm_buffer = {} + self.norm_data = {} + + self.tan_buffer = {} + self.tan_data = {} + + self.btan_buffer = {} + self.btan_data = {} + + self.prt1_buffer = {} + self.prt1_data = {} + + self.prt2_buffer = {} + self.prt2_data = {} + + self.prt3_buffer = {} + self.prt3_data = {} + + self.uv_buffer = {} + self.uv_data = {} + + self.render_texture_mat = {} + + self.vertex_dim = {} + self.n_vertices = {} + self.label_dim = {} + + self.norm_mat_unif = glGetUniformLocation(self.program, 'NormMat') + self.normalize_matrix = np.eye(4) + + self.shcoeff_unif = glGetUniformLocation(self.program, 'SHCoeffs') + self.shcoeffs = np.zeros((9, 3)) + self.shcoeffs[0, :] = 1.0 + #self.shcoeffs[1:,:] = np.random.rand(8,3) + + self.hasAlbedoUnif = glGetUniformLocation(self.program, 'hasAlbedoMap') + self.hasNormalUnif = glGetUniformLocation(self.program, 'hasNormalMap') + + self.analyticUnif = glGetUniformLocation(self.program, 'analytic') + self.analytic = False + + self.rot_mat_unif = glGetUniformLocation(self.program, 'RotMat') + self.rot_matrix = np.eye(3) + + def set_texture(self, mat_name, smplr_name, texture): + # texture_image: H x W x 3 + width = texture.shape[1] + height = texture.shape[0] + texture = np.flip(texture, 0) + img_data = np.fromstring(texture.tostring(), np.uint8) + + if mat_name not in self.render_texture_mat: + self.render_texture_mat[mat_name] = {} + if smplr_name in self.render_texture_mat[mat_name].keys(): + glDeleteTextures([self.render_texture_mat[mat_name][smplr_name]]) + del self.render_texture_mat[mat_name][smplr_name] + + self.render_texture_mat[mat_name][smplr_name] = glGenTextures(1) + glActiveTexture(GL_TEXTURE0) + + glPixelStorei(GL_UNPACK_ALIGNMENT, 1) + glBindTexture(GL_TEXTURE_2D, + self.render_texture_mat[mat_name][smplr_name]) + + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, width, height, 0, GL_RGB, + GL_UNSIGNED_BYTE, img_data) + + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAX_LEVEL, 3) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_LINEAR_MIPMAP_LINEAR) + + glGenerateMipmap(GL_TEXTURE_2D) + + def set_albedo(self, texture_image, mat_name='all'): + self.set_texture(mat_name, 'AlbedoMap', texture_image) + + def set_normal_map(self, texture_image, mat_name='all'): + self.set_texture(mat_name, 'NormalMap', texture_image) + + def set_mesh(self, + vertices, + faces, + norms, + faces_nml, + uvs, + faces_uvs, + prt, + faces_prt, + tans, + bitans, + verts_label=None, + mat_name='all'): + + self.vert_data[mat_name] = vertices[faces.reshape([-1])] + self.vert_label_data[mat_name] = 
verts_label[faces.reshape([-1])] + self.n_vertices[mat_name] = self.vert_data[mat_name].shape[0] + self.vertex_dim[mat_name] = self.vert_data[mat_name].shape[1] + self.label_dim[mat_name] = self.vert_label_data[mat_name].shape[1] + + if mat_name not in self.vert_buffer.keys(): + self.vert_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.vert_data[mat_name], GL_STATIC_DRAW) + + if mat_name not in self.vert_label_buffer.keys(): + self.vert_label_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.vert_label_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.vert_label_data[mat_name], + GL_STATIC_DRAW) + + self.uv_data[mat_name] = uvs[faces_uvs.reshape([-1])] + if mat_name not in self.uv_buffer.keys(): + self.uv_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.uv_data[mat_name], GL_STATIC_DRAW) + + self.norm_data[mat_name] = norms[faces_nml.reshape([-1])] + if mat_name not in self.norm_buffer.keys(): + self.norm_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.norm_data[mat_name], GL_STATIC_DRAW) + + self.tan_data[mat_name] = tans[faces_nml.reshape([-1])] + if mat_name not in self.tan_buffer.keys(): + self.tan_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.tan_data[mat_name], GL_STATIC_DRAW) + + self.btan_data[mat_name] = bitans[faces_nml.reshape([-1])] + if mat_name not in self.btan_buffer.keys(): + self.btan_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.btan_data[mat_name], GL_STATIC_DRAW) + + self.prt1_data[mat_name] = prt[faces_prt.reshape([-1])][:, :3] + self.prt2_data[mat_name] = prt[faces_prt.reshape([-1])][:, 3:6] + self.prt3_data[mat_name] = prt[faces_prt.reshape([-1])][:, 6:] + + if mat_name not in self.prt1_buffer.keys(): + self.prt1_buffer[mat_name] = glGenBuffers(1) + if mat_name not in self.prt2_buffer.keys(): + self.prt2_buffer[mat_name] = glGenBuffers(1) + if mat_name not in self.prt3_buffer.keys(): + self.prt3_buffer[mat_name] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.prt1_data[mat_name], GL_STATIC_DRAW) + glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.prt2_data[mat_name], GL_STATIC_DRAW) + glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[mat_name]) + glBufferData(GL_ARRAY_BUFFER, self.prt3_data[mat_name], GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def set_mesh_mtl(self, + vertices, + faces, + norms, + faces_nml, + uvs, + faces_uvs, + tans, + bitans, + prt, + verts_label=None): + for key in faces: + self.vert_data[key] = vertices[faces[key].reshape([-1])] + self.vert_label_data[key] = verts_label[faces[key].reshape([-1])] + self.n_vertices[key] = self.vert_data[key].shape[0] + self.vertex_dim[key] = self.vert_data[key].shape[1] + self.label_dim[key] = self.vert_label_data[key].shape[1] + + if key not in self.vert_buffer.keys(): + self.vert_buffer[key] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.vert_data[key], GL_STATIC_DRAW) + + if key not in self.vert_label_buffer.keys(): + self.vert_label_buffer[key] = glGenBuffers(1) + 
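+            # Per-vertex semantic labels ride along as an extra attribute;
+            # draw() binds this buffer to attribute location 8.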
glBindBuffer(GL_ARRAY_BUFFER, self.vert_label_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.vert_label_data[key], + GL_STATIC_DRAW) + + self.uv_data[key] = uvs[faces_uvs[key].reshape([-1])] + if key not in self.uv_buffer.keys(): + self.uv_buffer[key] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.uv_data[key], GL_STATIC_DRAW) + + self.norm_data[key] = norms[faces_nml[key].reshape([-1])] + if key not in self.norm_buffer.keys(): + self.norm_buffer[key] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.norm_data[key], GL_STATIC_DRAW) + + self.tan_data[key] = tans[faces_nml[key].reshape([-1])] + if key not in self.tan_buffer.keys(): + self.tan_buffer[key] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.tan_data[key], GL_STATIC_DRAW) + + self.btan_data[key] = bitans[faces_nml[key].reshape([-1])] + if key not in self.btan_buffer.keys(): + self.btan_buffer[key] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.btan_data[key], GL_STATIC_DRAW) + + self.prt1_data[key] = prt[faces[key].reshape([-1])][:, :3] + self.prt2_data[key] = prt[faces[key].reshape([-1])][:, 3:6] + self.prt3_data[key] = prt[faces[key].reshape([-1])][:, 6:] + + if key not in self.prt1_buffer.keys(): + self.prt1_buffer[key] = glGenBuffers(1) + if key not in self.prt2_buffer.keys(): + self.prt2_buffer[key] = glGenBuffers(1) + if key not in self.prt3_buffer.keys(): + self.prt3_buffer[key] = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.prt1_data[key], GL_STATIC_DRAW) + glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.prt2_data[key], GL_STATIC_DRAW) + glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[key]) + glBufferData(GL_ARRAY_BUFFER, self.prt3_data[key], GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def cleanup(self): + + glBindBuffer(GL_ARRAY_BUFFER, 0) + for key in self.vert_data: + glDeleteBuffers(1, [self.vert_buffer[key]]) + glDeleteBuffers(1, [self.norm_buffer[key]]) + glDeleteBuffers(1, [self.uv_buffer[key]]) + glDeleteBuffers(1, [self.vert_label_buffer[key]]) + + glDeleteBuffers(1, [self.tan_buffer[key]]) + glDeleteBuffers(1, [self.btan_buffer[key]]) + glDeleteBuffers(1, [self.prt1_buffer[key]]) + glDeleteBuffers(1, [self.prt2_buffer[key]]) + glDeleteBuffers(1, [self.prt3_buffer[key]]) + + glDeleteBuffers(1, []) + + for smplr in self.render_texture_mat[key]: + glDeleteTextures([self.render_texture_mat[key][smplr]]) + + self.vert_buffer = {} + self.vert_data = {} + + self.vert_label_buffer = {} + self.vert_label_data = {} + + self.norm_buffer = {} + self.norm_data = {} + + self.tan_buffer = {} + self.tan_data = {} + + self.btan_buffer = {} + self.btan_data = {} + + self.prt1_buffer = {} + self.prt1_data = {} + + self.prt2_buffer = {} + self.prt2_data = {} + + self.prt3_buffer = {} + self.prt3_data = {} + + self.uv_buffer = {} + self.uv_data = {} + + self.render_texture_mat = {} + + self.vertex_dim = {} + self.n_vertices = {} + self.label_dim = {} + + def randomize_sh(self): + self.shcoeffs[0, :] = 0.8 + self.shcoeffs[1:, :] = 1.0 * np.random.rand(8, 3) + + def set_sh(self, sh): + self.shcoeffs = sh + + def set_norm_mat(self, scale, center): + N = np.eye(4) + N[:3, :3] = scale * np.eye(3) + N[:3, 3] = -scale * center + + self.normalize_matrix = N + + def draw(self): + 
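+        # Per-material draw pass: upload the normalization, model-view and
+        # projection matrices plus the SH coefficients, bind the vertex
+        # attributes (position, normal, uv, tangent, bitangent, the three PRT
+        # bands and the per-vertex labels at locations 0-8), bind any albedo /
+        # normal-map textures, and issue one glDrawArrays call per material.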
self.draw_init() + + glDisable(GL_BLEND) + #glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + glEnable(GL_MULTISAMPLE) + + glUseProgram(self.program) + glUniformMatrix4fv(self.norm_mat_unif, 1, GL_FALSE, + self.normalize_matrix.transpose()) + glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, + self.model_view_matrix.transpose()) + glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, + self.projection_matrix.transpose()) + + if 'AlbedoMap' in self.render_texture_mat['all']: + glUniform1ui(self.hasAlbedoUnif, GLuint(1)) + else: + glUniform1ui(self.hasAlbedoUnif, GLuint(0)) + + if 'NormalMap' in self.render_texture_mat['all']: + glUniform1ui(self.hasNormalUnif, GLuint(1)) + else: + glUniform1ui(self.hasNormalUnif, GLuint(0)) + + glUniform1ui(self.analyticUnif, + GLuint(1) if self.analytic else GLuint(0)) + + glUniform3fv(self.shcoeff_unif, 9, self.shcoeffs) + + glUniformMatrix3fv(self.rot_mat_unif, 1, GL_FALSE, + self.rot_matrix.transpose()) + + for mat in self.vert_buffer: + # Handle vertex buffer + glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat]) + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, self.vertex_dim[mat], GL_DOUBLE, GL_FALSE, + 0, None) + + # Handle normal buffer + glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat]) + glEnableVertexAttribArray(1) + glVertexAttribPointer(1, 3, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle uv buffer + glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[mat]) + glEnableVertexAttribArray(2) + glVertexAttribPointer(2, 2, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle tan buffer + glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[mat]) + glEnableVertexAttribArray(3) + glVertexAttribPointer(3, 3, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle btan buffer + glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[mat]) + glEnableVertexAttribArray(4) + glVertexAttribPointer(4, 3, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle PTR buffer + glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[mat]) + glEnableVertexAttribArray(5) + glVertexAttribPointer(5, 3, GL_DOUBLE, GL_FALSE, 0, None) + + glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[mat]) + glEnableVertexAttribArray(6) + glVertexAttribPointer(6, 3, GL_DOUBLE, GL_FALSE, 0, None) + + glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[mat]) + glEnableVertexAttribArray(7) + glVertexAttribPointer(7, 3, GL_DOUBLE, GL_FALSE, 0, None) + + # Handle vertex label buffer + glBindBuffer(GL_ARRAY_BUFFER, self.vert_label_buffer[mat]) + glEnableVertexAttribArray(8) + glVertexAttribPointer(8, self.label_dim[mat], GL_DOUBLE, GL_FALSE, + 0, None) + + for i, smplr in enumerate(self.render_texture_mat[mat]): + glActiveTexture(GL_TEXTURE0 + i) + glBindTexture(GL_TEXTURE_2D, + self.render_texture_mat[mat][smplr]) + glUniform1i(glGetUniformLocation(self.program, smplr), i) + + glDrawArrays(GL_TRIANGLES, 0, self.n_vertices[mat]) + + glDisableVertexAttribArray(8) + glDisableVertexAttribArray(7) + glDisableVertexAttribArray(6) + glDisableVertexAttribArray(5) + glDisableVertexAttribArray(4) + glDisableVertexAttribArray(3) + glDisableVertexAttribArray(2) + glDisableVertexAttribArray(1) + glDisableVertexAttribArray(0) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + glUseProgram(0) + + glDisable(GL_BLEND) + glDisable(GL_MULTISAMPLE) + + self.draw_end() diff --git a/lib/renderer/gl/render.py b/lib/renderer/gl/render.py new file mode 100644 index 0000000000000000000000000000000000000000..90b8ebd31074a9ac4344b786d896e306d108dfee --- /dev/null +++ b/lib/renderer/gl/render.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der 
Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from ctypes import * + +import numpy as np +from .framework import * + +GLUT = None + + +# NOTE: Render class assumes GL context is created already. +class Render: + + def __init__(self, + width=1600, + height=1200, + name='GL Renderer', + program_files=['simple.fs', 'simple.vs'], + color_size=1, + ms_rate=1, + egl=False): + self.width = width + self.height = height + self.name = name + self.use_inverse_depth = False + self.egl = egl + + glEnable(GL_DEPTH_TEST) + + glClampColor(GL_CLAMP_READ_COLOR, GL_FALSE) + glClampColor(GL_CLAMP_FRAGMENT_COLOR, GL_FALSE) + glClampColor(GL_CLAMP_VERTEX_COLOR, GL_FALSE) + + # init program + shader_list = [] + + for program_file in program_files: + _, ext = os.path.splitext(program_file) + if ext == '.vs': + shader_list.append(loadShader(GL_VERTEX_SHADER, program_file)) + elif ext == '.fs': + shader_list.append(loadShader(GL_FRAGMENT_SHADER, + program_file)) + elif ext == '.gs': + shader_list.append(loadShader(GL_GEOMETRY_SHADER, + program_file)) + + self.program = createProgram(shader_list) + + for shader in shader_list: + glDeleteShader(shader) + + # Init uniform variables + self.model_mat_unif = glGetUniformLocation(self.program, 'ModelMat') + self.persp_mat_unif = glGetUniformLocation(self.program, 'PerspMat') + + self.vertex_buffer = glGenBuffers(1) + + # Init screen quad program and buffer + self.quad_program, self.quad_buffer = self.init_quad_program() + + # Configure frame buffer + self.frame_buffer = glGenFramebuffers(1) + glBindFramebuffer(GL_FRAMEBUFFER, self.frame_buffer) + + self.intermediate_fbo = None + if ms_rate > 1: + # Configure texture buffer to render to + self.color_buffer = [] + for i in range(color_size): + color_buffer = glGenTextures(1) + multi_sample_rate = ms_rate + glBindTexture(GL_TEXTURE_2D_MULTISAMPLE, color_buffer) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, + GL_LINEAR) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_LINEAR) + glTexImage2DMultisample(GL_TEXTURE_2D_MULTISAMPLE, + multi_sample_rate, GL_RGBA32F, + self.width, self.height, GL_TRUE) + glBindTexture(GL_TEXTURE_2D_MULTISAMPLE, 0) + glFramebufferTexture2D(GL_FRAMEBUFFER, + GL_COLOR_ATTACHMENT0 + i, + GL_TEXTURE_2D_MULTISAMPLE, color_buffer, + 0) + self.color_buffer.append(color_buffer) + + self.render_buffer = glGenRenderbuffers(1) + glBindRenderbuffer(GL_RENDERBUFFER, self.render_buffer) + glRenderbufferStorageMultisample(GL_RENDERBUFFER, + multi_sample_rate, + GL_DEPTH24_STENCIL8, self.width, + self.height) + glBindRenderbuffer(GL_RENDERBUFFER, 0) + glFramebufferRenderbuffer(GL_FRAMEBUFFER, + GL_DEPTH_STENCIL_ATTACHMENT, + GL_RENDERBUFFER, self.render_buffer) + + attachments = [] + for i in range(color_size): + attachments.append(GL_COLOR_ATTACHMENT0 + i) + 
glDrawBuffers(color_size, attachments) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + + self.intermediate_fbo = glGenFramebuffers(1) + glBindFramebuffer(GL_FRAMEBUFFER, self.intermediate_fbo) + + self.screen_texture = [] + for i in range(color_size): + screen_texture = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, screen_texture) + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA32F, self.width, + self.height, 0, GL_RGBA, GL_FLOAT, None) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_LINEAR) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, + GL_LINEAR) + glFramebufferTexture2D(GL_FRAMEBUFFER, + GL_COLOR_ATTACHMENT0 + i, GL_TEXTURE_2D, + screen_texture, 0) + self.screen_texture.append(screen_texture) + + glDrawBuffers(color_size, attachments) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + else: + self.color_buffer = [] + for i in range(color_size): + color_buffer = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, color_buffer) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, + GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_NEAREST) + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA32F, self.width, + self.height, 0, GL_RGBA, GL_FLOAT, None) + glFramebufferTexture2D(GL_FRAMEBUFFER, + GL_COLOR_ATTACHMENT0 + i, GL_TEXTURE_2D, + color_buffer, 0) + self.color_buffer.append(color_buffer) + + # Configure depth texture map to render to + self.depth_buffer = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, self.depth_buffer) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_DEPTH_TEXTURE_MODE, GL_INTENSITY) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_COMPARE_MODE, + GL_COMPARE_R_TO_TEXTURE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_COMPARE_FUNC, GL_LEQUAL) + glTexImage2D(GL_TEXTURE_2D, 0, GL_DEPTH_COMPONENT, self.width, + self.height, 0, GL_DEPTH_COMPONENT, GL_FLOAT, None) + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, + GL_TEXTURE_2D, self.depth_buffer, 0) + + attachments = [] + for i in range(color_size): + attachments.append(GL_COLOR_ATTACHMENT0 + i) + glDrawBuffers(color_size, attachments) + self.screen_texture = self.color_buffer + + glBindFramebuffer(GL_FRAMEBUFFER, 0) + + # Configure texture buffer if needed + self.render_texture = None + + # NOTE: original render_texture only support one input + # this is tentative member of this issue + self.render_texture_v2 = {} + + # Inner storage for buffer data + self.vertex_data = None + self.vertex_dim = None + self.n_vertices = None + + self.model_view_matrix = None + self.projection_matrix = None + + if not egl: + global GLUT + import OpenGL.GLUT as GLUT + GLUT.glutDisplayFunc(self.display) + + def init_quad_program(self): + shader_list = [] + + shader_list.append(loadShader(GL_VERTEX_SHADER, "quad.vs")) + shader_list.append(loadShader(GL_FRAGMENT_SHADER, "quad.fs")) + + the_program = createProgram(shader_list) + + for shader in shader_list: + glDeleteShader(shader) + + # vertex attributes for a quad that fills the entire screen in Normalized Device Coordinates. 
+ # positions # texCoords + quad_vertices = np.array([ + -1.0, 1.0, 0.0, 1.0, -1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 1.0, 0.0, + -1.0, 1.0, 0.0, 1.0, 1.0, -1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0 + ]) + + quad_buffer = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, quad_buffer) + glBufferData(GL_ARRAY_BUFFER, quad_vertices, GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + return the_program, quad_buffer + + def set_mesh(self, vertices, faces): + self.vertex_data = vertices[faces.reshape([-1])] + self.vertex_dim = self.vertex_data.shape[1] + self.n_vertices = self.vertex_data.shape[0] + + glBindBuffer(GL_ARRAY_BUFFER, self.vertex_buffer) + glBufferData(GL_ARRAY_BUFFER, self.vertex_data, GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def set_viewpoint(self, projection, model_view): + self.projection_matrix = projection + self.model_view_matrix = model_view + + def draw_init(self): + glBindFramebuffer(GL_FRAMEBUFFER, self.frame_buffer) + glEnable(GL_DEPTH_TEST) + + glClearColor(0.0, 0.0, 0.0, 0.0) + if self.use_inverse_depth: + glDepthFunc(GL_GREATER) + glClearDepth(0.0) + else: + glDepthFunc(GL_LESS) + glClearDepth(1.0) + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + + def draw_end(self): + if self.intermediate_fbo is not None: + for i in range(len(self.color_buffer)): + glBindFramebuffer(GL_READ_FRAMEBUFFER, self.frame_buffer) + glReadBuffer(GL_COLOR_ATTACHMENT0 + i) + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self.intermediate_fbo) + glDrawBuffer(GL_COLOR_ATTACHMENT0 + i) + glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, + self.width, self.height, GL_COLOR_BUFFER_BIT, + GL_NEAREST) + + glBindFramebuffer(GL_FRAMEBUFFER, 0) + glDepthFunc(GL_LESS) + glClearDepth(1.0) + + def draw(self): + self.draw_init() + + glUseProgram(self.program) + glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, + self.model_view_matrix.transpose()) + glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, + self.projection_matrix.transpose()) + + glBindBuffer(GL_ARRAY_BUFFER, self.vertex_buffer) + + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, self.vertex_dim, GL_DOUBLE, GL_FALSE, 0, None) + + glDrawArrays(GL_TRIANGLES, 0, self.n_vertices) + + glDisableVertexAttribArray(0) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + glUseProgram(0) + + self.draw_end() + + def get_color(self, color_id=0): + glBindFramebuffer( + GL_FRAMEBUFFER, self.intermediate_fbo + if self.intermediate_fbo is not None else self.frame_buffer) + glReadBuffer(GL_COLOR_ATTACHMENT0 + color_id) + data = glReadPixels(0, + 0, + self.width, + self.height, + GL_RGBA, + GL_FLOAT, + outputType=None) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + rgb = data.reshape(self.height, self.width, -1) + rgb = np.flip(rgb, 0) + return rgb + + def get_z_value(self): + glBindFramebuffer(GL_FRAMEBUFFER, self.frame_buffer) + data = glReadPixels(0, + 0, + self.width, + self.height, + GL_DEPTH_COMPONENT, + GL_FLOAT, + outputType=None) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + z = data.reshape(self.height, self.width) + z = np.flip(z, 0) + return z + + def display(self): + self.draw() + + if not self.egl: + # First we draw a scene. + # Notice the result is stored in the texture buffer. + + # Then we return to the default frame buffer since we will display on the screen. + glBindFramebuffer(GL_FRAMEBUFFER, 0) + + # Do the clean-up. + glClearColor(0.0, 0.0, 0.0, 0.0) + glClear(GL_COLOR_BUFFER_BIT) + + # We draw a rectangle which covers the whole screen. 
+ glUseProgram(self.quad_program) + glBindBuffer(GL_ARRAY_BUFFER, self.quad_buffer) + + size_of_double = 8 + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, 2, GL_DOUBLE, GL_FALSE, + 4 * size_of_double, None) + glEnableVertexAttribArray(1) + glVertexAttribPointer(1, 2, GL_DOUBLE, GL_FALSE, + 4 * size_of_double, + c_void_p(2 * size_of_double)) + + glDisable(GL_DEPTH_TEST) + + # The stored texture is then mapped to this rectangle. + # properly assing color buffer texture + glActiveTexture(GL_TEXTURE0) + glBindTexture(GL_TEXTURE_2D, self.screen_texture[0]) + glUniform1i( + glGetUniformLocation(self.quad_program, 'screenTexture'), 0) + + glDrawArrays(GL_TRIANGLES, 0, 6) + + glDisableVertexAttribArray(1) + glDisableVertexAttribArray(0) + + glEnable(GL_DEPTH_TEST) + glBindBuffer(GL_ARRAY_BUFFER, 0) + glUseProgram(0) + + GLUT.glutSwapBuffers() + GLUT.glutPostRedisplay() + + def show(self): + if not self.egl: + GLUT.glutMainLoop() diff --git a/lib/renderer/gl/render2.py b/lib/renderer/gl/render2.py new file mode 100644 index 0000000000000000000000000000000000000000..5d250eb9ca667e80427c0780f7887c902274cdd9 --- /dev/null +++ b/lib/renderer/gl/render2.py @@ -0,0 +1,389 @@ +''' +MIT License + +Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
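A minimal usage sketch of the offscreen `Render` class above, assuming a GL context already exists (EGL or GLUT, created elsewhere) and that the `simple.vs`/`simple.fs` shaders are found on the loader's search path; the geometry and matrices are placeholders:

```
import numpy as np
from lib.renderer.gl.render import Render

rndr = Render(width=512, height=512, egl=True)   # GL context must already be current

# A single placeholder triangle; any (V, 3) vertices / (F, 3) faces work.
vertices = np.array([[-0.5, -0.5, 0.0], [0.5, -0.5, 0.0], [0.0, 0.5, 0.0]])
faces = np.array([[0, 1, 2]])
rndr.set_mesh(vertices, faces)

rndr.set_viewpoint(np.eye(4), np.eye(4))         # identity projection / model-view
rndr.draw()

rgba = rndr.get_color(0)        # (512, 512, 4) float image, already flipped upright
depth = rndr.get_z_value()      # (512, 512) depth buffer
```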
+''' +import numpy as np +from OpenGL.GLUT import * +from .framework import * + +_glut_window = None + + +class Render: + + def __init__(self, + width=1600, + height=1200, + name='GL Renderer', + program_files=['simple.fs', 'simple.vs'], + color_size=1, + ms_rate=1): + self.width = width + self.height = height + self.name = name + self.display_mode = GLUT_DOUBLE | GLUT_RGB | GLUT_DEPTH + self.use_inverse_depth = False + + global _glut_window + if _glut_window is None: + glutInit() + glutInitDisplayMode(self.display_mode) + glutInitWindowSize(self.width, self.height) + glutInitWindowPosition(0, 0) + _glut_window = glutCreateWindow("My Render.") + + # glEnable(GL_DEPTH_CLAMP) + glEnable(GL_DEPTH_TEST) + + glClampColor(GL_CLAMP_READ_COLOR, GL_FALSE) + glClampColor(GL_CLAMP_FRAGMENT_COLOR, GL_FALSE) + glClampColor(GL_CLAMP_VERTEX_COLOR, GL_FALSE) + + # init program + shader_list = [] + + for program_file in program_files: + _, ext = os.path.splitext(program_file) + if ext == '.vs': + shader_list.append(loadShader(GL_VERTEX_SHADER, program_file)) + elif ext == '.fs': + shader_list.append(loadShader(GL_FRAGMENT_SHADER, + program_file)) + elif ext == '.gs': + shader_list.append(loadShader(GL_GEOMETRY_SHADER, + program_file)) + + self.program = createProgram(shader_list) + + for shader in shader_list: + glDeleteShader(shader) + + # Init uniform variables + self.model_mat_unif = glGetUniformLocation(self.program, 'ModelMat') + self.persp_mat_unif = glGetUniformLocation(self.program, 'PerspMat') + + self.vertex_buffer = glGenBuffers(1) + + # Init screen quad program and buffer + self.quad_program, self.quad_buffer = self.init_quad_program() + + # Configure frame buffer + self.frame_buffer = glGenFramebuffers(1) + glBindFramebuffer(GL_FRAMEBUFFER, self.frame_buffer) + + self.intermediate_fbo = None + if ms_rate > 1: + # Configure texture buffer to render to + self.color_buffer = [] + for i in range(color_size): + color_buffer = glGenTextures(1) + multi_sample_rate = ms_rate + glBindTexture(GL_TEXTURE_2D_MULTISAMPLE, color_buffer) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, + GL_LINEAR) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_LINEAR) + glTexImage2DMultisample(GL_TEXTURE_2D_MULTISAMPLE, + multi_sample_rate, GL_RGBA32F, + self.width, self.height, GL_TRUE) + glBindTexture(GL_TEXTURE_2D_MULTISAMPLE, 0) + glFramebufferTexture2D(GL_FRAMEBUFFER, + GL_COLOR_ATTACHMENT0 + i, + GL_TEXTURE_2D_MULTISAMPLE, color_buffer, + 0) + self.color_buffer.append(color_buffer) + + self.render_buffer = glGenRenderbuffers(1) + glBindRenderbuffer(GL_RENDERBUFFER, self.render_buffer) + glRenderbufferStorageMultisample(GL_RENDERBUFFER, + multi_sample_rate, + GL_DEPTH24_STENCIL8, self.width, + self.height) + glBindRenderbuffer(GL_RENDERBUFFER, 0) + glFramebufferRenderbuffer(GL_FRAMEBUFFER, + GL_DEPTH_STENCIL_ATTACHMENT, + GL_RENDERBUFFER, self.render_buffer) + + attachments = [] + for i in range(color_size): + attachments.append(GL_COLOR_ATTACHMENT0 + i) + glDrawBuffers(color_size, attachments) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + + self.intermediate_fbo = glGenFramebuffers(1) + glBindFramebuffer(GL_FRAMEBUFFER, self.intermediate_fbo) + + self.screen_texture = [] + for i in range(color_size): + screen_texture = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, screen_texture) + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA32F, self.width, + self.height, 0, 
GL_RGBA, GL_FLOAT, None) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_LINEAR) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, + GL_LINEAR) + glFramebufferTexture2D(GL_FRAMEBUFFER, + GL_COLOR_ATTACHMENT0 + i, GL_TEXTURE_2D, + screen_texture, 0) + self.screen_texture.append(screen_texture) + + glDrawBuffers(color_size, attachments) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + else: + self.color_buffer = [] + for i in range(color_size): + color_buffer = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, color_buffer) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, + GL_CLAMP_TO_EDGE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, + GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, + GL_NEAREST) + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA32F, self.width, + self.height, 0, GL_RGBA, GL_FLOAT, None) + glFramebufferTexture2D(GL_FRAMEBUFFER, + GL_COLOR_ATTACHMENT0 + i, GL_TEXTURE_2D, + color_buffer, 0) + self.color_buffer.append(color_buffer) + + # Configure depth texture map to render to + self.depth_buffer = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, self.depth_buffer) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_DEPTH_TEXTURE_MODE, GL_INTENSITY) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_COMPARE_MODE, + GL_COMPARE_R_TO_TEXTURE) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_COMPARE_FUNC, GL_LEQUAL) + glTexImage2D(GL_TEXTURE_2D, 0, GL_DEPTH_COMPONENT, self.width, + self.height, 0, GL_DEPTH_COMPONENT, GL_FLOAT, None) + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, + GL_TEXTURE_2D, self.depth_buffer, 0) + + attachments = [] + for i in range(color_size): + attachments.append(GL_COLOR_ATTACHMENT0 + i) + glDrawBuffers(color_size, attachments) + self.screen_texture = self.color_buffer + + glBindFramebuffer(GL_FRAMEBUFFER, 0) + + # Configure texture buffer if needed + self.render_texture = None + + # NOTE: original render_texture only support one input + # this is tentative member of this issue + self.render_texture_v2 = {} + + # Inner storage for buffer data + self.vertex_data = None + self.vertex_dim = None + self.n_vertices = None + + self.model_view_matrix = None + self.projection_matrix = None + + glutDisplayFunc(self.display) + + def init_quad_program(self): + shader_list = [] + + shader_list.append(loadShader(GL_VERTEX_SHADER, "quad.vs")) + shader_list.append(loadShader(GL_FRAGMENT_SHADER, "quad.fs")) + + the_program = createProgram(shader_list) + + for shader in shader_list: + glDeleteShader(shader) + + # vertex attributes for a quad that fills the entire screen in Normalized Device Coordinates. 
+ # positions # texCoords + quad_vertices = np.array([ + -1.0, 1.0, 0.0, 1.0, -1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 1.0, 0.0, + -1.0, 1.0, 0.0, 1.0, 1.0, -1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0 + ]) + + quad_buffer = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, quad_buffer) + glBufferData(GL_ARRAY_BUFFER, quad_vertices, GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + return the_program, quad_buffer + + def set_mesh(self, vertices, faces): + self.vertex_data = vertices[faces.reshape([-1])] + self.vertex_dim = self.vertex_data.shape[1] + self.n_vertices = self.vertex_data.shape[0] + + glBindBuffer(GL_ARRAY_BUFFER, self.vertex_buffer) + glBufferData(GL_ARRAY_BUFFER, self.vertex_data, GL_STATIC_DRAW) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def set_viewpoint(self, projection, model_view): + self.projection_matrix = projection + self.model_view_matrix = model_view + + def draw_init(self): + glBindFramebuffer(GL_FRAMEBUFFER, self.frame_buffer) + glEnable(GL_DEPTH_TEST) + + # glClearColor(0.0, 0.0, 0.0, 0.0) + glClearColor(1.0, 1.0, 1.0, 0.0) # Black background + + if self.use_inverse_depth: + glDepthFunc(GL_GREATER) + glClearDepth(0.0) + else: + glDepthFunc(GL_LESS) + glClearDepth(1.0) + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + + def draw_end(self): + if self.intermediate_fbo is not None: + for i in range(len(self.color_buffer)): + glBindFramebuffer(GL_READ_FRAMEBUFFER, self.frame_buffer) + glReadBuffer(GL_COLOR_ATTACHMENT0 + i) + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self.intermediate_fbo) + glDrawBuffer(GL_COLOR_ATTACHMENT0 + i) + glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, + self.width, self.height, GL_COLOR_BUFFER_BIT, + GL_NEAREST) + + glBindFramebuffer(GL_FRAMEBUFFER, 0) + glDepthFunc(GL_LESS) + glClearDepth(1.0) + + def draw(self): + self.draw_init() + + glUseProgram(self.program) + glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, + self.model_view_matrix.transpose()) + glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, + self.projection_matrix.transpose()) + + glBindBuffer(GL_ARRAY_BUFFER, self.vertex_buffer) + + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, self.vertex_dim, GL_DOUBLE, GL_FALSE, 0, None) + + glDrawArrays(GL_TRIANGLES, 0, self.n_vertices) + + glDisableVertexAttribArray(0) + + glBindBuffer(GL_ARRAY_BUFFER, 0) + + glUseProgram(0) + + self.draw_end() + + def get_color(self, color_id=0): + glBindFramebuffer( + GL_FRAMEBUFFER, self.intermediate_fbo + if self.intermediate_fbo is not None else self.frame_buffer) + glReadBuffer(GL_COLOR_ATTACHMENT0 + color_id) + data = glReadPixels(0, + 0, + self.width, + self.height, + GL_RGBA, + GL_FLOAT, + outputType=None) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + rgb = data.reshape(self.height, self.width, -1) + rgb = np.flip(rgb, 0) + return rgb + + def get_z_value(self): + glBindFramebuffer(GL_FRAMEBUFFER, self.frame_buffer) + data = glReadPixels(0, + 0, + self.width, + self.height, + GL_DEPTH_COMPONENT, + GL_FLOAT, + outputType=None) + glBindFramebuffer(GL_FRAMEBUFFER, 0) + z = data.reshape(self.height, self.width) + z = np.flip(z, 0) + return z + + def display(self): + # First we draw a scene. + # Notice the result is stored in the texture buffer. + self.draw() + + # Then we return to the default frame buffer since we will display on the screen. + glBindFramebuffer(GL_FRAMEBUFFER, 0) + + # Do the clean-up. 
+ # glClearColor(0.0, 0.0, 0.0, 0.0) #Black background + glClearColor(1.0, 1.0, 1.0, 0.0) # Black background + glClear(GL_COLOR_BUFFER_BIT) + + # We draw a rectangle which covers the whole screen. + glUseProgram(self.quad_program) + glBindBuffer(GL_ARRAY_BUFFER, self.quad_buffer) + + size_of_double = 8 + glEnableVertexAttribArray(0) + glVertexAttribPointer(0, 2, GL_DOUBLE, GL_FALSE, 4 * size_of_double, + None) + glEnableVertexAttribArray(1) + glVertexAttribPointer(1, 2, GL_DOUBLE, GL_FALSE, 4 * size_of_double, + c_void_p(2 * size_of_double)) + + glDisable(GL_DEPTH_TEST) + + # The stored texture is then mapped to this rectangle. + # properly assing color buffer texture + glActiveTexture(GL_TEXTURE0) + glBindTexture(GL_TEXTURE_2D, self.screen_texture[0]) + glUniform1i(glGetUniformLocation(self.quad_program, 'screenTexture'), + 0) + + glDrawArrays(GL_TRIANGLES, 0, 6) + + glDisableVertexAttribArray(1) + glDisableVertexAttribArray(0) + + glEnable(GL_DEPTH_TEST) + glBindBuffer(GL_ARRAY_BUFFER, 0) + glUseProgram(0) + + glutSwapBuffers() + glutPostRedisplay() + + def show(self): + glutMainLoop() diff --git a/lib/renderer/glm.py b/lib/renderer/glm.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a45b0fe7396c8d372b85bbe782ae6d6cca4d0e --- /dev/null +++ b/lib/renderer/glm.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np + + +def vec3(x, y, z): + return np.array([x, y, z], dtype=np.float32) + + +def radians(v): + return np.radians(v) + + +def identity(): + return np.identity(4, dtype=np.float32) + + +def empty(): + return np.zeros([4, 4], dtype=np.float32) + + +def magnitude(v): + return np.linalg.norm(v) + + +def normalize(v): + m = magnitude(v) + return v if m == 0 else v / m + + +def dot(u, v): + return np.sum(u * v) + + +def cross(u, v): + res = vec3(0, 0, 0) + res[0] = u[1] * v[2] - u[2] * v[1] + res[1] = u[2] * v[0] - u[0] * v[2] + res[2] = u[0] * v[1] - u[1] * v[0] + return res + + +# below functions can be optimized + + +def translate(m, v): + res = np.copy(m) + res[:, 3] = m[:, 0] * v[0] + m[:, 1] * v[1] + m[:, 2] * v[2] + m[:, 3] + return res + + +def rotate(m, angle, v): + a = angle + c = np.cos(a) + s = np.sin(a) + + axis = normalize(v) + temp = (1 - c) * axis + + rot = empty() + rot[0][0] = c + temp[0] * axis[0] + rot[0][1] = temp[0] * axis[1] + s * axis[2] + rot[0][2] = temp[0] * axis[2] - s * axis[1] + + rot[1][0] = temp[1] * axis[0] - s * axis[2] + rot[1][1] = c + temp[1] * axis[1] + rot[1][2] = temp[1] * axis[2] + s * axis[0] + + rot[2][0] = temp[2] * axis[0] + s * axis[1] + rot[2][1] = temp[2] * axis[1] - s * axis[0] + rot[2][2] = c + temp[2] * axis[2] + + res = empty() + res[:, 0] = m[:, 0] * rot[0][0] + m[:, 1] * rot[0][1] + m[:, 2] * rot[0][2] + res[:, 1] = m[:, 0] * rot[1][0] + m[:, 1] * rot[1][1] + m[:, 2] * rot[1][2] + res[:, 2] = m[:, 0] * rot[2][0] + m[:, 1] * rot[2][1] + m[:, 2] * rot[2][2] + res[:, 3] = m[:, 3] + return res + + +def perspective(fovy, aspect, zNear, zFar): + tanHalfFovy = np.tan(fovy / 2) + + res = empty() + res[0][0] = 1 / (aspect * tanHalfFovy) + res[1][1] = 1 / (tanHalfFovy) + res[2][3] = -1 + res[2][2] = -(zFar + zNear) / (zFar - zNear) + res[3][2] = -(2 * zFar * zNear) / (zFar - zNear) + + return res.T + + +def ortho(left, right, bottom, top, zNear, zFar): + # res = np.ones([4, 4], dtype=np.float32) + res = identity() + res[0][0] = 2 / (right - left) + res[1][1] = 2 / (top - bottom) + res[2][2] = -2 / (zFar - zNear) + res[3][0] = -(right + left) / (right - left) + res[3][1] = -(top + bottom) / (top - bottom) + res[3][2] = -(zFar + zNear) / (zFar - zNear) + return res.T + + +def lookat(eye, center, up): + f = normalize(center - eye) + s = normalize(cross(f, up)) + u = cross(s, f) + + res = identity() + res[0][0] = s[0] + res[1][0] = s[1] + res[2][0] = s[2] + res[0][1] = u[0] + res[1][1] = u[1] + res[2][1] = u[2] + res[0][2] = -f[0] + res[1][2] = -f[1] + res[2][2] = -f[2] + res[3][0] = -dot(s, eye) + res[3][1] = -dot(u, eye) + res[3][2] = -dot(f, eye) + return res.T + + +def transform(d, m): + return np.dot(m, d.T).T diff --git a/lib/renderer/mesh.py b/lib/renderer/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..145883f39cae84abbbd5abce6d4829bfc581a8bb --- /dev/null +++ b/lib/renderer/mesh.py @@ -0,0 +1,532 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
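The `glm` helpers above mirror the GLM math library: the camera builders return 4x4 matrices ready to left-multiply homogeneous column vectors, and `transform` applies such a matrix to row-stacked points. A small sketch of these conventions, with placeholder values:

```
import numpy as np
from lib.renderer.glm import vec3, lookat, ortho, transform

view = lookat(vec3(0, 0, 2), vec3(0, 0, 0), vec3(0, 1, 0))   # camera 2 units back
proj = ortho(-1, 1, -1, 1, 0.1, 10.0)                        # symmetric ortho volume

points = np.array([[0.0, 0.0, 0.0, 1.0],
                   [0.5, 0.5, 0.0, 1.0]], dtype=np.float32)  # homogeneous rows
clip = transform(transform(points, view), proj)              # proj @ view @ p, per row
```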
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.dataset.mesh_util import SMPLX +from lib.common.render_utils import face_vertices +import numpy as np +import lib.smplx as smplx +import trimesh +import torch +import torch.nn.functional as F + +model_init_params = dict(gender='male', + model_type='smplx', + model_path=SMPLX().model_dir, + create_global_orient=False, + create_body_pose=False, + create_betas=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_expression=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + create_transl=False, + num_pca_comps=12) + + +def get_smpl_model(model_type, gender): + return smplx.create(**model_init_params) + + +def normalization(data): + _range = np.max(data) - np.min(data) + return ((data - np.min(data)) / _range) + + +def sigmoid(x): + z = 1 / (1 + np.exp(-x)) + return z + + +def load_fit_body(fitted_path, + scale, + smpl_type='smplx', + smpl_gender='neutral', + noise_dict=None): + + param = np.load(fitted_path, allow_pickle=True) + for key in param.keys(): + param[key] = torch.as_tensor(param[key]) + + smpl_model = get_smpl_model(smpl_type, smpl_gender) + model_forward_params = dict(betas=param['betas'], + global_orient=param['global_orient'], + body_pose=param['body_pose'], + left_hand_pose=param['left_hand_pose'], + right_hand_pose=param['right_hand_pose'], + jaw_pose=param['jaw_pose'], + leye_pose=param['leye_pose'], + reye_pose=param['reye_pose'], + expression=param['expression'], + return_verts=True) + + if noise_dict is not None: + model_forward_params.update(noise_dict) + + smpl_out = smpl_model(**model_forward_params) + + smpl_verts = ( + (smpl_out.vertices[0] * param['scale'] + param['translation']) * + scale).detach() + smpl_joints = ( + (smpl_out.joints[0] * param['scale'] + param['translation']) * + scale).detach() + smpl_mesh = trimesh.Trimesh(smpl_verts, + smpl_model.faces, + process=False, + maintain_order=True) + + return smpl_mesh, smpl_joints + + +def load_ori_fit_body(fitted_path, smpl_type='smplx', smpl_gender='neutral'): + + param = np.load(fitted_path, allow_pickle=True) + for key in param.keys(): + param[key] = torch.as_tensor(param[key]) + + smpl_model = get_smpl_model(smpl_type, smpl_gender) + model_forward_params = dict(betas=param['betas'], + global_orient=param['global_orient'], + body_pose=param['body_pose'], + left_hand_pose=param['left_hand_pose'], + right_hand_pose=param['right_hand_pose'], + jaw_pose=param['jaw_pose'], + leye_pose=param['leye_pose'], + reye_pose=param['reye_pose'], + expression=param['expression'], + return_verts=True) + + smpl_out = smpl_model(**model_forward_params) + + smpl_verts = smpl_out.vertices[0].detach() + smpl_mesh = trimesh.Trimesh(smpl_verts, + smpl_model.faces, + process=False, + maintain_order=True) + + return smpl_mesh + + +def save_obj_mesh(mesh_path, verts, faces): + file = open(mesh_path, 'w') + for v in verts: + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +# https://github.com/ratcave/wavefront_reader +def read_mtlfile(fname): + materials = {} + with open(fname) as f: + lines = f.read().splitlines() + + for line in lines: + if line: + split_line = line.strip().split(' ', 1) + if len(split_line) < 2: + continue + + prefix, data = split_line[0], split_line[1] + if 'newmtl' in prefix: + material = 
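A hedged usage sketch of the SMPL-X fit loader above. `smplx_param.pkl` is a hypothetical fitted-parameter file holding the betas, pose, `scale` and `translation` entries the function reads, and the SMPL-X model files are assumed to be installed under `SMPLX().model_dir` as described in the README:

```
# Sketch under assumptions: path and scale value are placeholders.
smpl_mesh, smpl_joints = load_fit_body('smplx_param.pkl', scale=1.0,
                                       smpl_type='smplx', smpl_gender='male')
smpl_mesh.export('smpl_fit.obj')   # trimesh.Trimesh with the fit's scale/translation applied

# Or, without the per-scan scale/translation:
ori_mesh = load_ori_fit_body('smplx_param.pkl')
```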
{} + materials[data] = material + elif materials: + if data: + split_data = data.strip().split(' ') + + # assume texture maps are in the same level + # WARNING: do not include space in your filename!! + if 'map' in prefix: + material[prefix] = split_data[-1].split('\\')[-1] + elif len(split_data) > 1: + material[prefix] = tuple(float(d) for d in split_data) + else: + try: + material[prefix] = int(data) + except ValueError: + material[prefix] = float(data) + + return materials + + +def load_obj_mesh_mtl(mesh_file): + vertex_data = [] + norm_data = [] + uv_data = [] + + face_data = [] + face_norm_data = [] + face_uv_data = [] + + # face per material + face_data_mat = {} + face_norm_data_mat = {} + face_uv_data_mat = {} + + # current material name + mtl_data = None + cur_mat = None + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + elif values[0] == 'vn': + vn = list(map(float, values[1:4])) + norm_data.append(vn) + elif values[0] == 'vt': + vt = list(map(float, values[1:3])) + uv_data.append(vt) + elif values[0] == 'mtllib': + mtl_data = read_mtlfile( + mesh_file.replace(mesh_file.split('/')[-1], values[1])) + elif values[0] == 'usemtl': + cur_mat = values[1] + elif values[0] == 'f': + # local triangle data + l_face_data = [] + l_face_uv_data = [] + l_face_norm_data = [] + + # quad mesh + if len(values) > 4: + f = list( + map( + lambda x: int(x.split('/')[0]) if int(x.split('/')[0]) + < 0 else int(x.split('/')[0]) - 1, values[1:4])) + l_face_data.append(f) + f = list( + map( + lambda x: int(x.split('/')[0]) + if int(x.split('/')[0]) < 0 else int(x.split('/')[0]) - + 1, [values[3], values[4], values[1]])) + l_face_data.append(f) + # tri mesh + else: + f = list( + map( + lambda x: int(x.split('/')[0]) if int(x.split('/')[0]) + < 0 else int(x.split('/')[0]) - 1, values[1:4])) + l_face_data.append(f) + # deal with texture + if len(values[1].split('/')) >= 2: + # quad mesh + if len(values) > 4: + f = list( + map( + lambda x: int(x.split('/')[1]) + if int(x.split('/')[1]) < 0 else int( + x.split('/')[1]) - 1, values[1:4])) + l_face_uv_data.append(f) + f = list( + map( + lambda x: int(x.split('/')[1]) + if int(x.split('/')[1]) < 0 else int( + x.split('/')[1]) - 1, + [values[3], values[4], values[1]])) + l_face_uv_data.append(f) + # tri mesh + elif len(values[1].split('/')[1]) != 0: + f = list( + map( + lambda x: int(x.split('/')[1]) + if int(x.split('/')[1]) < 0 else int( + x.split('/')[1]) - 1, values[1:4])) + l_face_uv_data.append(f) + # deal with normal + if len(values[1].split('/')) == 3: + # quad mesh + if len(values) > 4: + f = list( + map( + lambda x: int(x.split('/')[2]) + if int(x.split('/')[2]) < 0 else int( + x.split('/')[2]) - 1, values[1:4])) + l_face_norm_data.append(f) + f = list( + map( + lambda x: int(x.split('/')[2]) + if int(x.split('/')[2]) < 0 else int( + x.split('/')[2]) - 1, + [values[3], values[4], values[1]])) + l_face_norm_data.append(f) + # tri mesh + elif len(values[1].split('/')[2]) != 0: + f = list( + map( + lambda x: int(x.split('/')[2]) + if int(x.split('/')[2]) < 0 else int( + x.split('/')[2]) - 1, values[1:4])) + l_face_norm_data.append(f) + + face_data += l_face_data + face_uv_data += l_face_uv_data + face_norm_data += l_face_norm_data + + if cur_mat is not None: + if cur_mat not 
in face_data_mat.keys(): + face_data_mat[cur_mat] = [] + if cur_mat not in face_uv_data_mat.keys(): + face_uv_data_mat[cur_mat] = [] + if cur_mat not in face_norm_data_mat.keys(): + face_norm_data_mat[cur_mat] = [] + face_data_mat[cur_mat] += l_face_data + face_uv_data_mat[cur_mat] += l_face_uv_data + face_norm_data_mat[cur_mat] += l_face_norm_data + + vertices = np.array(vertex_data) + faces = np.array(face_data) + + norms = np.array(norm_data) + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) + + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) + + out_tuple = (vertices, faces, norms, face_normals, uvs, face_uvs) + + if cur_mat is not None and mtl_data is not None: + for key in face_data_mat: + face_data_mat[key] = np.array(face_data_mat[key]) + face_uv_data_mat[key] = np.array(face_uv_data_mat[key]) + face_norm_data_mat[key] = np.array(face_norm_data_mat[key]) + + out_tuple += (face_data_mat, face_norm_data_mat, face_uv_data_mat, + mtl_data) + + return out_tuple + + +def load_scan(mesh_file, with_normal=False, with_texture=False): + vertex_data = [] + norm_data = [] + uv_data = [] + + face_data = [] + face_norm_data = [] + face_uv_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + elif values[0] == 'vn': + vn = list(map(float, values[1:4])) + norm_data.append(vn) + elif values[0] == 'vt': + vt = list(map(float, values[1:3])) + uv_data.append(vt) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + [values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + # deal with texture + if len(values[1].split('/')) >= 2: + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[1]), values[1:4])) + face_uv_data.append(f) + f = list( + map(lambda x: int(x.split('/')[1]), + [values[3], values[4], values[1]])) + face_uv_data.append(f) + # tri mesh + elif len(values[1].split('/')[1]) != 0: + f = list(map(lambda x: int(x.split('/')[1]), values[1:4])) + face_uv_data.append(f) + # deal with normal + if len(values[1].split('/')) == 3: + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[2]), values[1:4])) + face_norm_data.append(f) + f = list( + map(lambda x: int(x.split('/')[2]), + [values[3], values[4], values[1]])) + face_norm_data.append(f) + # tri mesh + elif len(values[1].split('/')[2]) != 0: + f = list(map(lambda x: int(x.split('/')[2]), values[1:4])) + face_norm_data.append(f) + + vertices = np.array(vertex_data) + faces = np.array(face_data) - 1 + + if with_texture and with_normal: + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) - 1 + norms = np.array(norm_data) + if norms.shape[0] == 0: + norms = compute_normal(vertices, faces) + face_normals = faces + else: + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) - 1 + return vertices, faces, norms, face_normals, uvs, face_uvs + + if with_texture: + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) - 1 + return vertices, faces, uvs, face_uvs + + if with_normal: + norms = 
np.array(norm_data) + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) - 1 + return vertices, faces, norms, face_normals + + return vertices, faces + + +def normalize_v3(arr): + ''' Normalize a numpy array of 3 component vectors shape=(n,3) ''' + lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2) + eps = 0.00000001 + lens[lens < eps] = eps + arr[:, 0] /= lens + arr[:, 1] /= lens + arr[:, 2] /= lens + return arr + + +def compute_normal(vertices, faces): + # Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal + norm = np.zeros(vertices.shape, dtype=vertices.dtype) + # Create an indexed view into the vertex array using the array of three indices for triangles + tris = vertices[faces] + # Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle + n = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0]) + # n is now an array of normals per triangle. The length of each normal is dependent the vertices, + # we need to normalize these, so that our next step weights each normal equally. + normalize_v3(n) + # now we have a normalized array of normals, one per triangle, i.e., per triangle normals. + # But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle, + # the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards. + # The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array + norm[faces[:, 0]] += n + norm[faces[:, 1]] += n + norm[faces[:, 2]] += n + normalize_v3(norm) + + return norm + + +def compute_normal_batch(vertices, faces): + + bs, nv = vertices.shape[:2] + bs, nf = faces.shape[:2] + + vert_norm = torch.zeros(bs * nv, 3).type_as(vertices) + tris = face_vertices(vertices, faces) + face_norm = F.normalize(torch.cross(tris[:, :, 1] - tris[:, :, 0], + tris[:, :, 2] - tris[:, :, 0]), + dim=-1) + + faces = (faces + + (torch.arange(bs).type_as(faces) * nv)[:, None, None]).view( + -1, 3) + + vert_norm[faces[:, 0]] += face_norm.view(-1, 3) + vert_norm[faces[:, 1]] += face_norm.view(-1, 3) + vert_norm[faces[:, 2]] += face_norm.view(-1, 3) + + vert_norm = F.normalize(vert_norm, dim=-1).view(bs, nv, 3) + + return vert_norm + + +# compute tangent and bitangent +def compute_tangent(vertices, faces, normals, uvs, faceuvs): + # NOTE: this could be numerically unstable around [0,0,1] + # but other current solutions are pretty freaky somehow + c1 = np.cross(normals, np.array([0, 1, 0.0])) + tan = c1 + normalize_v3(tan) + btan = np.cross(normals, tan) + + # NOTE: traditional version is below + + # pts_tris = vertices[faces] + # uv_tris = uvs[faceuvs] + + # W = np.stack([pts_tris[::, 1] - pts_tris[::, 0], pts_tris[::, 2] - pts_tris[::, 0]],2) + # UV = np.stack([uv_tris[::, 1] - uv_tris[::, 0], uv_tris[::, 2] - uv_tris[::, 0]], 1) + + # for i in range(W.shape[0]): + # W[i,::] = W[i,::].dot(np.linalg.inv(UV[i,::])) + + # tan = np.zeros(vertices.shape, dtype=vertices.dtype) + # tan[faces[:,0]] += W[:,:,0] + # tan[faces[:,1]] += W[:,:,0] + # tan[faces[:,2]] += W[:,:,0] + + # btan = np.zeros(vertices.shape, dtype=vertices.dtype) + # btan[faces[:,0]] += W[:,:,1] + # btan[faces[:,1]] += W[:,:,1] + # btan[faces[:,2]] += W[:,:,1] + + # normalize_v3(tan) + + # ndott = np.sum(normals*tan, 1, keepdims=True) + # tan = tan - ndott * normals + + # normalize_v3(btan) + # normalize_v3(tan) + + # tan[np.sum(np.cross(normals, tan) * 
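A short usage sketch of the OBJ loader and normal helpers above; `scan.obj` is a hypothetical mesh with `v/vt/vn` records:

```
from lib.renderer.mesh import load_scan, compute_normal

verts, faces, norms, face_norms, uvs, face_uvs = load_scan(
    'scan.obj', with_normal=True, with_texture=True)

# Vertex normals can also be rebuilt from geometry alone: compute_normal
# averages the (unit) normals of the faces adjacent to each vertex.
rebuilt_norms = compute_normal(verts, faces)
```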
btan, 1) < 0,:] *= -1.0 + + return tan, btan diff --git a/lib/renderer/opengl_util.py b/lib/renderer/opengl_util.py new file mode 100644 index 0000000000000000000000000000000000000000..74ba082393b91943ec4649885de07f994c3bcda9 --- /dev/null +++ b/lib/renderer/opengl_util.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from lib.renderer.mesh import load_scan, compute_tangent, compute_normal, load_obj_mesh_mtl +from lib.dataset.mesh_util import projection +from lib.renderer.gl.prt_render import PRTRender +from lib.renderer.camera import Camera +import os +import cv2 +import math +import random +import numpy as np + + +def render_result(rndr, shader_id, path, mask=False): + + cam_render = rndr.get_color(shader_id) + cam_render = cv2.cvtColor(cam_render, cv2.COLOR_RGBA2BGRA) + + os.makedirs(os.path.dirname(path), exist_ok=True) + if shader_id != 2: + cv2.imwrite(path, np.uint8(255.0 * cam_render)) + else: + cam_render[:, :, -1] -= 0.5 + cam_render[:, :, -1] *= 2.0 + if not mask: + cv2.imwrite(path, np.uint8(255.0 / 2.0 * (cam_render + 1.0))) + else: + cv2.imwrite(path, np.uint8(-1.0 * cam_render[:, :, [3]])) + + +def make_rotate(rx, ry, rz): + sinX = np.sin(rx) + sinY = np.sin(ry) + sinZ = np.sin(rz) + + cosX = np.cos(rx) + cosY = np.cos(ry) + cosZ = np.cos(rz) + + Rx = np.zeros((3, 3)) + Rx[0, 0] = 1.0 + Rx[1, 1] = cosX + Rx[1, 2] = -sinX + Rx[2, 1] = sinX + Rx[2, 2] = cosX + + Ry = np.zeros((3, 3)) + Ry[0, 0] = cosY + Ry[0, 2] = sinY + Ry[1, 1] = 1.0 + Ry[2, 0] = -sinY + Ry[2, 2] = cosY + + Rz = np.zeros((3, 3)) + Rz[0, 0] = cosZ + Rz[0, 1] = -sinZ + Rz[1, 0] = sinZ + Rz[1, 1] = cosZ + Rz[2, 2] = 1.0 + + R = np.matmul(np.matmul(Rz, Ry), Rx) + return R + + +def rotateSH(SH, R): + SHn = SH + + # 1st order + SHn[1] = R[1, 1] * SH[1] - R[1, 2] * SH[2] + R[1, 0] * SH[3] + SHn[2] = -R[2, 1] * SH[1] + R[2, 2] * SH[2] - R[2, 0] * SH[3] + SHn[3] = R[0, 1] * SH[1] - R[0, 2] * SH[2] + R[0, 0] * SH[3] + + # 2nd order + SHn[4:, 0] = rotateBand2(SH[4:, 0], R) + SHn[4:, 1] = rotateBand2(SH[4:, 1], R) + SHn[4:, 2] = rotateBand2(SH[4:, 2], R) + + return SHn + + +def rotateBand2(x, R): + s_c3 = 0.94617469575 + s_c4 = -0.31539156525 + s_c5 = 0.54627421529 + + s_c_scale = 1.0 / 0.91529123286551084 + s_c_scale_inv = 0.91529123286551084 + + s_rc2 = 1.5853309190550713 * s_c_scale + s_c4_div_c3 = s_c4 / s_c3 + s_c4_div_c3_x2 = (s_c4 / s_c3) * 2.0 + + s_scale_dst2 = s_c3 * s_c_scale_inv + s_scale_dst4 = s_c5 * s_c_scale_inv + + sh0 = x[3] + x[4] + x[4] - x[1] + sh1 = x[0] + s_rc2 * x[2] + x[3] + x[4] + sh2 = x[0] + sh3 = -x[3] + sh4 = -x[1] + + r2x = R[0][0] + R[0][1] + r2y = R[1][0] + R[1][1] + r2z = R[2][0] + R[2][1] + + r3x = R[0][0] + R[0][2] + r3y = R[1][0] + R[1][2] + r3z = R[2][0] + R[2][2] + + r4x = R[0][1] + R[0][2] + r4y = R[1][1] + R[1][2] + r4z = R[2][1] + R[2][2] + + sh0_x = sh0 * R[0][0] + sh0_y = sh0 * R[1][0] + d0 = sh0_x * R[1][0] + d1 = 
sh0_y * R[2][0] + d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3) + d3 = sh0_x * R[2][0] + d4 = sh0_x * R[0][0] - sh0_y * R[1][0] + + sh1_x = sh1 * R[0][2] + sh1_y = sh1 * R[1][2] + d0 += sh1_x * R[1][2] + d1 += sh1_y * R[2][2] + d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3) + d3 += sh1_x * R[2][2] + d4 += sh1_x * R[0][2] - sh1_y * R[1][2] + + sh2_x = sh2 * r2x + sh2_y = sh2 * r2y + d0 += sh2_x * r2y + d1 += sh2_y * r2z + d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2) + d3 += sh2_x * r2z + d4 += sh2_x * r2x - sh2_y * r2y + + sh3_x = sh3 * r3x + sh3_y = sh3 * r3y + d0 += sh3_x * r3y + d1 += sh3_y * r3z + d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2) + d3 += sh3_x * r3z + d4 += sh3_x * r3x - sh3_y * r3y + + sh4_x = sh4 * r4x + sh4_y = sh4 * r4y + d0 += sh4_x * r4y + d1 += sh4_y * r4z + d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2) + d3 += sh4_x * r4z + d4 += sh4_x * r4x - sh4_y * r4y + + dst = x + dst[0] = d0 + dst[1] = -d1 + dst[2] = d2 * s_scale_dst2 + dst[3] = -d3 + dst[4] = d4 * s_scale_dst4 + + return dst + + +def load_calib(param, render_size=512): + # pixel unit / world unit + ortho_ratio = param['ortho_ratio'] + # world unit / model unit + scale = param['scale'] + # camera center world coordinate + center = param['center'] + # model rotation + R = param['R'] + + translate = -np.matmul(R, center).reshape(3, 1) + extrinsic = np.concatenate([R, translate], axis=1) + extrinsic = np.concatenate( + [extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0) + # Match camera space to image pixel space + scale_intrinsic = np.identity(4) + scale_intrinsic[0, 0] = scale / ortho_ratio + scale_intrinsic[1, 1] = -scale / ortho_ratio + scale_intrinsic[2, 2] = scale / ortho_ratio + # Match image pixel space to image uv space + uv_intrinsic = np.identity(4) + uv_intrinsic[0, 0] = 1.0 / float(render_size // 2) + uv_intrinsic[1, 1] = 1.0 / float(render_size // 2) + uv_intrinsic[2, 2] = 1.0 / float(render_size // 2) + + intrinsic = np.matmul(uv_intrinsic, scale_intrinsic) + calib = np.concatenate([extrinsic, intrinsic], axis=0) + return calib + + +def render_prt_ortho(out_path, + folder_name, + subject_name, + shs, + rndr, + rndr_uv, + im_size, + angl_step=4, + n_light=1, + pitch=[0]): + cam = Camera(width=im_size, height=im_size) + cam.ortho_ratio = 0.4 * (512 / im_size) + cam.near = -100 + cam.far = 100 + cam.sanity_check() + + # set path for obj, prt + mesh_file = os.path.join(folder_name, subject_name + '_100k.obj') + if not os.path.exists(mesh_file): + print('ERROR: obj file does not exist!!', mesh_file) + return + prt_file = os.path.join(folder_name, 'bounce', 'bounce0.txt') + if not os.path.exists(prt_file): + print('ERROR: prt file does not exist!!!', prt_file) + return + face_prt_file = os.path.join(folder_name, 'bounce', 'face.npy') + if not os.path.exists(face_prt_file): + print('ERROR: face prt file does not exist!!!', prt_file) + return + text_file = os.path.join(folder_name, 'tex', subject_name + '_dif_2k.jpg') + if not os.path.exists(text_file): + print('ERROR: dif file does not exist!!', text_file) + return + + texture_image = cv2.imread(text_file) + texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB) + + vertices, faces, normals, faces_normals, textures, face_textures = load_scan( + mesh_file, with_normal=True, with_texture=True) + vmin = vertices.min(0) + vmax = vertices.max(0) + up_axis = 1 if (vmax - vmin).argmax() == 1 else 2 + + vmed = np.median(vertices, 0) + vmed[up_axis] = 0.5 * (vmax[up_axis] + vmin[up_axis]) + y_scale = 180 / (vmax[up_axis] - vmin[up_axis]) + + rndr.set_norm_mat(y_scale, 
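`load_calib` above stacks the 4x4 extrinsic (world to camera) on top of the 4x4 orthographic intrinsic (camera to normalized image coordinates) into a single 8x4 array. A hedged sketch of splitting and applying it, where `0_0_00.npy` stands for one of the `PARAM/*.npy` dictionaries written by `render_prt_ortho` below:

```
import numpy as np
from lib.renderer.opengl_util import load_calib

param = np.load('0_0_00.npy', allow_pickle=True).item()   # hypothetical PARAM file
calib = load_calib(param, render_size=512)

extrinsic, intrinsic = calib[:4], calib[4:]               # each block is 4x4
pt_world = np.array([[0.0, 0.0, 0.0, 1.0]]).T             # homogeneous column
pt_uv = intrinsic @ extrinsic @ pt_world                  # x, y land in roughly [-1, 1]
```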
vmed) + rndr_uv.set_norm_mat(y_scale, vmed) + + tan, bitan = compute_tangent(vertices, faces, normals, textures, + face_textures) + prt = np.loadtxt(prt_file) + face_prt = np.load(face_prt_file) + rndr.set_mesh(vertices, faces, normals, faces_normals, textures, + face_textures, prt, face_prt, tan, bitan) + rndr.set_albedo(texture_image) + + rndr_uv.set_mesh(vertices, faces, normals, faces_normals, textures, + face_textures, prt, face_prt, tan, bitan) + rndr_uv.set_albedo(texture_image) + + os.makedirs(os.path.join(out_path, 'GEO', 'OBJ', subject_name), + exist_ok=True) + os.makedirs(os.path.join(out_path, 'PARAM', subject_name), exist_ok=True) + os.makedirs(os.path.join(out_path, 'RENDER', subject_name), exist_ok=True) + os.makedirs(os.path.join(out_path, 'MASK', subject_name), exist_ok=True) + os.makedirs(os.path.join(out_path, 'UV_RENDER', subject_name), + exist_ok=True) + os.makedirs(os.path.join(out_path, 'UV_MASK', subject_name), exist_ok=True) + os.makedirs(os.path.join(out_path, 'UV_POS', subject_name), exist_ok=True) + os.makedirs(os.path.join(out_path, 'UV_NORMAL', subject_name), + exist_ok=True) + + if not os.path.exists(os.path.join(out_path, 'val.txt')): + f = open(os.path.join(out_path, 'val.txt'), 'w') + f.close() + + # copy obj file + cmd = 'cp %s %s' % (mesh_file, + os.path.join(out_path, 'GEO', 'OBJ', subject_name)) + print(cmd) + os.system(cmd) + + for p in pitch: + for y in tqdm(range(0, 360, angl_step)): + R = np.matmul(make_rotate(math.radians(p), 0, 0), + make_rotate(0, math.radians(y), 0)) + if up_axis == 2: + R = np.matmul(R, make_rotate(math.radians(90), 0, 0)) + + rndr.rot_matrix = R + rndr_uv.rot_matrix = R + rndr.set_camera(cam) + rndr_uv.set_camera(cam) + + for j in range(n_light): + sh_id = random.randint(0, shs.shape[0] - 1) + sh = shs[sh_id] + sh_angle = 0.2 * np.pi * (random.random() - 0.5) + sh = rotateSH(sh, make_rotate(0, sh_angle, 0).T) + + dic = { + 'sh': sh, + 'ortho_ratio': cam.ortho_ratio, + 'scale': y_scale, + 'center': vmed, + 'R': R + } + + rndr.set_sh(sh) + rndr.analytic = False + rndr.use_inverse_depth = False + rndr.display() + + out_all_f = rndr.get_color(0) + out_mask = out_all_f[:, :, 3] + out_all_f = cv2.cvtColor(out_all_f, cv2.COLOR_RGBA2BGR) + + np.save( + os.path.join(out_path, 'PARAM', subject_name, + '%d_%d_%02d.npy' % (y, p, j)), dic) + cv2.imwrite( + os.path.join(out_path, 'RENDER', subject_name, + '%d_%d_%02d.jpg' % (y, p, j)), + 255.0 * out_all_f) + cv2.imwrite( + os.path.join(out_path, 'MASK', subject_name, + '%d_%d_%02d.png' % (y, p, j)), + 255.0 * out_mask) + + rndr_uv.set_sh(sh) + rndr_uv.analytic = False + rndr_uv.use_inverse_depth = False + rndr_uv.display() + + uv_color = rndr_uv.get_color(0) + uv_color = cv2.cvtColor(uv_color, cv2.COLOR_RGBA2BGR) + cv2.imwrite( + os.path.join(out_path, 'UV_RENDER', subject_name, + '%d_%d_%02d.jpg' % (y, p, j)), + 255.0 * uv_color) + + if y == 0 and j == 0 and p == pitch[0]: + uv_pos = rndr_uv.get_color(1) + uv_mask = uv_pos[:, :, 3] + cv2.imwrite( + os.path.join(out_path, 'UV_MASK', subject_name, + '00.png'), 255.0 * uv_mask) + + data = { + 'default': uv_pos[:, :, :3] + } # default is a reserved name + pyexr.write( + os.path.join(out_path, 'UV_POS', subject_name, + '00.exr'), data) + + uv_nml = rndr_uv.get_color(2) + uv_nml = cv2.cvtColor(uv_nml, cv2.COLOR_RGBA2BGR) + cv2.imwrite( + os.path.join(out_path, 'UV_NORMAL', subject_name, + '00.png'), 255.0 * uv_nml) diff --git a/lib/renderer/prt_util.py b/lib/renderer/prt_util.py new file mode 100644 index 
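A hedged sketch of driving the orthographic PRT rendering above. `rndr` and `rndr_uv` are assumed to be `PRTRender` instances constructed elsewhere (see `prt_render.py`), `env_sh.npy` is an assumed file of spherical-harmonic environment lights with shape (K, 9, 3), and the scan folder must contain `<subject>_100k.obj`, `bounce/` and `tex/` as checked at the top of the function (the loop body also relies on `tqdm` and `pyexr` being importable):

```
import numpy as np
from lib.renderer.opengl_util import render_prt_ortho

shs = np.load('env_sh.npy')                      # (K, 9, 3) SH environment lights
render_prt_ortho(out_path='./out_data',
                 folder_name='./scans/subject_01',
                 subject_name='subject_01',
                 shs=shs, rndr=rndr, rndr_uv=rndr_uv,
                 im_size=512, angl_step=4,       # one render every 4 degrees of yaw
                 n_light=1, pitch=[0])
```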
0000000000000000000000000000000000000000..dd617c9e17f8899d5e697d287720e432c8607df4 --- /dev/null +++ b/lib/renderer/prt_util.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import os +import trimesh +import numpy as np +import math +from scipy.special import sph_harm +import argparse +from tqdm import tqdm +from trimesh.util import bounds_tree + + +def factratio(N, D): + if N >= D: + prod = 1.0 + for i in range(D + 1, N + 1): + prod *= i + return prod + else: + prod = 1.0 + for i in range(N + 1, D + 1): + prod *= i + return 1.0 / prod + + +def KVal(M, L): + return math.sqrt(((2 * L + 1) / (4 * math.pi)) * (factratio(L - M, L + M))) + + +def AssociatedLegendre(M, L, x): + if M < 0 or M > L or np.max(np.abs(x)) > 1.0: + return np.zeros_like(x) + + pmm = np.ones_like(x) + if M > 0: + somx2 = np.sqrt((1.0 + x) * (1.0 - x)) + fact = 1.0 + for i in range(1, M + 1): + pmm = -pmm * fact * somx2 + fact = fact + 2 + + if L == M: + return pmm + else: + pmmp1 = x * (2 * M + 1) * pmm + if L == M + 1: + return pmmp1 + else: + pll = np.zeros_like(x) + for i in range(M + 2, L + 1): + pll = (x * (2 * i - 1) * pmmp1 - (i + M - 1) * pmm) / (i - M) + pmm = pmmp1 + pmmp1 = pll + return pll + + +def SphericalHarmonic(M, L, theta, phi): + if M > 0: + return math.sqrt(2.0) * KVal(M, L) * np.cos( + M * phi) * AssociatedLegendre(M, L, np.cos(theta)) + elif M < 0: + return math.sqrt(2.0) * KVal(-M, L) * np.sin( + -M * phi) * AssociatedLegendre(-M, L, np.cos(theta)) + else: + return KVal(0, L) * AssociatedLegendre(0, L, np.cos(theta)) + + +def save_obj(mesh_path, verts): + file = open(mesh_path, 'w') + for v in verts: + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + file.close() + + +def sampleSphericalDirections(n): + xv = np.random.rand(n, n) + yv = np.random.rand(n, n) + theta = np.arccos(1 - 2 * xv) + phi = 2.0 * math.pi * yv + + phi = phi.reshape(-1) + theta = theta.reshape(-1) + + vx = -np.sin(theta) * np.cos(phi) + vy = -np.sin(theta) * np.sin(phi) + vz = np.cos(theta) + return np.stack([vx, vy, vz], 1), phi, theta + + +def getSHCoeffs(order, phi, theta): + shs = [] + for n in range(0, order + 1): + for m in range(-n, n + 1): + s = SphericalHarmonic(m, n, theta, phi) + shs.append(s) + + return np.stack(shs, 1) + + +def computePRT(mesh_path, scale, n, order): + + prt_dir = os.path.join(os.path.dirname(mesh_path), "prt") + bounce_path = os.path.join(prt_dir, "bounce.npy") + face_path = os.path.join(prt_dir, "face.npy") + + os.makedirs(prt_dir, exist_ok=True) + + PRT = None + F = None + + if os.path.exists(bounce_path) and os.path.exists(face_path): + + PRT = np.load(bounce_path) + F = np.load(face_path) + + else: + + mesh = trimesh.load(mesh_path, + skip_materials=True, + process=False, + maintain_order=True) + mesh.vertices *= scale + + vectors_orig, phi, theta = sampleSphericalDirections(n) + SH_orig = getSHCoeffs(order, phi, theta) + + w = 
4.0 * math.pi / (n * n) + + origins = mesh.vertices + normals = mesh.vertex_normals + n_v = origins.shape[0] + + origins = np.repeat(origins[:, None], n, axis=1).reshape(-1, 3) + normals = np.repeat(normals[:, None], n, axis=1).reshape(-1, 3) + PRT_all = None + for i in range(n): + SH = np.repeat(SH_orig[None, (i * n):((i + 1) * n)], n_v, + axis=0).reshape(-1, SH_orig.shape[1]) + vectors = np.repeat(vectors_orig[None, (i * n):((i + 1) * n)], + n_v, + axis=0).reshape(-1, 3) + + dots = (vectors * normals).sum(1) + front = (dots > 0.0) + + delta = 1e-3 * min(mesh.bounding_box.extents) + + hits = mesh.ray.intersects_any(origins + delta * normals, vectors) + nohits = np.logical_and(front, np.logical_not(hits)) + + PRT = (nohits.astype(np.float32) * dots)[:, None] * SH + + if PRT_all is not None: + PRT_all += (PRT.reshape(-1, n, SH.shape[1]).sum(1)) + else: + PRT_all = (PRT.reshape(-1, n, SH.shape[1]).sum(1)) + + PRT = w * PRT_all + F = mesh.faces + + np.save(bounce_path, PRT) + np.save(face_path, F) + + # NOTE: trimesh sometimes break the original vertex order, but topology will not change. + # when loading PRT in other program, use the triangle list from trimesh. + + return PRT, F + + +def testPRT(obj_path, n=40): + + os.makedirs(os.path.join(os.path.dirname(obj_path), + f'../bounce/{os.path.basename(obj_path)[:-4]}'), + exist_ok=True) + + PRT, F = computePRT(obj_path, n, 2) + np.savetxt( + os.path.join(os.path.dirname(obj_path), + f'../bounce/{os.path.basename(obj_path)[:-4]}', + 'bounce.npy'), PRT) + np.save( + os.path.join(os.path.dirname(obj_path), + f'../bounce/{os.path.basename(obj_path)[:-4]}', + 'face.npy'), F) diff --git a/lib/smplx/LICENSE b/lib/smplx/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3034a97b164d6e006655493e950314ec58e200cd --- /dev/null +++ b/lib/smplx/LICENSE @@ -0,0 +1,58 @@ +License + +Software Copyright License for non-commercial scientific research purposes +Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License + +Ownership / Licensees +The Software and the associated materials has been developed at the + +Max Planck Institute for Intelligent Systems (hereinafter "MPI"). + +Any copyright or patent right is owned by and proprietary material of the + +Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) + +hereinafter the “Licensor”. 
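A hedged numerical check of the Monte-Carlo spherical-harmonic machinery above: with the uniform sphere samples from `sampleSphericalDirections` and the same `4*pi/(n*n)` weight used in `computePRT`, the sampled second-order basis is approximately orthonormal:

```
import math
import numpy as np
from lib.renderer.prt_util import sampleSphericalDirections, getSHCoeffs

n = 64
dirs, phi, theta = sampleSphericalDirections(n)   # n*n uniformly distributed directions
Y = getSHCoeffs(2, phi, theta)                    # (n*n, 9) real SH basis samples
gram = (4.0 * math.pi / (n * n)) * (Y.T @ Y)      # Monte-Carlo estimate of the Gram matrix
print(np.abs(gram - np.eye(9)).max())             # small, and shrinking as n grows
```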
+ +License Grant +Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: + +To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization; +To use the Model & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects; +Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Model & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. + +The Model & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Model & Software to train methods/algorithms/neural networks/etc. for commercial use of any kind. By downloading the Model & Software, you agree not to reverse engineer it. + +No Distribution +The Model & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. + +Disclaimer of Representations and Warranties +You expressly acknowledge and agree that the Model & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model & Software, (ii) that the use of the Model & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model & Software will not cause any damage of any kind to you or a third party. + +Limitation of Liability +Because this Model & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. +Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. +Patent claims generated through the usage of the Model & Software cannot be directed towards the copyright holders. +The Model & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Model & Software and is not responsible for any problems such modifications cause. 
+ +No Maintenance Services +You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Model & Software at any time. + +Defects of the Model & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. + +Publications using the Model & Software +You acknowledge that the Model & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Model & Software. + +Citation: + + +@inproceedings{SMPL-X:2019, + title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image}, + author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.}, + booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)}, + year = {2019} +} +Commercial licensing opportunities +For commercial uses of the Software, please send email to ps-license@tue.mpg.de + +This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. diff --git a/lib/smplx/README.md b/lib/smplx/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e000e63af4569d8fae38346be370ba815662674d --- /dev/null +++ b/lib/smplx/README.md @@ -0,0 +1,207 @@ +## SMPL-X: A new joint 3D model of the human body, face and hands together + +[[Paper Page](https://smpl-x.is.tue.mpg.de)] [[Paper](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/497/SMPL-X.pdf)] +[[Supp. Mat.](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/498/SMPL-X-supp.pdf)] + +![SMPL-X Examples](./images/teaser_fig.png) + +## Table of Contents + * [License](#license) + * [Description](#description) + * [News](#news) + * [Installation](#installation) + * [Downloading the model](#downloading-the-model) + * [Loading SMPL-X, SMPL+H and SMPL](#loading-smpl-x-smplh-and-smpl) + * [SMPL and SMPL+H setup](#smpl-and-smplh-setup) + * [Model loading](https://github.com/vchoutas/smplx#model-loading) + * [MANO and FLAME correspondences](#mano-and-flame-correspondences) + * [Example](#example) + * [Modifying the global pose of the model](#modifying-the-global-pose-of-the-model) + * [Citation](#citation) + * [Acknowledgments](#acknowledgments) + * [Contact](#contact) + +## License + +Software Copyright License for **non-commercial scientific research purposes**. +Please read carefully the [terms and conditions](https://github.com/vchoutas/smplx/blob/master/LICENSE) and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. 
Any infringement of the terms of this agreement will automatically terminate your rights under this [License](./LICENSE). + +## Disclaimer + +The original images used for the figures 1 and 2 of the paper can be found in this link. +The images in the paper are used under license from gettyimages.com. +We have acquired the right to use them in the publication, but redistribution is not allowed. +Please follow the instructions on the given link to acquire right of usage. +Our results are obtained on the 483 × 724 pixels resolution of the original images. + +## Description + +*SMPL-X* (SMPL eXpressive) is a unified body model with shape parameters trained jointly for the +face, hands and body. *SMPL-X* uses standard vertex based linear blend skinning with learned corrective blend +shapes, has N = 10, 475 vertices and K = 54 joints, +which include joints for the neck, jaw, eyeballs and fingers. +SMPL-X is defined by a function M(θ, β, ψ), where θ is the pose parameters, β the shape parameters and +ψ the facial expression parameters. + +## News + +- 3 November 2020: We release the code to transfer between the models in the + SMPL family. For more details on the code, go to this [readme + file](./transfer_model/README.md). A detailed explanation on how the mappings + were extracted can be found [here](./transfer_model/docs/transfer.md). +- 23 September 2020: A UV map is now available for SMPL-X, please check the + Downloads section of the website. +- 20 August 2020: The full shape and expression space of SMPL-X are now available. + +## Installation + +To install the model please follow the next steps in the specified order: +1. To install from PyPi simply run: + ```Shell + pip install smplx[all] + ``` +2. Clone this repository and install it using the *setup.py* script: +```Shell +git clone https://github.com/vchoutas/smplx +python setup.py install +``` + +## Downloading the model + +To download the *SMPL-X* model go to [this project website](https://smpl-x.is.tue.mpg.de) and register to get access to the downloads section. + +To download the *SMPL+H* model go to [this project website](http://mano.is.tue.mpg.de) and register to get access to the downloads section. + +To download the *SMPL* model go to [this](http://smpl.is.tue.mpg.de) (male and female models) and [this](http://smplify.is.tue.mpg.de) (gender neutral model) project website and register to get access to the downloads section. + +## Loading SMPL-X, SMPL+H and SMPL + +### SMPL and SMPL+H setup + +The loader gives the option to use any of the SMPL-X, SMPL+H, SMPL, and MANO models. Depending on the model you want to use, please follow the respective download instructions. To switch between MANO, SMPL, SMPL+H and SMPL-X just change the *model_path* or *model_type* parameters. For more details please check the docs of the model classes. +Before using SMPL and SMPL+H you should follow the instructions in [tools/README.md](./tools/README.md) to remove the +Chumpy objects from both model pkls, as well as merge the MANO parameters with SMPL+H. 
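+
+After this setup, a minimal loading sketch might look like the snippet below. It is only an illustration: the `models` argument is an assumed path that should point at the directory layout shown in the next section, and zero coefficients are used purely as placeholder input.
+
+```python
+import torch
+import smplx
+
+# Build an SMPL-X model from the `models/` directory described below.
+model = smplx.create('models', model_type='smplx',
+                     gender='neutral', ext='npz',
+                     num_betas=10, num_expression_coeffs=10)
+
+# Zero shape/expression coefficients correspond to the mean template.
+betas = torch.zeros([1, model.num_betas])
+expression = torch.zeros([1, model.num_expression_coeffs])
+
+output = model(betas=betas, expression=expression, return_verts=True)
+vertices = output.vertices.detach().cpu().numpy().squeeze()  # (10475, 3)
+joints = output.joints.detach().cpu().numpy().squeeze()
+```
+
+Passing `model_type='smpl'`, `'smplh'` or `'mano'` to `create` loads the corresponding model from the same directory tree.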
+ +### Model loading + +You can either use the [create](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L54) +function from [body_models](./smplx/body_models.py) or directly call the constructor for the +[SMPL](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L106), +[SMPL+H](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L395) and +[SMPL-X](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L628) model. The path to the model can either be the path to the file with the parameters or a directory with the following structure: +```bash +models +├── smpl +│   ├── SMPL_FEMALE.pkl +│   └── SMPL_MALE.pkl +│   └── SMPL_NEUTRAL.pkl +├── smplh +│   ├── SMPLH_FEMALE.pkl +│   └── SMPLH_MALE.pkl +├── mano +| ├── MANO_RIGHT.pkl +| └── MANO_LEFT.pkl +└── smplx + ├── SMPLX_FEMALE.npz + ├── SMPLX_FEMALE.pkl + ├── SMPLX_MALE.npz + ├── SMPLX_MALE.pkl + ├── SMPLX_NEUTRAL.npz + └── SMPLX_NEUTRAL.pkl +``` + + +## MANO and FLAME correspondences + +The vertex correspondences between SMPL-X and MANO, FLAME can be downloaded +from [the project website](https://smpl-x.is.tue.mpg.de). If you have extracted +the correspondence data in the folder *correspondences*, then use the following +scripts to visualize them: + +1. To view MANO correspondences run the following command: + +``` +python examples/vis_mano_vertices.py --model-folder $SMPLX_FOLDER --corr-fname correspondences/MANO_SMPLX_vertex_ids.pkl +``` + +2. To view FLAME correspondences run the following command: + +``` +python examples/vis_flame_vertices.py --model-folder $SMPLX_FOLDER --corr-fname correspondences/SMPL-X__FLAME_vertex_ids.npy +``` + +## Example + +After installing the *smplx* package and downloading the model parameters you should be able to run the *demo.py* +script to visualize the results. For this step you have to install the [pyrender](https://pyrender.readthedocs.io/en/latest/index.html) and [trimesh](https://trimsh.org/) packages. + +`python examples/demo.py --model-folder $SMPLX_FOLDER --plot-joints=True --gender="neutral"` + +![SMPL-X Examples](./images/example.png) + +## Modifying the global pose of the model + +If you want to modify the global pose of the model, i.e. the root rotation and +translation, to a new coordinate system for example, you need to take into +account that the model rotation uses the pelvis as the center of rotation. A +more detailed description can be found in the following +[link](https://www.dropbox.com/scl/fi/zkatuv5shs8d4tlwr8ecc/Change-parameters-to-new-coordinate-system.paper?dl=0&rlkey=lotq1sh6wzkmyttisc05h0in0). +If something is not clear, please let me know so that I can update the +description. + +## Citation + +Depending on which model is loaded for your project, i.e. SMPL-X or SMPL+H or SMPL, please cite the most relevant work below, listed in the same order: + +``` +@inproceedings{SMPL-X:2019, + title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image}, + author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.}, + booktitle = {Proceedings IEEE Conf. 
on Computer Vision and Pattern Recognition (CVPR)}, + year = {2019} +} +``` + +``` +@article{MANO:SIGGRAPHASIA:2017, + title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together}, + author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.}, + journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, + volume = {36}, + number = {6}, + series = {245:1--245:17}, + month = nov, + year = {2017}, + month_numeric = {11} + } +``` + +``` +@article{SMPL:2015, + author = {Loper, Matthew and Mahmood, Naureen and Romero, Javier and Pons-Moll, Gerard and Black, Michael J.}, + title = {{SMPL}: A Skinned Multi-Person Linear Model}, + journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, + month = oct, + number = {6}, + pages = {248:1--248:16}, + publisher = {ACM}, + volume = {34}, + year = {2015} +} +``` + +This repository was originally developed for SMPL-X / SMPLify-X (CVPR 2019), you might be interested in having a look: [https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de). + +## Acknowledgments + +### Facial Contour + +Special thanks to [Soubhik Sanyal](https://github.com/soubhiksanyal) for sharing the Tensorflow code used for the facial +landmarks. + +## Contact +The code of this repository was implemented by [Vassilis Choutas](vassilis.choutas@tuebingen.mpg.de). + +For questions, please contact [smplx@tue.mpg.de](smplx@tue.mpg.de). + +For commercial licensing (and all related questions for business applications), please contact [ps-licensing@tue.mpg.de](ps-licensing@tue.mpg.de). diff --git a/lib/smplx/__init__.py b/lib/smplx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..886949df670691d1ef5995737cafa285224826c4 --- /dev/null +++ b/lib/smplx/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from .body_models import ( + create, + SMPL, + SMPLH, + SMPLX, + MANO, + FLAME, + build_layer, + SMPLLayer, + SMPLHLayer, + SMPLXLayer, + MANOLayer, + FLAMELayer, +) diff --git a/lib/smplx/body_models.py b/lib/smplx/body_models.py new file mode 100644 index 0000000000000000000000000000000000000000..d48b03e06ad4fb0b65d7113a4408431e435be318 --- /dev/null +++ b/lib/smplx/body_models.py @@ -0,0 +1,2455 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional, Dict, Union +import os +import os.path as osp +import pickle + +import numpy as np +import torch +import torch.nn as nn +from collections import namedtuple + +import logging + +logging.getLogger("smplx").setLevel(logging.ERROR) + +from .lbs import (lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords) + +from .vertex_ids import vertex_ids as VERTEX_IDS +from .utils import (Struct, to_np, to_tensor, Tensor, Array, SMPLOutput, + SMPLHOutput, SMPLXOutput, MANOOutput, FLAMEOutput, + find_joint_kin_chain) +from .vertex_joint_selector import VertexJointSelector + +ModelOutput = namedtuple('ModelOutput', [ + 'vertices', 'joints', 'full_pose', 'betas', 'global_orient', 'body_pose', + 'expression', 'left_hand_pose', 'right_hand_pose', 'jaw_pose' +]) +ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields) + + +class SMPL(nn.Module): + + NUM_JOINTS = 23 + NUM_BODY_JOINTS = 23 + SHAPE_SPACE_DIM = 300 + + def __init__(self, + model_path: str, + kid_template_path: str = '', + data_struct: Optional[Struct] = None, + create_betas: bool = True, + betas: Optional[Tensor] = None, + num_betas: int = 10, + create_global_orient: bool = True, + global_orient: Optional[Tensor] = None, + create_body_pose: bool = True, + body_pose: Optional[Tensor] = None, + create_transl: bool = True, + transl: Optional[Tensor] = None, + dtype=torch.float32, + batch_size: int = 1, + joint_mapper=None, + gender: str = 'neutral', + age: str = 'adult', + vertex_ids: Dict[str, int] = None, + v_template: Optional[Union[Tensor, Array]] = None, + v_personal: Optional[Union[Tensor, Array]] = None, + **kwargs) -> None: + ''' SMPL model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_global_orient: bool, optional + Flag for creating a member variable for the global orientation + of the body. (default = True) + global_orient: torch.tensor, optional, Bx3 + The default value for the global orientation variable. + (default = None) + create_body_pose: bool, optional + Flag for creating a member variable for the pose of the body. + (default = True) + body_pose: torch.tensor, optional, Bx(Body Joints * 3) + The default value for the body pose variable. + (default = None) + num_betas: int, optional + Number of shape components to use + (default = 10). + create_betas: bool, optional + Flag for creating a member variable for the shape space + (default = True). + betas: torch.tensor, optional, Bx10 + The default value for the shape member variable. + (default = None) + create_transl: bool, optional + Flag for creating a member variable for the translation + of the body. (default = True) + transl: torch.tensor, optional, Bx3 + The default value for the transl variable. + (default = None) + dtype: torch.dtype, optional + The data type for the created variables + batch_size: int, optional + The batch size used for creating the member variables + joint_mapper: object, optional + An object that re-maps the joints. Useful if one wants to + re-order the SMPL joints to some other convention (e.g. 
MSCOCO) + (default = None) + gender: str, optional + Which gender to load + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.gender = gender + self.age = age + + if data_struct is None: + if osp.isdir(model_path): + model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl') + smpl_path = os.path.join(model_path, model_fn) + else: + smpl_path = model_path + assert osp.exists(smpl_path), 'Path {} does not exist!'.format( + smpl_path) + + with open(smpl_path, 'rb') as smpl_file: + data_struct = Struct( + **pickle.load(smpl_file, encoding='latin1')) + + super(SMPL, self).__init__() + self.batch_size = batch_size + shapedirs = data_struct.shapedirs + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM): + # print(f'WARNING: You are using a {self.name()} model, with only' + # ' 10 shape coefficients.') + num_betas = min(num_betas, 10) + else: + num_betas = min(num_betas, self.SHAPE_SPACE_DIM) + + if self.age == 'kid': + v_template_smil = np.load(kid_template_path) + v_template_smil -= np.mean(v_template_smil, axis=0) + v_template_diff = np.expand_dims(v_template_smil - + data_struct.v_template, + axis=2) + shapedirs = np.concatenate( + (shapedirs[:, :, :num_betas], v_template_diff), axis=2) + num_betas = num_betas + 1 + + self._num_betas = num_betas + shapedirs = shapedirs[:, :, :num_betas] + # The shape components + self.register_buffer('shapedirs', + to_tensor(to_np(shapedirs), dtype=dtype)) + + if vertex_ids is None: + # SMPL and SMPL-H share the same topology, so any extra joints can + # be drawn from the same place + vertex_ids = VERTEX_IDS['smplh'] + + self.dtype = dtype + + self.joint_mapper = joint_mapper + + self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids, + **kwargs) + + self.faces = data_struct.f + self.register_buffer( + 'faces_tensor', + to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long)) + + if create_betas: + if betas is None: + default_betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype) + else: + if torch.is_tensor(betas): + default_betas = betas.clone().detach() + else: + default_betas = torch.tensor(betas, dtype=dtype) + + self.register_parameter( + 'betas', nn.Parameter(default_betas, requires_grad=True)) + + # The tensor that contains the global rotation of the model + # It is separated from the pose of the joints in case we wish to + # optimize only over one of them + if create_global_orient: + if global_orient is None: + default_global_orient = torch.zeros([batch_size, 3], + dtype=dtype) + else: + if torch.is_tensor(global_orient): + default_global_orient = global_orient.clone().detach() + else: + default_global_orient = torch.tensor(global_orient, + dtype=dtype) + + global_orient = nn.Parameter(default_global_orient, + requires_grad=True) + self.register_parameter('global_orient', global_orient) + + if create_body_pose: + if body_pose is None: + default_body_pose = torch.zeros( + [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) + else: + if torch.is_tensor(body_pose): + default_body_pose = body_pose.clone().detach() + else: + default_body_pose = torch.tensor(body_pose, dtype=dtype) + self.register_parameter( + 'body_pose', nn.Parameter(default_body_pose, + requires_grad=True)) + + if create_transl: + if transl is None: + default_transl = torch.zeros([batch_size, 3], + dtype=dtype, + requires_grad=True) + else: + default_transl = torch.tensor(transl, dtype=dtype) + self.register_parameter( + 'transl', nn.Parameter(default_transl, requires_grad=True)) 
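+
+        # NOTE: betas, global_orient, body_pose and transl registered above are
+        # nn.Parameters (requires_grad=True). They serve as per-instance
+        # defaults: forward() falls back to them whenever the corresponding
+        # argument is None, and they can also be optimized directly when
+        # fitting the model to observations.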
+ + if v_template is None: + v_template = data_struct.v_template + + if not torch.is_tensor(v_template): + v_template = to_tensor(to_np(v_template), dtype=dtype) + + if v_personal is not None: + v_personal = to_tensor(to_np(v_personal), dtype=dtype) + v_template += v_personal + + # The vertices of the template model + self.register_buffer('v_template', v_template) + + j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype) + self.register_buffer('J_regressor', j_regressor) + + # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 + num_pose_basis = data_struct.posedirs.shape[-1] + # 207 x 20670 + posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T + self.register_buffer('posedirs', to_tensor(to_np(posedirs), + dtype=dtype)) + + # indices of parents for each joints + parents = to_tensor(to_np(data_struct.kintree_table[0])).long() + parents[0] = -1 + self.register_buffer('parents', parents) + + self.register_buffer( + 'lbs_weights', to_tensor(to_np(data_struct.weights), dtype=dtype)) + + @property + def num_betas(self): + return self._num_betas + + @property + def num_expression_coeffs(self): + return 0 + + def create_mean_pose(self, data_struct) -> Tensor: + pass + + def name(self) -> str: + return 'SMPL' + + @torch.no_grad() + def reset_params(self, **params_dict) -> None: + for param_name, param in self.named_parameters(): + if param_name in params_dict: + param[:] = torch.tensor(params_dict[param_name]) + else: + param.fill_(0) + + def get_num_verts(self) -> int: + return self.v_template.shape[0] + + def get_num_faces(self) -> int: + return self.faces.shape[0] + + def extra_repr(self) -> str: + msg = [ + f'Gender: {self.gender.upper()}', + f'Number of joints: {self.J_regressor.shape[0]}', + f'Betas: {self.num_betas}', + ] + return '\n'.join(msg) + + def forward(self, + betas: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts=True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs) -> SMPLOutput: + ''' Forward pass for the SMPL model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape Bx(J*3) + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient + if global_orient is not None else self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None and hasattr(self, 'transl'): + transl = self.transl + + full_pose = torch.cat([global_orient, body_pose], dim=1) + + batch_size = max(betas.shape[0], global_orient.shape[0], + body_pose.shape[0]) + + if betas.shape[0] != batch_size: + num_repeats = int(batch_size / betas.shape[0]) + betas = betas.expand(num_repeats, -1) + + vertices, joints = lbs(betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=pose2rot) + + joints = self.vertex_joint_selector(vertices, joints) + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLOutput(vertices=vertices if return_verts else None, + global_orient=global_orient, + body_pose=body_pose, + joints=joints, + betas=betas, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLLayer(SMPL): + + def __init__(self, *args, **kwargs) -> None: + # Just create a SMPL module without any member variables + super(SMPLLayer, self).__init__( + create_body_pose=False, + create_betas=False, + create_global_orient=False, + create_transl=False, + *args, + **kwargs, + ) + + def forward(self, + betas: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts=True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs) -> SMPLOutput: + ''' Forward pass for the SMPL model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + Body pose. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + model_vars = [betas, global_orient, body_pose, transl] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, + -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, + device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + full_pose = torch.cat([ + global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3) + ], + dim=1) + + vertices, joints = lbs(betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=False) + + joints = self.vertex_joint_selector(vertices, joints) + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLOutput(vertices=vertices if return_verts else None, + global_orient=global_orient, + body_pose=body_pose, + joints=joints, + betas=betas, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLH(SMPL): + + # The hand joints are replaced by MANO + NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2 + NUM_HAND_JOINTS = 15 + NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + + def __init__(self, + model_path, + kid_template_path: str = '', + data_struct: Optional[Struct] = None, + create_left_hand_pose: bool = True, + left_hand_pose: Optional[Tensor] = None, + create_right_hand_pose: bool = True, + right_hand_pose: Optional[Tensor] = None, + use_pca: bool = True, + num_pca_comps: int = 6, + flat_hand_mean: bool = False, + batch_size: int = 1, + gender: str = 'neutral', + age: str = 'adult', + dtype=torch.float32, + vertex_ids=None, + use_compressed: bool = True, + ext: str = 'pkl', + **kwargs) -> None: + ''' SMPLH model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_left_hand_pose: bool, optional + Flag for creating a member variable for the pose of the left + hand. (default = True) + left_hand_pose: torch.tensor, optional, BxP + The default value for the left hand pose member variable. + (default = None) + create_right_hand_pose: bool, optional + Flag for creating a member variable for the pose of the right + hand. (default = True) + right_hand_pose: torch.tensor, optional, BxP + The default value for the right hand pose member variable. + (default = None) + num_pca_comps: int, optional + The number of PCA components to use for each hand. + (default = 6) + flat_hand_mean: bool, optional + If False, then the pose of the hand is initialized to False. 
+ batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype, optional + The data type for the created variables + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.num_pca_comps = num_pca_comps + # If no data structure is passed, then load the data from the given + # model folder + if data_struct is None: + # Load the model + if osp.isdir(model_path): + model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext) + smplh_path = os.path.join(model_path, model_fn) + else: + smplh_path = model_path + assert osp.exists(smplh_path), 'Path {} does not exist!'.format( + smplh_path) + + if ext == 'pkl': + with open(smplh_path, 'rb') as smplh_file: + model_data = pickle.load(smplh_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(smplh_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**model_data) + + if vertex_ids is None: + vertex_ids = VERTEX_IDS['smplh'] + + super(SMPLH, self).__init__(model_path=model_path, + kid_template_path=kid_template_path, + data_struct=data_struct, + batch_size=batch_size, + vertex_ids=vertex_ids, + gender=gender, + age=age, + use_compressed=use_compressed, + dtype=dtype, + ext=ext, + **kwargs) + + self.use_pca = use_pca + self.num_pca_comps = num_pca_comps + self.flat_hand_mean = flat_hand_mean + + left_hand_components = data_struct.hands_componentsl[:num_pca_comps] + right_hand_components = data_struct.hands_componentsr[:num_pca_comps] + + self.np_left_hand_components = left_hand_components + self.np_right_hand_components = right_hand_components + if self.use_pca: + self.register_buffer( + 'left_hand_components', + torch.tensor(left_hand_components, dtype=dtype)) + self.register_buffer( + 'right_hand_components', + torch.tensor(right_hand_components, dtype=dtype)) + + if self.flat_hand_mean: + left_hand_mean = np.zeros_like(data_struct.hands_meanl) + else: + left_hand_mean = data_struct.hands_meanl + + if self.flat_hand_mean: + right_hand_mean = np.zeros_like(data_struct.hands_meanr) + else: + right_hand_mean = data_struct.hands_meanr + + self.register_buffer('left_hand_mean', + to_tensor(left_hand_mean, dtype=self.dtype)) + self.register_buffer('right_hand_mean', + to_tensor(right_hand_mean, dtype=self.dtype)) + + # Create the buffers for the pose of the left hand + hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS + if create_left_hand_pose: + if left_hand_pose is None: + default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype) + + left_hand_pose_param = nn.Parameter(default_lhand_pose, + requires_grad=True) + self.register_parameter('left_hand_pose', left_hand_pose_param) + + if create_right_hand_pose: + if right_hand_pose is None: + default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype) + + right_hand_pose_param = nn.Parameter(default_rhand_pose, + requires_grad=True) + self.register_parameter('right_hand_pose', right_hand_pose_param) + + # Create the buffer for the mean pose. 
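+        # The mean pose offsets only the hand joints (the global orientation and
+        # body pose means are zero); it is added to full_pose in forward(), so
+        # with flat_hand_mean=True the hands default to the flat open pose.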
+ pose_mean_tensor = self.create_mean_pose(data_struct, + flat_hand_mean=flat_hand_mean) + if not torch.is_tensor(pose_mean_tensor): + pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype) + self.register_buffer('pose_mean', pose_mean_tensor) + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], + dtype=self.dtype) + + pose_mean = torch.cat([ + global_orient_mean, body_pose_mean, self.left_hand_mean, + self.right_hand_mean + ], + dim=0) + return pose_mean + + def name(self) -> str: + return 'SMPL+H' + + def extra_repr(self): + msg = super(SMPLH, self).extra_repr() + msg = [msg] + if self.use_pca: + msg.append(f'Number of PCA components: {self.num_pca_comps}') + msg.append(f'Flat hand mean: {self.flat_hand_mean}') + return '\n'.join(msg) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs) -> SMPLHOutput: + ''' + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient + if global_orient is not None else self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + left_hand_pose = (left_hand_pose if left_hand_pose is not None else + self.left_hand_pose) + right_hand_pose = (right_hand_pose if right_hand_pose is not None else + self.right_hand_pose) + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + left_hand_pose = torch.einsum( + 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) + right_hand_pose = torch.einsum( + 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) + + full_pose = torch.cat( + [global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1) + + full_pose += self.pose_mean + + vertices, joints = lbs(betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=pose2rot) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLHOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLHLayer(SMPLH): + + def __init__(self, *args, **kwargs) -> None: + ''' SMPL+H as a layer model constructor + ''' + super(SMPLHLayer, self).__init__(create_global_orient=False, + create_body_pose=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_betas=False, + create_transl=False, + *args, + **kwargs) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: 
Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs) -> SMPLHOutput: + ''' Forward pass for the SMPL+H model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + model_vars = [ + betas, global_orient, body_pose, transl, left_hand_pose, + right_hand_pose + ] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, + device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + # Concatenate all pose vectors + full_pose = torch.cat([ + global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3) + ], + dim=1) + + vertices, joints = lbs(betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=False) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + if self.joint_mapper is not None: + 
joints = self.joint_mapper(joints) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLHOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLX(SMPLH): + ''' + SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters + trained jointly for the face, hands and body. + SMPL-X uses standard vertex based linear blend skinning with learned + corrective blend shapes, has N=10475 vertices and K=54 joints, + which includes joints for the neck, jaw, eyeballs and fingers. + ''' + + NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21 + NUM_HAND_JOINTS = 15 + NUM_FACE_JOINTS = 3 + NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS + EXPRESSION_SPACE_DIM = 100 + NECK_IDX = 12 + + def __init__(self, + model_path: str, + kid_template_path: str = '', + num_expression_coeffs: int = 10, + create_expression: bool = True, + expression: Optional[Tensor] = None, + create_jaw_pose: bool = True, + jaw_pose: Optional[Tensor] = None, + create_leye_pose: bool = True, + leye_pose: Optional[Tensor] = None, + create_reye_pose=True, + reye_pose: Optional[Tensor] = None, + use_face_contour: bool = False, + batch_size: int = 1, + gender: str = 'neutral', + age: str = 'adult', + dtype=torch.float32, + ext: str = 'npz', + **kwargs) -> None: + ''' SMPLX model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + num_expression_coeffs: int, optional + Number of expression components to use + (default = 10). + create_expression: bool, optional + Flag for creating a member variable for the expression space + (default = True). + expression: torch.tensor, optional, Bx10 + The default value for the expression member variable. + (default = None) + create_jaw_pose: bool, optional + Flag for creating a member variable for the jaw pose. + (default = False) + jaw_pose: torch.tensor, optional, Bx3 + The default value for the jaw pose variable. + (default = None) + create_leye_pose: bool, optional + Flag for creating a member variable for the left eye pose. + (default = False) + leye_pose: torch.tensor, optional, Bx10 + The default value for the left eye pose variable. + (default = None) + create_reye_pose: bool, optional + Flag for creating a member variable for the right eye pose. + (default = False) + reye_pose: torch.tensor, optional, Bx10 + The default value for the right eye pose variable. 
+ (default = None) + use_face_contour: bool, optional + Whether to compute the keypoints that form the facial contour + batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype + The data type for the created variables + ''' + + # Load the model + if osp.isdir(model_path): + model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext) + smplx_path = os.path.join(model_path, model_fn) + else: + smplx_path = model_path + assert osp.exists(smplx_path), 'Path {} does not exist!'.format( + smplx_path) + + if ext == 'pkl': + with open(smplx_path, 'rb') as smplx_file: + model_data = pickle.load(smplx_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(smplx_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + + data_struct = Struct(**model_data) + + super(SMPLX, self).__init__(model_path=model_path, + kid_template_path=kid_template_path, + data_struct=data_struct, + dtype=dtype, + batch_size=batch_size, + vertex_ids=VERTEX_IDS['smplx'], + gender=gender, + age=age, + ext=ext, + **kwargs) + + lmk_faces_idx = data_struct.lmk_faces_idx + self.register_buffer('lmk_faces_idx', + torch.tensor(lmk_faces_idx, dtype=torch.long)) + lmk_bary_coords = data_struct.lmk_bary_coords + self.register_buffer('lmk_bary_coords', + torch.tensor(lmk_bary_coords, dtype=dtype)) + + self.use_face_contour = use_face_contour + if self.use_face_contour: + dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx + dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, + dtype=torch.long) + self.register_buffer('dynamic_lmk_faces_idx', + dynamic_lmk_faces_idx) + + dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords + dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, + dtype=dtype) + self.register_buffer('dynamic_lmk_bary_coords', + dynamic_lmk_bary_coords) + + neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) + self.register_buffer( + 'neck_kin_chain', torch.tensor(neck_kin_chain, + dtype=torch.long)) + + if create_jaw_pose: + if jaw_pose is None: + default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) + jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) + self.register_parameter('jaw_pose', jaw_pose_param) + + if create_leye_pose: + if leye_pose is None: + default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_leye_pose = torch.tensor(leye_pose, dtype=dtype) + leye_pose_param = nn.Parameter(default_leye_pose, + requires_grad=True) + self.register_parameter('leye_pose', leye_pose_param) + + if create_reye_pose: + if reye_pose is None: + default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_reye_pose = torch.tensor(reye_pose, dtype=dtype) + reye_pose_param = nn.Parameter(default_reye_pose, + requires_grad=True) + self.register_parameter('reye_pose', reye_pose_param) + + shapedirs = data_struct.shapedirs + if len(shapedirs.shape) < 3: + shapedirs = shapedirs[:, :, None] + if (shapedirs.shape[-1] < + self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM): + # print(f'WARNING: You are using a {self.name()} model, with only' + # ' 10 shape and 10 expression coefficients.') + expr_start_idx = 10 + expr_end_idx = 20 + num_expression_coeffs = min(num_expression_coeffs, 10) + else: + expr_start_idx = self.SHAPE_SPACE_DIM + expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs + num_expression_coeffs = 
min(num_expression_coeffs, + self.EXPRESSION_SPACE_DIM) + + self._num_expression_coeffs = num_expression_coeffs + + expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] + self.register_buffer('expr_dirs', + to_tensor(to_np(expr_dirs), dtype=dtype)) + + if create_expression: + if expression is None: + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype) + else: + default_expression = torch.tensor(expression, dtype=dtype) + expression_param = nn.Parameter(default_expression, + requires_grad=True) + self.register_parameter('expression', expression_param) + + def name(self) -> str: + return 'SMPL-X' + + @property + def num_expression_coeffs(self): + return self._num_expression_coeffs + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], + dtype=self.dtype) + jaw_pose_mean = torch.zeros([3], dtype=self.dtype) + leye_pose_mean = torch.zeros([3], dtype=self.dtype) + reye_pose_mean = torch.zeros([3], dtype=self.dtype) + + pose_mean = np.concatenate([ + global_orient_mean, body_pose_mean, jaw_pose_mean, leye_pose_mean, + reye_pose_mean, self.left_hand_mean, self.right_hand_mean + ], + axis=0) + + return pose_mean + + def extra_repr(self): + msg = super(SMPLX, self).extra_repr() + msg = [ + msg, + f'Number of Expression Coefficients: {self.num_expression_coeffs}' + ] + return '\n'.join(msg) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + return_joint_transformation: bool = False, + return_vertex_transformation: bool = False, + **kwargs) -> SMPLXOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape Bx(J*3) + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + left_hand_pose: torch.tensor, optional, shape BxP + If given, ignore the member variable `left_hand_pose` and + use this instead. It should either contain PCA coefficients or + joint rotations in axis-angle format. 
+ right_hand_pose: torch.tensor, optional, shape BxP + If given, ignore the member variable `right_hand_pose` and + use this instead. It should either contain PCA coefficients or + joint rotations in axis-angle format. + jaw_pose: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `jaw_pose` and + use this instead. It should either joint rotations in + axis-angle format. + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient + if global_orient is not None else self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + + left_hand_pose = (left_hand_pose if left_hand_pose is not None else + self.left_hand_pose) + right_hand_pose = (right_hand_pose if right_hand_pose is not None else + self.right_hand_pose) + jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose + leye_pose = leye_pose if leye_pose is not None else self.leye_pose + reye_pose = reye_pose if reye_pose is not None else self.reye_pose + expression = expression if expression is not None else self.expression + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + left_hand_pose = torch.einsum( + 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) + right_hand_pose = torch.einsum( + 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) + + full_pose = torch.cat([ + global_orient, body_pose, jaw_pose, leye_pose, reye_pose, + left_hand_pose, right_hand_pose + ], + dim=1) + + # Add the mean pose of the model. 
Does not affect the body, only the + # hands when flat_hand_mean == False + full_pose += self.pose_mean + + batch_size = max(betas.shape[0], global_orient.shape[0], + body_pose.shape[0]) + # Concatenate the shape and expression coefficients + scale = int(batch_size / betas.shape[0]) + if scale > 1: + betas = betas.expand(scale, -1) + shape_components = torch.cat([betas, expression], dim=-1) + + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + if return_joint_transformation or return_vertex_transformation: + vertices, joints, joint_transformation, vertex_transformation = lbs( + shape_components, + full_pose, + self.v_template, + shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=pose2rot, + return_transformation=True) + else: + vertices, joints = lbs( + shape_components, + full_pose, + self.v_template, + shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=pose2rot, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( + batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + self.batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, + full_pose, + self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=True, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + + lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat([ + lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords + ], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + # Map the joints to the current dataset + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLXOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None, + joint_transformation=joint_transformation + if return_joint_transformation else None, + vertex_transformation=vertex_transformation + if return_vertex_transformation else None) + return output + + +class SMPLXLayer(SMPLX): + + def __init__(self, *args, **kwargs) -> None: + # Just create a SMPLX module without any member variables + super(SMPLXLayer, self).__init__( + create_global_orient=False, + create_body_pose=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + create_betas=False, + create_expression=False, + create_transl=False, + *args, + **kwargs, + ) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + 
return_verts: bool = True, + return_full_pose: bool = False, + **kwargs) -> SMPLXOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. It is expected to be in rotation matrix + format. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + Expression coefficients. + For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full pose vector (default=False) + Returns + ------- + output: ModelOutput + A data class that contains the posed vertices and joints + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + + model_vars = [ + betas, global_orient, body_pose, transl, expression, + left_hand_pose, right_hand_pose, jaw_pose + ] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, + -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if expression is None: + expression = torch.zeros([batch_size, self.num_expression_coeffs], + dtype=dtype, + device=device) + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, + device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + # Concatenate all pose vectors + full_pose = torch.cat([ + global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + jaw_pose.reshape(-1, 1, 3, 3), + leye_pose.reshape(-1, 1, 3, 3), + reye_pose.reshape(-1, 1, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3) + ], + dim=1) + shape_components = torch.cat([betas, expression], dim=-1) + + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs( + shape_components, + full_pose, + self.v_template, + shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=False, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( + batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, + full_pose, + self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=False, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + + lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat([ + lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords + ], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + # Map the joints to the current dataset + + if 
self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLXOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + jaw_pose=jaw_pose, + transl=transl, + full_pose=full_pose if return_full_pose else None) + return output + + +class MANO(SMPL): + # The hand joints are replaced by MANO + NUM_BODY_JOINTS = 1 + NUM_HAND_JOINTS = 15 + NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS + + def __init__(self, + model_path: str, + is_rhand: bool = True, + data_struct: Optional[Struct] = None, + create_hand_pose: bool = True, + hand_pose: Optional[Tensor] = None, + use_pca: bool = True, + num_pca_comps: int = 6, + flat_hand_mean: bool = False, + batch_size: int = 1, + dtype=torch.float32, + vertex_ids=None, + use_compressed: bool = True, + ext: str = 'pkl', + **kwargs) -> None: + ''' MANO model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_hand_pose: bool, optional + Flag for creating a member variable for the pose of the right + hand. (default = True) + hand_pose: torch.tensor, optional, BxP + The default value for the right hand pose member variable. + (default = None) + num_pca_comps: int, optional + The number of PCA components to use for each hand. + (default = 6) + flat_hand_mean: bool, optional + If False, then the pose of the hand is initialized to False. 
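The two SMPL-X forward passes above differ mainly in rotation format: `SMPLX.forward` takes axis-angle parameters (with optionally PCA-compressed hand poses) and falls back to its registered member tensors, while `SMPLXLayer.forward` expects rotation matrices and substitutes identity rotations and zero coefficients for anything left as `None`. The sketch below is a minimal, hedged usage example; the import path and `MODEL_DIR` are assumptions (mirroring the upstream `smplx` package layout), not paths fixed by this repository.
```
# Minimal sketch, assuming lib/smplx re-exports the classes like upstream smplx
# and that the SMPL-X model files live under MODEL_DIR (hypothetical path).
import torch
from lib.smplx import SMPLX, SMPLXLayer

MODEL_DIR = "data/smpl_related/models/smplx"

model = SMPLX(MODEL_DIR, gender="neutral", use_pca=True, num_pca_comps=6, batch_size=2)
out = model(
    body_pose=torch.zeros(2, 63),       # 21 body joints x 3 axis-angle values
    left_hand_pose=torch.zeros(2, 6),   # PCA coefficients, since use_pca=True
    right_hand_pose=torch.zeros(2, 6),
    return_verts=True,
)
print(out.vertices.shape, out.joints.shape)

layer = SMPLXLayer(MODEL_DIR, gender="neutral")
identity = torch.eye(3).expand(2, 21, 3, 3)          # rotation-matrix format
out_layer = layer(body_pose=identity, return_full_pose=True)
print(out_layer.full_pose.shape)                     # (2, 55, 3, 3): 1+21+1+1+1+15+15 joints
```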
+ batch_size: int, optional + The batch size used for creating the member variables + dtype: torch.dtype, optional + The data type for the created variables + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.num_pca_comps = num_pca_comps + self.is_rhand = is_rhand + # If no data structure is passed, then load the data from the given + # model folder + if data_struct is None: + # Load the model + if osp.isdir(model_path): + model_fn = 'MANO_{}.{ext}'.format( + 'RIGHT' if is_rhand else 'LEFT', ext=ext) + mano_path = os.path.join(model_path, model_fn) + else: + mano_path = model_path + self.is_rhand = True if 'RIGHT' in os.path.basename( + model_path) else False + assert osp.exists(mano_path), 'Path {} does not exist!'.format( + mano_path) + + if ext == 'pkl': + with open(mano_path, 'rb') as mano_file: + model_data = pickle.load(mano_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(mano_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**model_data) + + if vertex_ids is None: + vertex_ids = VERTEX_IDS['smplh'] + + super(MANO, self).__init__(model_path=model_path, + data_struct=data_struct, + batch_size=batch_size, + vertex_ids=vertex_ids, + use_compressed=use_compressed, + dtype=dtype, + ext=ext, + **kwargs) + + # add only MANO tips to the extra joints + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + list(VERTEX_IDS['mano'].values()), dtype=torch.long) + + self.use_pca = use_pca + self.num_pca_comps = num_pca_comps + if self.num_pca_comps == 45: + self.use_pca = False + self.flat_hand_mean = flat_hand_mean + + hand_components = data_struct.hands_components[:num_pca_comps] + + self.np_hand_components = hand_components + + if self.use_pca: + self.register_buffer('hand_components', + torch.tensor(hand_components, dtype=dtype)) + + if self.flat_hand_mean: + hand_mean = np.zeros_like(data_struct.hands_mean) + else: + hand_mean = data_struct.hands_mean + + self.register_buffer('hand_mean', to_tensor(hand_mean, + dtype=self.dtype)) + + # Create the buffers for the pose of the left hand + hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS + if create_hand_pose: + if hand_pose is None: + default_hand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_hand_pose = torch.tensor(hand_pose, dtype=dtype) + + hand_pose_param = nn.Parameter(default_hand_pose, + requires_grad=True) + self.register_parameter('hand_pose', hand_pose_param) + + # Create the buffer for the mean pose. + pose_mean = self.create_mean_pose(data_struct, + flat_hand_mean=flat_hand_mean) + pose_mean_tensor = pose_mean.clone().to(dtype) + # pose_mean_tensor = torch.tensor(pose_mean, dtype=dtype) + self.register_buffer('pose_mean', pose_mean_tensor) + + def name(self) -> str: + return 'MANO' + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. 
If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0) + return pose_mean + + def extra_repr(self): + msg = [super(MANO, self).extra_repr()] + if self.use_pca: + msg.append(f'Number of PCA components: {self.num_pca_comps}') + msg.append(f'Flat hand mean: {self.flat_hand_mean}') + return '\n'.join(msg) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs) -> MANOOutput: + ''' Forward pass for the MANO model + ''' + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient + if global_orient is not None else self.global_orient) + betas = betas if betas is not None else self.betas + hand_pose = (hand_pose if hand_pose is not None else self.hand_pose) + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + hand_pose = torch.einsum('bi,ij->bj', + [hand_pose, self.hand_components]) + + full_pose = torch.cat([global_orient, hand_pose], dim=1) + full_pose += self.pose_mean + + vertices, joints = lbs( + betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=True, + ) + + # # Add pre-selected extra joints that might be needed + # joints = self.vertex_joint_selector(vertices, joints) + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = MANOOutput(vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + hand_pose=hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class MANOLayer(MANO): + + def __init__(self, *args, **kwargs) -> None: + ''' MANO as a layer model constructor + ''' + super(MANOLayer, self).__init__(create_global_orient=False, + create_hand_pose=False, + create_betas=False, + create_transl=False, + *args, + **kwargs) + + def name(self) -> str: + return 'MANO' + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs) -> MANOOutput: + ''' Forward pass for the MANO model + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if hand_pose is None: + hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, + device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat([global_orient, hand_pose], dim=1) + vertices, joints = lbs(betas, + full_pose, + self.v_template, + self.shapedirs, + self.posedirs, + 
self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=False) + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = MANOOutput(vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + hand_pose=hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class FLAME(SMPL): + NUM_JOINTS = 5 + SHAPE_SPACE_DIM = 300 + EXPRESSION_SPACE_DIM = 100 + NECK_IDX = 0 + + def __init__(self, + model_path: str, + data_struct=None, + num_expression_coeffs=10, + create_expression: bool = True, + expression: Optional[Tensor] = None, + create_neck_pose: bool = True, + neck_pose: Optional[Tensor] = None, + create_jaw_pose: bool = True, + jaw_pose: Optional[Tensor] = None, + create_leye_pose: bool = True, + leye_pose: Optional[Tensor] = None, + create_reye_pose=True, + reye_pose: Optional[Tensor] = None, + use_face_contour=False, + batch_size: int = 1, + gender: str = 'neutral', + dtype: torch.dtype = torch.float32, + ext='pkl', + **kwargs) -> None: + ''' FLAME model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + num_expression_coeffs: int, optional + Number of expression components to use + (default = 10). + create_expression: bool, optional + Flag for creating a member variable for the expression space + (default = True). + expression: torch.tensor, optional, Bx10 + The default value for the expression member variable. + (default = None) + create_neck_pose: bool, optional + Flag for creating a member variable for the neck pose. + (default = False) + neck_pose: torch.tensor, optional, Bx3 + The default value for the neck pose variable. + (default = None) + create_jaw_pose: bool, optional + Flag for creating a member variable for the jaw pose. + (default = False) + jaw_pose: torch.tensor, optional, Bx3 + The default value for the jaw pose variable. + (default = None) + create_leye_pose: bool, optional + Flag for creating a member variable for the left eye pose. + (default = False) + leye_pose: torch.tensor, optional, Bx10 + The default value for the left eye pose variable. + (default = None) + create_reye_pose: bool, optional + Flag for creating a member variable for the right eye pose. + (default = False) + reye_pose: torch.tensor, optional, Bx10 + The default value for the right eye pose variable. 
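For MANO and MANOLayer above, the main switch is `use_pca`: with PCA enabled, `hand_pose` carries `num_pca_comps` coefficients that `forward` expands through `hand_components`; passing `num_pca_comps=45` turns PCA off in the constructor, so `hand_pose` is read as 15 axis-angle joint rotations. A hedged sketch follows; the import path and `MANO_DIR` are assumptions, and the directory is expected to contain `MANO_RIGHT.pkl`.
```
# Minimal sketch under the assumptions stated above.
import torch
from lib.smplx import MANO

MANO_DIR = "data/smpl_related/models/mano"   # hypothetical path

pca_hand = MANO(MANO_DIR, is_rhand=True, use_pca=True, num_pca_comps=6, batch_size=1)
out = pca_hand(hand_pose=torch.zeros(1, 6))            # 6 PCA coefficients
print(out.vertices.shape)                              # (1, 778, 3) for MANO

# num_pca_comps=45 disables PCA (see the constructor), so hand_pose is
# 15 joints x 3 axis-angle values.
aa_hand = MANO(MANO_DIR, is_rhand=True, num_pca_comps=45, batch_size=1)
out_aa = aa_hand(hand_pose=torch.zeros(1, 45), return_full_pose=True)
print(out_aa.full_pose.shape)                          # (1, 48): global orient + 15 joints
```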
+ (default = None) + use_face_contour: bool, optional + Whether to compute the keypoints that form the facial contour + batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype + The data type for the created variables + ''' + model_fn = f'FLAME_{gender.upper()}.{ext}' + flame_path = os.path.join(model_path, model_fn) + assert osp.exists(flame_path), 'Path {} does not exist!'.format( + flame_path) + if ext == 'npz': + file_data = np.load(flame_path, allow_pickle=True) + elif ext == 'pkl': + with open(flame_path, 'rb') as smpl_file: + file_data = pickle.load(smpl_file, encoding='latin1') + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**file_data) + + super(FLAME, self).__init__(model_path=model_path, + data_struct=data_struct, + dtype=dtype, + batch_size=batch_size, + gender=gender, + ext=ext, + **kwargs) + + self.use_face_contour = use_face_contour + + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + [], dtype=torch.long) + + if create_neck_pose: + if neck_pose is None: + default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_neck_pose = torch.tensor(neck_pose, dtype=dtype) + neck_pose_param = nn.Parameter(default_neck_pose, + requires_grad=True) + self.register_parameter('neck_pose', neck_pose_param) + + if create_jaw_pose: + if jaw_pose is None: + default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) + jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) + self.register_parameter('jaw_pose', jaw_pose_param) + + if create_leye_pose: + if leye_pose is None: + default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_leye_pose = torch.tensor(leye_pose, dtype=dtype) + leye_pose_param = nn.Parameter(default_leye_pose, + requires_grad=True) + self.register_parameter('leye_pose', leye_pose_param) + + if create_reye_pose: + if reye_pose is None: + default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_reye_pose = torch.tensor(reye_pose, dtype=dtype) + reye_pose_param = nn.Parameter(default_reye_pose, + requires_grad=True) + self.register_parameter('reye_pose', reye_pose_param) + + shapedirs = data_struct.shapedirs + if len(shapedirs.shape) < 3: + shapedirs = shapedirs[:, :, None] + if (shapedirs.shape[-1] < + self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM): + # print(f'WARNING: You are using a {self.name()} model, with only' + # ' 10 shape and 10 expression coefficients.') + expr_start_idx = 10 + expr_end_idx = 20 + num_expression_coeffs = min(num_expression_coeffs, 10) + else: + expr_start_idx = self.SHAPE_SPACE_DIM + expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs + num_expression_coeffs = min(num_expression_coeffs, + self.EXPRESSION_SPACE_DIM) + + self._num_expression_coeffs = num_expression_coeffs + + expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] + self.register_buffer('expr_dirs', + to_tensor(to_np(expr_dirs), dtype=dtype)) + + if create_expression: + if expression is None: + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype) + else: + default_expression = torch.tensor(expression, dtype=dtype) + expression_param = nn.Parameter(default_expression, + requires_grad=True) + self.register_parameter('expression', expression_param) + + # The pickle file that contains the barycentric coordinates for + # regressing the landmarks + 
landmark_bcoord_filename = osp.join(model_path, + 'flame_static_embedding.pkl') + + with open(landmark_bcoord_filename, 'rb') as fp: + landmarks_data = pickle.load(fp, encoding='latin1') + + lmk_faces_idx = landmarks_data['lmk_face_idx'].astype(np.int64) + self.register_buffer('lmk_faces_idx', + torch.tensor(lmk_faces_idx, dtype=torch.long)) + lmk_bary_coords = landmarks_data['lmk_b_coords'] + self.register_buffer('lmk_bary_coords', + torch.tensor(lmk_bary_coords, dtype=dtype)) + if self.use_face_contour: + face_contour_path = os.path.join(model_path, + 'flame_dynamic_embedding.npy') + contour_embeddings = np.load(face_contour_path, + allow_pickle=True, + encoding='latin1')[()] + + dynamic_lmk_faces_idx = np.array( + contour_embeddings['lmk_face_idx'], dtype=np.int64) + dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, + dtype=torch.long) + self.register_buffer('dynamic_lmk_faces_idx', + dynamic_lmk_faces_idx) + + dynamic_lmk_b_coords = torch.tensor( + contour_embeddings['lmk_b_coords'], dtype=dtype) + self.register_buffer('dynamic_lmk_bary_coords', + dynamic_lmk_b_coords) + + neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) + self.register_buffer( + 'neck_kin_chain', torch.tensor(neck_kin_chain, + dtype=torch.long)) + + @property + def num_expression_coeffs(self): + return self._num_expression_coeffs + + def name(self) -> str: + return 'FLAME' + + def extra_repr(self): + msg = [ + super(FLAME, self).extra_repr(), + f'Number of Expression Coefficients: {self.num_expression_coeffs}', + f'Use face contour: {self.use_face_contour}', + ] + return '\n'.join(msg) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + neck_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs) -> FLAMEOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape Bx10 + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape Bx10 + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + jaw_pose: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `jaw_pose` and + use this instead. It should either joint rotations in + axis-angle format. + jaw_pose: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `jaw_pose` and + use this instead. It should either joint rotations in + axis-angle format. + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
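The FLAME constructor above wires up expression blend shapes and, when `use_face_contour=True`, the dynamic contour landmarks driven by the neck rotation. Below is a hedged sketch of a forward pass with expression and jaw parameters; the import path and `FLAME_DIR` are assumptions, and the folder is expected to also hold `flame_static_embedding.pkl` (plus `flame_dynamic_embedding.npy` when the contour is enabled).
```
# Minimal sketch under the assumptions stated above.
import torch
from lib.smplx import FLAME

FLAME_DIR = "data/smpl_related/models/flame"   # hypothetical path

head = FLAME(FLAME_DIR, gender="neutral", num_expression_coeffs=10,
             use_face_contour=True, batch_size=1)
out = head(
    expression=torch.zeros(1, 10),
    jaw_pose=torch.zeros(1, 3),     # axis-angle
    return_verts=True,
)
print(out.vertices.shape)           # (1, 5023, 3) for FLAME
print(out.joints.shape)             # 5 skeleton joints + 51 static + 17 contour landmarks
```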
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient + if global_orient is not None else self.global_orient) + jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose + neck_pose = neck_pose if neck_pose is not None else self.neck_pose + + leye_pose = leye_pose if leye_pose is not None else self.leye_pose + reye_pose = reye_pose if reye_pose is not None else self.reye_pose + + betas = betas if betas is not None else self.betas + expression = expression if expression is not None else self.expression + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + full_pose = torch.cat( + [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) + + batch_size = max(betas.shape[0], global_orient.shape[0], + jaw_pose.shape[0]) + # Concatenate the shape and expression coefficients + scale = int(batch_size / betas.shape[0]) + if scale > 1: + betas = betas.expand(scale, -1) + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs( + shape_components, + full_pose, + self.v_template, + shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=pose2rot, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( + batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + self.batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, + full_pose, + self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=True, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat([ + lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords + ], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = FLAMEOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + neck_pose=neck_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None) + return output + + +class FLAMELayer(FLAME): + + def __init__(self, *args, **kwargs) -> None: + ''' FLAME as a layer model constructor ''' + super(FLAMELayer, self).__init__(create_betas=False, + create_expression=False, + create_global_orient=False, + create_neck_pose=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + *args, + **kwargs) + + def forward(self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + neck_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: 
Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs) -> FLAMEOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if neck_pose is None: + neck_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, + device=device) + if expression is None: + expression = torch.zeros([batch_size, self.num_expression_coeffs], + dtype=dtype, + device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat( + [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) + + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs( + shape_components, + full_pose, + self.v_template, + shapedirs, + self.posedirs, + self.J_regressor, + self.parents, + self.lbs_weights, + pose2rot=False, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( + batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + self.batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, + full_pose, + self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=False, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + 
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat([ + lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords + ], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = FLAMEOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + neck_pose=neck_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None) + return output + + +def build_layer( + model_path: str, + model_type: str = 'smpl', + **kwargs +) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]: + ''' Method for creating a model from a path and a model type + + Parameters + ---------- + model_path: str + Either the path to the model you wish to load or a folder, + where each subfolder contains the differents types, i.e.: + model_path: + | + |-- smpl + |-- SMPL_FEMALE + |-- SMPL_NEUTRAL + |-- SMPL_MALE + |-- smplh + |-- SMPLH_FEMALE + |-- SMPLH_MALE + |-- smplx + |-- SMPLX_FEMALE + |-- SMPLX_NEUTRAL + |-- SMPLX_MALE + |-- mano + |-- MANO RIGHT + |-- MANO LEFT + |-- flame + |-- FLAME_FEMALE + |-- FLAME_MALE + |-- FLAME_NEUTRAL + + model_type: str, optional + When model_path is a folder, then this parameter specifies the + type of model to be loaded + **kwargs: dict + Keyword arguments + + Returns + ------- + body_model: nn.Module + The PyTorch module that implements the corresponding body model + Raises + ------ + ValueError: In case the model type is not one of SMPL, SMPLH, + SMPLX, MANO or FLAME + ''' + + if osp.isdir(model_path): + model_path = os.path.join(model_path, model_type) + else: + model_type = osp.basename(model_path).split('_')[0].lower() + + if model_type.lower() == 'smpl': + return SMPLLayer(model_path, **kwargs) + elif model_type.lower() == 'smplh': + return SMPLHLayer(model_path, **kwargs) + elif model_type.lower() == 'smplx': + return SMPLXLayer(model_path, **kwargs) + elif 'mano' in model_type.lower(): + return MANOLayer(model_path, **kwargs) + elif 'flame' in model_type.lower(): + return FLAMELayer(model_path, **kwargs) + else: + raise ValueError(f'Unknown model type {model_type}, exiting!') + + +def create(model_path: str, + model_type: str = 'smpl', + **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]: + ''' Method for creating a model from a path and a model type + + Parameters + ---------- + model_path: str + Either the path to the model you wish to load or a folder, + where each subfolder contains the differents types, i.e.: + model_path: + | + |-- smpl + |-- SMPL_FEMALE + |-- SMPL_NEUTRAL + |-- SMPL_MALE + |-- smplh + |-- SMPLH_FEMALE + |-- SMPLH_MALE + |-- smplx + |-- SMPLX_FEMALE + |-- SMPLX_NEUTRAL + |-- SMPLX_MALE + |-- mano + |-- MANO RIGHT + |-- MANO LEFT + + model_type: str, optional + When model_path is a folder, then this parameter specifies the + type of model to be loaded + **kwargs: dict + Keyword arguments + + Returns + ------- + body_model: nn.Module + The PyTorch module that implements the corresponding body model + Raises + ------ + ValueError: In case the model 
type is not one of SMPL, SMPLH, + SMPLX, MANO or FLAME + ''' + + # If it's a folder, assume + if osp.isdir(model_path): + model_path = os.path.join(model_path, model_type) + else: + model_type = osp.basename(model_path).split('_')[0].lower() + + if model_type.lower() == 'smpl': + return SMPL(model_path, **kwargs) + elif model_type.lower() == 'smplh': + return SMPLH(model_path, **kwargs) + elif model_type.lower() == 'smplx': + return SMPLX(model_path, **kwargs) + elif 'mano' in model_type.lower(): + return MANO(model_path, **kwargs) + elif 'flame' in model_type.lower(): + return FLAME(model_path, **kwargs) + else: + raise ValueError(f'Unknown model type {model_type}, exiting!') diff --git a/lib/smplx/joint_names.py b/lib/smplx/joint_names.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3a10f8cef8b50075dc9f680459fc5d596a0013 --- /dev/null +++ b/lib/smplx/joint_names.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +JOINT_NAMES = [ + 'pelvis', + 'left_hip', + 'right_hip', + 'spine1', + 'left_knee', + 'right_knee', + 'spine2', + 'left_ankle', + 'right_ankle', + 'spine3', + 'left_foot', + 'right_foot', + 'neck', + 'left_collar', + 'right_collar', + 'head', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'jaw', + 'left_eye_smplhf', + 'right_eye_smplhf', + 'left_index1', + 'left_index2', + 'left_index3', + 'left_middle1', + 'left_middle2', + 'left_middle3', + 'left_pinky1', + 'left_pinky2', + 'left_pinky3', + 'left_ring1', + 'left_ring2', + 'left_ring3', + 'left_thumb1', + 'left_thumb2', + 'left_thumb3', + 'right_index1', + 'right_index2', + 'right_index3', + 'right_middle1', + 'right_middle2', + 'right_middle3', + 'right_pinky1', + 'right_pinky2', + 'right_pinky3', + 'right_ring1', + 'right_ring2', + 'right_ring3', + 'right_thumb1', + 'right_thumb2', + 'right_thumb3', + 'nose', + 'right_eye', + 'left_eye', + 'right_ear', + 'left_ear', + 'left_big_toe', + 'left_small_toe', + 'left_heel', + 'right_big_toe', + 'right_small_toe', + 'right_heel', + 'left_thumb', + 'left_index', + 'left_middle', + 'left_ring', + 'left_pinky', + 'right_thumb', + 'right_index', + 'right_middle', + 'right_ring', + 'right_pinky', + 'right_eye_brow1', + 'right_eye_brow2', + 'right_eye_brow3', + 'right_eye_brow4', + 'right_eye_brow5', + 'left_eye_brow5', + 'left_eye_brow4', + 'left_eye_brow3', + 'left_eye_brow2', + 'left_eye_brow1', + 'nose1', + 'nose2', + 'nose3', + 'nose4', + 'right_nose_2', + 'right_nose_1', + 'nose_middle', + 'left_nose_1', + 'left_nose_2', + 'right_eye1', + 'right_eye2', + 'right_eye3', + 'right_eye4', + 'right_eye5', + 'right_eye6', + 'left_eye4', + 'left_eye3', + 'left_eye2', + 'left_eye1', + 'left_eye6', + 'left_eye5', + 'right_mouth_1', + 'right_mouth_2', + 'right_mouth_3', + 'mouth_top', + 'left_mouth_3', + 'left_mouth_2', + 'left_mouth_1', + 
'left_mouth_5', # 59 in OpenPose output + 'left_mouth_4', # 58 in OpenPose output + 'mouth_bottom', + 'right_mouth_4', + 'right_mouth_5', + 'right_lip_1', + 'right_lip_2', + 'lip_top', + 'left_lip_2', + 'left_lip_1', + 'left_lip_3', + 'lip_bottom', + 'right_lip_3', + # Face contour + 'right_contour_1', + 'right_contour_2', + 'right_contour_3', + 'right_contour_4', + 'right_contour_5', + 'right_contour_6', + 'right_contour_7', + 'right_contour_8', + 'contour_middle', + 'left_contour_8', + 'left_contour_7', + 'left_contour_6', + 'left_contour_5', + 'left_contour_4', + 'left_contour_3', + 'left_contour_2', + 'left_contour_1', +] diff --git a/lib/smplx/lbs.py b/lib/smplx/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..6d32bcd0a538ba6c6943789511d18272cee1ee22 --- /dev/null +++ b/lib/smplx/lbs.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from typing import Tuple, List, Optional +import numpy as np + +import torch +import torch.nn.functional as F + +from .utils import rot_mat_to_euler, Tensor + + +def find_dynamic_lmk_idx_and_bcoords( + vertices: Tensor, + pose: Tensor, + dynamic_lmk_faces_idx: Tensor, + dynamic_lmk_b_coords: Tensor, + neck_kin_chain: List[int], + pose2rot: bool = True, +) -> Tuple[Tensor, Tensor]: + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. + + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + neck_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. 
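Looping back to the `build_layer` and `create` factories defined at the end of `body_models.py`: both resolve a model directory (or file) plus a `model_type` string into the right class, with `create` returning the parameter-holding models and `build_layer` the stateless `*Layer` variants. A hedged usage sketch, with the import path and models root as placeholder assumptions:
```
# Minimal sketch, assuming create/build_layer are re-exported as in upstream smplx.
import torch
from lib.smplx import create, build_layer

MODELS_ROOT = "data/smpl_related/models"   # hypothetical root with smplx/, mano/, flame/ subfolders

smplx_model = create(MODELS_ROOT, model_type="smplx", gender="neutral", use_pca=False)
smplx_layer = build_layer(MODELS_ROOT, model_type="smplx", gender="neutral")

out = smplx_model(betas=torch.zeros(1, 10))        # axis-angle interface, member defaults
out_layer = smplx_layer(betas=torch.zeros(1, 10))  # rotation-matrix interface, identity poses
print(out.vertices.shape, out_layer.vertices.shape)
```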
+ ''' + + dtype = vertices.dtype + batch_size = vertices.shape[0] + + if pose2rot: + aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + neck_kin_chain) + rot_mats = batch_rodrigues(aa_pose.view(-1, + 3)).view(batch_size, -1, 3, 3) + else: + rot_mats = torch.index_select(pose.view(batch_size, -1, 3, 3), 1, + neck_kin_chain) + + rel_rot_mat = torch.eye(3, device=vertices.device, + dtype=dtype).unsqueeze_(dim=0).repeat( + batch_size, 1, 1) + for idx in range(len(neck_kin_chain)): + rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, + y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks(vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, + lmk_bary_coords: Tensor) -> Tensor: + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. + lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( + batch_size, -1, 3) + + lmk_faces += torch.arange(batch_size, dtype=torch.long, + device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def lbs( + betas: Tensor, + pose: Tensor, + v_template: Tensor, + shapedirs: Tensor, + posedirs: Tensor, + J_regressor: Tensor, + parents: Tensor, + lbs_weights: Tensor, + pose2rot: bool = True, + return_transformation: bool = False, +) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: 
bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device, dtype = betas.device, betas.dtype + + # Add shape contribution + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + + # 3. Add pose blend shapes + # N x J x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + if pose2rot: + rot_mats = batch_rodrigues(pose.view(-1, + 3)).view([batch_size, -1, 3, 3]) + + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + # (N x P) x (P, V * 3) -> N x V x 3 + pose_offsets = torch.matmul(pose_feature, + posedirs).view(batch_size, -1, 3) + else: + pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident + rot_mats = pose.view(batch_size, -1, 3, 3) + + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + # 5. Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, + device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + if return_transformation: + return verts, J_transformed, A, T + + return verts, J_transformed + + +def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor: + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor: + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. 
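+    # Shapes, per the docstring above: betas is (B, num_betas) and shape_disps is
+    # (V, 3, num_betas), so the einsum below yields a (B, V, 3) tensor of per-vertex
+    # offsets; it is equivalent to torch.tensordot(betas, shape_disps, dims=([1], [2])).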
+ blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues( + rot_vecs: Tensor, + epsilon: float = 1e-8, +) -> Tensor: + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device, dtype = rot_vecs.device, rot_vecs.dtype + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def transform_mat(R: Tensor, t: Tensor) -> Tensor: + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], + dim=2) + + +def batch_rigid_transform(rot_mats: Tensor, + joints: Tensor, + parents: Tensor, + dtype=torch.float32) -> Tensor: + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + joints : torch.tensor BxNx3 + Locations of joints + parents : torch.tensor BxN + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape( + -1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, + i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms diff --git a/lib/smplx/utils.py b/lib/smplx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2d37e09f0fa06fdd5949e936ba501db0c1ca9b --- /dev/null +++ b/lib/smplx/utils.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der 
Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import NewType, Union, Optional +from dataclasses import dataclass, asdict, fields +import numpy as np +import torch + +Tensor = NewType('Tensor', torch.Tensor) +Array = NewType('Array', np.ndarray) + + +@dataclass +class ModelOutput: + vertices: Optional[Tensor] = None + joints: Optional[Tensor] = None + full_pose: Optional[Tensor] = None + global_orient: Optional[Tensor] = None + transl: Optional[Tensor] = None + + def __getitem__(self, key): + return getattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + def __iter__(self): + return self.keys() + + def keys(self): + keys = [t.name for t in fields(self)] + return iter(keys) + + def values(self): + values = [getattr(self, t.name) for t in fields(self)] + return iter(values) + + def items(self): + data = [(t.name, getattr(self, t.name)) for t in fields(self)] + return iter(data) + + +@dataclass +class SMPLOutput(ModelOutput): + betas: Optional[Tensor] = None + body_pose: Optional[Tensor] = None + + +@dataclass +class SMPLHOutput(SMPLOutput): + left_hand_pose: Optional[Tensor] = None + right_hand_pose: Optional[Tensor] = None + transl: Optional[Tensor] = None + + +@dataclass +class SMPLXOutput(SMPLHOutput): + expression: Optional[Tensor] = None + jaw_pose: Optional[Tensor] = None + joint_transformation: Optional[Tensor] = None + vertex_transformation: Optional[Tensor] = None + + +@dataclass +class MANOOutput(ModelOutput): + betas: Optional[Tensor] = None + hand_pose: Optional[Tensor] = None + + +@dataclass +class FLAMEOutput(ModelOutput): + betas: Optional[Tensor] = None + expression: Optional[Tensor] = None + jaw_pose: Optional[Tensor] = None + neck_pose: Optional[Tensor] = None + + +def find_joint_kin_chain(joint_id, kinematic_tree): + kin_chain = [] + curr_idx = joint_id + while curr_idx != -1: + kin_chain.append(curr_idx) + curr_idx = kinematic_tree[curr_idx] + return kin_chain + + +def to_tensor(array: Union[Array, Tensor], dtype=torch.float32) -> Tensor: + if torch.is_tensor(array): + return array + else: + return torch.tensor(array, dtype=dtype) + + +class Struct(object): + + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) diff --git a/lib/smplx/vertex_ids.py b/lib/smplx/vertex_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc7d88b5eaf4bdec7d1b1eaf2049f567bc1b8d1 --- /dev/null +++ b/lib/smplx/vertex_ids.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# 
Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +# Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to +# MSCOCO and OpenPose joints +vertex_ids = { + 'smplh': { + 'nose': 332, + 'reye': 6260, + 'leye': 2800, + 'rear': 4071, + 'lear': 583, + 'rthumb': 6191, + 'rindex': 5782, + 'rmiddle': 5905, + 'rring': 6016, + 'rpinky': 6133, + 'lthumb': 2746, + 'lindex': 2319, + 'lmiddle': 2445, + 'lring': 2556, + 'lpinky': 2673, + 'LBigToe': 3216, + 'LSmallToe': 3226, + 'LHeel': 3387, + 'RBigToe': 6617, + 'RSmallToe': 6624, + 'RHeel': 6787 + }, + 'smplx': { + 'nose': 9120, + 'reye': 9929, + 'leye': 9448, + 'rear': 616, + 'lear': 6, + 'rthumb': 8079, + 'rindex': 7669, + 'rmiddle': 7794, + 'rring': 7905, + 'rpinky': 8022, + 'lthumb': 5361, + 'lindex': 4933, + 'lmiddle': 5058, + 'lring': 5169, + 'lpinky': 5286, + 'LBigToe': 5770, + 'LSmallToe': 5780, + 'LHeel': 8846, + 'RBigToe': 8463, + 'RSmallToe': 8474, + 'RHeel': 8635 + }, + 'mano': { + 'thumb': 744, + 'index': 320, + 'middle': 443, + 'ring': 554, + 'pinky': 671, + } +} diff --git a/lib/smplx/vertex_joint_selector.py b/lib/smplx/vertex_joint_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..71a849a09350a0035db93830e8d09a29072ee28d --- /dev/null +++ b/lib/smplx/vertex_joint_selector.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + +import torch +import torch.nn as nn + +from .utils import to_tensor + + +class VertexJointSelector(nn.Module): + + def __init__(self, + vertex_ids=None, + use_hands=True, + use_feet_keypoints=True, + **kwargs): + super(VertexJointSelector, self).__init__() + + extra_joints_idxs = [] + + face_keyp_idxs = np.array([ + vertex_ids['nose'], vertex_ids['reye'], vertex_ids['leye'], + vertex_ids['rear'], vertex_ids['lear'] + ], + dtype=np.int64) + + extra_joints_idxs = np.concatenate([extra_joints_idxs, face_keyp_idxs]) + + if use_feet_keypoints: + feet_keyp_idxs = np.array([ + vertex_ids['LBigToe'], vertex_ids['LSmallToe'], + vertex_ids['LHeel'], vertex_ids['RBigToe'], + vertex_ids['RSmallToe'], vertex_ids['RHeel'] + ], + dtype=np.int32) + + extra_joints_idxs = np.concatenate( + [extra_joints_idxs, feet_keyp_idxs]) + + if use_hands: + self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] + + tips_idxs = [] + for hand_id in ['l', 'r']: + for tip_name in self.tip_names: + tips_idxs.append(vertex_ids[hand_id + tip_name]) + + extra_joints_idxs = np.concatenate([extra_joints_idxs, tips_idxs]) + + self.register_buffer('extra_joints_idxs', + to_tensor(extra_joints_idxs, dtype=torch.long)) + + def forward(self, vertices, joints): + extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) + joints = torch.cat([joints, extra_joints], dim=1) + + return joints diff --git a/mvdiffusion/data/dreamdata.py b/mvdiffusion/data/dreamdata.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc90435cf90c2403d2f06300fdcf51649f33b92 --- /dev/null +++ b/mvdiffusion/data/dreamdata.py @@ -0,0 +1,544 @@ +import numpy as np +import torch +from torch.utils.data import Dataset +import json +from typing import Tuple, Optional, Any +import cv2 +import random +import os +import math +from PIL import Image, ImageOps +from .normal_utils import worldNormal2camNormal, img2normal, norm_normalize +from icecream import ic +def shift_list(lst, n): + length = len(lst) + n = n % length # Ensure n is within the range of the list length + return lst[-n:] + lst[:-n] + + +class ObjaverseDataset(Dataset): + def __init__(self, + root_dir: str, + azi_interval: float, + random_views: int, + predict_relative_views: list, + bg_color: Any, + object_list: str, + prompt_embeds_path: str, + img_wh: Tuple[int, int], + validation: bool = False, + num_validation_samples: int = 64, + num_samples: Optional[int] = None, + invalid_list: Optional[str] = None, + trans_norm_system: bool = True, # if True, transform all normals map into the cam system of front view + # augment_data: bool = False, + side_views_rate: float = 0., + read_normal: bool = True, + read_color: bool = False, + read_depth: bool = False, + mix_color_normal: bool = False, + random_view_and_domain: bool = False, + load_cache: bool = False, + exten: str = '.png', + elevation_list: Optional[str] = None, + with_smpl: Optional[bool] = False, + ) -> None: + """Create a dataset from a folder of images. 
+ If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = root_dir + self.fixed_views = int(360 // azi_interval) + self.bg_color = bg_color + self.validation = validation + self.num_samples = num_samples + self.trans_norm_system = trans_norm_system + # self.augment_data = augment_data + self.img_wh = img_wh + self.read_normal = read_normal + self.read_color = read_color + self.read_depth = read_depth + self.mix_color_normal = mix_color_normal # mix load color and normal maps + self.random_view_and_domain = random_view_and_domain # load normal or rgb of a single view + self.random_views = random_views + self.load_cache = load_cache + self.total_views = int(self.fixed_views * (self.random_views + 1)) + self.predict_relative_views = predict_relative_views + self.pred_view_nums = len(self.predict_relative_views) + self.exten = exten + self.side_views_rate = side_views_rate + self.with_smpl = with_smpl + if self.with_smpl: + self.smpl_image_path = 'smpl_image' + self.smpl_normal_path = 'smpl_normal' + + + ic(self.total_views) + ic(self.fixed_views) + ic(self.predict_relative_views) + ic(self.with_smpl) + + self.objects = [] + if object_list is not None: + for dataset_list in object_list: + with open(dataset_list, 'r') as f: + objects = json.load(f) + self.objects.extend(objects) + else: + self.objects = os.listdir(self.root_dir) + + # load fixed camera poses + self.trans_cv2gl_mat = np.linalg.inv(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])) + self.fix_cam_poses = [] + camera_path = os.path.join(self.root_dir, self.objects[0], 'camera') + for vid in range(0, self.total_views, self.random_views+1): + cam_info = np.load(f'{camera_path}/{vid:03d}.npy', allow_pickle=True).item() + assert cam_info['camera'] == 'ortho', 'Only support predict ortho camera !!!' 
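The per-view camera files read in the loop above are pickled dicts stored at `<object>/camera/{vid:03d}.npy`, carrying a `'camera'` tag (asserted to be `'ortho'`) and a world-to-camera `'extrinsic'` matrix of which only the first three rows are ever used. A minimal sketch of that layout (the identity extrinsic and the local file name below are placeholders, not real training data):

```
# Write one camera file in the format the dataset expects (illustrative values).
import numpy as np

cam_info = {
    'camera': 'ortho',              # the loader only supports orthographic cameras
    'extrinsic': np.eye(4)[:3],     # placeholder 3x4 world-to-camera matrix
}
np.save('000.npy', cam_info)        # real files live under <object>/camera/000.npy, 001.npy, ...

# Reading mirrors the dataset code:
loaded = np.load('000.npy', allow_pickle=True).item()
assert loaded['camera'] == 'ortho'
print(loaded['extrinsic'].shape)    # (3, 4)
```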
+ self.fix_cam_poses.append(cam_info['extrinsic']) + random.shuffle(self.objects) + + + if elevation_list: + with open(elevation_list, 'r') as f: + ele_list = [o.strip() for o in f.readlines()] + self.objects = set(ele_list) & set(self.objects) + + self.all_objects = set(self.objects) + self.all_objects = list(self.all_objects) + + self.validation = validation + if not validation: + self.all_objects = self.all_objects[:-num_validation_samples] + # print('Warning: you are fitting in small-scale dataset') + # self.all_objects = self.all_objects + else: + self.all_objects = self.all_objects[-num_validation_samples:] + + if num_samples is not None: + self.all_objects = self.all_objects[:num_samples] + ic(len(self.all_objects)) + print(f"loaded {len(self.all_objects)} in the dataset") + + normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt') + color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') + if len(self.predict_relative_views) == 6: + self.normal_prompt_embedding = normal_prompt_embedding + self.color_prompt_embedding = color_prompt_embedding + elif len(self.predict_relative_views) == 4: + self.normal_prompt_embedding = torch.stack([normal_prompt_embedding[0], normal_prompt_embedding[2], normal_prompt_embedding[3], normal_prompt_embedding[4], normal_prompt_embedding[6]] , 0) + self.color_prompt_embedding = torch.stack([color_prompt_embedding[0], color_prompt_embedding[2], color_prompt_embedding[3], color_prompt_embedding[4], color_prompt_embedding[6]] , 0) + + # flip back and left views + if len(self.predict_relative_views) == 6: + self.flip_views = [3, 4] + elif len(self.predict_relative_views) == 4: + self.flip_views = [2, 3] + + # self.backup_data = self.__getitem_norm__(0, 'Thuman2.0/0340') + self.backup_data = self.__getitem_norm__(0) + + def trans_cv2gl(self, rt): + r, t = rt[:3, :3], rt[:3, -1] + r = np.matmul(self.trans_cv2gl_mat, r) + t = np.matmul(self.trans_cv2gl_mat, t) + return np.concatenate([r, t[:, None]], axis=-1) + + def cartesian_to_spherical(self, xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + def get_T(self, target_RT, cond_RT): + R, T = target_RT[:3, :3], target_RT[:3, -1] + T_target = -R.T @ T # change to cam2world + + R, T = cond_RT[:3, :3], cond_RT[:3, -1] + T_cond = -R.T @ T + + theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_theta, d_azimuth + + def get_bg_color(self): + if self.bg_color == 'white': + bg_color = np.array([1., 1., 1.], dtype=np.float32) + elif self.bg_color == 'black': + bg_color = np.array([0., 0., 0.], dtype=np.float32) + elif self.bg_color == 'gray': + bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) + elif self.bg_color == 'random': + bg_color = np.random.rand(3) + elif self.bg_color == 'three_choices': + white = np.array([1., 1., 1.], dtype=np.float32) + black = np.array([0., 0., 
0.], dtype=np.float32) + gray = np.array([0.5, 0.5, 0.5], dtype=np.float32) + bg_color = random.choice([white, black, gray]) + elif isinstance(self.bg_color, float): + bg_color = np.array([self.bg_color] * 3, dtype=np.float32) + else: + raise NotImplementedError + return bg_color + + def crop_image(self, top_left, img): + size = max(self.img_wh) + tar_size = size - top_left * 2 + + alpha_np = np.asarray(img)[:, :, 3] + + + coords = np.argwhere(alpha_np > 0.5) + x_min, y_min = coords.min(axis=0) + x_max, y_max = coords.max(axis=0) + + img = img.crop((x_min, y_min, x_max, y_max)).resize((tar_size, tar_size)) + img = ImageOps.expand(img, border=(top_left, top_left, top_left, top_left), fill=0) + return img + + def load_cropped_img(self, img_path, bg_color, top_left, return_type='np'): + rgba = Image.open(img_path) + rgba = self.crop_image(top_left, rgba) + rgba = np.array(rgba) + rgba = rgba.astype(np.float32) / 255. # [0, 1] + img, alpha = rgba[..., :3], rgba[..., 3:4] + + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + + def load_image(self, img_path, bg_color, alpha=None, return_type='np'): + # not using cv2 as may load in uint16 format + # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255] + # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC) + # pil always returns uint8 + rgba = np.array(Image.open(img_path).resize(self.img_wh)) + rgba = rgba.astype(np.float32) / 255. # [0, 1] + + img = rgba[..., :3] + if alpha is None: + assert rgba.shape[-1] == 4 + alpha = rgba[..., 3:4] + assert alpha.sum() > 1e-8, 'w/o foreground' + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + + def load_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'): + normal_np = np.array(Image.open(img_path).resize(self.img_wh))[:, :, :3] + assert np.var(normal_np) > 1e-8, 'pure normal' + normal_cv = img2normal(normal_np) + + normal_relative_cv = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv) + normal_relative_cv = norm_normalize(normal_relative_cv) + + normal_relative_gl = normal_relative_cv + normal_relative_gl[..., 1:] = -normal_relative_gl[..., 1:] + + img = (normal_relative_cv*0.5 + 0.5).astype(np.float32) # [0, 1] + + if alpha.shape[-1] != 1: + alpha = alpha[:, :, None] + + + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + else: + raise NotImplementedError + + return img + + def load_halfbody_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'): + normal_np = np.array(Image.open(img_path).resize(self.img_wh).crop((256, 0, 512, 256)).resize(self.img_wh))[:, :, :3] + assert np.var(normal_np) > 1e-8, 'pure normal' + normal_cv = img2normal(normal_np) + + normal_relative_cv = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv) + normal_relative_cv = norm_normalize(normal_relative_cv) + # normal_relative_gl = normal_relative_cv[..., [ 0, 2, 1]] + # normal_relative_gl[..., 2] = -normal_relative_gl[..., 2] + normal_relative_gl = normal_relative_cv + normal_relative_gl[..., 1:] = -normal_relative_gl[..., 1:] + + img = (normal_relative_cv*0.5 + 0.5).astype(np.float32) # [0, 
1] + + if alpha.shape[-1] != 1: + alpha = alpha[:, :, None] + + + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + else: + raise NotImplementedError + + return img + + def __len__(self): + return len(self.all_objects) + + def load_halfbody_image(self, img_path, bg_color, alpha=None, return_type='np'): + + + rgba = np.array(Image.open(img_path).resize(self.img_wh).crop((256, 0, 512, 256)).resize(self.img_wh)) + rgba = rgba.astype(np.float32) / 255. # [0, 1] + + img = rgba[..., :3] + if alpha is None: + assert rgba.shape[-1] == 4 + alpha = rgba[..., 3:4] + assert alpha.sum() > 1e-8, 'w/o foreground' + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + def __getitem_norm__(self, index, debug_object=None): + # get the bg color + bg_color = self.get_bg_color() + if debug_object is not None: + object_name = debug_object + else: + object_name = self.all_objects[index % len(self.all_objects)] + face_info = np.load(f'{self.root_dir}/{object_name}/face_info.npy', allow_pickle=True).item() + # front_fixed_idx = face_info['top3_vid'][0] // (self.random_views+1) + if self.side_views_rate > 0 and random.random() < self.side_views_rate: + front_fixed_idx = random.choice(face_info['top3_vid']) + else: + front_fixed_idx = face_info['top3_vid'][0] + with_face_idx = list(face_info.keys()) + with_face_idx.remove('top3_vid') + + assert front_fixed_idx in with_face_idx, 'not detected face' + + if self.validation: + cond_ele0_idx = front_fixed_idx + cond_random_idx = 0 + else: + if object_name[:9] == 'realistic': # This dataset set has random pose + cond_ele0_idx = random.choice(range(self.fixed_views)) + cond_random_idx = random.choice(range(self.random_views+1)) + else: + cond_vid = front_fixed_idx + cond_ele0_idx = cond_vid // (self.random_views + 1) + cond_ele0_vid = cond_ele0_idx * (self.random_views + 1) + cond_random_idx = 0 + + # condition info + cond_ele0_vid = cond_ele0_idx * (self.random_views + 1) + cond_vid = cond_ele0_vid + cond_random_idx + cond_ele0_w2c = self.fix_cam_poses[cond_ele0_idx] + + img_tensors_in = [ + self.load_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1) + ] * self.pred_view_nums + [ + self.load_halfbody_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1) + ] + + # output info + pred_vids = [(cond_ele0_vid + i * (self.random_views+1)) % self.total_views for i in self.predict_relative_views] + # pred_w2cs = [self.fix_cam_poses[(cond_ele0_idx + i) % self.fixed_views] for i in self.predict_relative_views] + img_tensors_out = [] + normal_tensors_out = [] + smpl_tensors_in = [] + for i, vid in enumerate(pred_vids): + # output image + img_tensor, alpha_ = self.load_image(f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}", bg_color, return_type='pt') + img_tensor = img_tensor.permute(2, 0, 1) # (3, H, W) + if i in self.flip_views: img_tensor = torch.flip(img_tensor, [2]) + img_tensors_out.append(img_tensor) + + # output normal + normal_tensor = self.load_normal(f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}", bg_color, alpha_.numpy(), RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1) + if i in self.flip_views: normal_tensor = 
torch.flip(normal_tensor, [2]) + normal_tensors_out.append(normal_tensor) + + # input smpl image + if self.with_smpl: + smpl_image_tensor, smpl_alpha_ = self.load_image(f"{self.root_dir}/{object_name}/{self.smpl_image_path}/{vid:03d}{self.exten}", bg_color, return_type='pt') + smpl_image_tensor = smpl_image_tensor.permute(2, 0, 1) # (3, H, W) + if i in self.flip_views: smpl_image_tensor = torch.flip(smpl_image_tensor, [2]) + smpl_tensors_in.append(smpl_image_tensor) + + # faces + if i == 0: + face_clr_out, face_alpha_out = self.load_halfbody_image(f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}", bg_color, return_type='pt') + face_clr_out = face_clr_out.permute(2, 0, 1) + face_nrm_out = self.load_halfbody_normal(f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}", bg_color, face_alpha_out.numpy(), RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1) + if self.with_smpl: + face_smpl_in = self.load_halfbody_image(f"{self.root_dir}/{object_name}/{self.smpl_image_path}/{vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1) + + img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W) + img_tensors_out.append(face_clr_out) + img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W) + normal_tensors_out.append(face_nrm_out) + normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W) + + if self.with_smpl: + smpl_tensors_in = smpl_tensors_in + [face_smpl_in] + smpl_tensors_in = torch.stack(smpl_tensors_in, dim=0).float() # (Nv, 3, H, W) + + item = { + 'id': object_name.replace('/', '_'), + 'vid':cond_vid, + 'imgs_in': img_tensors_in, + 'imgs_out': img_tensors_out, + 'normals_out': normal_tensors_out, + 'normal_prompt_embeddings': self.normal_prompt_embedding, + 'color_prompt_embeddings': self.color_prompt_embedding, + } + if self.with_smpl: + item.update({'smpl_imgs_in': smpl_tensors_in}) + return item + + def __getitem__(self, index): + try: + data = self.__getitem_norm__(index) + return data + except: + print("load error ", self.all_objects[index%len(self.all_objects)] ) + return self.backup_data + + +def draw_kps(image, kps): + nose_pos = kps[2].astype(np.int32) + top_left = nose_pos - 64 + bottom_right = nose_pos + 64 + image_cv = image.copy() + img = cv2.rectangle(image_cv, tuple(top_left), tuple(bottom_right), (0, 255, 0), 2) + return img + +if __name__ == "__main__": + # pass + from torch.utils.data import DataLoader + from torchvision.utils import make_grid + from PIL import ImageDraw, ImageFont + def draw_text(img, text, pos, color=(128, 128, 128)): + draw = ImageDraw.Draw(img) + # font = ImageFont.truetype(size= size) + font = ImageFont.load_default() + font = font.font_variant(size=10) + draw.text(pos, text, color, font=font) + return img + random.seed(11) + train_params = dict( + root_dir='/aifs4su/mmcode/lipeng/human_8view_with_smplx/', + azi_interval=45., + random_views=0, + predict_relative_views=[0,2,4,6], + bg_color='white', + object_list=['../../data_lists/human_only_scan_with_smplx.json'], + img_wh=(768, 768), + validation=False, + num_validation_samples=10, + read_normal=True, + read_color=True, + read_depth=False, + # mix_color_normal= True, + random_view_and_domain=False, + load_cache=False, + exten='.png', + prompt_embeds_path='fixed_prompt_embeds_7view', + side_views_rate=0.1, + with_smpl=True + ) + train_dataset = ObjaverseDataset(**train_params) + data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0) + if False: + case = 
'CustomHumans/0593_00083_06_00101' + batch = train_dataset.__getitem_norm__(0, case) + imgs = [] + obj_name = batch['id'][:8] + imgs_in = batch['imgs_in'] + imgs_out = batch['imgs_out'] + normal_out = batch['normals_out'] + imgs_vis = torch.cat([imgs_in[0:1], imgs_in[-1:], imgs_out, normal_out], 0) + img_vis = make_grid(imgs_vis, nrow=16).permute(1, 2,0) + img_vis = (img_vis.numpy() * 255).astype(np.uint8) + img_vis = Image.fromarray(img_vis) + img_vis = draw_text(img_vis, obj_name, (5, 1)) + img_vis = torch.from_numpy(np.array(img_vis)).permute(2, 0, 1) / 255. + imgs.append(img_vis) + imgs = torch.stack(imgs, dim=0) + img_grid = make_grid(imgs, nrow=4, padding=0) + img_grid = img_grid.permute(1, 2, 0).numpy() + img_grid = (img_grid * 255).astype(np.uint8) + img_grid = Image.fromarray(img_grid) + img_grid.save(f'../../debug/{case.replace("/", "_")}.png') + else: + imgs = [] + i = 0 + for batch in data_loader: + # print(i) + if i < 4: + i += 1 + obj_name = batch['id'][0][:8] + imgs_in = batch['imgs_in'].squeeze(0) + smpl_in = batch['smpl_imgs_in'].squeeze(0) + imgs_out = batch['imgs_out'].squeeze(0) + normal_out = batch['normals_out'].squeeze(0) + imgs_vis = torch.cat([imgs_in[0:1], imgs_in[-1:], smpl_in, imgs_out, normal_out], 0) + img_vis = make_grid(imgs_vis, nrow=12).permute(1, 2,0) + img_vis = (img_vis.numpy() * 255).astype(np.uint8) + print(img_vis.shape) + # import pdb;pdb.set_trace() + # nose_kps = batch['face_kps'][0].numpy() + # print(nose_kps) + # img_vis = draw_kps(img_vis, nose_kps) + img_vis = Image.fromarray(img_vis) + img_vis = draw_text(img_vis, obj_name, (5, 1)) + img_vis = torch.from_numpy(np.array(img_vis)).permute(2, 0, 1) / 255. + imgs.append(img_vis) + else: + break + imgs = torch.stack(imgs, dim=0) + img_grid = make_grid(imgs, nrow=1, padding=0) + img_grid = img_grid.permute(1, 2, 0).numpy() + img_grid = (img_grid * 255).astype(np.uint8) + img_grid = Image.fromarray(img_grid) + img_grid.save('../../debug/noele_imgs_out_10.png') + + + + + diff --git a/mvdiffusion/data/fixed_prompt_embeds_7view/clr_embeds.pt b/mvdiffusion/data/fixed_prompt_embeds_7view/clr_embeds.pt new file mode 100644 index 0000000000000000000000000000000000000000..df8f14b1c4fbb8e7d976c3fa90135943f9ce4ae5 --- /dev/null +++ b/mvdiffusion/data/fixed_prompt_embeds_7view/clr_embeds.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64bf1cc171540aeed9ff9966d00aa2d0434582677b3744423b37242d381a6309 +size 1105067 diff --git a/mvdiffusion/data/fixed_prompt_embeds_7view/normal_embeds.pt b/mvdiffusion/data/fixed_prompt_embeds_7view/normal_embeds.pt new file mode 100644 index 0000000000000000000000000000000000000000..ee403d766058a903346f945900359c4c58d3da8e --- /dev/null +++ b/mvdiffusion/data/fixed_prompt_embeds_7view/normal_embeds.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53d56c2819935fdbd1a9e60c4eee10a8af5fa9be6fe4eccfe44aa921a4af7275 +size 1105082 diff --git a/mvdiffusion/data/generate_fixed_text_embeds.py b/mvdiffusion/data/generate_fixed_text_embeds.py new file mode 100644 index 0000000000000000000000000000000000000000..ce077ff3eb20ec7b310e847901ba0e0182dbd539 --- /dev/null +++ b/mvdiffusion/data/generate_fixed_text_embeds.py @@ -0,0 +1,52 @@ +from transformers import CLIPTokenizer, CLIPTextModel +import torch +import os + +root = '/mnt/data/lipeng/' +pretrained_model_name_or_path = 'stabilityai/stable-diffusion-2-1-unclip' + + +weight_dtype = torch.float16 +device = torch.device("cuda:0") +tokenizer = 
CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") +text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') +text_encoder = text_encoder.to(device, dtype=weight_dtype) + +def generate_mv_embeds(): + path = './fixed_prompt_embeds_7view' + os.makedirs(path, exist_ok=True) + views = ["front", "front_right", "right", "back", "left", "front_left", 'face'] + clr_prompt = [f"a rendering image of 3D human, {view} view, color map." for view in views] + normal_prompt = [f"a rendering image of 3D human, {view} view, normal map." for view in views] + + + for id, text_prompt in enumerate([clr_prompt, normal_prompt]): + print(text_prompt) + text_inputs = tokenizer(text_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(text_prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + if hasattr(text_encoder.config, "use_attention_mask") and text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask,) + prompt_embeds = prompt_embeds[0].detach().cpu() + print(prompt_embeds.shape) + + + # print(prompt_embeds.dtype) + if id == 0: + torch.save(prompt_embeds, f'./{path}/clr_embeds.pt') + else: + torch.save(prompt_embeds, f'./{path}/normal_embeds.pt') + print('done') + + + + +generate_mv_embeds() diff --git a/mvdiffusion/data/normal_utils.py b/mvdiffusion/data/normal_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a390082ca0d891f044eca3c6fc291d8f29036de --- /dev/null +++ b/mvdiffusion/data/normal_utils.py @@ -0,0 +1,78 @@ +import numpy as np +def deg2rad(deg): + return deg*np.pi/180 + +def inv_RT(RT): + # RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0) + RT_inv = np.linalg.inv(RT) + + return RT_inv[:3, :] +def camNormal2worldNormal(rot_c2w, camNormal): + H,W,_ = camNormal.shape + normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + return normal_img + +def worldNormal2camNormal(rot_w2c, normal_map_world): + H,W,_ = normal_map_world.shape + # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + # faster version + # Reshape the normal map into a 2D array where each row represents a normal vector + normal_map_flat = normal_map_world.reshape(-1, 3) + + # Transform the normal vectors using the transformation matrix + normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T) + + # Reshape the transformed normal map back to its original shape + normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape) + + return normal_map_camera + +def trans_normal(normal, RT_w2c, RT_w2c_target): + + # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) + # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) + + relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) + return worldNormal2camNormal(relative_RT[:3,:3], normal) + +def trans_normal_complex(normal, RT_w2c, RT_w2c_rela_to_cond): + # camview -> world -> condview 
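Since worldNormal2camNormal and camNormal2worldNormal above are inverses of each other (one multiplies every pixel normal by R_w2c, the other by R_c2w), a quick sanity check using those two helpers looks like this (the sizes and the random orthonormal matrix are illustrative):

```
import numpy as np

H, W = 4, 4
n_world = np.random.randn(H, W, 3)
n_world /= np.linalg.norm(n_world, axis=-1, keepdims=True)

R_w2c, _ = np.linalg.qr(np.random.randn(3, 3))   # a random orthonormal matrix

n_cam = worldNormal2camNormal(R_w2c, n_world)    # per-pixel R_w2c @ n
n_back = camNormal2worldNormal(R_w2c.T, n_cam)   # per-pixel R_c2w @ n, with R_c2w = R_w2c.T
assert np.allclose(n_back, n_world, atol=1e-6)
```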
+ normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) + # debug_normal_world = normal2img(normal_world) + + # relative_RT = np.matmul(RT_w2c_rela_to_cond[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) + normal_target_cam = worldNormal2camNormal(RT_w2c_rela_to_cond[:3,:3], normal_world) + # normal_condview = normal2img(normal_target_cam) + return normal_target_cam +def img2normal(img): + return (img/255.)*2-1 + +def normal2img(normal): + return np.uint8((normal*0.5+0.5)*255) + +def norm_normalize(normal, dim=-1): + + normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) + + return normal + +def plot_grid_images(images, row, col, path=None): + import cv2 + """ + Args: + images: np.array [B, H, W, 3] + row: + col: + save_path: + + Returns: + + """ + images = images.detach().cpu().numpy() + assert row * col == images.shape[0] + images = np.vstack([np.hstack(images[r * col:(r + 1) * col]) for r in range(row)]) + if path: + cv2.imwrite(path, images[:,:,::-1] * 255) + return images \ No newline at end of file diff --git a/mvdiffusion/data/single_image_dataset.py b/mvdiffusion/data/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f40462e815dd80623828202d4c4803fc48d17823 --- /dev/null +++ b/mvdiffusion/data/single_image_dataset.py @@ -0,0 +1,354 @@ + +import numpy as np +import torch +from torch.utils.data import Dataset +from PIL import Image +import PIL +from typing import Tuple, Optional +import random +import os +from icecream import ic +import cv2 + + +def add_margin(pil_img, color=0, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + +def scale_and_place_object(image, scale_factor): + assert np.shape(image)[-1]==4 # RGBA + + # Extract the alpha channel (transparency) and the object (RGB channels) + alpha_channel = image[:, :, 3] + + # Find the bounding box coordinates of the object + coords = cv2.findNonZero(alpha_channel) + x, y, width, height = cv2.boundingRect(coords) + + # Calculate the scale factor for resizing + original_height, original_width = image.shape[:2] + + if width > height: + size = width + original_size = original_width + else: + size = height + original_size = original_height + + scale_factor = min(scale_factor, size / (original_size+0.0)) + + new_size = scale_factor * original_size + scale_factor = new_size / size + + # Calculate the new size based on the scale factor + new_width = int(width * scale_factor) + new_height = int(height * scale_factor) + + center_x = original_width // 2 + center_y = original_height // 2 + + paste_x = center_x - (new_width // 2) + paste_y = center_y - (new_height // 2) + + # Resize the object (RGB channels) to the new size + rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height)) + + # Create a new RGBA image with the resized image + new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8) + + new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object + + return new_image + +class SingleImageDataset(Dataset): + def __init__(self, + root_dir: str, + num_views: int, + img_wh: Tuple[int, int], + bg_color: str, + crop_size: int = 224, + single_image: Optional[PIL.Image.Image] = None, + num_validation_samples: Optional[int] = None, + filepaths: Optional[list] = None, + cond_type: Optional[str] = None, + prompt_embeds_path: Optional[str] = None, + gt_path: Optional[str] = 
None, + margin_size: Optional[int] = 0, + smpl_folder: Optional[str] = None, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = root_dir + self.num_views = num_views + self.img_wh = img_wh + self.crop_size = crop_size + self.bg_color = bg_color + self.cond_type = cond_type + self.gt_path = gt_path + + + if single_image is None: + file_list = sorted(os.listdir(self.root_dir)) + # Filter the files that end with .png or .jpg + self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg', '.webp'))] + else: + self.file_list = None + + # load all images + self.all_images = [] + self.all_alphas = [] + self.all_faces = [] + + self.all_face_embeddings = [] + bg_color = self.get_bg_color() + + if single_image is not None: + face_info = self.get_face_info(single_image) + image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image) + self.all_images.append(image) + self.all_alphas.append(alpha) + self.all_faces.append(self.process_face(f'{self.root_dir}/{single_image}', face_info['bbox'].astype(np.int32), bg_color)) + else: + for file in self.file_list: + print(os.path.join(self.root_dir, file)) + image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt') + self.all_images.append(image) + self.all_alphas.append(alpha) + + face, _ = self.load_face(os.path.join(self.root_dir, file), bg_color, return_type='pt') + self.all_faces.append(face) + + self.all_images = self.all_images[:num_validation_samples] + self.all_alphas = self.all_alphas[:num_validation_samples] + self.all_faces = self.all_faces[:num_validation_samples] + + ic(len(self.all_images)) + + try: + normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt') + color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') + self.normal_text_embeds = normal_prompt_embedding + self.color_text_embeds = color_prompt_embedding + except: + self.color_text_embeds = torch.load(f'{prompt_embeds_path}/embeds.pt') + self.normal_text_embeds = None + + def __len__(self): + return len(self.all_images) + + def get_face_info(self, file): + file_name = file.split('.')[0] + face_info = np.load(f'{self.root_dir}/{file_name}_face_info.npy', allow_pickle=True).item() + return face_info + + + def get_bg_color(self): + if self.bg_color == 'white': + bg_color = np.array([1., 1., 1.], dtype=np.float32) + elif self.bg_color == 'black': + bg_color = np.array([0., 0., 0.], dtype=np.float32) + elif self.bg_color == 'gray': + bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) + elif self.bg_color == 'random': + bg_color = np.random.rand(3) + elif isinstance(self.bg_color, float): + bg_color = np.array([self.bg_color] * 3, dtype=np.float32) + else: + raise NotImplementedError + return bg_color + + + def load_image(self, img_path, bg_color, return_type='np', Imagefile=None): + # pil always returns uint8 + if Imagefile is None: + image_input = Image.open(img_path) + else: + image_input = Imagefile + image_size = self.img_wh[0] + + if self.crop_size!=-1: + alpha_np = np.asarray(image_input)[:, :, 3] + coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) + h, w = ref_img_.height, ref_img_.width + scale = self.crop_size / max(h, w) + h_, w_ = int(scale * h), int(scale * w) + ref_img_ = ref_img_.resize((w_, 
h_)) + image_input = add_margin(ref_img_, size=image_size) + else: + image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) + image_input = image_input.resize((image_size, image_size)) + + # img = scale_and_place_object(img, self.scale_ratio) + img = np.array(image_input) + img = img.astype(np.float32) / 255. # [0, 1] + assert img.shape[-1] == 4 # RGBA + + alpha = img[...,3:4] + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + def load_face(self, img_path, bg_color, return_type='np', Imagefile=None): + # pil always returns uint8 + if Imagefile is None: + image_input = Image.open(img_path) + else: + image_input = Imagefile + image_size = self.img_wh[0] + + if self.crop_size!=-1: + alpha_np = np.asarray(image_input)[:, :, 3] + coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) + h, w = ref_img_.height, ref_img_.width + scale = self.crop_size / max(h, w) + h_, w_ = int(scale * h), int(scale * w) + ref_img_ = ref_img_.resize((w_, h_)) + image_input = add_margin(ref_img_, size=image_size) + else: + image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) + image_input = image_input.resize((image_size, image_size)) + + image_input = image_input.crop((256, 0, 512, 256)).resize((self.img_wh[0], self.img_wh[1])) + + # img = scale_and_place_object(img, self.scale_ratio) + img = np.array(image_input) + img = img.astype(np.float32) / 255. # [0, 1] + assert img.shape[-1] == 4 # RGBA + + alpha = img[...,3:4] + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + def __len__(self): + return len(self.all_images) + + def process_face(self, img_path, bbox, bg_color, normal_path=None, w2c=None, h=512, w=512): + image = Image.open(img_path) + bbox_w, bbox_h = bbox[2] - bbox[0], bbox[3] - bbox[1] + if bbox_w > bbox_h: + bbox[1] -= (bbox_w - bbox_h) // 2 + bbox[3] += (bbox_w - bbox_h) // 2 + else: + bbox[0] -= (bbox_h - bbox_w) // 2 + bbox[2] += (bbox_h - bbox_w) // 2 + bbox[0:2] -= 20 + bbox[2:4] += 20 + image = image.crop(bbox) + + image = image.resize((w, h)) + image = np.array(image) / 255. 
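For orientation, the crop_size handling in load_image and load_face above amounts to: crop the RGBA cutout to its alpha bounding box, scale it so its longer side equals crop_size, then centre it on an img_wh canvas, so a larger crop_size makes the subject fill more of the frame. A minimal sketch with a synthetic image (the sizes and the fake red subject are placeholders):

```
import numpy as np
from PIL import Image

img_wh, crop_size = 768, 740
rgba = Image.new('RGBA', (500, 900), (0, 0, 0, 0))
rgba.paste(Image.new('RGBA', (300, 700), (255, 0, 0, 255)), (100, 100))  # fake subject

alpha = np.asarray(rgba)[:, :, 3]
ys, xs = np.nonzero(alpha)
subject = rgba.crop((xs.min(), ys.min(), xs.max(), ys.max()))
scale = crop_size / max(subject.size)
subject = subject.resize((int(subject.width * scale), int(subject.height * scale)))

canvas = Image.new('RGBA', (img_wh, img_wh), (0, 0, 0, 0))
canvas.paste(subject, ((img_wh - subject.width) // 2, (img_wh - subject.height) // 2))
print(canvas.size)  # (768, 768); the subject's longer side is now ~crop_size pixels
```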
+ img, alpha = image[:, :, :3], image[:, :, 3:4] + img = img * alpha + bg_color * (1 - alpha) + + padded_img = np.full((self.img_wh[0], self.img_wh[1], 3), bg_color, dtype=np.float32) + dx = (self.img_wh[0] - w) // 2 + dy = (self.img_wh[1] - h) // 2 + padded_img[dy:dy+h, dx:dx+w] = img + padded_img = torch.from_numpy(padded_img).permute(2,0,1) + + return padded_img + + def __getitem__(self, index): + image = self.all_images[index%len(self.all_images)] + # alpha = self.all_alphas[index%len(self.all_images)] + if self.file_list is not None: + filename = self.file_list[index%len(self.all_images)].replace(".png", "") + else: + filename = 'null' + img_tensors_in = [ + image.permute(2, 0, 1) + ] * (self.num_views-1) + [ + self.all_faces[index%len(self.all_images)].permute(2, 0, 1) + ] + + + img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W) + + normal_prompt_embeddings = self.normal_text_embeds if hasattr(self, 'normal_text_embeds') else None + color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None + + if normal_prompt_embeddings is None: + out = { + 'imgs_in': img_tensors_in, + 'color_prompt_embeddings': color_prompt_embeddings, + 'filename': filename, + } + else: + out = { + 'imgs_in': img_tensors_in, + 'normal_prompt_embeddings': normal_prompt_embeddings, + 'color_prompt_embeddings': color_prompt_embeddings, + 'filename': filename, + } + return out + + + +if __name__ == "__main__": + # pass + from torch.utils.data import DataLoader + from torchvision.utils import make_grid + from PIL import ImageDraw, ImageFont + def draw_text(img, text, pos, color=(128, 128, 128)): + draw = ImageDraw.Draw(img) + # font = ImageFont.truetype(size= size) + font = ImageFont.load_default() + font = font.font_variant(size=10) + draw.text(pos, text, color, font=font) + return img + random.seed(11) + test_params = dict( + root_dir='../../evaluate', + bg_color='white', + img_wh=(768, 768), + prompt_embeds_path='fixed_prompt_embeds_7view', + num_views=5, + crop_size=740, + ) + train_dataset = SingleImageDataset(**test_params) + data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0) + + batch = train_dataset.__getitem__(0) + imgs = [] + obj_name = 'test_case' + imgs_in = batch['imgs_in'] + imgs_vis = torch.cat([imgs_in[0:1], imgs_in[-1:]], 0) + img_vis = make_grid(imgs_vis, nrow=2).permute(1, 2,0) + img_vis = (img_vis.numpy() * 255).astype(np.uint8) + img_vis = Image.fromarray(img_vis) + img_vis = draw_text(img_vis, obj_name, (5, 1)) + img_vis = torch.from_numpy(np.array(img_vis)).permute(2, 0, 1) / 255. 
+ imgs.append(img_vis) + imgs = torch.stack(imgs, dim=0) + img_grid = make_grid(imgs, nrow=4, padding=0) + img_grid = img_grid.permute(1, 2, 0).numpy() + img_grid = (img_grid * 255).astype(np.uint8) + img_grid = Image.fromarray(img_grid) + img_grid.save(f'../../debug/{obj_name}.png') diff --git a/mvdiffusion/data/six_human_pose/000.npy b/mvdiffusion/data/six_human_pose/000.npy new file mode 100644 index 0000000000000000000000000000000000000000..e6015fc84f252c34afe1e64b75dad2ef981bb0f3 --- /dev/null +++ b/mvdiffusion/data/six_human_pose/000.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b801503c201690f70d2e258a45aed6ca913939c04714fe2f0d26468320e2ccec +size 718 diff --git a/mvdiffusion/data/six_human_pose/001.npy b/mvdiffusion/data/six_human_pose/001.npy new file mode 100644 index 0000000000000000000000000000000000000000..dc8f5dbdc7e7e4265f1e075cd742872c4af06b24 --- /dev/null +++ b/mvdiffusion/data/six_human_pose/001.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:314c4ef4d310e5cde33ad804f47f2577ce391ddf83bb628e9f449ba58a4274da +size 718 diff --git a/mvdiffusion/data/six_human_pose/002.npy b/mvdiffusion/data/six_human_pose/002.npy new file mode 100644 index 0000000000000000000000000000000000000000..74bd973c785cea85cb740452ed340a0d62e523dc --- /dev/null +++ b/mvdiffusion/data/six_human_pose/002.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f8cdc7fd863fcb60da9ce7c76188f7ff9d0352b519dec4ed946b32bd479c844 +size 718 diff --git a/mvdiffusion/data/six_human_pose/003.npy b/mvdiffusion/data/six_human_pose/003.npy new file mode 100644 index 0000000000000000000000000000000000000000..e3069676c58989c8adb1b46d7f231e0b85d088a2 --- /dev/null +++ b/mvdiffusion/data/six_human_pose/003.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:113e1a650ab4b9caf4752f733b31921f39ab598723c481b78c36856c7754f59a +size 718 diff --git a/mvdiffusion/data/six_human_pose/004.npy b/mvdiffusion/data/six_human_pose/004.npy new file mode 100644 index 0000000000000000000000000000000000000000..08029646871538dedf6041ef0dab6370a8dd3d2b --- /dev/null +++ b/mvdiffusion/data/six_human_pose/004.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b9d3f1fd5420b55eb7f0f29004084a0aa09a690947c72f50e44d80bfdc0fe06 +size 718 diff --git a/mvdiffusion/data/six_human_pose/005.npy b/mvdiffusion/data/six_human_pose/005.npy new file mode 100644 index 0000000000000000000000000000000000000000..6d16e911bd65675c25e78749f8bf6e3a4e99e6a9 --- /dev/null +++ b/mvdiffusion/data/six_human_pose/005.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:439e2eaa2141d8f06348b2bccdf1325707ce425bee7a6a5c05a2d68f671bd396 +size 718 diff --git a/mvdiffusion/data/six_human_pose/006.npy b/mvdiffusion/data/six_human_pose/006.npy new file mode 100644 index 0000000000000000000000000000000000000000..824b119521bd897a287dc1deff6930904e50d026 --- /dev/null +++ b/mvdiffusion/data/six_human_pose/006.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc462001b577b91d940996b350b53fa0f1105ca306f862124a054ba8b828cd3c +size 718 diff --git a/mvdiffusion/data/six_human_pose/007.npy b/mvdiffusion/data/six_human_pose/007.npy new file mode 100644 index 0000000000000000000000000000000000000000..adac639df2cda4a23560ba23717be231cd1d9222 --- /dev/null +++ b/mvdiffusion/data/six_human_pose/007.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:e0140436712df3901abb16bc492bcf2e4bc50ab1e0fce743b6dbe35bff724795 +size 718 diff --git a/mvdiffusion/data/testdata_with_smpl.py b/mvdiffusion/data/testdata_with_smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..5980c5bc777609d1cde8dbe006c2dac9215a3558 --- /dev/null +++ b/mvdiffusion/data/testdata_with_smpl.py @@ -0,0 +1,385 @@ + +import numpy as np +import torch +from torch.utils.data import Dataset +from PIL import Image +from typing import Tuple, Optional +import cv2 +import random +import os +import PIL +from icecream import ic + +def add_margin(pil_img, color=0, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + +def scale_and_place_object(image, scale_factor): + assert np.shape(image)[-1]==4 # RGBA + + # Extract the alpha channel (transparency) and the object (RGB channels) + alpha_channel = image[:, :, 3] + + # Find the bounding box coordinates of the object + coords = cv2.findNonZero(alpha_channel) + x, y, width, height = cv2.boundingRect(coords) + + # Calculate the scale factor for resizing + original_height, original_width = image.shape[:2] + + if width > height: + size = width + original_size = original_width + else: + size = height + original_size = original_height + + scale_factor = min(scale_factor, size / (original_size+0.0)) + + new_size = scale_factor * original_size + scale_factor = new_size / size + + # Calculate the new size based on the scale factor + new_width = int(width * scale_factor) + new_height = int(height * scale_factor) + + center_x = original_width // 2 + center_y = original_height // 2 + + paste_x = center_x - (new_width // 2) + paste_y = center_y - (new_height // 2) + + # Resize the object (RGB channels) to the new size + rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height)) + + # Create a new RGBA image with the resized image + new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8) + + new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object + + return new_image + +class SingleImageDataset(Dataset): + def __init__(self, + root_dir: str, + num_views: int, + img_wh: Tuple[int, int], + bg_color: str, + margin_size: int = 0, + single_image: Optional[PIL.Image.Image] = None, + num_validation_samples: Optional[int] = None, + filepaths: Optional[list] = None, + cond_type: Optional[str] = None, + prompt_embeds_path: Optional[str] = None, + gt_path: Optional[str] = None, + crop_size: Optional[int] = 720, + smpl_folder: Optional[str] = 'smpl_image_pymaf', + ) -> None: + """Create a dataset from a folder of images. 
+ If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = root_dir + self.num_views = num_views + self.img_wh = img_wh + self.margin_size = margin_size + self.bg_color = bg_color + self.cond_type = cond_type + self.gt_path = gt_path + self.crop_size = crop_size + self.smpl_folder = smpl_folder + if single_image is None: + if filepaths is None: + # Get a list of all files in the directory + file_list = os.listdir(self.root_dir) + else: + file_list = filepaths + + # Filter the files that end with .png or .jpg + self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg', '.webp'))] + else: + self.file_list = [single_image] + + ic(len(self.file_list)) + + try: + normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt') + color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') + if self.num_views == 7: + self.normal_text_embeds = normal_prompt_embedding + self.color_text_embeds = color_prompt_embedding + elif self.num_views == 5: + self.normal_text_embeds = torch.stack([normal_prompt_embedding[0], normal_prompt_embedding[2], normal_prompt_embedding[3], normal_prompt_embedding[4], normal_prompt_embedding[6]] , 0) + self.color_text_embeds = torch.stack([color_prompt_embedding[0], color_prompt_embedding[2], color_prompt_embedding[3], color_prompt_embedding[4], color_prompt_embedding[6]] , 0) + except: + self.color_text_embeds = torch.load(f'{prompt_embeds_path}/embeds.pt') + self.normal_text_embeds = None + + def __len__(self): + return len(self.file_list) + + def get_face_info(self, file): + file_name = file.split('.')[0] + face_info = np.load(f'{self.root_dir}/{file_name}_face_info.npy', allow_pickle=True).item() + return face_info + + + def get_bg_color(self): + if self.bg_color == 'white': + bg_color = np.array([1., 1., 1.], dtype=np.float32) + elif self.bg_color == 'black': + bg_color = np.array([0., 0., 0.], dtype=np.float32) + elif self.bg_color == 'gray': + bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) + elif self.bg_color == 'random': + bg_color = np.random.rand(3) + elif isinstance(self.bg_color, float): + bg_color = np.array([self.bg_color] * 3, dtype=np.float32) + else: + raise NotImplementedError + return bg_color + + def load_smpl_images(self, smpl_path, bg_color, return_type='np'): + if self.num_views - 1 == 4: + _views = [0, 2, 4, 6] + flip_views = [4, 6] + elif self.num_views - 1 == 6: + _views = [0, 1, 2, 4, 6, 7] + flip_views = [4, 6] + elif self.num_views - 1 == 8: + _views = [0, 1, 2, 3, 4, 5, 6, 7] + flip_views = [4, 5, 6, 7] + + imgs = [] + alphas = [] + for i in _views: + smpl_image = Image.open(os.path.join(smpl_path, f'{i:03d}.png')) + if i == 0: + assert smpl_image.size[0] == self.img_wh[0] + smpl_alpha = np.asarray(smpl_image)[...,3] + coords = np.stack(np.nonzero(smpl_alpha), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + crop_size = max(max_x - min_x, max_y - min_y) + self.margin_size + # print(crop_size) + smpl_image = np.asarray(smpl_image).astype(np.float32) / 255. 
# [0, 1] + alpha = smpl_image[...,3:4] + img = smpl_image[...,:3] * alpha + bg_color * (1 - alpha) + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + if i in flip_views: + img = torch.flip(img, [1]) + alpha = torch.flip(alpha, [1]) + imgs.append(img) + alphas.append(alpha) + return imgs, crop_size, alphas + + def load_image(self, img_path, bg_color, crop_size, return_type='np', Imagefile=None): + # pil always returns uint8 + if Imagefile is None: + image_input = Image.open(img_path) + else: + image_input = Imagefile + image_size = self.img_wh[0] + + alpha_np = np.asarray(image_input)[:, :, 3] + coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) + h, w = ref_img_.height, ref_img_.width + scale = crop_size / max(h, w) + h_, w_ = int(scale * h), int(scale * w) + ref_img_ = ref_img_.resize((w_, h_)) + image_input = add_margin(ref_img_, size=image_size) + + + # img = scale_and_place_object(img, self.scale_ratio) + img = np.array(image_input) + img = img.astype(np.float32) / 255. # [0, 1] + assert img.shape[-1] == 4 # RGBA + + alpha = img[...,3:4] + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + def load_face(self, img_path, bg_color, return_type='np', crop_size=-1): + image_input = Image.open(img_path) + image_size = self.img_wh[0] + + if crop_size > 0: # color image + alpha_np = np.asarray(image_input)[:, :, 3] + coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) + h, w = ref_img_.height, ref_img_.width + scale = crop_size / max(h, w) + h_, w_ = int(scale * h), int(scale * w) + ref_img_ = ref_img_.resize((w_, h_)) + image_input = add_margin(ref_img_, size=image_size) + + + image_input = image_input.crop((256, 0, 512, 256)).resize((self.img_wh[0], self.img_wh[1])) + + # img = scale_and_place_object(img, self.scale_ratio) + img = np.array(image_input) + img = img.astype(np.float32) / 255. # [0, 1] + assert img.shape[-1] == 4 # RGBA + + alpha = img[...,3:4] + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img + + + def process_face(self, img_path, bbox, bg_color, normal_path=None, w2c=None, h=512, w=512): + image = Image.open(img_path) + bbox_w, bbox_h = bbox[2] - bbox[0], bbox[3] - bbox[1] + if bbox_w > bbox_h: + bbox[1] -= (bbox_w - bbox_h) // 2 + bbox[3] += (bbox_w - bbox_h) // 2 + else: + bbox[0] -= (bbox_h - bbox_w) // 2 + bbox[2] += (bbox_h - bbox_w) // 2 + bbox[0:2] -= 20 + bbox[2:4] += 20 + image = image.crop(bbox) + + image = image.resize((w, h)) + image = np.array(image) / 255. 
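The paths assembled by load_smpl_images above and by __getitem__ further below assume a specific on-disk layout: the root directory holds RGBA person images, plus one sub-folder of pre-rendered orthographic SMPL views per image inside smpl_folder. A small sketch of how the paths resolve (the image name is a placeholder; the root and smpl_folder values follow the test block at the end of this file):

```
import os

root_dir, smpl_folder, fname = 'examples/CAPE', 'gt_smpl_image', 'person_a.png'
stem = fname.split('.')[0]

rgba_path = os.path.join(root_dir, fname)              # RGBA input image
smpl_dir = os.path.join(root_dir, smpl_folder, stem)   # its pre-rendered SMPL views
smpl_views = [os.path.join(smpl_dir, f'{i:03d}.png')   # eight views; the loader picks a
              for i in range(8)]                       # subset depending on num_views

print(rgba_path)       # examples/CAPE/person_a.png
print(smpl_views[0])   # examples/CAPE/gt_smpl_image/person_a/000.png (view 000; its
                       # top-centre crop also serves as the SMPL "face" condition)
```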
+ img, alpha = image[:, :, :3], image[:, :, 3:4] + img = img * alpha + bg_color * (1 - alpha) + + padded_img = np.full((self.img_wh[0], self.img_wh[1], 3), bg_color, dtype=np.float32) + dx = (self.img_wh[0] - w) // 2 + dy = (self.img_wh[1] - h) // 2 + padded_img[dy:dy+h, dx:dx+w] = img + padded_img = torch.from_numpy(padded_img).permute(2,0,1) + + return padded_img + + def __getitem__(self, index): + filename = self.file_list[index].split('.')[0] + bg_color = self.get_bg_color() + + smpl_images, crop_size, smpl_alphas = self.load_smpl_images(f'{self.root_dir}/{self.smpl_folder}/{filename}', bg_color, return_type='pt') + smpl_face = self.load_face(f'{self.root_dir}/{self.smpl_folder}/{filename}/000.png', bg_color, return_type='pt') + + image, _ = self.load_image(f'{self.root_dir}/{self.file_list[index]}', bg_color, crop_size, return_type='pt') # m + face = self.load_face(f'{self.root_dir}/{self.file_list[index]}', bg_color, return_type='pt', crop_size=crop_size) # m + + img_tensors_in = [ image.permute(2, 0, 1) ] * (self.num_views-1) + [ face.permute(2, 0, 1)] + smpl_tensors_in = [ tmp.permute(2, 0, 1) for tmp in smpl_images ] + [ smpl_face.permute(2, 0, 1) ] + smpl_alphas = [ tmp.permute(2, 0, 1) for tmp in smpl_alphas ] + + # import pdb; pdb.set_trace() + img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W) + smpl_tensors_in = torch.stack(smpl_tensors_in, dim=0).float() # (Nv, 3, H, W) + smpl_alphas = torch.stack(smpl_alphas, dim=0).float() # (Nv, 1, H, W) + + normal_prompt_embeddings = self.normal_text_embeds if hasattr(self, 'normal_text_embeds') else None + color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None + + if normal_prompt_embeddings is None: + out = { + 'imgs_in': img_tensors_in, + 'smpl_imgs_in': smpl_tensors_in, + 'smpl_alphas': smpl_alphas, + 'color_prompt_embeddings': color_prompt_embeddings, + 'filename': filename, + } + else: + out = { + 'imgs_in': img_tensors_in, + 'smpl_imgs_in': smpl_tensors_in, + 'smpl_alphas': smpl_alphas, + 'normal_prompt_embeddings': normal_prompt_embeddings, + 'color_prompt_embeddings': color_prompt_embeddings, + 'filename': filename, + } + return out + + + +if __name__ == "__main__": + # pass + from torch.utils.data import DataLoader + from torchvision.utils import make_grid + from PIL import ImageDraw, ImageFont + def draw_text(img, text, pos, color=(128, 128, 128)): + draw = ImageDraw.Draw(img) + # font = ImageFont.truetype(size= size) + font = ImageFont.load_default() + font = font.font_variant(size=10) + draw.text(pos, text, color, font=font) + return img + random.seed(11) + test_params = dict( + root_dir='../../examples/CAPE', + bg_color='white', + img_wh=(768, 768), + prompt_embeds_path='fixed_prompt_embeds_7view', + num_views=7, + # crop_size=740, + margin_size=15, + smpl_folder='gt_smpl_image', + ) + train_dataset = SingleImageDataset(**test_params) + data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0) + + for batch in data_loader: + # batch = train_dataset.__getitem__(1) + imgs = [] + obj_name = batch['filename'][0] + imgs_in = batch['imgs_in'][0] + imgs_smpl_in = batch['smpl_imgs_in'][0] + alphas_smpl = batch['smpl_alphas'][0] + + img0 = (imgs_in[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8) + img1 = (imgs_smpl_in[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8) + print(img0.shape, img1.shape) + smpl_alpha = alphas_smpl[0].permute(1, 2, 0).repeat(1, 1, 3).numpy() + img0[smpl_alpha > 0.5] = img1[smpl_alpha > 0.5] + 
Image.fromarray(img0).save(f'../../debug/{obj_name}_rgb.png') + # Image.fromarray(img1).save(f'../../debug/{obj_name}_smpl.png') + + exit() + imgs_vis = torch.cat([imgs_in, imgs_smpl_in], 0) + img_vis = make_grid(imgs_vis, nrow=4).permute(1, 2,0) + img_vis = (img_vis.numpy() * 255).astype(np.uint8) + img_vis = Image.fromarray(img_vis) + img_vis = draw_text(img_vis, obj_name, (5, 1)) + img_vis = torch.from_numpy(np.array(img_vis)).permute(2, 0, 1) / 255. + imgs.append(img_vis) + imgs = torch.stack(imgs, dim=0) + img_grid = make_grid(imgs, nrow=4, padding=0) + img_grid = img_grid.permute(1, 2, 0).numpy() + img_grid = (img_grid * 255).astype(np.uint8) + img_grid = Image.fromarray(img_grid) + img_grid.save(f'../../debug/{obj_name}.png') + print('finished.') diff --git a/mvdiffusion/models_unclip/attn_processors.py b/mvdiffusion/models_unclip/attn_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..cff80f2b79c66c1d4419c2fa1abc02d2dde1e944 --- /dev/null +++ b/mvdiffusion/models_unclip/attn_processors.py @@ -0,0 +1,631 @@ +from typing import Any, Dict, Optional + +import torch +from torch import nn + + + +from diffusers.models.attention import Attention +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat +import math + +import torch.nn.functional as F +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +class RowwiseMVAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class IPCDAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersIPCDAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True, + cd_attention_mid=False + ): + # print(num_views) + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length)) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
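The key step in this processor is the regrouping right below: query, key and value are rearranged from per-view token sequences of length h*w into sequences that hold the same image row from every view, so one xformers call attends across views while staying row-wise (and the cd_attention_mid branch additionally concatenates the two halves of the batch, i.e. the two domains, along the width first). A small shape sketch of that regrouping (sizes are illustrative):

```
import torch
from einops import rearrange

b, v, h, w, c = 2, 6, 4, 4, 8        # batch, views, feature height/width, channels
x = torch.randn(b * v, h * w, c)     # the "(b v) (h w) c" layout entering the processor

rows = rearrange(x, "(b v) (h w) c -> (b h) (v w) c", v=v, h=h)
print(rows.shape)                    # torch.Size([8, 24, 8]) == (b*h, v*w, c)

back = rearrange(rows, "(b h) (v w) c -> (b v) (h w) c", v=v, h=h)
assert torch.equal(back, x)          # the regrouping is exactly invertible
```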
+ _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + print('Warning: using group norm, pay attention to use it in row-wise attention') + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + # pdb.set_trace() + def transpose(tensor): + tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height) + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c + tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c + tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height) + return tensor + # print(mvcd_attention) + # import pdb;pdb.set_trace() + if cd_attention_mid: + key = transpose(key_raw) + value = transpose(value_raw) + query = transpose(query) + else: + key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) + + + query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64]) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if cd_attention_mid: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height) + hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c + hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c + hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height) + else: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersIPCDAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def process(self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2, + num_views=6): + ### TODO: num_views + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length)) + height_st = height // 3 + height_end = height - height_st + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + + # ip attn + # hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c', v=num_views) + # body_hidden_states, face_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1, :, :] + # print(body_hidden_states.shape, face_hidden_states.shape) + # import pdb;pdb.set_trace() + # hidden_states = body_hidden_states + attn.ip_scale * repeat(head_hidden_states.detach(), 'b l c -> (b v) l c', v=n_view) + # hidden_states = rearrange( + # torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states.unsqueeze(1)], dim=1), + # 'b v l c -> (b v) l c') + + # face cross attention + # ip_hidden_states = repeat(face_hidden_states.detach(), 'b l c -> (b v) l c', v=num_views-1) + # ip_key = attn.to_k_ip(ip_hidden_states) + # ip_value = attn.to_v_ip(ip_hidden_states) + # ip_key = attn.head_to_batch_dim(ip_key).contiguous() + # ip_value = attn.head_to_batch_dim(ip_value).contiguous() + # ip_query = attn.head_to_batch_dim(body_hidden_states).contiguous() + # ip_hidden_states = xformers.ops.memory_efficient_attention(ip_query, ip_key, ip_value, attn_bias=attention_mask) + # ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + # ip_hidden_states = attn.to_out_ip[0](ip_hidden_states) + # ip_hidden_states = attn.to_out_ip[1](ip_hidden_states) + # import pdb;pdb.set_trace() + + + def transpose(tensor): + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c + tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c + # tensor = rearrange(tensor, "(b v) l c -> b v l c", v=num_views+1) + # body, face = tensor[:, :-1, :], tensor[:, -1:, :] # b,v,l,c; b,1,l,c + # face = face.repeat(1, 
num_views, 1, 1) # b,v,l,c + # tensor = torch.cat([body, face], dim=2) # b, v, 4hw, c + # tensor = rearrange(tensor, "b v l c -> (b v) l c") + return tensor + key = transpose(key) + value = transpose(value) + query = transpose(query) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2) # bv, hw, c + + hidden_states_normal = rearrange(hidden_states_normal, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height) + face_normal = rearrange(hidden_states_normal[:, -1, :, :, :], 'b h w c -> b c h w').detach() + face_normal = rearrange(F.interpolate(face_normal, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c') + hidden_states_normal = hidden_states_normal.clone() # Create a copy of hidden_states_normal + hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_normal + # hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_normal + hidden_states_normal = rearrange(hidden_states_normal, "b v h w c -> (b v) (h w) c") + + + hidden_states_color = rearrange(hidden_states_color, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height) + face_color = rearrange(hidden_states_color[:, -1, :, :, :], 'b h w c -> b c h w').detach() + face_color = rearrange(F.interpolate(face_color, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c') + hidden_states_color = hidden_states_color.clone() # Create a copy of hidden_states_color + hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_color + # hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_color + hidden_states_color = rearrange(hidden_states_color, "b v h w c -> (b v) (h w) c") + + hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c + + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2, + ): + hidden_states = self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks) + # hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c') + # body_hidden_states, head_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1:, :, :] + # import pdb;pdb.set_trace() + # hidden_states = body_hidden_states + attn.ip_scale * head_hidden_states.detach().repeat(1, views, 1, 1) + # hidden_states = rearrange( + # torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states], dim=1), + # 'b v l c -> (b v) l c') 
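The `process` method above first concatenates the normal and color halves of the batch along the token dimension so both domains attend jointly; after attention and the output projection, it pastes a downsampled copy of the dedicated face view's features into the top-centre region of the front view, separately for the normal and color branches. A standalone sketch of that face-blending step (dummy resolution; the real sizes depend on the latent resolution at each UNet level):

```
# Illustrative sketch of the face-to-front-view blending in XFormersIPCDAttnProcessor.process.
import torch
import torch.nn.functional as F
from einops import rearrange

b, num_views, height, c = 1, 6, 48, 320     # 6 body views + 1 dedicated face view
h_st = height // 3                          # side length of the pasted face region
h_end = height - h_st

hidden = torch.randn(b, num_views + 1, height, height, c)   # per-view feature maps
face = rearrange(hidden[:, -1], "b h w c -> b c h w").detach()
face = rearrange(F.interpolate(face, size=(h_st, h_st), mode="bilinear"),
                 "b c h w -> b h w c")

# paste the resized face features into the top-centre of the front view (index 0)
hidden = hidden.clone()
hidden[:, 0, :h_st, h_st:h_end, :] = (
    0.5 * hidden[:, 0, :h_st, h_st:h_end, :] + 0.5 * face
)
```

The 0.5/0.5 weighting matches the active line in the source; the commented-out 0.1/0.9 alternative would weight the face features more heavily.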
+ return hidden_states + +class IPCrossAttn(Attention): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, + query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention) + + self.ip_scale = ip_scale + # self.num_tokens = num_tokens + + # self.to_k_ip = nn.Linear(query_dim, self.inner_dim, bias=False) + # self.to_v_ip = nn.Linear(query_dim, self.inner_dim, bias=False) + + # self.to_out_ip = nn.ModuleList([]) + # self.to_out_ip.append(nn.Linear(self.inner_dim, self.inner_dim, bias=bias)) + # self.to_out_ip.append(nn.Dropout(dropout)) + # nn.init.zeros_(self.to_k_ip.weight.data) + # nn.init.zeros_(self.to_v_ip.weight.data) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersIPCrossAttnProcessor() + self.set_processor(processor) + +class XFormersIPCrossAttnProcessor: + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1 + ): + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # ip attn + # hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c', v=num_views) + # body_hidden_states, face_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1, :, :] + # print(body_hidden_states.shape, face_hidden_states.shape) + # import pdb;pdb.set_trace() + # hidden_states = body_hidden_states + attn.ip_scale * repeat(head_hidden_states.detach(), 'b l c -> (b v) l c', v=n_view) + # hidden_states = rearrange( + # torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states.unsqueeze(1)], dim=1), + # 'b v l c -> (b v) l c') + + # face cross attention + # ip_hidden_states = repeat(face_hidden_states.detach(), 'b l c -> (b v) l c', v=num_views-1) + # ip_key = attn.to_k_ip(ip_hidden_states) + # ip_value = attn.to_v_ip(ip_hidden_states) + # ip_key = attn.head_to_batch_dim(ip_key).contiguous() + # ip_value = 
attn.head_to_batch_dim(ip_value).contiguous() + # ip_query = attn.head_to_batch_dim(body_hidden_states).contiguous() + # ip_hidden_states = xformers.ops.memory_efficient_attention(ip_query, ip_key, ip_value, attn_bias=attention_mask) + # ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + # ip_hidden_states = attn.to_out_ip[0](ip_hidden_states) + # ip_hidden_states = attn.to_out_ip[1](ip_hidden_states) + # import pdb;pdb.set_trace() + + # body_hidden_states = body_hidden_states + attn.ip_scale * ip_hidden_states + # hidden_states = rearrange( + # torch.cat([rearrange(body_hidden_states, '(b v) l c -> b v l c', v=num_views-1), face_hidden_states.unsqueeze(1)], dim=1), + # 'b v l c -> (b v) l c') + # import pdb;pdb.set_trace() + # + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + + # TODO: region control + # region control + # if len(region_control.prompt_image_conditioning) == 1: + # region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) + # if region_mask is not None: + # h, w = region_mask.shape[:2] + # ratio = (h * w / query.shape[1]) ** 0.5 + # mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) + # else: + # mask = torch.ones_like(ip_hidden_states) + # ip_hidden_states = ip_hidden_states * mask + + return hidden_states + + +class RowwiseMVProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + cd_attention_mid=False + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length)) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) + # pdb.set_trace() + # multi-view self-attention + def transpose(tensor): + tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height) + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c + tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c + tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height) + return tensor + + if cd_attention_mid: + key = transpose(key) + value 
= transpose(value) + query = transpose(query) + else: + key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if cd_attention_mid: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height) + hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c + hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c + hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height) + else: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CDAttention(Attention): + # def __init__(self, ip_scale, + # query_dim, heads, dim_head, dropout, bias, cross_attention_dim, upcast_attention, processor): + # super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, processor=processor) + + # self.ip_scale = ip_scale + + # self.to_k_ip = nn.Linear(query_dim, self.inner_dim, bias=False) + # self.to_v_ip = nn.Linear(query_dim, self.inner_dim, bias=False) + # nn.init.zeros_(self.to_k_ip.weight.data) + # nn.init.zeros_(self.to_v_ip.weight.data) + + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersCDAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class XFormersCDAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + def transpose(tensor): + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c + tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c + return tensor + key = transpose(key) + value = transpose(value) + query = transpose(query) + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = torch.cat([hidden_states[:, 0], hidden_states[:, 1]], dim=0) # 2bv hw c + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + \ No newline at end of file diff --git a/mvdiffusion/models_unclip/face_networks.py b/mvdiffusion/models_unclip/face_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..75c4fc7e2197adf445fb4dc313df7b3f6a2fde50 --- /dev/null +++ b/mvdiffusion/models_unclip/face_networks.py @@ -0,0 +1,142 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = 
dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def prepare_face_proj_model(ckpt_path, image_emb_dim=512, num_tokens=16, cross_attention_dim=1024, pretrain=True, + ): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=cross_attention_dim, # self.unet.config.cross_attention_dim, + ff_mult=4, + ) + # image_proj_model.eval() + if pretrain: + state_dict = torch.load(ckpt_path, map_location="cpu") + if 'image_proj' in state_dict: + state_dict = state_dict["image_proj"] + image_proj_model.load_state_dict(state_dict) + return image_proj_model + + \ No newline at end of file diff --git a/mvdiffusion/models_unclip/transformer_mv2d_self_rowwise.py b/mvdiffusion/models_unclip/transformer_mv2d_self_rowwise.py new file mode 100644 index 0000000000000000000000000000000000000000..e40770d68e252d0fc064ef1063af3c43dd9ccdbf --- /dev/null +++ b/mvdiffusion/models_unclip/transformer_mv2d_self_rowwise.py @@ -0,0 +1,633 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
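The `Resampler` above (from `face_networks.py` in this diff) follows a Perceiver-style design: a fixed set of learned latent queries cross-attends to the face embedding and is projected to the UNet's cross-attention width, so a variable-length identity feature becomes a fixed number of conditioning tokens. A minimal usage sketch with the defaults that `prepare_face_proj_model` passes (the shape of the face embedding fed at runtime is an assumption, not shown in this diff):

```
# Illustrative sketch: turning a face identity embedding into cross-attention tokens.
import torch
from mvdiffusion.models_unclip.face_networks import Resampler

resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20,
                      num_queries=16, embedding_dim=512, output_dim=1024)

face_embeds = torch.randn(2, 1, 512)   # assumed: one 512-d identity embedding per image
tokens = resampler(face_embeds)        # -> (2, 16, 1024), matching the cross-attention dim
```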
+# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange +import pdb +from .attn_processors import IPCDAttention, RowwiseMVAttention, IPCrossAttn + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@dataclass +class TransformerMV2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerMV2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. 
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + num_views: int = 1, + cd_attention_mid: bool=False, + cd_attention_last: bool=False, + multiview_attention: bool=True, + sparse_mv_attention: bool = True, # not used + mvcd_attention: bool=False, + use_dino: bool=False + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicMVTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + dino_feature: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. 
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerMV2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicMVTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + mvcd_attention: bool = False, + rowwise_attention: bool = True, + use_dino: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.mvcd_attention = mvcd_attention + self.cd_attention_mid = cd_attention_mid + self.rowwise_attention = multiview_attention and rowwise_attention + + if mvcd_attention and (not cd_attention_mid): + # add cross domain attn to self attn + self.attn1 = IPCDAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + # processor=JointCDAttnProcessor() + ) + else: + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention + ) + # 1.1 rowwise multiview attention + if self.rowwise_attention: + # print('INFO: using self+row_wise mv attention...') + self.norm_mv = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn_mv = RowwiseMVAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + # processor=MVAttnProcessor() + ) + nn.init.zeros_(self.attn_mv.to_out[0].weight.data) + else: + self.norm_mv = None + self.attn_mv = None + + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. 
+ self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = IPCrossAttn( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + # nn.init.zeros_(self.attn2.to_out_ip[0].weight.data) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + # multiview_attention=self.multiview_attention, + # mvcd_attention=self.mvcd_attention, + **cross_attention_kwargs, + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 1.1 row wise multiview attention + hidden_states = rearrange(hidden_states, '(b v) n c -> b v n c', v=self.num_views) + body_hidden_states, face_hidden_states = \ + rearrange(hidden_states[:, :-1, :, :], 'b v n c -> (b v) n c'), hidden_states[:, -1:, :, :] + if self.rowwise_attention: + norm_hidden_states = ( + self.norm_mv(body_hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(body_hidden_states) + ) + attn_output = self.attn_mv( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + num_views=self.num_views-1, + multiview_attention=self.multiview_attention, + cd_attention_mid=self.cd_attention_mid, + **cross_attention_kwargs, + ) + body_hidden_states = attn_output + body_hidden_states + hidden_states = rearrange(torch.cat([ + rearrange(body_hidden_states, '(b v) n c -> b v n c', v=self.num_views-1), face_hidden_states], dim=1), 'b v n c -> (b v) n c') + + # 2. 
Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, dino_feature], dim=1) if dino_feature is not None else encoder_hidden_states + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + num_views=self.num_views, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + diff --git a/mvdiffusion/models_unclip/unet_mv2d_blocks.py b/mvdiffusion/models_unclip/unet_mv2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..c145f27261b1c0a8f28c4ca15984e44d5a5579e1 --- /dev/null +++ b/mvdiffusion/models_unclip/unet_mv2d_blocks.py @@ -0,0 +1,974 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
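In `BasicMVTransformerBlock.forward` above, the batch is reshaped so the dedicated face view can be split off: row-wise multiview attention (`attn_mv`) runs only over the body views, and the face view is concatenated back unchanged before cross-attention. A shape-only sketch of that split-and-merge (the attention call itself is omitted):

```
# Shape-only sketch of the body/face split in BasicMVTransformerBlock.forward.
import torch
from einops import rearrange

b, num_views, n, c = 2, 7, 1024, 320        # 6 body views + 1 face view per sample
hidden = torch.randn(b * num_views, n, c)

hidden = rearrange(hidden, "(b v) n c -> b v n c", v=num_views)
body = rearrange(hidden[:, :-1], "b v n c -> (b v) n c")   # (b*(num_views-1), n, c)
face = hidden[:, -1:]                                      # (b, 1, n, c)

# ... self.attn_mv(body, num_views=num_views - 1) would be applied here ...

hidden = rearrange(
    torch.cat([rearrange(body, "(b v) n c -> b v n c", v=num_views - 1), face], dim=1),
    "b v n c -> (b v) n c",
)
assert hidden.shape == (b * num_views, n, c)
```

Note that `__init__` zero-initialises `attn_mv.to_out[0].weight`, so the added multiview branch perturbs the pretrained features only minimally at the start of fine-tuning.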
+from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D + +from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D +from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, + num_views=1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif down_block_type == "CrossAttnDownBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") + return CrossAttnDownBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + 
use_dino=use_dino + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, 
+ resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, + num_views=1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif up_block_type == "CrossAttnUpBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") + return CrossAttnUpBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + 
selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + 
attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockMV2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + if selfattn_block == "custom": + from .transformer_mv2d import TransformerMV2DModel + elif selfattn_block == "rowwise": + from .transformer_mv2d_rowwise import TransformerMV2DModel + elif selfattn_block == "self_rowwise": + from .transformer_mv2d_self_rowwise import TransformerMV2DModel + elif selfattn_block == "inpaint": + from .transformer_inpaint import TransformerMV2DModel + else: + raise NotImplementedError + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + ) + else: + raise NotImplementedError + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + 
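+            # The first resnet was already applied to hidden_states above; each remaining
+            # multiview attention block is paired with the resnet that follows it.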
hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + dino_feature=dino_feature, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if selfattn_block == "custom": + from .transformer_mv2d import TransformerMV2DModel + elif selfattn_block == "rowwise": + from .transformer_mv2d_rowwise import TransformerMV2DModel + elif selfattn_block == "self_rowwise": + from .transformer_mv2d_self_rowwise import TransformerMV2DModel + elif selfattn_block == "inpaint": + from .transformer_inpaint import TransformerMV2DModel + else: + raise NotImplementedError + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + 
cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + dino_feature, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + dino_feature=dino_feature, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnDownBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if selfattn_block == "custom": + from .transformer_mv2d import TransformerMV2DModel + elif selfattn_block == "rowwise": + from .transformer_mv2d_rowwise import TransformerMV2DModel + elif selfattn_block == "self_rowwise": + from .transformer_mv2d_self_rowwise import TransformerMV2DModel + elif selfattn_block == "inpaint": + from .transformer_inpaint import TransformerMV2DModel + else: + raise NotImplementedError + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + 
time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + dino_feature, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + diff --git a/mvdiffusion/models_unclip/unet_mv2d_condition.py b/mvdiffusion/models_unclip/unet_mv2d_condition.py new file mode 100644 index 
0000000000000000000000000000000000000000..547bc03361f2dddaa07c70ecaced080283053b63 --- /dev/null +++ b/mvdiffusion/models_unclip/unet_mv2d_condition.py @@ -0,0 +1,1723 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model +from diffusers.models.unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, +) +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_torch_version, + logging, +) +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.utils.hub_utils import HF_HUB_OFFLINE +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE + +from diffusers import __version__ +from .unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) +from einops import rearrange, repeat + +from diffusers import __version__ +from mvdiffusion.models_unclip.unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMV2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
+ """ + + sample: torch.FloatTensor = None + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.linear1 = nn.Linear(dim, dim) + self.activation = nn.SiLU() + self.linear2 = nn.Linear(dim, dim) + + def forward(self, x): + identity = x + out = self.linear1(x) + out = self.activation(out) + out = self.linear2(out) + out += identity + out = self.activation(out) + return out + +class ResidualLiner(nn.Module): + def __init__(self, in_features, out_features, dim, act=None, num_block=1): + super(ResidualLiner, self).__init__() + self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU()) + + blocks = nn.ModuleList() + for _ in range(num_block): + blocks.append(ResidualBlock(dim)) + self.blocks = blocks + + self.linear_out = nn.Linear(dim, out_features) + self.act = act + + def forward(self, x): + out = self.linear_in(x) + for block in self.blocks: + out = block(out) + out = self.linear_out(out) + if self.act is not None: + out = self.act(out) + return out + +class BasicConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(BasicConvBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) + self.act = nn.SiLU() + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) + self.downsample = nn.Sequential() + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), + nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) + ) + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.act(out) + out = self.conv2(out) + out = self.norm2(out) + out += self.downsample(identity) + out = self.act(out) + return out + +class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 
+ time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] 
= None, + projection_camera_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool = False, + regress_elevation: bool = False, + regress_focal_length: bool = False, + num_regress_blocks: int = 4, + use_dino: bool = False, + addition_downsample: bool = False, + addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280), + ): + super().__init__() + + self.sample_size = sample_size + self.num_views = num_views + self.mvcd_attention = mvcd_attention + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 
+ ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
+ ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + 
resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + self.addition_downsample = addition_downsample + if self.addition_downsample: + inc = block_out_channels[-1] + self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.conv_block = nn.ModuleList() + self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1)) + for dim_ in addition_channels[1:-1]: + self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1)) + self.conv_block.append(BasicConvBlock(dim_, inc)) + self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False) + nn.init.zeros_(self.addition_conv_out.weight.data) + self.addition_act_out = nn.SiLU() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.regress_elevation = regress_elevation + self.regress_focal_length = regress_focal_length + if regress_elevation or regress_focal_length: + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim) + + regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels + + if regress_elevation: + self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks) + if regress_focal_length: + self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks) + ''' + self.regress_elevation = regress_elevation + self.regress_focal_length = regress_focal_length + if regress_elevation and (not regress_focal_length): + print("Regressing elevation") + cam_dim = 1 + elif regress_focal_length and (not regress_elevation): + print("Regressing focal length") + cam_dim = 6 + elif regress_elevation and regress_focal_length: + print("Regressing both elevation and focal length") + cam_dim = 7 + else: + cam_dim = 0 + assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim" + if regress_elevation or regress_focal_length: + self.elevation_regressor = nn.ModuleList([ + nn.Linear(block_out_channels[-1], 1280), + nn.SiLU(), + nn.Linear(1280, 1280), + nn.SiLU(), + nn.Linear(1280, cam_dim) + ]) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.focal_act = nn.Softmax(dim=-1) + self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim) + ''' + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + 
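+            # Up-path channel bookkeeping: `prev_output_channel` is what the previous up block
+            # produced, while `output_channel`/`input_channel` walk the reversed
+            # `block_out_channels` so the skip connections from the down path match in width.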
prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + dino_feature: Optional[torch.Tensor] = None, + return_dict: bool = True, + vis_max_min: bool = False, + ) -> Union[UNetMV2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + record_max_min = {} + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
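        # Illustrative aside (not part of the original code): with, say, 3 upsamplers
        # the factor below is 2**3 = 8, so a 96x72 latent needs no special handling,
        # whereas a 96x70 latent sets forward_upsample_size = True and the exact
        # target shape is forwarded to the up blocks:
        #
        #   default_overall_up_factor = 2 ** 3
        #   any(s % default_overall_up_factor != 0 for s in (96, 70))   # -> True (70 % 8 == 6)
        #   any(s % default_overall_up_factor != 0 for s in (96, 72))   # -> False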
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + if vis_max_min: record_max_min["sample"] = (sample.min().detach().float().cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + if vis_max_min: record_max_min["t_emb"] = (t_emb.min().detach().float().cpu().numpy().tolist(), t_emb.max().detach().float().cpu().numpy().tolist()) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + if vis_max_min: record_max_min["class_emb"] = (class_emb.min().detach().float().cpu().numpy().tolist(), class_emb.max().detach().float().cpu().numpy().tolist()) + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + if aug_emb is not None and vis_max_min: record_max_min["aug_emb"] = (aug_emb.min().detach().float().cpu().numpy().tolist(), aug_emb.max().detach().float().cpu().numpy().tolist()) + emb_pre_act = emb + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in 
added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + if vis_max_min: record_max_min["conv_in"] = (sample.min().detach().float().cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for i, downsample_block in enumerate(self.down_blocks): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + if vis_max_min: record_max_min[f"down_block_{i}"] = (sample.min().detach().float().cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + if self.addition_downsample: + global_sample = sample + global_sample = self.downsample(global_sample) + for layer in self.conv_block: + global_sample = layer(global_sample) + global_sample = self.addition_act_out(self.addition_conv_out(global_sample)) + global_sample = self.upsample(global_sample) + if vis_max_min: record_max_min["global_sample"] = (global_sample.min().detach().float().cpu().numpy().tolist(), global_sample.max().detach().float().cpu().numpy().tolist()) + # 4. 
mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + if vis_max_min: record_max_min["mid_block"] = (sample.min().detach().float().cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + + # 4.1 regress elevation and focal length + # # predict elevation -> embed -> projection -> add to time emb + if self.regress_elevation or self.regress_focal_length: + pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C) + if self.mvcd_attention: + pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0) + pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C) + pose_pred = [] + if self.regress_elevation: + ele_pred = self.elevation_regressor(pool_embeds) + ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views) + ele_pred = torch.mean(ele_pred, dim=1) + pose_pred.append(ele_pred) # b, c + if vis_max_min: record_max_min["ele_pred"] = (ele_pred.min().detach().float().cpu().numpy().tolist(), ele_pred.max().detach().float().cpu().numpy().tolist()) + + if self.regress_focal_length: + focal_pred = self.focal_regressor(pool_embeds) + focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views) + focal_pred = torch.mean(focal_pred, dim=1) + pose_pred.append(focal_pred) + if vis_max_min: record_max_min["focal_pred"] = (focal_pred.min().detach().float().cpu().numpy().tolist(), focal_pred.max().detach().float().cpu().numpy().tolist()) + pose_pred = torch.cat(pose_pred, dim=-1) + # 'e_de_da_sincos', (B, 2) + pose_embeds = torch.cat([ + torch.sin(pose_pred), + torch.cos(pose_pred) + ], dim=-1) + pose_embeds = self.camera_embedding(pose_embeds) + pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0) + if vis_max_min: record_max_min["pose_embeds"] = (pose_embeds.min().detach().float().cpu().numpy().tolist(), pose_embeds.max().detach().float().cpu().numpy().tolist()) + if self.mvcd_attention: + pose_embeds = torch.cat([pose_embeds,] * 2, dim=0) + + emb = pose_embeds + emb_pre_act + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + ''' + if self.regress_elevation or self.regress_focal_length: + pose_pred = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (B, C) + + for liner in self.elevation_regressor: + pose_pred = liner(pose_pred) + + pose_pred = torch.cat([ + pose_pred[:, 0:1], + self.focal_act(pose_pred[:, 1:]) + ], dim=-1) + # 'e_de_da_sincos', (B, 2) + pose_embeds = torch.cat([ + torch.sin(pose_pred), + torch.cos(pose_pred) + ], dim=-1) + pose_embeds = self.camera_embedding(pose_embeds) + emb = pose_embeds + emb_pre_act + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + ''' + if is_controlnet: + sample = sample + mid_block_additional_residual + + if self.addition_downsample: + sample = sample + global_sample + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + if vis_max_min: record_max_min[f"upsample_block_{i}"] = (torch.abs(sample.min().detach().float()).cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + up_s = sample + if torch.isnan(sample).any() or torch.isinf(sample).any(): + print("NAN in sample, stop training.") + exit() + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + if vis_max_min: record_max_min[f"conv_norm_out"] = (sample.min().detach().float().cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if vis_max_min: record_max_min[f"conv_out"] = (sample.min().detach().float().cpu().numpy().tolist(), sample.max().detach().float().cpu().numpy().tolist()) + if not return_dict: + return (sample,) + # return (sample, pose_pred) + if self.regress_elevation or self.regress_focal_length: + return UNetMV2DConditionOutput(sample=sample), pose_pred, record_max_min, up_s + else: + return UNetMV2DConditionOutput(sample=sample), up_s + # return UNetMV2DConditionOutput(sample=sample), up_s, record_max_min + + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + camera_embedding_type: str, num_views: int, sample_size: int, + zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, + projection_camera_embeddings_input_dim: int=2, + cd_attention_last: bool = False, num_regress_blocks: int = 4, + cd_attention_mid: bool = False, multiview_attention: bool = True, + sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False, + in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False, + init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False, use_face_adapter: bool=True, + **kwargs + ): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. 
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. 
Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors: + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." 
+ ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # modify config + config["_class_name"] = cls.__name__ + config['in_channels'] = in_channels + config['out_channels'] = out_channels + config['sample_size'] = sample_size # training resolution + config['num_views'] = num_views + config['cd_attention_last'] = cd_attention_last + config['cd_attention_mid'] = cd_attention_mid + config['multiview_attention'] = multiview_attention + config['sparse_mv_attention'] = sparse_mv_attention + config['selfattn_block'] = selfattn_block + config['mvcd_attention'] = mvcd_attention + config["down_block_types"] = [ + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D" + ] + config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" + config["up_block_types"] = [ + "UpBlock2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D" + ] + + + config['regress_elevation'] = regress_elevation # true + config['regress_focal_length'] = regress_focal_length # true + config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim # 2 for elevation and 10 for focal_length + config['use_dino'] = use_dino + config['num_regress_blocks'] = num_regress_blocks + config['addition_downsample'] = addition_downsample + # load model + model_file = None + if from_flax: + raise NotImplementedError + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + import copy + state_dict_pretrain = load_state_dict(model_file, variant=variant) + state_dict = copy.deepcopy(state_dict_pretrain) + if init_mvattn_with_selfattn: + for key in state_dict_pretrain: + if 'attn1' in 
key: + key_mv = key.replace('attn1', 'attn_mv') + state_dict[key_mv] = state_dict_pretrain[key] + if 'to_out.0.weight' in key: + nn.init.zeros_(state_dict[key_mv].data) + if 'transformer_blocks' in key and 'norm1' in key: # in case that initialize the norm layer in resnet block + key_mv = key.replace('norm1', 'norm_mv') + state_dict[key_mv] = state_dict_pretrain[key] + del state_dict_pretrain + + model._convert_deprecated_attention_blocks(state_dict) + + conv_in_weight = state_dict['conv_in.weight'] + conv_out_weight = state_dict['conv_out.weight'] + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=True, + ) + if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_in.weight.data[:,:4] = conv_in_weight + + # whether to place all zero to new layers? + if zero_init_conv_in: + model.conv_in.weight.data[:,4:] = 0. + + if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_out.weight.data[:,:4] = conv_out_weight + if out_channels == 8: # copy for the last 4 channels + model.conv_out.weight.data[:, 4:] = conv_out_weight + + if (regress_elevation or regress_focal_length) and zero_init_camera_projection: # true + params = [p for p in model.camera_embedding.parameters()] + torch.nn.init.zeros_(params[-1].data) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." 
+ ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + diff --git a/mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py b/mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..e35c858201523301626635a8a270dcfe567ba1c4 --- /dev/null +++ b/mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py @@ -0,0 +1,651 @@ +import inspect +import warnings +from typing import Callable, List, Optional, Union, Dict, Any +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.embeddings import get_timestep_embedding +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +import os +import torchvision.transforms.functional as TF +from einops import rearrange +logger = logging.get_logger(__name__) + +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image to image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + feature_extractor ([`CLIPFeatureExtractor`]): + Feature extractor for image pre-processing before being encoded. + image_encoder ([`CLIPVisionModelWithProjection`]): + CLIP vision model for encoding images. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by `noise_level` in `StableUnCLIPPipeline.__call__`. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + # image encoding components + feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + vae: AutoencoderKL + + def __init__( + self, + # image encoding components + feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + num_views: int = 7, + ): + super().__init__() + + self.register_modules( + feature_extractor=feature_extractor, + image_encoder=image_encoder, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.num_views: int = num_views + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. 
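        Example (illustrative; `pipe` is assumed to be an instance of this pipeline,
        and `accelerate` must be installed):

        ```py
        pipe.enable_sequential_cpu_offload(gpu_id=0)
        ```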
+ """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list + models = [ + self.image_encoder, + self.text_encoder, + self.unet, + self.vae, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + + if do_classifier_free_guidance: + normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0) + prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0) + + return prompt_embeds + + def _encode_image( + self, + image_pil, + smpl_pil, + device, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level: int=0, + generator: Optional[torch.Generator] = None + ): + dtype = next(self.image_encoder.parameters()).dtype + # ______________________________clip image embedding______________________________ + image = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + # duplicate image embeddings for each generation per prompt, using mps friendly method + # image_embeds = image_embeds.unsqueeze(1) + # note: the condition input is same + image_embeds = image_embeds.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + normal_image_embeds, color_image_embeds = torch.chunk(image_embeds, 2, dim=0) + negative_prompt_embeds = torch.zeros_like(normal_image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0) + + # _____________________________vae input latents__________________________________________________ + def vae_encode(tensor): + image_pt = torch.stack([TF.to_tensor(img) for img in tensor], dim=0).to(device) + image_pt = image_pt * 2.0 - 1.0 + image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor + # Note: repeat differently from official pipelines + image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1) + return image_latents + + image_latents = vae_encode(image_pil) + if smpl_pil is not None: + smpl_latents = vae_encode(smpl_pil) + image_latents = torch.cat([image_latents, smpl_latents], 1) + + if do_classifier_free_guidance: + normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0) + image_latents = torch.cat([torch.zeros_like(normal_image_latents), normal_image_latents, + torch.zeros_like(color_image_latents), color_image_latents], 0) + + return image_embeds, image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + noise_level, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. 
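        Example (illustrative sketch; `pipe` is assumed to be an instance of this
        pipeline and the embedding width is hypothetical):

        ```py
        import torch

        image_embeds = torch.randn(14, 768)
        noised = pipe.noise_image_embeddings(image_embeds, noise_level=0)
        print(noised.shape)   # torch.Size([14, 1536]): a sinusoidal embedding of
                              # `noise_level`, of the same width, is appended
        ```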
+ """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + def process_dino_feature(self, feat, device, num_images_per_prompt, do_classifier_free_guidance): + feat = feat.to(dtype=self.text_encoder.dtype, device=device) + if do_classifier_free_guidance: + # # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + # seq_len = negative_prompt_embeds.shape[1] + + # negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + # negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + feat = torch.cat([feat, feat], 0) + return feat + @torch.no_grad() + # @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + prompt_embeds: torch.FloatTensor = None, + dino_feature: torch.FloatTensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + image_embeds: Optional[torch.FloatTensor] = None, + gt_img_in: Optional[torch.FloatTensor] = None, + smpl_in: Optional[torch.FloatTensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which + the unet will be conditioned on. 
Note that the image is _not_ encoded by the VAE and then used as the latents in the denoising
+                process, as in the standard Stable Diffusion text-guided image-variation process.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 20):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 10.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` in equation 2 of the [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+                Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages images
+                that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`]; ignored for other schedulers.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in + the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as + `latents`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level + ) + + # 2. Define call parameters + if isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + assert batch_size >= self.num_views and batch_size % self.num_views == 0 + elif isinstance(image, PIL.Image.Image): + image = [image]*self.num_views*2 + batch_size = self.num_views*2 + + if isinstance(prompt, str): + prompt = [prompt] * self.num_views * 2 + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale != 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + if dino_feature is not None: + dino_feature = self.process_dino_feature(dino_feature, device=device, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt) + + # 4. 
Encoder input image + if isinstance(image, list): + image_pil = image + smpl_pil = smpl_in + elif isinstance(image, torch.Tensor): + image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] + smpl_pil = [TF.to_pil_image(smpl_in[i]) for i in range(smpl_in.shape[0])] if smpl_in is not None else None + noise_level = torch.tensor([noise_level], device=device) + image_embeds, image_latents = self._encode_image( + image_pil=image_pil, + smpl_pil=smpl_pil, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + noise_level=noise_level, + generator=generator, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.out_channels + if gt_img_in is not None: + latents = gt_img_in * self.scheduler.init_noise_sigma + else: + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + eles, focals = [], [] + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + if do_classifier_free_guidance: + normal_latents, color_latents = torch.chunk(latents, 2, dim=0) + latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0) + else: + latent_model_input = latents + latent_model_input = torch.cat([ + latent_model_input, image_latents + ], dim=1) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + unet_out = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + dino_feature=dino_feature, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False) + + noise_pred = unet_out[0] + + + # perform guidance + if do_classifier_free_guidance: + normal_noise_pred_uncond, normal_noise_pred_text, color_noise_pred_uncond, color_noise_pred_text = torch.chunk(noise_pred, 4, dim=0) + + noise_pred_uncond, noise_pred_text = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0), torch.cat([normal_noise_pred_text, color_noise_pred_text], 0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. 
Post-processing + if not output_type == "latent": + if num_channels_latents == 8: + latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0) + with torch.no_grad(): + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + # self.final_offload_hook.offload() + if not return_dict: + return (image, ) + return ImagePipelineOutput(images=image) diff --git a/node_config/gpu.yaml b/node_config/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd9a2c0baf6506b571f90b3a0a341470d605c926 --- /dev/null +++ b/node_config/gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +main_process_port: 8002 +mixed_precision: 'fp16' +num_machines: 1 +num_processes: 7 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/node_config/rank.yaml b/node_config/rank.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43e32331d446e0811384026b6f6cd3d4cef60b48 --- /dev/null +++ b/node_config/rank.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_process_ip: 10.22.4.87 +main_process_port: 8002 +main_training_function: main +mixed_precision: fp16 +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/reconstruct.py b/reconstruct.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcb49828bdcd248b652219ed8ecf87ed22cf76a --- /dev/null +++ b/reconstruct.py @@ -0,0 +1,449 @@ +from core.remesh import calc_vertex_normals +from core.opt import MeshOptimizer +from utils.func import make_sparse_camera, make_round_views +from utils.render import NormalsRenderer +import torch.optim as optim +from tqdm import tqdm +from utils.video_utils import write_video +from omegaconf import OmegaConf +import numpy as np +import os +from PIL import Image +import kornia +import torch +import torch.nn as nn +import trimesh +from icecream import ic +from utils.project_mesh import multiview_color_projection, get_cameras_list +from utils.mesh_utils import to_py3d_mesh, rot6d_to_rotmat, tensor2variable +from utils.project_mesh import project_color, get_cameras_list +from utils.smpl_util import SMPLX +from lib.dataset.mesh_util import apply_vertex_mask, part_removal, poisson, keep_largest +from scipy.spatial.transform import Rotation as R +from scipy.spatial import KDTree +import argparse +#### ------------------- config---------------------- +bg_color = np.array([1,1,1]) + +class colorModel(nn.Module): + def __init__(self, renderer, v, f, c): + super().__init__() + self.renderer = renderer + self.v = v + self.f = f + self.colors = nn.Parameter(c, requires_grad=True) + self.bg_color = torch.from_numpy(bg_color).float().to(self.colors.device) + def forward(self, return_mask=False): + rgba = self.renderer.render(self.v, self.f, colors=self.colors) + if return_mask: + return rgba + else: + mask = rgba[..., 3:] + return rgba[..., :3] * mask + self.bg_color * (1 - mask) + + +def scale_mesh(vert): + min_bbox, max_bbox = 
vert.min(0)[0], vert.max(0)[0] + center = (min_bbox + max_bbox) / 2 + offset = -center + vert = vert + offset + + max_dist = torch.max(torch.sqrt(torch.sum(vert**2, dim=1))) + scale = 1.0 / max_dist + return scale, offset + +def save_mesh(save_name, vertices, faces, color=None): + trimesh.Trimesh( + vertices.detach().cpu().numpy(), + faces.detach().cpu().numpy(), + vertex_colors=(color.detach().cpu().numpy() * 255).astype(np.uint8) if color is not None else None) \ + .export(save_name) + + + + +class ReMesh: + def __init__(self, opt, econ_dataset): + self.opt = opt + self.device = torch.device(f"cuda:{opt.gpu_id}" if torch.cuda.is_available() else "cpu") + self.num_view = opt.num_view + + self.out_path = opt.res_path + os.makedirs(self.out_path, exist_ok=True) + self.resolution = opt.resolution + self.views = ['front_face', 'front_right', 'right', 'back', 'left', 'front_left' ] + self.weights = torch.Tensor([1., 0.4, 0.8, 1.0, 0.8, 0.4]).view(6,1,1,1).to(self.device) + + self.renderer = self.prepare_render() + # pose prediction + self.econ_dataset = econ_dataset + self.smplx_face = torch.Tensor(econ_dataset.faces.astype(np.int64)).long().to(self.device) + + def prepare_render(self): + ### ------------------- prepare camera and renderer---------------------- + mv, proj = make_sparse_camera(self.opt.cam_path, self.opt.scale, views=[0,1,2,4,6,7], device=self.device) + renderer = NormalsRenderer(mv, proj, [self.resolution, self.resolution], device=self.device) + return renderer + + def proj_texture(self, fused_images, vertices, faces): + mesh = to_py3d_mesh(vertices, faces) + mesh = mesh.to(self.device) + camera_focal = 1/2 + cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal) + mesh = multiview_color_projection(mesh, fused_images, camera_focal=camera_focal, resolution=self.resolution, weights=self.weights.squeeze().cpu().numpy(), + device=self.device, complete_unseen=True, confidence_threshold=0.2, cameras_list=cameras_list) + return mesh + + def get_invisible_idx(self, imgs, vertices, faces): + mesh = to_py3d_mesh(vertices, faces) + mesh = mesh.to(self.device) + camera_focal = 1/2 + if self.num_view == 6: + cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal) + elif self.num_view == 4: + cameras_list = get_cameras_list([0, 90, 180, 270], device=self.device, focal=camera_focal) + valid_vert_id = [] + vertices_colors = torch.zeros((vertices.shape[0], 3)).float().to(self.device) + valid_cnt = torch.zeros((vertices.shape[0])).to(self.device) + for cam, img, weight in zip(cameras_list, imgs, self.weights.squeeze()): + ret = project_color(mesh, cam, img, eps=0.01, resolution=self.resolution, device=self.device) + # print(ret['valid_colors'].shape) + valid_cnt[ret['valid_verts']] += weight + vertices_colors[ret['valid_verts']] += ret['valid_colors']*weight + valid_mask = valid_cnt > 1 + invalid_mask = valid_cnt < 1 + vertices_colors[valid_mask] /= valid_cnt[valid_mask][:, None] + + # visibility + invisible_vert = valid_cnt < 1 + invisible_vert_indices = torch.nonzero(invisible_vert).squeeze() + # vertices_colors[invalid_vert] = torch.tensor([1.0, 0.0, 0.0]).float().to("cuda") + return vertices_colors, invisible_vert_indices + + def inpaint_missed_colors(self, all_vertices, all_colors, missing_indices): + all_vertices = all_vertices.detach().cpu().numpy() + all_colors = all_colors.detach().cpu().numpy() + missing_indices = missing_indices.detach().cpu().numpy() + + + non_missing_indices = 
np.setdiff1d(np.arange(len(all_vertices)), missing_indices) + + kdtree = KDTree(all_vertices[non_missing_indices]) + + + for missing_index in missing_indices: + missing_vertex = all_vertices[missing_index] + + _, nearest_index = kdtree.query(missing_vertex.reshape(1, -1)) + + interpolated_color = all_colors[non_missing_indices[nearest_index]] + + all_colors[missing_index] = interpolated_color + + return torch.from_numpy(all_colors).to(self.device) + + def load_training_data(self, case): + ###------------------ load target images ------------------------------- + kernal = torch.ones(3, 3) + erode_iters = 2 + normals = [] + masks = [] + colors = [] + for idx, view in enumerate(self.views): + # for idx in [0,2,3,4]: + normal = Image.open(f'{self.opt.mv_path}/{case}/normals_{view}_masked.png') + # normal = Image.open(f'{data_path}/{case}/normals/{idx:02d}_rgba.png') + normal = normal.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR) + normal = np.array(normal).astype(np.float32) / 255. + mask = normal[..., 3:] # alpha + mask_troch = torch.from_numpy(mask).unsqueeze(0) + for _ in range(erode_iters): + mask_torch = kornia.morphology.erosion(mask_troch, kernal) + mask_erode = mask_torch.squeeze(0).numpy() + masks.append(mask_erode) + normal = normal[..., :3] * mask_erode + normals.append(normal) + + color = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png') + color = color.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR) + color = np.array(color).astype(np.float32) / 255. + color_mask = color[..., 3:] # alpha + # color_dilate = color[..., :3] * color_mask + bg_color * (1 - color_mask) + color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode) + colors.append(color_dilate) + + masks = np.stack(masks, 0) + masks = torch.from_numpy(masks).to(self.device) + normals = np.stack(normals, 0) + target_normals = torch.from_numpy(normals).to(self.device) + colors = np.stack(colors, 0) + target_colors = torch.from_numpy(colors).to(self.device) + return masks, target_colors, target_normals + + def preprocess(self, color_pils, normal_pils): + ###------------------ load target images ------------------------------- + kernal = torch.ones(3, 3) + erode_iters = 2 + normals = [] + masks = [] + colors = [] + for normal, color in zip(normal_pils, color_pils): + normal = normal.resize((self.resolution, self.resolution), Image.BILINEAR) + normal = np.array(normal).astype(np.float32) / 255. + mask = normal[..., 3:] # alpha + mask_troch = torch.from_numpy(mask).unsqueeze(0) + for _ in range(erode_iters): + mask_torch = kornia.morphology.erosion(mask_troch, kernal) + mask_erode = mask_torch.squeeze(0).numpy() + masks.append(mask_erode) + normal = normal[..., :3] * mask_erode + normals.append(normal) + + color = color.resize((self.resolution, self.resolution), Image.BILINEAR) + color = np.array(color).astype(np.float32) / 255. 
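+            # composite the RGB channels over the constant background using the eroded normal mask
+            # (rather than the color image's own alpha), so the color target matches the eroded normal silhouette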
+ color_mask = color[..., 3:] # alpha + # color_dilate = color[..., :3] * color_mask + bg_color * (1 - color_mask) + color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode) + colors.append(color_dilate) + + masks = np.stack(masks, 0) + masks = torch.from_numpy(masks).to(self.device) + normals = np.stack(normals, 0) + target_normals = torch.from_numpy(normals).to(self.device) + colors = np.stack(colors, 0) + target_colors = torch.from_numpy(colors).to(self.device) + return masks, target_colors, target_normals + + def optimize_case(self, case, pose, clr_img, nrm_img, opti_texture=True): + case_path = f'{self.out_path}/{case}' + os.makedirs(case_path, exist_ok=True) + + if clr_img is not None: + masks, target_colors, target_normals = self.preprocess(clr_img, nrm_img) + else: + masks, target_colors, target_normals = self.load_training_data(case) + + # rotation + rz = R.from_euler('z', 180, degrees=True).as_matrix() + ry = R.from_euler('y', 180, degrees=True).as_matrix() + rz = torch.from_numpy(rz).float().to(self.device) + ry = torch.from_numpy(ry).float().to(self.device) + + scale, offset = None, None + + global_orient = pose["global_orient"] # pymaf_res[idx]['smplx_params']['body_pose'][:, :1, :, :2].to(device).reshape(1, 1, -1) # data["global_orient"] + body_pose = pose["body_pose"] # pymaf_res[idx]['smplx_params']['body_pose'][:, 1:22, :, :2].to(device).reshape(1, 21, -1) # data["body_pose"] + left_hand_pose = pose["left_hand_pose"] # pymaf_res[idx]['smplx_params']['left_hand_pose'][:, :, :, :2].to(device).reshape(1, 15, -1) + right_hand_pose = pose["right_hand_pose"] # pymaf_res[idx]['smplx_params']['right_hand_pose'][:, :, :, :2].to(device).reshape(1, 15, -1) + beta = pose["betas"] + + # The optimizer and variables + optimed_pose = torch.tensor(body_pose, + device=self.device, + requires_grad=True) # [1,23,3,3] + optimed_trans = torch.tensor(pose["trans"], + device=self.device, + requires_grad=True) # [3] + optimed_betas = torch.tensor(beta, + device=self.device, + requires_grad=True) # [1,200] + optimed_orient = torch.tensor(global_orient, + device=self.device, + requires_grad=True) # [1,1,3,3] + optimed_rhand = torch.tensor(right_hand_pose, + device=self.device, + requires_grad=True) + optimed_lhand = torch.tensor(left_hand_pose, + device=self.device, + requires_grad=True) + + optimed_params = [ + {'params': [optimed_lhand, optimed_rhand], 'lr': 1e-3}, + {'params': [optimed_betas, optimed_trans, optimed_orient, optimed_pose], 'lr': 3e-3}, + ] + optimizer_smpl = torch.optim.Adam( + optimed_params, + amsgrad=True, + ) + scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_smpl, + mode="min", + factor=0.5, + verbose=0, + min_lr=1e-5, + patience=5, + ) + smpl_steps = 100 + + for i in tqdm(range(smpl_steps)): + optimizer_smpl.zero_grad() + # 6d_rot to rot_mat + optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view( + -1, 6)).unsqueeze(0) + optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view( + -1, 6)).unsqueeze(0) + + smpl_verts, smpl_landmarks, smpl_joints = self.econ_dataset.smpl_model( + shape_params=optimed_betas, + expression_params=tensor2variable(pose["exp"], self.device), + body_pose=optimed_pose_mat, + global_pose=optimed_orient_mat, + jaw_pose=tensor2variable(pose["jaw_pose"], self.device), + left_hand_pose=optimed_lhand, + right_hand_pose=optimed_rhand, + + ) + + smpl_verts = smpl_verts + optimed_trans + + v_smpl = torch.matmul(torch.matmul(smpl_verts.squeeze(0), rz.T), ry.T) + if scale is None: + scale, offset = 
scale_mesh(v_smpl.detach()) + v_smpl = (v_smpl + offset) * scale * 2 + # if i == 0: + # save_mesh(f'{case_path}/{case}_init_smpl.obj', v_smpl, self.smplx_face) + # exit() + normals = calc_vertex_normals(v_smpl, self.smplx_face) + nrm = self.renderer.render(v_smpl, self.smplx_face, normals=normals) + + masks_ = nrm[..., 3:] + smpl_mask_loss = ((masks_ - masks) * self.weights).abs().mean() + smpl_nrm_loss = ((nrm[..., :3] - target_normals) * self.weights).abs().mean() + + smpl_loss = smpl_mask_loss + smpl_nrm_loss + # smpl_loss = smpl_mask_loss + smpl_loss.backward() + optimizer_smpl.step() + scheduler_smpl.step(smpl_loss) + + mesh_smpl = trimesh.Trimesh(vertices=v_smpl.detach().cpu().numpy(), faces=self.smplx_face.detach().cpu().numpy()) + + + nrm_opt = MeshOptimizer(v_smpl.detach(), self.smplx_face.detach(), edge_len_lims=[0.01, 0.1]) + vertices, faces = nrm_opt.vertices, nrm_opt.faces + + # ###----------------------- optimization iterations------------------------------------- + for i in tqdm(range(self.opt.iters)): + nrm_opt.zero_grad() + + normals = calc_vertex_normals(vertices,faces) + nrm = self.renderer.render(vertices,faces, normals=normals) + normals = nrm[..., :3] + # if i < 800: + loss = ((normals-target_normals) * self.weights).abs().mean() + # else: + # loss = ((normals-target_images) * masks).abs().mean() + + alpha = nrm[..., 3:] + loss += ((alpha - masks) * self.weights).abs().mean() + + loss.backward() + + nrm_opt.step() + + vertices,faces = nrm_opt.remesh() + + if self.opt.debug and i % self.opt.snapshot_step == 0: + import imageio + os.makedirs(f'{case_path}/normals', exist_ok=True) + imageio.imwrite(f'{case_path}/normals/{i:02d}.png',(nrm.detach()[0,:,:,:3]*255).clamp(max=255).type(torch.uint8).cpu().numpy()) + # mesh_remeshed = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy()) + # mesh_remeshed.export(f'{case_path}/{case}_remeshed_step{i}.obj') + torch.cuda.empty_cache() + + mesh_remeshed = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy()) + mesh_remeshed.export(f'{case_path}/{case}_remeshed.obj') + # save_mesh(case, vertices, faces) + vertices = vertices.detach() + faces = faces.detach() + + #### replace hand + smpl_data = SMPLX() + if self.opt.replace_hand and True in pose['hands_visibility'][0]: + hand_mask = torch.zeros(smpl_data.smplx_verts.shape[0], ) + if pose['hands_visibility'][0][0]: + hand_mask.index_fill_( + 0, torch.tensor(smpl_data.smplx_mano_vid_dict["left_hand"]), 1.0 + ) + if pose['hands_visibility'][0][1]: + hand_mask.index_fill_( + 0, torch.tensor(smpl_data.smplx_mano_vid_dict["right_hand"]), 1.0 + ) + + hand_mesh = apply_vertex_mask(mesh_smpl.copy(), hand_mask) + body_mesh = part_removal( + mesh_remeshed.copy(), + hand_mesh, + 0.08, + self.device, + mesh_smpl.copy(), + region="hand" + ) + final = poisson(sum([hand_mesh, body_mesh]), f'{case_path}/{case}_final.obj', 10, False) + else: + final = poisson(mesh_remeshed, f'{case_path}/{case}_final.obj', 10, False) + vertices = torch.from_numpy(final.vertices).float().to(self.device) + faces = torch.from_numpy(final.faces).long().to(self.device) + # Differing from paper, we use the texturing method in Unique3D + masked_color = [] + for tmp in clr_img: + # tmp = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png') + tmp = tmp.resize((self.resolution, self.resolution), Image.BILINEAR) + tmp = np.array(tmp).astype(np.float32) / 255. 
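+            # H x W x C float image in [0, 1] -> C x H x W tensor for the multiview color projection step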
+ masked_color.append(torch.from_numpy(tmp).permute(2, 0, 1).to(self.device)) + + meshes = self.proj_texture(masked_color, vertices, faces) + vertices = meshes.verts_packed().float() + faces = meshes.faces_packed().long() + colors = meshes.textures.verts_features_packed().float() + save_mesh(f'./{case_path}/result_clr_scale{self.opt.scale}_{case}.obj', vertices, faces, colors) + self.evaluate(vertices, colors, faces, save_path=f'{case_path}/result_clr_scale{self.opt.scale}_{case}.mp4', save_nrm=True) + + + def evaluate(self, target_vertices, target_colors, target_faces, save_path=None, save_nrm=False): + mv, proj = make_round_views(60, self.opt.scale, device=self.device) + renderer = NormalsRenderer(mv, proj, [512, 512], device=self.device) + + target_images = renderer.render(target_vertices,target_faces, colors=target_colors) + target_images = target_images.detach().cpu().numpy() + target_images = target_images[..., :3] * target_images[..., 3:4] + bg_color * (1 - target_images[..., 3:4]) + target_images = (target_images.clip(0, 1) * 255).astype(np.uint8) + + if save_nrm: + target_normals = calc_vertex_normals(target_vertices, target_faces) + # target_normals[:, 2] *= -1 + target_normals = renderer.render(target_vertices, target_faces, normals=target_normals) + target_normals = target_normals.detach().cpu().numpy() + target_normals = target_normals[..., :3] * target_normals[..., 3:4] + bg_color * (1 - target_normals[..., 3:4]) + target_normals = (target_normals.clip(0, 1) * 255).astype(np.uint8) + frames = [np.concatenate([img, nrm], 1) for img, nrm in zip(target_images, target_normals)] + else: + frames = [img for img in target_images] + if save_path is not None: + write_video(frames, fps=25, save_path=save_path) + return frames + + def run(self): + cases = sorted(os.listdir(self.opt.imgs_path)) + for idx in range(len(cases)): + case = cases[idx].split('.')[0] + print(f'Processing {case}') + pose = self.econ_dataset.__getitem__(idx) + v, f, c = self.optimize_case(case, pose, None, None, opti_texture=True) + self.evaluate(v, c, f, save_path=f'{self.opt.res_path}/{case}/result_clr_scale{self.opt.scale}_{case}.mp4', save_nrm=True) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--config", help="path to the yaml configs file", default='config.yaml') + args, extras = parser.parse_known_args() + + opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras)) + from econdataset import SMPLDataset + dataset_param = {'image_dir': opt.imgs_path, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'} + econdata = SMPLDataset(dataset_param, device='cuda') + EHuman = ReMesh(opt, econdata) + EHuman.run() + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7814ff8fd4b176161e8b222a08d8b72ceb1b10ae --- /dev/null +++ b/requirements.txt @@ -0,0 +1,217 @@ +absl-py==2.1.0 +accelerate==1.1.1 +addict==2.4.0 +aiofiles==23.2.1 +aiohappyeyeballs==2.4.3 +aiohttp==3.11.2 +aiosignal==1.3.1 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.6.2.post1 +asttokens==2.4.1 +async-timeout==5.0.1 +attrs==24.2.0 +blinker==1.9.0 +certifi==2022.12.7 +cffi==1.17.1 +charset-normalizer==2.1.1 +click==8.1.7 +colorama==0.4.6 +coloredlogs==15.0.1 +comm==0.2.2 +ConfigArgParse==1.7 +contourpy==1.3.1 +cycler==0.12.1 +dash==2.18.2 +dash-core-components==2.0.0 +dash-html-components==2.0.0 +dash-table==5.0.0 +dataclasses-json==0.6.7 +datasets==3.1.0 +decorator==4.4.2 
+Deprecated==1.2.15 +diffusers==0.27.2 +dill==0.3.8 +docker-pycreds==0.4.0 +einops==0.8.0 +entrypoints==0.4 +exceptiongroup==1.2.2 +executing==2.1.0 +fastapi==0.115.5 +fastjsonschema==2.21.0 +ffmpy==0.4.0 +filelock==3.13.1 +fire==0.7.0 +Flask==3.0.3 +flatbuffers==24.3.25 +fonttools==4.55.0 +frozenlist==1.5.0 +fsspec==2024.2.0 +gitdb==4.0.11 +GitPython==3.1.43 +gradio==5.6.0 +gradio_client==1.4.3 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.27.2 +huggingface-hub==0.24.5 +humanfriendly==10.0 +icecream==2.1.3 +idna==3.4 +imageio==2.36.0 +imageio-ffmpeg==0.5.1 +importlib_metadata==8.5.0 +iopath==0.1.10 +ipycanvas==0.13.3 +ipyevents==2.0.2 +ipython==8.30.0 +ipywidgets==8.1.5 +itsdangerous==2.2.0 +jax==0.4.35 +jaxlib==0.4.35 +jaxtyping==0.2.34 +jedi==0.19.2 +Jinja2==3.1.3 +joblib==1.4.2 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +jupyter_client==7.4.9 +jupyter_core==5.7.2 +jupyterlab_widgets==3.0.13 +kaolin==0.17.0 +kiwisolver==1.4.7 +kornia==0.7.4 +kornia_rs==0.1.7 +lazy_loader==0.4 +llvmlite==0.43.0 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +marshmallow==3.23.1 +matplotlib==3.9.2 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mediapipe==0.10.18 +ml_dtypes==0.5.0 +moviepy==1.0.3 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +mypy-extensions==1.0.0 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.2.1 +numba==0.60.0 +numpy==1.26.3 +nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@729261dc64c4241ea36efda84fbf532cc8b425b8 +nvidia-cuda-runtime-cu12==12.6.77 +nvidia-pyindex==1.0.9 +nvidia-tensorrt==99.0.0 +omegaconf==2.3.0 +onnxruntime==1.20.0 +onnxruntime-gpu==1.20.0 +open3d==0.18.0 +opencv-contrib-python==4.10.0.84 +opencv-python==4.10.0.84 +opencv-python-headless==4.10.0.84 +opt_einsum==3.4.0 +orjson==3.10.11 +ort-nightly-gpu==1.15.0.dev20230502003 +packaging==24.2 +pandas==2.2.3 +parso==0.8.4 +peft==0.13.2 +pexpect==4.9.0 +pillow==10.2.0 +platformdirs==4.3.6 +plotly==5.24.1 +pooch==1.8.2 +portalocker==2.10.1 +proglog==0.1.10 +prompt_toolkit==3.0.48 +propcache==0.2.0 +protobuf==4.25.5 +psutil==5.9.8 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyarrow==18.0.0 +pybind11==2.13.6 +pycparser==2.22 +pydantic==2.9.2 +pydantic_core==2.23.4 +pydub==0.25.1 +pygltflib==1.16.2 +Pygments==2.18.0 +PyMatting==1.1.13 +pymeshlab==2023.12.post2 +pyparsing==3.2.0 +pyquaternion==0.9.9 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +python-multipart==0.0.12 +pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47 +pytz==2024.2 +PyYAML==6.0.2 +pyzmq==26.2.0 +referencing==0.35.1 +regex==2024.11.6 +rembg==2.0.59 +requests==2.32.3 +retrying==1.3.4 +rich==13.9.4 +rpds-py==0.21.0 +ruff==0.7.4 +safehttpx==0.1.1 +safetensors==0.4.5 +scikit-image==0.24.0 +scikit-learn==1.5.2 +scipy==1.14.1 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==2.18.0 +setproctitle==1.3.4 +shellingham==1.5.4 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +sounddevice==0.5.1 +spaces==0.30.4 +stack-data==0.6.3 +starlette==0.41.2 +sympy==1.13.1 +tenacity==9.0.0 +tensorrt==10.6.0 +tensorrt-cu12==10.6.0 +tensorrt-cu12-bindings==10.6.0 +tensorrt-cu12-libs==10.6.0 +termcolor==2.5.0 +threadpoolctl==3.5.0 +tifffile==2024.9.20 +tokenizers==0.20.3 +tomlkit==0.12.0 +torch_scatter==2.1.2 +tornado==6.4.2 +tqdm==4.67.0 +traitlets==5.14.3 +transformers==4.46.2 +trimesh==4.5.2 +triton==2.1.0 +typeguard==2.13.3 +typer==0.13.0 +typing-inspect==0.9.0 +typing_extensions==4.9.0 +tzdata==2024.2 +urllib3==1.26.13 +usd-core==24.11 +uvicorn==0.32.0 +wandb==0.18.7 +warp-lang==1.4.2 
+wcwidth==0.2.13 +websockets==12.0 +Werkzeug==3.0.6 +widgetsnbextension==4.0.13 +wrapt==1.16.0 +xformers==0.0.22.post7 +xxhash==3.5.0 +yacs==0.1.8 +yarl==1.17.1 +zipp==3.21.0 diff --git a/scripts/inference_768.sh b/scripts/inference_768.sh new file mode 100644 index 0000000000000000000000000000000000000000..480f5a8d861f9e4442f07240bfda047dcd70fe6f --- /dev/null +++ b/scripts/inference_768.sh @@ -0,0 +1,13 @@ +seed=600 +gpu=$1 +CUDA_VISIBLE_DEVICES=$gpu python inference.py --config configs/inference-768-6view.yaml \ + pretrained_model_name_or_path='pengHTYX/PSHuman_Unclip_768_6views' \ + validation_dataset.crop_size=740 \ + with_smpl=false \ + validation_dataset.root_dir='examples' \ + seed=$seed \ + num_views=7 \ + save_mode='rgb' # if save rgba images for each view, if not, noly save concatenated images for visualization + + + diff --git a/scripts/train_768.sh b/scripts/train_768.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f234ee5bdacfec3f311985038ebfaadeb50656c --- /dev/null +++ b/scripts/train_768.sh @@ -0,0 +1,8 @@ +#!/bin/bash +export WANDB_API_KEY=$KEY # replace $KEY with your wandb key +export CUDA_VISIBLE_DEVICES=0,1,2,3 +### CMD + +accelerate launch --config_file node_config/gpu.yaml --num_processes 4 \ + train_mvdiffusion_unit_unclip.py \ + --config configs/train-768-6view-onlyscan_face.yaml > log/log_$$.txt 2>&1 diff --git a/train_mvdiffusion_unit_unclip.py b/train_mvdiffusion_unit_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..8cea716143ff472a0f23f90bfb9ab7344697036b --- /dev/null +++ b/train_mvdiffusion_unit_unclip.py @@ -0,0 +1,1135 @@ +import logging +import warnings +from typing import Callable, List, Optional, Union, Dict, Any + +import PIL +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.embeddings import get_timestep_embedding +from diffusers.schedulers import KarrasDiffusionSchedulers, PNDMScheduler, DDIMScheduler, DDPMScheduler +from diffusers.utils import deprecate +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +import transformers +import diffusers +import accelerate +from accelerate import Accelerator +from torchvision.transforms import InterpolationMode +import argparse +from omegaconf import OmegaConf +from mvdiffusion.models_unclip.unet_mv2d_condition import UNetMV2DConditionModel +# from mvdiffusion.data.objaverse_dataset_unclip_xxdata import ObjaverseDataset as MVDiffusionDataset +from mvdiffusion.data.dreamdata import ObjaverseDataset as MVDiffusionDataset +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from accelerate.logging import get_logger +import os +import numpy as np +from PIL import Image 
+import math +from tqdm import tqdm +from einops import rearrange, repeat +from torchvision.transforms import InterpolationMode +from einops import rearrange, repeat +from diffusers.schedulers import PNDMScheduler +from collections import defaultdict +from torchvision.utils import make_grid, save_image +from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline +from dataclasses import dataclass +import json +import shutil +from mvdiffusion.models_unclip.face_networks import prepare_face_proj_model +logger = get_logger(__name__, log_level="INFO") +@dataclass +class TrainingConfig: + pretrained_model_name_or_path: str + pretrained_unet_path: Optional[str] + clip_path: str + revision: Optional[str] + data_common: Optional[dict] + train_dataset: Dict + validation_dataset: Dict + validation_train_dataset: Dict + output_dir: str + checkpoint_prefix: str + seed: Optional[int] + train_batch_size: int + validation_batch_size: int + validation_train_batch_size: int + max_train_steps: int + gradient_accumulation_steps: int + gradient_checkpointing: bool + learning_rate: float + scale_lr: bool + lr_scheduler: str + step_rules: Optional[str] + lr_warmup_steps: int + snr_gamma: Optional[float] + use_8bit_adam: bool + allow_tf32: bool + use_ema: bool + dataloader_num_workers: int + adam_beta1: float + adam_beta2: float + adam_weight_decay: float + adam_epsilon: float + max_grad_norm: Optional[float] + prediction_type: Optional[str] + logging_dir: str + vis_dir: str + mixed_precision: Optional[str] + report_to: Optional[str] + local_rank: int + checkpointing_steps: int + checkpoints_total_limit: Optional[int] + resume_from_checkpoint: Optional[str] + enable_xformers_memory_efficient_attention: bool + validation_steps: int + validation_sanity_check: bool + tracker_project_name: str + + trainable_modules: Optional[list] + use_classifier_free_guidance: bool + condition_drop_rate: float + scale_input_latents: bool + regress_elevation: bool + regress_focal_length: bool + elevation_loss_weight: float + focal_loss_weight: float + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_guidance_scales: List[float] + validation_grid_nrow: int + camera_embedding_lr_mult: float + plot_pose_acc: bool + num_views: int + data_view_num: Optional[int] + pred_type: str + drop_type: str + with_smpl: Optional[bool] + +@torch.no_grad() +def convert_image( + tensor, + fp, + format: Optional[str] = None, + **kwargs, +) -> None: + """ + Save a given Tensor into an image file. + + Args: + tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, + saves the tensor as a grid of images by calling ``make_grid``. + fp (string or file object): A filename or a file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. + **kwargs: Other arguments are documented in ``make_grid``. 
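+
+    Example::
+
+        # illustrative usage: save a (B, 3, H, W) batch in [0, 1] as a grid with 4 images per row
+        convert_image(images, "grid.jpg", nrow=4)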
+ """ + grid = make_grid(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + im = Image.fromarray(ndarr) + im.save(fp, format=format) + +def log_validation_joint(dataloader, vae, feature_extractor, image_encoder, image_normlizer, image_noising_scheduler, tokenizer, text_encoder, + unet, face_proj_model, cfg:TrainingConfig, accelerator, weight_dtype, global_step, name, save_dir): + + pipeline = StableUnCLIPImg2ImgPipeline( + image_encoder=image_encoder, feature_extractor=feature_extractor, image_normalizer=image_normlizer, + image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, + vae=vae, unet=accelerator.unwrap_model(unet), + scheduler=DDIMScheduler.from_pretrained_linear(cfg.pretrained_model_name_or_path, subfolder="scheduler"), + **cfg.pipe_kwargs + ) + + pipeline.set_progress_bar_config(disable=True) + + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=unet.device).manual_seed(cfg.seed) + + images_cond, pred_cat = [], defaultdict(list) + for i, batch in tqdm(enumerate(dataloader)): + images_cond.append(batch['imgs_in'][:, 0]) + if face_proj_model is not None: + face_embeds = batch['face_embed'] + face_embeds = torch.cat([face_embeds]*2, dim=0) + face_embeds = rearrange(face_embeds, "B Nv L C -> (B Nv) L C") + face_embeds = face_embeds.to(device=accelerator.device, dtype=weight_dtype) + face_embeds = face_proj_model(face_embeds) + else: + face_embeds = None + # if dino_encoder: + # dino_input = TF.resize(batch['imgs_in'][:, 0], (224, 224)).float().to(accelerator.device) + # dino_feature = dino_encoder(dino_input) + # dino_feature = repeat(dino_feature, "B N C -> (B V) N C", V=cfg.num_views*2) + # else: + # dino_feature = None + imgs_in = torch.cat([batch['imgs_in']]*2, dim=0) + num_views = imgs_in.shape[1] + imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W) + + if cfg.with_smpl: + smpl_in = torch.cat([batch['smpl_imgs_in']]*2, dim=0) + smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W") + else: + smpl_in = None + + normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings'] + prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0) + prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") + with torch.autocast("cuda"): + # B*Nv images + for guidance_scale in cfg.validation_guidance_scales: + out = pipeline( + imgs_in, None, prompt_embeds=prompt_embeddings, + dino_feature=face_embeds, smpl_in=smpl_in, + generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs + ).images + + bsz = out.shape[0] // 2 + normals_pred = out[:bsz] + images_pred = out[bsz:] + # print(normals_pred.shape, images_pred.shape) + pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1)) # b, 3, h, w + + # from icecream import ic + images_cond_all = torch.cat(images_cond, dim=0) + images_pred_all = {} + for k, v in pred_cat.items(): + images_pred_all[k] = torch.cat(v, dim=0).cpu() + # print(images_pred_all[k].shape) + # import pdb;pdb.set_trace() + nrow = cfg.validation_grid_nrow + # ncol = images_cond_all.shape[0] // nrow + images_cond_grid = make_grid(images_cond_all, nrow=1, padding=0, value_range=(0, 1)) + edge_pad = torch.zeros(list(images_cond_grid.shape[:2]) + 
[3], dtype=torch.float32) + images_vis = torch.cat([images_cond_grid, edge_pad], -1) + for k, v in images_pred_all.items(): + images_vis = torch.cat([images_vis, make_grid(v, nrow=nrow, padding=0, value_range=(0, 1)), edge_pad], -1) + save_image(images_vis, os.path.join(save_dir, f"{name}-{global_step}.jpg")) + torch.cuda.empty_cache() + +def log_validation(dataloader, vae, feature_extractor, image_encoder, image_normlizer, image_noising_scheduler, tokenizer, text_encoder, + unet, face_proj_model, cfg:TrainingConfig, accelerator, weight_dtype, global_step, name, save_dir): + logger.info(f"Running {name} ... ") + + pipeline = StableUnCLIPImg2ImgPipeline( + image_encoder=image_encoder, feature_extractor=feature_extractor, image_normalizer=image_normlizer, + image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, + vae=vae, unet=accelerator.unwrap_model(unet), + scheduler=DDIMScheduler.from_pretrained_linear(cfg.pretrained_model_name_or_path, subfolder="scheduler"), + **cfg.pipe_kwargs + ) + + pipeline.set_progress_bar_config(disable=True) + + if cfg.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed) + + images_cond, images_gt, images_pred = [], [], defaultdict(list) + for i, batch in enumerate(dataloader): + # (B, Nv, 3, H, W) + imgs_in, colors_out, normals_out = batch['imgs_in'], batch['imgs_out'], batch['normals_out'] + images_cond.append(imgs_in[:, 0, :, :, :]) + + # repeat (2B, Nv, 3, H, W) + imgs_in = torch.cat([imgs_in]*2, dim=0) + imgs_out = torch.cat([normals_out, colors_out], dim=0) + imgs_in, imgs_out = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W"), rearrange(imgs_out, "B Nv C H W -> (B Nv) C H W") + images_gt.append(imgs_out) + + if cfg.with_smpl: + smpl_in = torch.cat([batch['smpl_imgs_in']]*2, dim=0) + smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W") + else: + smpl_in = None + + prompt_embeddings = torch.cat([batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']], dim=0) + # (B*Nv, N, C) + prompt_embeds = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") + prompt_embeds = prompt_embeds.to(weight_dtype) + + if face_proj_model is not None: + face_embeds = batch['face_embed'] + face_embeds = torch.cat([face_embeds]*2, dim=0) + face_embeds = rearrange(face_embeds, "B Nv L C -> (B Nv) L C") + face_embeds = face_embeds.to(device=accelerator.device, dtype=weight_dtype) + face_embeds = face_proj_model(face_embeds) + else: + face_embeds = None + with torch.autocast("cuda"): + # B*Nv images + for guidance_scale in cfg.validation_guidance_scales: + out = pipeline( + imgs_in, None, prompt_embeds=prompt_embeds, smpl_in=smpl_in, dino_feature=face_embeds, generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs + ).images + shape = out.shape + out0, out1 = out[:shape[0]//2], out[shape[0]//2:] + out = [] + for ii in range(shape[0]//2): + out.append(out0[ii]) + out.append(out1[ii]) + out = torch.stack(out, dim=0) + images_pred[f"{name}-sample_cfg{guidance_scale:.1f}"].append(out) + + images_cond_all = torch.cat(images_cond, dim=0) + images_gt_all = torch.cat(images_gt, dim=0) + images_pred_all = {} + for k, v in images_pred.items(): + images_pred_all[k] = torch.cat(v, dim=0).cpu() + + nrow = cfg.validation_grid_nrow * 2 + images_cond_grid = make_grid(images_cond_all, nrow=1, 
padding=0, value_range=(0, 1)) + images_gt_grid = make_grid(images_gt_all, nrow=nrow, padding=0, value_range=(0, 1)) + edge_pad = torch.zeros(list(images_cond_grid.shape[:2]) + [3], dtype=torch.float32) + images_vis = torch.cat([images_cond_grid.cpu(), edge_pad], -1) + for k, v in images_pred_all.items(): + images_vis = torch.cat([images_vis, make_grid(v, nrow=nrow, padding=0, value_range=(0, 1)), edge_pad], -1) + + # images_pred_grid = {} + # for k, v in images_pred_all.items(): + # images_pred_grid[k] = make_grid(v, nrow=nrow, padding=0, value_range=(0, 1)) + save_image(images_vis, os.path.join(save_dir, f"{global_step}-{name}-cond.jpg")) + save_image(images_gt_grid, os.path.join(save_dir, f"{global_step}-{name}-gt.jpg")) + torch.cuda.empty_cache() + + +def noise_image_embeddings( + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + image_normalizer: Optional[StableUnCLIPImageNormalizer] = None, + image_noising_scheduler: Optional[DDPMScheduler] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + image_embeds = image_normalizer.scale(image_embeds) + + image_embeds = image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
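+    # cast the sinusoidal noise-level embedding to the embedding dtype, then append it along the feature dimension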
+ noise_level = noise_level.to(image_embeds.dtype) + image_embeds = torch.cat((image_embeds, noise_level), 1) + return image_embeds + + +def main(cfg: TrainingConfig): + # -------------------------------------------prepare custom log and accelaeator -------------------------------- + # override local_rank with envvar + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank not in [-1, cfg.local_rank]: + cfg.local_rank = env_local_rank + + logging_dir = os.path.join(cfg.output_dir, cfg.logging_dir) + model_dir = os.path.join(cfg.checkpoint_prefix, cfg.output_dir) + vis_dir = os.path.join(cfg.output_dir, cfg.vis_dir) + accelerator_project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=logging_dir) + # print(os.getenv("SLURM_PROCID"), os.getenv("SLURM_LOCALID"), os.getenv("SLURM_NODEID"), os.getenv('GLOBAL_RANK'), os.getenv('LOCAL_RANK'), os.getenv('RNAK'), os.getenv('MASTER_ADDR')) + # exit() + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + mixed_precision=cfg.mixed_precision, + log_with=cfg.report_to, + project_config=accelerator_project_config, + ) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + set_seed(cfg.seed) + + # Handle the repository creation + if accelerator.is_main_process: + os.makedirs(model_dir, exist_ok=True) + os.makedirs(cfg.output_dir, exist_ok=True) + os.makedirs(vis_dir, exist_ok=True) + OmegaConf.save(cfg, os.path.join(cfg.output_dir, 'config.yaml')) + ## -------------------------------------- load models -------------------------------- + image_encoder = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", revision=cfg.revision) + feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision) + image_noising_scheduler = DDPMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_noising_scheduler") + image_normlizer = StableUnCLIPImageNormalizer.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_normalizer") + + tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="tokenizer", revision=cfg.revision) + text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder='text_encoder', revision=cfg.revision) + # note: official code use PNDMScheduler + noise_scheduler = DDPMScheduler.from_pretrained_linear(cfg.pretrained_model_name_or_path, subfolder="scheduler") + vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision) + if cfg.pretrained_unet_path is None: + + unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs) + else: + logger.info(f'laod pretrained model from {cfg.pretrained_unet_path}') + unet = 
UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs) + # unet = UNet2DConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision) + if cfg.unet_from_pretrained_kwargs.use_dino: + from models.dinov2_wrapper import Dinov2Wrapper + dino_encoder = Dinov2Wrapper(model_name='dinov2_vitb14', freeze=True) + else: + dino_encoder = None + + # TODO: extract face projection model weights + if cfg.unet_from_pretrained_kwargs.use_face_adapter: + face_proj_model = prepare_face_proj_model('models/image_proj_model.pth', cross_attention_dim=1024, pretrain=False) + else: + face_proj_model = None + + if cfg.use_ema: + ema_unet = EMAModel(unet.parameters(), model_cls=UNetMV2DConditionModel, model_config=unet.config) + # ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # Freeze vae, image_encoder, text_encoder + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + image_normlizer.requires_grad_(False) + text_encoder.requires_grad_(False) + if face_proj_model is not None: face_proj_model.requires_grad_(True) + + if cfg.trainable_modules is None: + unet.requires_grad_(True) + else: + unet.requires_grad_(False) + for name, module in unet.named_modules(): + if name.endswith(tuple(cfg.trainable_modules)): + for params in module.parameters(): + params.requires_grad = True + + if cfg.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + print("use xformers to speed up") + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if cfg.use_ema: + ema_unet.save_pretrained(os.path.join(cfg.checkpoint_prefix, output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(cfg.checkpoint_prefix, output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if cfg.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(cfg.checkpoint_prefix, input_dir, "unet_ema"), UNetMV2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNetMV2DConditionModel.from_pretrained(os.path.join(cfg.checkpoint_prefix, input_dir), subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if cfg.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if cfg.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + # -------------------------------------- optimizer and lr -------------------------------- + if cfg.scale_lr: + cfg.learning_rate = ( + cfg.learning_rate * cfg.gradient_accumulation_steps * cfg.train_batch_size * accelerator.num_processes + ) + # Initialize the optimizer + if cfg.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + params, params_class_embedding, params_rowwise_layers = [], [], [] + for name, param in unet.named_parameters(): + if ('class_embedding' in name) or ('camera_embedding' in name): + params_class_embedding.append(param) + elif ('attn_mv' in name) or ('norm_mv' in name): + # print('Find mv attn block') + params_rowwise_layers.append(param) + else: + params.append(param) + opti_params = [{"params": params, "lr": cfg.learning_rate}] + if len(params_class_embedding) > 0: + opti_params.append({"params": params_class_embedding, "lr": cfg.learning_rate * cfg.camera_embedding_lr_mult}) + if len(params_rowwise_layers) > 0: + opti_params.append({"params": params_rowwise_layers, "lr": cfg.learning_rate * cfg.camera_embedding_lr_mult}) + optimizer = optimizer_cls( + opti_params, + betas=(cfg.adam_beta1, cfg.adam_beta2), + weight_decay=cfg.adam_weight_decay, + eps=cfg.adam_epsilon, + ) + lr_scheduler = get_scheduler( + cfg.lr_scheduler, + step_rules=cfg.step_rules, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps * accelerator.num_processes, + num_training_steps=cfg.max_train_steps * accelerator.num_processes, + ) + # -------------------------------------- load dataset -------------------------------- + # Get the training dataset + train_dataset = MVDiffusionDataset( + **cfg.train_dataset + ) + if cfg.with_smpl: + from mvdiffusion.data.testdata_with_smpl import SingleImageDataset + else: + from mvdiffusion.data.single_image_dataset import SingleImageDataset + validation_dataset = SingleImageDataset( + **cfg.validation_dataset + ) + validation_train_dataset = MVDiffusionDataset( + **cfg.validation_train_dataset + ) + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.dataloader_num_workers, + ) + validation_dataloader = torch.utils.data.DataLoader( + validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers + ) + validation_train_dataloader = torch.utils.data.DataLoader( + validation_train_dataset, batch_size=cfg.validation_train_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers + ) + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + if cfg.use_ema: + ema_unet.to(accelerator.device) + # -------------------------------------- cast dtype and device -------------------------------- + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
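# A minimal equivalent of the dtype selection below, assuming the standard
# `accelerate` mixed-precision strings ("no", "fp16", "bf16"):
#   weight_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(accelerator.mixed_precision, torch.float32)
# Only the frozen modules (vae, image_encoder, image_normlizer, text_encoder) are cast to it;
# the trainable unet stays in float32 and is handled by `accelerator.prepare`.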
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + cfg.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + cfg.mixed_precision = accelerator.mixed_precision + + # Move text_encode and vae to gpu and cast to weight_dtype + image_encoder.to(accelerator.device, dtype=weight_dtype) + image_normlizer.to(accelerator.device, weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + if face_proj_model: face_proj_model.to(accelerator.device, dtype=weight_dtype) + if dino_encoder: dino_encoder.to(accelerator.device) + + clip_image_mean = torch.as_tensor(feature_extractor.image_mean)[:,None,None].to(accelerator.device, dtype=torch.float32) + clip_image_std = torch.as_tensor(feature_extractor.image_std)[:,None,None].to(accelerator.device, dtype=torch.float32) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + # tracker_config = dict(vars(cfg)) + tracker_config = {} + accelerator.init_trackers( + project_name= cfg.tracker_project_name, + config= tracker_config, + init_kwargs={"wandb": + {"entity": "lpstarry", + "notes": cfg.output_dir.split('/')[-1], + # "tags": [cfg.output_dir.split('/')[-1]], + }},) + + # -------------------------------------- load pipeline -------------------------------- + # pipe = StableUnCLIPImg2ImgPipeline(feature_extractor=feature_extractor, + # image_encoder=image_encoder, + # image_normalizer=image_normlizer, + # image_noising_scheduler= image_noising_scheduler, + # tokenizer=tokenizer, + # text_encoder=text_encoder, + # unet=unet, + # scheduler=noise_scheduler, + # vae=vae).to('cuda') + + # -------------------------------------- train -------------------------------- + total_batch_size = cfg.train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps + generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed) + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {cfg.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + if cfg.resume_from_checkpoint != "latest": + path = os.path.basename(cfg.resume_from_checkpoint) + else: + # Get the most recent checkpoint + if os.path.exists(os.path.join(model_dir, "checkpoint")): + path = "checkpoint" + else: + dirs = os.listdir(model_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{cfg.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + cfg.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(model_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + if False: + # log_validation_joint( + # validation_dataloader, + # vae, + # feature_extractor, + # image_encoder, + # image_normlizer, + # image_noising_scheduler, + # tokenizer, + # text_encoder, + # unet, + # dino_encoder, + # cfg, + # accelerator, + # weight_dtype, + # global_step, + # 'validation', + # vis_dir + # ) + log_validation( + validation_train_dataloader, + vae, + feature_extractor, + image_encoder, + image_normlizer, + image_noising_scheduler, + tokenizer, + text_encoder, + unet, + cfg, + accelerator, + weight_dtype, + global_step, + 'validation-train', + vis_dir + ) + exit() + + progress_bar = tqdm( + range(0, cfg.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + new_layer_norm = {} + # Main training loop + + for epoch in range(first_epoch, num_train_epochs): + unet.train() + train_mse_loss, train_ele_loss, train_focal_loss = 0.0, 0.0, 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + # if cfg.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + # if step % cfg.gradient_accumulation_steps == 0: + # progress_bar.update(1) + # continue + + with accelerator.accumulate(unet): + # (B, Nv, 3, H, W) + imgs_in, colors_out, normals_out = batch['imgs_in'], batch['imgs_out'], batch['normals_out'] + ids = batch['id'] + bnm, Nv = imgs_in.shape[:2] + # repeat (2B, Nv, 3, H, W) + imgs_in = torch.cat([imgs_in]*2, dim=0) + imgs_out = torch.cat([normals_out, colors_out], dim=0) + # (B*Nv, 3, H, W) + imgs_in, imgs_out = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W"), rearrange(imgs_out, "B Nv C H W -> (B Nv) C H W") + imgs_in, imgs_out = imgs_in.to(weight_dtype), imgs_out.to(weight_dtype) + + if cfg.with_smpl: + smpl_in = batch['smpl_imgs_in'] + smpl_in = torch.cat([smpl_in]*2, dim=0) + smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W").to(weight_dtype) + else: + smpl_in = None + + prompt_embeddings = torch.cat([batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']], dim=0) + # (B*Nv, N, C) + prompt_embeds = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") + prompt_embeds = prompt_embeds.to(weight_dtype) # BV, L, C + # ------------------------------------project face embed -------------------------------- + if face_proj_model is not None: + face_embeds = batch['face_embed'] + face_embeds = torch.cat([face_embeds]*2, dim=0) + face_embeds = rearrange(face_embeds, "B Nv L C -> (B Nv) L C") + face_embeds = face_embeds.to(weight_dtype) + face_embeds = face_proj_model(face_embeds) + else: + face_embeds = None + # ------------------------------------Encoder input image -------------------------------- + imgs_in_proc = TF.resize(imgs_in, (feature_extractor.crop_size['height'], feature_extractor.crop_size['width']), interpolation=InterpolationMode.BICUBIC) + # do the normalization in float32 to preserve precision + imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(weight_dtype) + # (B*Nv, 1024) + image_embeddings = image_encoder(imgs_in_proc).image_embeds + + noise_level = torch.tensor([0], device=accelerator.device) + # (B*Nv, 2048) + image_embeddings = noise_image_embeddings(image_embeddings, noise_level, generator=generator, image_normalizer=image_normlizer, + image_noising_scheduler= image_noising_scheduler).to(weight_dtype) + #--------------------------------------vae input and output latents --------------------------------------- + cond_vae_embeddings = vae.encode(imgs_in * 2.0 - 1.0).latent_dist.mode() # + if cfg.scale_input_latents: + cond_vae_embeddings *= vae.config.scaling_factor + if cfg.with_smpl: + cond_smpl_embeddings = vae.encode(smpl_in * 2.0 - 1.0).latent_dist.mode() + if cfg.scale_input_latents: + cond_smpl_embeddings *= vae.config.scaling_factor + cond_vae_embeddings = torch.cat([cond_vae_embeddings, cond_smpl_embeddings], dim=1) + # sample outputs latent + latents = vae.encode(imgs_out * 2.0 - 1.0).latent_dist.sample() * vae.config.scaling_factor + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # same noise for different views of the same object + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz // cfg.num_views,), 
device=latents.device) + timesteps = repeat(timesteps, "b -> (b v)", v=cfg.num_views) + timesteps = timesteps.long() + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if cfg.use_classifier_free_guidance and cfg.condition_drop_rate > 0.: + if cfg.drop_type == 'drop_as_a_whole': + # drop a group of normals and colors as a whole + random_p = torch.rand(bnm, device=latents.device, generator=generator) + + # Sample masks for the conditioning images. + image_mask_dtype = cond_vae_embeddings.dtype + image_mask = 1 - ( + (random_p >= cfg.condition_drop_rate).to(image_mask_dtype) + * (random_p < 3 * cfg.condition_drop_rate).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bnm, 1, 1, 1, 1).repeat(1, Nv, 1, 1, 1) + image_mask = rearrange(image_mask, "B Nv C H W -> (B Nv) C H W") + image_mask = torch.cat([image_mask]*2, dim=0) + # Final image conditioning. + cond_vae_embeddings = image_mask * cond_vae_embeddings + + # Sample masks for the conditioning images. + clip_mask_dtype = image_embeddings.dtype + clip_mask = 1 - ( + (random_p < 2 * cfg.condition_drop_rate).to(clip_mask_dtype) + ) + clip_mask = clip_mask.reshape(bnm, 1, 1).repeat(1, Nv, 1) + clip_mask = rearrange(clip_mask, "B Nv C -> (B Nv) C") + clip_mask = torch.cat([clip_mask]*2, dim=0) + # Final image conditioning. + image_embeddings = clip_mask * image_embeddings + elif cfg.drop_type == 'drop_independent': + random_p = torch.rand(bsz, device=latents.device, generator=generator) + + # Sample masks for the conditioning images. + image_mask_dtype = cond_vae_embeddings.dtype + image_mask = 1 - ( + (random_p >= cfg.condition_drop_rate).to(image_mask_dtype) + * (random_p < 3 * cfg.condition_drop_rate).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + cond_vae_embeddings = image_mask * cond_vae_embeddings + + # Sample masks for the conditioning images. + clip_mask_dtype = image_embeddings.dtype + clip_mask = 1 - ( + (random_p < 2 * cfg.condition_drop_rate).to(clip_mask_dtype) + ) + clip_mask = clip_mask.reshape(bsz, 1, 1) + # Final image conditioning. + image_embeddings = clip_mask * image_embeddings + + # (B*Nv, 8, Hl, Wl) + latent_model_input = torch.cat([noisy_latents, cond_vae_embeddings], dim=1) + model_out = unet( + latent_model_input, + timesteps, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeddings, + dino_feature=face_embeds, + vis_max_min=False + ) + + if cfg.regress_elevation or cfg.regress_focal_length: + model_pred = model_out[0].sample + pose_pred = model_out[1] + else: + model_pred = model_out[0].sample + pose_pred = None + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + # target = noise_scheduler._get_prev_sample(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if cfg.snr_gamma is None: + loss_mse = F.mse_loss(model_pred.float(), target.float(), reduction="mean").to(weight_dtype) + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 
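# Concretely, each per-sample weight computed below is min(SNR(t), snr_gamma) / SNR(t),
# with SNR(t) = alpha_bar_t / (1 - alpha_bar_t) as returned by compute_snr above.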
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss_mse = loss.mean().to(weight_dtype) + # Gather the losses across all processes for logging (if we use distributed training). + avg_mse_loss = accelerator.gather(loss_mse.repeat(cfg.train_batch_size)).mean() + train_mse_loss += avg_mse_loss.item() / cfg.gradient_accumulation_steps + + if cfg.regress_elevation: + loss_ele = F.mse_loss(pose_pred[:, 0:1], batch['elevations_cond'].to(accelerator.device).float(), reduction="mean").to(weight_dtype) + avg_ele_loss = accelerator.gather(loss_ele.repeat(cfg.train_batch_size)).mean() + train_ele_loss += avg_ele_loss.item() / cfg.gradient_accumulation_steps + if cfg.plot_pose_acc: + ele_acc = torch.sum(torch.abs(pose_pred[:, 0:1] - torch.cat([batch['elevations_cond']]*2)) < 0.01) / pose_pred.shape[0] + else: + loss_ele = torch.tensor(0.0, device=accelerator.device, dtype=weight_dtype) + train_ele_loss += torch.tensor(0.0, device=accelerator.device, dtype=weight_dtype) + if cfg.plot_pose_acc: + ele_acc = torch.tensor(0.0, device=accelerator.device, dtype=weight_dtype) + + if cfg.regress_focal_length: + loss_focal = F.mse_loss(pose_pred[:, 1:], batch['focal_cond'].to(accelerator.device).float(), reduction="mean").to(weight_dtype) + avg_focal_loss = accelerator.gather(loss_focal.repeat(cfg.train_batch_size)).mean() + train_focal_loss += avg_focal_loss.item() / cfg.gradient_accumulation_steps + if cfg.plot_pose_acc: + focal_acc = torch.sum(torch.abs(pose_pred[:, 1:] - torch.cat([batch['focal_cond']]*2)) < 0.01) / pose_pred.shape[0] + else: + loss_focal = torch.tensor(0.0, device=accelerator.device, dtype=weight_dtype) + train_focal_loss += torch.tensor(0.0, device=accelerator.device, dtype=weight_dtype) + if cfg.plot_pose_acc: + focal_acc = torch.tensor(0.0, device=accelerator.device, dtype=weight_dtype) + + # Backpropagate + loss = loss_mse + cfg.elevation_loss_weight * loss_ele + cfg.focal_loss_weight * loss_focal + accelerator.backward(loss) + + if accelerator.sync_gradients and cfg.max_grad_norm is not None: + accelerator.clip_grad_norm_(unet.parameters(), cfg.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if cfg.use_ema: + ema_unet.step(unet) + progress_bar.update(1) + global_step += 1 + # accelerator.log({"train_loss": train_loss}, step=global_step) + accelerator.log({"train_mse_loss": train_mse_loss}, step=global_step) + accelerator.log({"train_ele_loss": train_ele_loss}, step=global_step) + if cfg.plot_pose_acc: + accelerator.log({"ele_acc": ele_acc}, step=global_step) + accelerator.log({"focal_acc": focal_acc}, step=global_step) + accelerator.log({"train_focal_loss": train_focal_loss}, step=global_step) + + train_ele_loss, train_mse_loss, train_focal_loss = 0.0, 0.0, 0.0 + + if global_step % cfg.checkpointing_steps == 0: + if 
accelerator.is_main_process: + if cfg.checkpoints_total_limit is not None: + checkpoints = os.listdir(model_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= cfg.checkpoints_total_limit: + num_to_remove = len(checkpoints) - cfg.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(model_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(model_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % cfg.validation_steps == 0 or (cfg.validation_sanity_check and global_step == 1): + if accelerator.is_main_process: + if cfg.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + torch.cuda.empty_cache() + log_validation_joint( + validation_dataloader, + vae, + feature_extractor, + image_encoder, + image_normlizer, + image_noising_scheduler, + tokenizer, + text_encoder, + unet, + face_proj_model, + cfg, + accelerator, + weight_dtype, + global_step, + 'validation', + vis_dir + ) + log_validation( + validation_train_dataloader, + vae, + feature_extractor, + image_encoder, + image_normlizer, + image_noising_scheduler, + tokenizer, + text_encoder, + unet, + face_proj_model, + cfg, + accelerator, + weight_dtype, + global_step, + 'validation_train', + vis_dir + ) + + if cfg.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= cfg.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
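# Sketch of how the exported folder can be smoke-tested afterwards (illustrative only;
# assumes the custom StableUnCLIPImg2ImgPipeline saved below reloads via diffusers'
# standard `from_pretrained`, that PIL.Image is imported at module level, and that the
# example image path is a placeholder):
#   pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(os.path.join(model_dir, "ckpts"), torch_dtype=torch.float16).to("cuda")
#   img = Image.open("examples/person.png").convert("RGB").resize((768, 768))
#   res = pipe(img, "a rendering image of 3D models, left view, normal map.").images[0]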
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if cfg.use_ema: + ema_unet.copy_to(unet.parameters()) + pipeline = StableUnCLIPImg2ImgPipeline( + image_encoder=image_encoder, feature_extractor=feature_extractor, image_normalizer=image_normlizer, + image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, + vae=vae, unet=unet, + scheduler=DDIMScheduler.from_pretrained_linear(cfg.pretrained_model_name_or_path, subfolder="scheduler"), + **cfg.pipe_kwargs + ) + os.makedirs(os.path.join(model_dir, "ckpts"), exist_ok=True) + pipeline.save_pretrained(os.path.join(model_dir, "ckpts")) + + accelerator.end_training() + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + args = parser.parse_args() + schema = OmegaConf.structured(TrainingConfig) + cfg = OmegaConf.load(args.config) + cfg = OmegaConf.merge(schema, cfg) + main(cfg) + + # device = 'cuda' + # ## -------------------------------------- load models -------------------------------- + # image_encoder = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", revision=cfg.revision) + # feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision) + # image_noising_scheduler = DDPMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_noising_scheduler") + # image_normlizer = StableUnCLIPImageNormalizer.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_normalizer") + + # tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="tokenizer", revision=cfg.revision) + # text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder='text_encoder', revision=cfg.revision) + + # noise_scheduler = PNDMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler") + # vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision) + # unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs) + # # unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, + # # **cfg.unet_from_pretrained_kwargs + # # ) + + + # if cfg.enable_xformers_memory_efficient_attention: + # if is_xformers_available(): + # import xformers + + # xformers_version = version.parse(xformers.__version__) + # if xformers_version == version.parse("0.0.16"): + # print( + # "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 
+ # ) + # unet.enable_xformers_memory_efficient_attention() + # print("use xformers.") + + # # from diffusers import StableUnCLIPImg2ImgPipeline + # # -------------------------------------- load pipeline -------------------------------- + # pipe = StableUnCLIPImg2ImgPipeline(feature_extractor=feature_extractor, + # image_encoder=image_encoder, + # image_normalizer=image_normlizer, + # image_noising_scheduler= image_noising_scheduler, + # tokenizer=tokenizer, + # text_encoder=text_encoder, + # unet=unet, + # scheduler=noise_scheduler, + # vae=vae).to('cuda') + + # # -------------------------------------- input -------------------------------- + # # image = Image.open('test/woman.jpg') + # # w, h = image.size + # # image = np.asarray(image)[:w, :w, :] + # # image_in = Image.fromarray(image).resize((768, 768)) + + # im_path = '/mnt/pfs/users/longxiaoxiao/data/test_images/syncdreamer_testset/box.png' + # rgba = np.array(Image.open(im_path)) / 255.0 + # rgb = rgba[:,:,:3] + # alpha = rgba[:,:,3:4] + # bg_color = np.array([1., 1., 1.]) + # image_in = rgb * alpha + (1 - alpha) * bg_color[None,None,:] + # image_in = Image.fromarray((image_in * 255).astype(np.uint8)).resize((768, 768)) + # res = pipe(image_in, 'a rendering image of 3D models, left view, normal map.').images[0] + # res.save("unclip.png") \ No newline at end of file diff --git a/utils/func.py b/utils/func.py new file mode 100644 index 0000000000000000000000000000000000000000..b74484b56b75ffcf2c45ab85cb83f32a1ce74b8e --- /dev/null +++ b/utils/func.py @@ -0,0 +1,369 @@ +import torch +import io +import numpy as np +from pathlib import Path +import re +import trimesh +import imageio +import os +from scipy.spatial.transform import Rotation as R +def to_numpy(*args): + def convert(a): + if isinstance(a,torch.Tensor): + return a.detach().cpu().numpy() + assert a is None or isinstance(a,np.ndarray) + return a + + return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args) + +def save_obj( + vertices, + faces, + filename:Path, + colors=None, + ): + filename = Path(filename) + + bytes_io = io.BytesIO() + if colors is not None: + vertices = torch.cat((vertices, colors),dim=-1) + np.savetxt(bytes_io, vertices.detach().cpu().numpy(), 'v %.4f %.4f %.4f %.4f %.4f %.4f') + else: + np.savetxt(bytes_io, vertices.detach().cpu().numpy(), 'v %.4f %.4f %.4f') + np.savetxt(bytes_io, faces.cpu().numpy() + 1, 'f %d %d %d') #1-based indexing + + obj_path = filename.with_suffix('.obj') + with open(obj_path, 'w') as file: + file.write(bytes_io.getvalue().decode('UTF-8')) + +def save_glb( + filename, + v_pos, + t_pos_idx, + v_nrm=None, + v_tex=None, + t_tex_idx=None, + v_rgb=None, + ) -> str: + + mesh = trimesh.Trimesh( + vertices=v_pos, faces=t_pos_idx, vertex_normals=v_nrm, vertex_colors=v_rgb + ) + # not tested + if v_tex is not None: + mesh.visual = trimesh.visual.TextureVisuals(uv=v_tex) + mesh.export(filename) + + +def load_obj( + filename:Path, + device='cuda', + load_color=False + ) -> tuple[torch.Tensor,torch.Tensor]: + filename = Path(filename) + obj_path = filename.with_suffix('.obj') + with open(obj_path) as file: + obj_text = file.read() + num = r"([0-9\.\-eE]+)" + if load_color: + v = re.findall(f"(v {num} {num} {num} {num} {num} {num})",obj_text) + else: + v = re.findall(f"(v {num} {num} {num})",obj_text) + vertices = np.array(v)[:,1:].astype(np.float32) + all_faces = [] + f = re.findall(f"(f {num} {num} {num})",obj_text) + if f: + all_faces.append(np.array(f)[:,1:].astype(np.int32).reshape(-1,3,1)[...,:1]) + f = re.findall(f"(f 
{num}/{num} {num}/{num} {num}/{num})",obj_text) + if f: + all_faces.append(np.array(f)[:,1:].astype(np.int32).reshape(-1,3,2)[...,:2]) + f = re.findall(f"(f {num}/{num}/{num} {num}/{num}/{num} {num}/{num}/{num})",obj_text) + if f: + all_faces.append(np.array(f)[:,1:].astype(np.int32).reshape(-1,3,3)[...,:2]) + f = re.findall(f"(f {num}//{num} {num}//{num} {num}//{num})",obj_text) + if f: + all_faces.append(np.array(f)[:,1:].astype(np.int32).reshape(-1,3,2)[...,:1]) + all_faces = np.concatenate(all_faces,axis=0) + all_faces -= 1 #1-based indexing + faces = all_faces[:,:,0] + + vertices = torch.tensor(vertices,dtype=torch.float32,device=device) + faces = torch.tensor(faces,dtype=torch.long,device=device) + + return vertices,faces + +def save_ply( + filename:Path, + vertices:torch.Tensor, #V,3 + faces:torch.Tensor, #F,3 + vertex_colors:torch.Tensor=None, #V,3 + vertex_normals:torch.Tensor=None, #V,3 + ): + + filename = Path(filename).with_suffix('.ply') + vertices,faces,vertex_colors = to_numpy(vertices,faces,vertex_colors) + assert np.all(np.isfinite(vertices)) and faces.min()==0 and faces.max()==vertices.shape[0]-1 + + header = 'ply\nformat ascii 1.0\n' + + header += 'element vertex ' + str(vertices.shape[0]) + '\n' + header += 'property double x\n' + header += 'property double y\n' + header += 'property double z\n' + + if vertex_normals is not None: + header += 'property double nx\n' + header += 'property double ny\n' + header += 'property double nz\n' + + if vertex_colors is not None: + assert vertex_colors.shape[0] == vertices.shape[0] + color = (vertex_colors*255).astype(np.uint8) + header += 'property uchar red\n' + header += 'property uchar green\n' + header += 'property uchar blue\n' + + header += 'element face ' + str(faces.shape[0]) + '\n' + header += 'property list int int vertex_indices\n' + header += 'end_header\n' + + with open(filename, 'w') as file: + file.write(header) + + for i in range(vertices.shape[0]): + s = f"{vertices[i,0]} {vertices[i,1]} {vertices[i,2]}" + if vertex_normals is not None: + s += f" {vertex_normals[i,0]} {vertex_normals[i,1]} {vertex_normals[i,2]}" + if vertex_colors is not None: + s += f" {color[i,0]:03d} {color[i,1]:03d} {color[i,2]:03d}" + file.write(s+'\n') + + for i in range(faces.shape[0]): + file.write(f"3 {faces[i,0]} {faces[i,1]} {faces[i,2]}\n") + full_verts = vertices[faces] #F,3,3 + +def save_images( + images:torch.Tensor, #B,H,W,CH + dir:Path, + ): + dir = Path(dir) + dir.mkdir(parents=True,exist_ok=True) + if images.shape[-1]==1: + images = images.repeat(1,1,1,3) + for i in range(images.shape[0]): + imageio.imwrite(dir/f'{i:02d}.png',(images.detach()[i,:,:,:3]*255).clamp(max=255).type(torch.uint8).cpu().numpy()) +def normalize_scene(vertices): + bbox_min, bbox_max = vertices.min(axis=0)[0], vertices.max(axis=0)[0] + offset = -(bbox_min + bbox_max) / 2 + vertices = vertices + offset + + # print(offset) + dxyz = bbox_max - bbox_min + dist = torch.sqrt(dxyz[0]**2+ dxyz[1]**2+dxyz[2]**2) + scale = 1. 
/ dist + # print(scale) + vertices *= scale + return vertices +def normalize_vertices( + vertices:torch.Tensor, #V,3 + ): + """shift and resize mesh to fit into a unit sphere""" + vertices -= (vertices.min(dim=0)[0] + vertices.max(dim=0)[0]) / 2 + vertices /= torch.norm(vertices, dim=-1).max() + return vertices + +def laplacian( + num_verts:int, + edges: torch.Tensor #E,2 + ) -> torch.Tensor: #sparse V,V + """create sparse Laplacian matrix""" + V = num_verts + E = edges.shape[0] + + #adjacency matrix, + idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E) + ones = torch.ones(2*E, dtype=torch.float32, device=edges.device) + A = torch.sparse.FloatTensor(idx, ones, (V, V)) + + #degree matrix + deg = torch.sparse.sum(A, dim=1).to_dense() + idx = torch.arange(V, device=edges.device) + idx = torch.stack([idx, idx], dim=0) + D = torch.sparse.FloatTensor(idx, deg, (V, V)) + + return D - A + +def _translation(x, y, z, device): + return torch.tensor([[1., 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]],device=device) #4,4 + + +def make_round_views(view_nums, scale=2., device='cuda'): + w2c = [] + ortho_scale = scale/2 + projection = get_ortho_projection_matrix(-ortho_scale, ortho_scale, -ortho_scale, ortho_scale, 0.1, 100) + for i in reversed(range(view_nums)): + tmp = np.eye(4) + rot = R.from_euler('xyz', [0, 360/view_nums*i-180, 0], degrees=True).as_matrix() + rot[:, 2] *= -1 + tmp[:3, :3] = rot + tmp[2, 3] = -1.8 + w2c.append(tmp) + w2c = torch.from_numpy(np.stack(w2c, 0)).float().to(device=device) + projection = torch.from_numpy(projection).float().to(device=device) + return w2c, projection + +def make_star_views(az_degs, pol_degs, scale=2., device='cuda'): + w2c = [] + ortho_scale = scale/2 + projection = get_ortho_projection_matrix(-ortho_scale, ortho_scale, -ortho_scale, ortho_scale, 0.1, 100) + for pol in pol_degs: + for az in az_degs: + tmp = np.eye(4) + rot = R.from_euler('xyz', [0, az-180, 0], degrees=True).as_matrix() + rot[:, 2] *= -1 + rot_z = R.from_euler('xyz', [pol, 0, 0], degrees=True).as_matrix() + rot = rot_z @ rot + tmp[:3, :3] = rot + tmp[2, 3] = -1.8 + w2c.append(tmp) + w2c = torch.from_numpy(np.stack(w2c, 0)).float().to(device=device) + projection = torch.from_numpy(projection).float().to(device=device) + return w2c, projection + +# def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): +# if r is None: +# r = 1/distance +# A = az_count +# P = pol_count +# C = A * P + +# phi = torch.arange(0,A) * (2*torch.pi/A) +# phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone() +# phi_rot[:,0,2,2] = phi.cos() +# phi_rot[:,0,2,0] = -phi.sin() +# phi_rot[:,0,0,2] = phi.sin() +# phi_rot[:,0,0,0] = phi.cos() + +# theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2 +# theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone() +# theta_rot[0,:,1,1] = theta.cos() +# theta_rot[0,:,1,2] = -theta.sin() +# theta_rot[0,:,2,1] = theta.sin() +# theta_rot[0,:,2,2] = theta.cos() + +# mv = torch.empty((C,4,4), device=device) +# mv[:] = torch.eye(4, device=device) +# mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3) +# mv = _translation(0, 0, -distance, device) @ mv +# print(mv[:, :3, 3]) +# return mv, _projection(r, device) + +def get_ortho_projection_matrix(left, right, bottom, top, near, far): + projection_matrix = np.zeros((4, 4), dtype=np.float32) + + projection_matrix[0, 0] = 2.0 / (right - left) + projection_matrix[1, 1] = -2.0 / (top - bottom) # add a negative sign here as 
the y axis is flipped in nvdiffrast output + projection_matrix[2, 2] = -2.0 / (far - near) + + projection_matrix[0, 3] = -(right + left) / (right - left) + projection_matrix[1, 3] = -(top + bottom) / (top - bottom) + projection_matrix[2, 3] = -(far + near) / (far - near) + projection_matrix[3, 3] = 1.0 + + return projection_matrix + +def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + p = torch.zeros([4,4],device=device) + p[0,0] = 2*n/(r-l) + p[0,2] = (r+l)/(r-l) + p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1) + p[1,2] = (t+b)/(t-b) + p[2,2] = -(f+n)/(f-n) + p[2,3] = -(2*f*n)/(f-n) + p[3,2] = -1 + return p #4,4 +def get_perspective_projection_matrix(fov, aspect=1.0, near=0.1, far=100.0): + tan_half_fovy = torch.tan(torch.deg2rad(fov/2)) + projection_matrix = torch.zeros(4, 4) + projection_matrix[0, 0] = 1 / (aspect * tan_half_fovy) + projection_matrix[1, 1] = -1 / tan_half_fovy + projection_matrix[2, 2] = -(far + near) / (far - near) + projection_matrix[2, 3] = -2 * far * near / (far - near) + projection_matrix[3, 2] = -1 + +def make_sparse_camera(cam_path, scale=4., views=None, device='cuda', mode='ortho'): + + if mode == 'ortho': + ortho_scale = scale/2 + projection = get_ortho_projection_matrix(-ortho_scale, ortho_scale, -ortho_scale, ortho_scale, 0.1, 100) + else: + npy_data = np.load(os.path.join(cam_path, f'{i:03d}.npy'), allow_pickle=True).item() + fov = npy_data['fov'] + projection = get_perspective_projection_matrix(fov, aspect=1.0, near=0.1, far=100.0) + # projection = _projection(r=1/1.5, device=device, n=0.1, f=100) + # for view in ['front', 'right', 'back', 'left']: + # tmp = np.loadtxt(os.path.join(cam_path, f'{view}_RT.txt')) + # rot = tmp[:, [0, 2, 1]] + # rot[:, 2] *= -1 + # tmp[:3, :3] = rot + # tmp = np.concatenate([tmp, np.array([[0, 0, 0, 1]])], axis=0) + # c2w = np.linalg.inv(tmp) + # w2c.append(np.concatenate([tmp, np.array([[0, 0, 0, 1]])], axis=0)) + + ''' + world : + z + | + |____y + / + / + x + camera:(opencv) + z + / + /____x + | + | + y + ''' + if views is None: + views = [0, 1, 2, 4, 6, 7] + w2c = [] + for i in views: + npy_data = np.load(os.path.join(cam_path, f'{i:03d}.npy'), allow_pickle=True).item() + w2c_cv = npy_data['extrinsic'] + w2c_cv = np.concatenate([w2c_cv, np.array([[0, 0, 0, 1]])], axis=0) + c2w_cv = np.linalg.inv(w2c_cv) + + c2w_gl = c2w_cv[[1, 2, 0, 3], :] # invert world coordinate, y->x, z->y, x->z + c2w_gl[:3, 1:3] *= -1 # opencv->opengl, flip y and z + w2c_gl = np.linalg.inv(c2w_gl) + w2c.append(w2c_gl) + + # special pose for test + # w2c = np.eye(4) + # rot = R.from_euler('xyz', [0, 0, 0], degrees=True).as_matrix() + # w2c[:3, :3] = rot + # w2c[2, 3] = -1.5 + w2c = torch.from_numpy(np.stack(w2c, 0)).float().to(device=device) + projection = torch.from_numpy(projection).float().to(device=device) + return w2c, projection + +def make_sphere(level:int=2,radius=1.,device='cuda') -> tuple[torch.Tensor,torch.Tensor]: + sphere = trimesh.creation.icosphere(subdivisions=level, radius=radius, color=np.array([0.5, 0.5, 0.5])) + vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius + + # print(vertices.shape) + # exit() + faces = torch.tensor(sphere.faces, device=device, dtype=torch.long) + colors = torch.tensor(sphere.visual.vertex_colors[..., :3], device=device, dtype=torch.float32) + return vertices, faces, colors \ No newline at end of file diff --git a/utils/igl.py b/utils/igl.py new file mode 100644 index 
0000000000000000000000000000000000000000..ef26cc81bc1146ddbd7f1f36fac3bcd64d45068d --- /dev/null +++ b/utils/igl.py @@ -0,0 +1,45 @@ +import torch +import igl +import numpy as np + +@torch.no_grad() +def igl_flips( + vertices:np.array, #V,3 + faces:np.array, #F,3 + target_vertices:np.array, #VT,3 + target_faces:np.array, #FT,3 + )->tuple[np.array,np.array]: + + full_vertices = vertices[faces] #F,C=3,3 + face_centers = full_vertices.mean(axis=1) #F,3 + _,ind,points = igl.point_mesh_squared_distance(face_centers,target_vertices,target_faces) + target_faces = target_faces[ind] #F,3 + corners = target_vertices[target_faces] #F,3,3 + bary = igl.barycentric_coordinates_tri(points,corners[:,0].copy(),corners[:,1].copy(),corners[:,2].copy()) #P,3 + target_normals = igl.per_vertex_normals(target_vertices,target_faces,igl.PER_VERTEX_NORMALS_WEIGHTING_TYPE_AREA) + corner_normals = target_normals[target_faces] #P,3,3 + ref_normals = (bary[:,:,None] * corner_normals).sum(axis=1) #F,3 + face_normals = igl.per_face_normals(vertices,faces,np.array([0,0,0],dtype=np.float32)) #F,3 not normalized + flip = np.sum(ref_normals * face_normals, axis=-1)<0 #F + flipped_area = np.sum(flip * np.linalg.norm(face_normals,axis=-1)) + total_area = np.sum(np.linalg.norm(face_normals,axis=-1)) + ratio = flipped_area / total_area + return flip, ratio + + +@torch.no_grad() +def igl_distance( + vertices:np.array, #V,3 + faces:np.array, #F,3 + target_vertices:np.array, #VT,3 + target_faces:np.array, #FT,3 + ): + + dist1_sq,_,_ = igl.point_mesh_squared_distance(vertices,target_vertices,target_faces) + dist2_sq,_,_ = igl.point_mesh_squared_distance(target_vertices,vertices,faces) + vertex_distance = np.sqrt(dist1_sq) + + rms_distance = ((dist1_sq.mean()+dist2_sq.mean())/2)**.5 + max_distance = max(dist1_sq.max(),dist2_sq.max())**.5 + + return vertex_distance,rms_distance,max_distance \ No newline at end of file diff --git a/utils/img_util.py b/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fe05f3b9c0f1ab7ff0df631e8d972e9b915340e3 --- /dev/null +++ b/utils/img_util.py @@ -0,0 +1,7 @@ +from PIL import Image + +def add_margin(pil_img, color=0, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result \ No newline at end of file diff --git a/utils/mesh_utils.py b/utils/mesh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4892deed0d8ebde205f1e600cb02c2f25ceeede7 --- /dev/null +++ b/utils/mesh_utils.py @@ -0,0 +1,179 @@ +import torch +import numpy as np +import pymeshlab as ml +from pytorch3d.renderer import TexturesVertex +from pytorch3d.structures import Meshes +import torch +import torch.nn.functional as F +import trimesh +from pymeshlab import PercentageValue +import open3d as o3d + + +def tensor2variable(tensor, device): + # [1,23,3,3] + return torch.tensor(tensor, device=device, requires_grad=True) + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. 
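The six inputs are read as two 3-D vectors a1, a2 (the first two columns); the matrix is
recovered by Gram-Schmidt: b1 = normalize(a1), b2 = normalize(a2 - dot(b1, a2) * b1),
b3 = cross(b1, b2), and (b1, b2, b3) are stacked as the columns of the result.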
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def fix_vert_color_glb(mesh_path): + from pygltflib import GLTF2, Material, PbrMetallicRoughness + obj1 = GLTF2().load(mesh_path) + obj1.meshes[0].primitives[0].material = 0 + obj1.materials.append(Material( + pbrMetallicRoughness = PbrMetallicRoughness( + baseColorFactor = [1.0, 1.0, 1.0, 1.0], + metallicFactor = 0., + roughnessFactor = 1.0, + ), + emissiveFactor = [0.0, 0.0, 0.0], + doubleSided = True, + )) + obj1.save(mesh_path) + +def srgb_to_linear(c_srgb): + c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4) + return c_linear.clip(0, 1.) + +def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True): + # convert from pytorch3d meshes to trimesh mesh + vertices = meshes.verts_packed().cpu().float().numpy() + triangles = meshes.faces_packed().cpu().long().numpy() + np_color = meshes.textures.verts_features_packed().cpu().float().numpy() + if save_glb_path.endswith(".glb"): + # rotate 180 along +Y + vertices[:, [0, 2]] = -vertices[:, [0, 2]] + + if apply_sRGB_to_LinearRGB: + np_color = srgb_to_linear(np_color) + assert vertices.shape[0] == np_color.shape[0] + assert np_color.shape[1] == 3 + assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}" + mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color) + mesh.remove_unreferenced_vertices() + # save mesh + mesh.export(save_glb_path) + if save_glb_path.endswith(".glb"): + fix_vert_color_glb(save_glb_path) + print(f"saving to {save_glb_path}") + +def load_mesh_with_trimesh(file_name, file_type=None): + mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type) + if isinstance(mesh, trimesh.Scene): + assert len(mesh.geometry) > 0 + # save to obj first and load again to avoid offset issue + from io import BytesIO + with BytesIO() as f: + mesh.export(f, file_type="obj") + f.seek(0) + mesh = trimesh.load(f, file_type="obj") + if isinstance(mesh, trimesh.Scene): + # we lose texture information here + mesh = trimesh.util.concatenate( + tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) + for g in mesh.geometry.values())) + assert isinstance(mesh, trimesh.Trimesh) + + vertices = torch.from_numpy(mesh.vertices).T + faces = torch.from_numpy(mesh.faces).T + colors = None + if mesh.visual is not None: + if hasattr(mesh.visual, 'vertex_colors'): + colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255. + if colors is None: + # print("Warning: no vertex color found in mesh! 
Filling it with gray.") + colors = torch.ones_like(vertices) * 0.5 + return vertices, faces, colors + +def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes: + verts = torch.from_numpy(mesh.vertex_matrix()).float() + faces = torch.from_numpy(mesh.face_matrix()).long() + colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float() + textures = TexturesVertex(verts_features=[colors]) + return Meshes(verts=[verts], faces=[faces], textures=textures) + + +def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> ml.Mesh: + colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64) + m1 = ml.Mesh( + vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64), + face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32), + v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64), + v_color_matrix=colors_in) + return m1 + + +def to_pyml_mesh(vertices,faces): + m1 = ml.Mesh( + vertex_matrix=vertices.cpu().float().numpy().astype(np.float64), + face_matrix=faces.cpu().long().numpy().astype(np.int32), + ) + return m1 + + +def to_py3d_mesh(vertices, faces, normals=None): + from pytorch3d.structures import Meshes + from pytorch3d.renderer.mesh.textures import TexturesVertex + mesh = Meshes(verts=[vertices], faces=[faces], textures=None) + if normals is None: + normals = mesh.verts_normals_packed() + # set normals as vertext colors + mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5]) + return mesh + + +def from_py3d_mesh(mesh): + return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed() + + +def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25): + ms = ml.MeshSet() + ms.add_mesh(pyml_mesh, "cube_mesh") + + if apply_smooth: + ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False) + if apply_sub_divide: # 5s, slow + ms.apply_filter("meshing_repair_non_manifold_vertices") + ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces') + ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold)) + return meshlab_mesh_to_py3dmesh(ms.current_mesh()) + + + +def post_process_mesh(mesh, cluster_to_keep=1000): + """ + Post-process a mesh to filter out floaters and disconnected parts + """ + import copy + print("post processing the mesh to have {} clusterscluster_to_kep".format(cluster_to_keep)) + mesh_0 = copy.deepcopy(mesh) + with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm: + triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles()) + + triangle_clusters = np.asarray(triangle_clusters) + cluster_n_triangles = np.asarray(cluster_n_triangles) + cluster_area = np.asarray(cluster_area) + n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep] + n_cluster = max(n_cluster, 50) # filter meshes smaller than 50 + triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster + mesh_0.remove_triangles_by_mask(triangles_to_remove) + mesh_0.remove_unreferenced_vertices() + mesh_0.remove_degenerate_triangles() + print("num vertices raw {}".format(len(mesh.vertices))) + print("num vertices post {}".format(len(mesh_0.vertices))) + return mesh_0 \ No newline at end of file diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 
0000000000000000000000000000000000000000..c16fafa2ab8e7b934be711c41aed6e12001444fd --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,54 @@ +import os +from omegaconf import OmegaConf +from packaging import version + + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) +OmegaConf.register_new_resolver('add', lambda a, b: a + b) +OmegaConf.register_new_resolver('sub', lambda a, b: a - b) +OmegaConf.register_new_resolver('mul', lambda a, b: a * b) +OmegaConf.register_new_resolver('div', lambda a, b: a / b) +OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) +OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) +# ======================================================= # + + +def prompt(question): + inp = input(f"{question} (y/n)").lower().strip() + if inp and inp == 'y': + return True + if inp and inp == 'n': + return False + return prompt(question) + + +def load_config(*yaml_files, cli_args=[]): + yaml_confs = [OmegaConf.load(f) for f in yaml_files] + cli_conf = OmegaConf.from_cli(cli_args) + conf = OmegaConf.merge(*yaml_confs, cli_conf) + OmegaConf.resolve(conf) + return conf + + +def config_to_primitive(config, resolve=True): + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path, config): + with open(path, 'w') as fp: + OmegaConf.save(config=config, f=fp) + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def parse_version(ver): + return version.parse(ver) diff --git a/utils/project_mesh.py b/utils/project_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..c09a259393f4e790a83dc8ea9aaa3bbdc9542117 --- /dev/null +++ b/utils/project_mesh.py @@ -0,0 +1,376 @@ +from typing import List +import torch +import numpy as np +from PIL import Image +from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + RasterizationSettings, + TexturesVertex, + FoVPerspectiveCameras, + FoVOrthographicCameras, +) +from pytorch3d.renderer import MeshRasterizer + +def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1): + """ + Renders pix2face of visible faces. + + :param mesh: Pytorch3d.structures.Meshes + :param cameras: pytorch3d.renderer.Cameras + :param H: target image height + :param W: target image width + :param blur_radius: Float distance in the range [0, 2] used to expand the face + bounding boxes for rasterization. Setting blur radius + results in blurred edges around the shape instead of a + hard boundary. Set to 0 for no blur. + :param faces_per_pixel: (int) Number of faces to keep track of per pixel. + We return the nearest faces_per_pixel faces along the z-axis. 
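Example (illustrative; assumes `meshes` and `cameras` were built with pytorch3d elsewhere):
    out = render_pix2faces_py3d(meshes, cameras, H=256, W=256)
    pix_to_face = out["pix_to_face"]  # (N, H, W) face index per pixel, -1 where no face is hit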
+ """ + # Define the settings for rasterization and shading + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel + ) + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=raster_settings + ) + fragments: Fragments = rasterizer(meshes, cameras=cameras) + return { + "pix_to_face": fragments.pix_to_face[..., 0], + } + +import nvdiffrast.torch as dr + +def _warmup(glctx, device=None): + device = 'cuda' if device is None else device + #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + def tensor(*args, **kwargs): + return torch.tensor(*args, device=device, **kwargs) + pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) + tri = tensor([[0, 1, 2]], dtype=torch.int32) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +class Pix2FacesRenderer: + def __init__(self, device="cuda"): + self._glctx = dr.RasterizeCudaContext(device=device) + self.device = device + _warmup(self._glctx, device) + + def transform_vertices(self, meshes: Meshes, cameras: CamerasBase): + vertices = cameras.transform_points_ndc(meshes.verts_padded()) + + perspective_correct = cameras.is_perspective() + znear = cameras.get_znear() + if isinstance(znear, torch.Tensor): + znear = znear.min().item() + z_clip = None if not perspective_correct or znear is None else znear / 2 + + if z_clip: + vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip + vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices) + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32) + return vertices + + def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512): + meshes = meshes.to(self.device) + cameras = cameras.to(self.device) + vertices = self.transform_vertices(meshes, cameras) + faces = meshes.faces_packed().to(torch.int32) + rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4 + pix_to_face = rast_out[..., -1].to(torch.int32) - 1 + return pix_to_face + +pix2faces_renderer = None + +def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024): + # global pix2faces_renderer + # if pix2faces_renderer is None: + # pix2faces_renderer = Pix2FacesRenderer() + pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face'] + # pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution) + + unique_faces = torch.unique(pix_to_face.flatten()) + unique_faces = unique_faces[unique_faces != -1] + return unique_faces + +def project_color(meshes: Meshes, cameras: CamerasBase, image: torch.Tensor, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object. + cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object. + pil_image (PIL.Image.Image): The input image. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + debug (bool, optional): Whether to save debug images. Defaults to False. 
+ + Returns: + dict: A dictionary containing the following keys: + - "new_texture" (TexturesVertex): The updated texture with interpolated colors. + - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected. + - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices. + """ + meshes = meshes.to(device) + cameras = cameras.to(device) + + unique_faces = get_visible_faces(meshes, cameras, resolution=resolution) + + # visible faces + faces_normals = meshes.faces_normals_packed()[unique_faces] + faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True) + world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0] + view_direction = world_points[1] - world_points[0] + view_direction = view_direction / view_direction.norm(dim=0, keepdim=True) + + # find invalid faces + cos_angles = (faces_normals * view_direction).sum(dim=1) + assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}" + selected_faces = unique_faces[cos_angles < -eps] + + # find verts + faces = meshes.faces_packed()[selected_faces] # [N, 3] + verts = torch.unique(faces.flatten()) # [N, 1] + verts_coordinates = meshes.verts_packed()[verts] # [N, 3] + + # compute color + pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points + valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1 dict: + """ + meshes: the mesh with vertex color to be completed. + valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1] + """ + valid_index = valid_index.to(meshes.device) + colors = meshes.textures.verts_features_packed() # [V, 3] + V = colors.shape[0] + + invalid_index = torch.ones_like(colors[:, 0]).bool() # [V] + invalid_index[valid_index] = False + invalid_index = torch.arange(V).to(meshes.device)[invalid_index] + + L = meshes.laplacian_packed() + E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device) + L = L + E + # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device) + # L = L + E + colored_count = torch.ones_like(colors[:, 0]) # [V] + colored_count[invalid_index] = 0 + L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V] + + total_colored = colored_count.sum() + coloring_round = 0 + stage = "uncolored" + from tqdm import tqdm + pbar = tqdm(miniters=100) + while stage == "uncolored" or coloring_round > 0: + new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3] + new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1] + colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index]) + colored_count[invalid_index] = (new_count[:, 0] > 0).float() + + new_total_colored = colored_count.sum() + if new_total_colored > total_colored: + total_colored = new_total_colored + coloring_round += 1 + else: + stage = "colored" + coloring_round -= 1 + pbar.update(1) + if coloring_round > 10000: + print("coloring_round > 10000, break") + break + assert not torch.isnan(colors).any() + meshes.textures = TexturesVertex(verts_features=[colors]) + return meshes + +def multiview_color_projection(meshes: Meshes, image_list: torch.Tensor, cameras_list: List[CamerasBase]=None, camera_focal: float = 2 / 1.35, weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth") -> Meshes: + """ + 
Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh. + image_list (PIL.Image.Image): List of images. + cameras_list (list): List of cameras. + camera_focal (float, optional): The focal length of the camera, if cameras_list is not passed. Defaults to 2 / 1.35. + weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1. + complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False. + + Returns: + Meshes: the colored mesh + """ + # 1. preprocess inputs + if image_list is None: + raise ValueError("image_list is None") + if cameras_list is None: + if len(image_list) == 8: + cameras_list = get_8view_cameras(device, focal=camera_focal) + elif len(image_list) == 6: + cameras_list = get_6view_cameras(device, focal=camera_focal) + elif len(image_list) == 4: + cameras_list = get_4view_cameras(device, focal=camera_focal) + elif len(image_list) == 2: + cameras_list = get_2view_cameras(device, focal=camera_focal) + else: + raise ValueError("cameras_list is None, and can not be guessed from image_list") + if weights is None: + if len(image_list) == 8: + weights = [2.0, 0.05, 0.2, 0.02, 1.0, 0.02, 0.2, 0.05] + elif len(image_list) == 6: + weights = [2.0, 0.05, 0.2, 1.0, 0.2, 0.05] + # weights = [2.0, 0.2, 1.5, 1.0, 1.5, 0.2] + elif len(image_list) == 4: + weights = [2.0, 0.2, 1.0, 0.2] + elif len(image_list) == 2: + weights = [1.0, 1.0] + else: + raise ValueError("weights is None, and can not be guessed from image_list") + + # 2. 
run projection + meshes = meshes.clone().to(device) + assert len(cameras_list) == len(image_list) == len(weights) + original_color = meshes.textures.verts_features_packed() + assert not torch.isnan(original_color).any() + texture_counts = torch.zeros_like(original_color[..., :1]) + texture_values = torch.zeros_like(original_color) + max_texture_counts = torch.zeros_like(original_color[..., :1]) + max_texture_values = torch.zeros_like(original_color) + for camera, image, weight in zip(cameras_list, image_list, weights): + ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha) + if reweight_with_cosangle == "linear": + weight = (ret['cos_angles'].abs() * weight)[:, None] + elif reweight_with_cosangle == "square": + weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None] + if use_alpha: + weight = weight * ret['valid_alpha'] + assert weight.min() > -0.0001 + texture_counts[ret['valid_verts']] += weight + texture_values[ret['valid_verts']] += ret['valid_colors'] * weight + max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']]) + max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight) + + # Method2 + texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values) + if below_confidence_strategy == "smooth": + texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values) + elif below_confidence_strategy == "original": + texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values) + else: + raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported") + assert not torch.isnan(texture_values).any() + meshes.textures = TexturesVertex(verts_features=[texture_values]) + + if complete_unseen: + meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold]) + ret_mesh = meshes.detach() + del meshes + return ret_mesh + +def get_camera(R, T, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'): + if cam_type == 'fov': + camera = FoVPerspectiveCameras(device=R.device, R=R, T=T, fov=fov_in_degrees, degrees=True) + else: + focal_length = 1 / focal_length + camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) + return camera + +def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1): + ret = [] + for azim in azim_list: + R, T = look_at_view_transform(dist, 0, azim) + cameras: OrthographicCameras = get_camera(R, T, focal_length=focal, cam_type='orthogonal').to(device) + ret.append(cameras) + return ret + +def get_8view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 225, 270, 315, 0, 45, 90, 135], device=device, focal=focal) + +def get_6view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], device=device, focal=focal) + +def get_4view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 270, 0, 90], device=device, focal=focal) + +def get_2view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 0], device=device, focal=focal) + +def get_multiple_view_cameras(device, 
focal=2/1.35, offset=180, num_views=8, dist=1.1): + return get_cameras_list(azim_list = (np.linspace(0, 360, num_views+1)[:-1] + offset) % 360, device=device, focal=focal, dist=dist) + +def align_with_alpha_bbox(source_img, target_img, final_size=1024): + # align source_img with target_img using alpha channel + # source_img and target_img are PIL.Image.Image + source_img = source_img.convert("RGBA") + target_img = target_img.convert("RGBA").resize((final_size, final_size)) + source_np = np.array(source_img) + target_np = np.array(target_img) + source_alpha = source_np[:, :, 3] + target_alpha = target_np[:, :, 3] + bbox_source_min, bbox_source_max = np.argwhere(source_alpha > 0).min(axis=0), np.argwhere(source_alpha > 0).max(axis=0) + bbox_target_min, bbox_target_max = np.argwhere(target_alpha > 0).min(axis=0), np.argwhere(target_alpha > 0).max(axis=0) + source_content = source_np[bbox_source_min[0]:bbox_source_max[0]+1, bbox_source_min[1]:bbox_source_max[1]+1, :] + # resize source_content to fit in the position of target_content + source_content = Image.fromarray(source_content).resize((bbox_target_max[1]-bbox_target_min[1]+1, bbox_target_max[0]-bbox_target_min[0]+1), resample=Image.BICUBIC) + target_np[bbox_target_min[0]:bbox_target_max[0]+1, bbox_target_min[1]:bbox_target_max[1]+1, :] = np.array(source_content) + return Image.fromarray(target_np) + +def load_image_list_from_mvdiffusion(mvdiffusion_path, front_from_pil_or_path=None): + import os + image_list = [] + for dir in ['front', 'front_right', 'right', 'back', 'left', 'front_left']: + image_path = os.path.join(mvdiffusion_path, f"rgb_000_{dir}.png") + pil = Image.open(image_path) + if dir == 'front': + if front_from_pil_or_path is not None: + if isinstance(front_from_pil_or_path, str): + replace_pil = Image.open(front_from_pil_or_path) + else: + replace_pil = front_from_pil_or_path + # align replace_pil with pil using bounding box in alpha channel + pil = align_with_alpha_bbox(replace_pil, pil, final_size=1024) + image_list.append(pil) + return image_list + +def load_image_list_from_img_grid(img_grid_path, resolution = 1024): + img_list = [] + grid = Image.open(img_grid_path) + w, h = grid.size + for row in range(0, h, resolution): + for col in range(0, w, resolution): + img_list.append(grid.crop((col, row, col + resolution, row + resolution))) + return img_list \ No newline at end of file diff --git a/utils/remove_bg.py b/utils/remove_bg.py new file mode 100644 index 0000000000000000000000000000000000000000..fb8f437e57d683a425f3384a27f871bd630951fc --- /dev/null +++ b/utils/remove_bg.py @@ -0,0 +1,18 @@ +import os +from glob import glob +from rembg import remove +from argparse import ArgumentParser +from PIL import Image +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--path', type=str, required=True, help='Path to input images') + args = parser.parse_args() + + imgs = glob(os.path.join(args.path, '*.png')) + glob(os.path.join(args.path, '*.jpg')) + for img in imgs: + path = os.path.dirname(img) + name = os.path.basename(img).split('.')[0] + '_rmbg.png' + + img_np = Image.open(img) + img = remove(img_np) + img.save(os.path.join(args.path, name)) \ No newline at end of file diff --git a/utils/render.py b/utils/render.py new file mode 100644 index 0000000000000000000000000000000000000000..f551e487935ac4bb4437aa57caf59d6a0c15017e --- /dev/null +++ b/utils/render.py @@ -0,0 +1,50 @@ +from matplotlib import image +import nvdiffrast.torch as dr +import torch + +def _warmup(glctx, device): + #windows 
workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + + pos = torch.tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32, device=device) + tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +class NormalsRenderer: + + _glctx:dr.RasterizeGLContext = None + + def __init__( + self, + mv: torch.Tensor, #C,4,4 + proj: torch.Tensor, #C,4,4 + image_size: tuple[int,int], + device: str + ): + self._mvp = proj @ mv #C,4,4 + self._image_size = image_size + # self._glctx = dr.RasterizeGLContext() + self._glctx = dr.RasterizeCudaContext(device=device) + _warmup(self._glctx, device) + + def render(self, + vertices: torch.Tensor, #V,3 float + faces: torch.Tensor, #F,3 long + colors: torch.Tensor = None, #V,3 float + normals: torch.Tensor = None, #V,3 float + return_triangles: bool = False + ) -> torch.Tensor: #C,H,W,4 + + V = vertices.shape[0] + faces = faces.type(torch.int32) + vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4 + vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 + rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 + vert_nrm = (normals+1)/2 if normals is not None else colors + nrm, _ = dr.interpolate(vert_nrm, rast_out, faces) #C,H,W,3 + alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1 + nrm = torch.concat((nrm,alpha),dim=-1) #C,H,W,4 + nrm = dr.antialias(nrm, rast_out, vertices_clip, faces) #C,H,W,4 + if return_triangles: + return nrm, rast_out[..., -1] + return nrm #C,H,W,4 + diff --git a/utils/saving.py b/utils/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..93fdd424ef7de49a7637bd9a4d19f15e647c76ea --- /dev/null +++ b/utils/saving.py @@ -0,0 +1,725 @@ +import json +import os +import re +import shutil + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw +from pytorch_lightning.loggers import WandbLogger + +import lrm +from ..models.mesh import Mesh +from ..utils.typing import * + + +class SaverMixin: + _save_dir: Optional[str] = None + _wandb_logger: Optional[WandbLogger] = None + + def set_save_dir(self, save_dir: str): + self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + if data.dtype in [torch.float16, torch.bfloat16]: + data = data.float() + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = 
{"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + self._wandb_logger.log_image( + key=name, images=[self.get_save_path(filename)], step=step + ) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range, name, step) + return save_path + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(save_path, img) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = 
LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + self._wandb_logger.log_image( + key=name, images=[self.get_save_path(filename)], step=step + ) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap, name, step) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + elif isinstance(align, int): + h = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, or int" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h: + w = int(cols[i].shape[1] * h / cols[i].shape[0]) + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_CUBIC) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + name: Optional[str] = None, + step: Optional[int] = None, + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", 
black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + if name and self._wandb_logger: + self._wandb_logger.log_image(key=name, images=[save_path], step=step) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(save_path, imgs_full) + return save_path + + def save_data(self, filename, data) -> str: + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + save_path = self.get_save_path(filename) + np.savez(save_path, **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + save_path = self.get_save_path(filename) + np.save(save_path, data) + return save_path + + def save_state_dict(self, filename, data) -> str: + save_path = self.get_save_path(filename) + torch.save(data, save_path) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + if name and self._wandb_logger: + lrm.warn("Wandb logger does not support video logging yet!") + return save_path + + def save_img_sequences( + self, + seq_dir, + matcher, + save_format="mp4", + fps=30, + delete=True, + name: Optional[str] = None, + step: Optional[int] = None, + ): + seq_dir_ = os.path.join(self.get_save_dir(), seq_dir) + for f in os.listdir(seq_dir_): + img_dir_ = os.path.join(seq_dir_, f) + if not os.path.isdir(img_dir_): + continue + try: + self.save_img_sequence( + os.path.join(seq_dir, f), + 
os.path.join(seq_dir, f), + matcher, + save_format=save_format, + fps=fps, + name=f"{name}_{f}", + step=step, + ) + if delete: + shutil.rmtree(img_dir_) + except: + lrm.warn(f"Video saving for directory {seq_dir_} failed!") + + def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str: + import trimesh + + save_path = self.get_save_path(filename) + v_pos = self.convert_data(v_pos) + t_pos_idx = self.convert_data(t_pos_idx) + mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx) + mesh.export(save_path) + return save_path + + def save_obj( + self, + filename: str, + mesh: Mesh, + save_mat: bool = False, + save_normal: bool = False, + save_uv: bool = False, + save_vertex_color: bool = False, + map_Kd: Optional[Float[Tensor, "H W 3"]] = None, + map_Ks: Optional[Float[Tensor, "H W 3"]] = None, + map_Bump: Optional[Float[Tensor, "H W 3"]] = None, + map_Pm: Optional[Float[Tensor, "H W 1"]] = None, + map_Pr: Optional[Float[Tensor, "H W 1"]] = None, + map_format: str = "jpg", + ) -> List[str]: + save_paths: List[str] = [] + if not filename.endswith(".obj"): + filename += ".obj" + v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data( + mesh.t_pos_idx + ) + v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None + if save_normal: + v_nrm = self.convert_data(mesh.v_nrm) + if save_uv: + v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data( + mesh.t_tex_idx + ) + if save_vertex_color: + v_rgb = self.convert_data(mesh.v_rgb) + matname, mtllib = None, None + if save_mat: + matname = "default" + mtl_filename = filename.replace(".obj", ".mtl") + mtllib = os.path.basename(mtl_filename) + mtl_save_paths = self._save_mtl( + mtl_filename, + matname, + map_Kd=self.convert_data(map_Kd), + map_Ks=self.convert_data(map_Ks), + map_Bump=self.convert_data(map_Bump), + map_Pm=self.convert_data(map_Pm), + map_Pr=self.convert_data(map_Pr), + map_format=map_format, + ) + save_paths += mtl_save_paths + obj_save_path = self._save_obj( + filename, + v_pos, + t_pos_idx, + v_nrm=v_nrm, + v_tex=v_tex, + t_tex_idx=t_tex_idx, + v_rgb=v_rgb, + matname=matname, + mtllib=mtllib, + ) + save_paths.append(obj_save_path) + return save_paths + + def _save_obj( + self, + filename, + v_pos, + t_pos_idx, + v_nrm=None, + v_tex=None, + t_tex_idx=None, + v_rgb=None, + matname=None, + mtllib=None, + ) -> str: + obj_str = "" + if matname is not None: + obj_str += f"mtllib {mtllib}\n" + obj_str += f"g object\n" + obj_str += f"usemtl {matname}\n" + for i in range(len(v_pos)): + obj_str += f"v {v_pos[i][0]} {v_pos[i][1]} {v_pos[i][2]}" + if v_rgb is not None: + obj_str += f" {v_rgb[i][0]} {v_rgb[i][1]} {v_rgb[i][2]}" + obj_str += "\n" + if v_nrm is not None: + for v in v_nrm: + obj_str += f"vn {v[0]} {v[1]} {v[2]}\n" + if v_tex is not None: + for v in v_tex: + obj_str += f"vt {v[0]} {1.0 - v[1]}\n" + + for i in range(len(t_pos_idx)): + obj_str += "f" + for j in range(3): + obj_str += f" {t_pos_idx[i][j] + 1}/" + if v_tex is not None: + obj_str += f"{t_tex_idx[i][j] + 1}" + obj_str += "/" + if v_nrm is not None: + obj_str += f"{t_pos_idx[i][j] + 1}" + obj_str += "\n" + + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(obj_str) + return save_path + + def _save_mtl( + self, + filename, + matname, + Ka=(0.0, 0.0, 0.0), + Kd=(1.0, 1.0, 1.0), + Ks=(0.0, 0.0, 0.0), + map_Kd=None, + map_Ks=None, + map_Bump=None, + map_Pm=None, + map_Pr=None, + map_format="jpg", + step: Optional[int] = None, + ) -> List[str]: + mtl_save_path = self.get_save_path(filename) + 
save_paths = [mtl_save_path] + mtl_str = f"newmtl {matname}\n" + mtl_str += f"Ka {Ka[0]} {Ka[1]} {Ka[2]}\n" + if map_Kd is not None: + map_Kd_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_kd.{map_format}" + ) + mtl_str += f"map_Kd texture_kd.{map_format}\n" + self._save_rgb_image( + map_Kd_save_path, + map_Kd, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Kd", + step=step, + ) + save_paths.append(map_Kd_save_path) + else: + mtl_str += f"Kd {Kd[0]} {Kd[1]} {Kd[2]}\n" + if map_Ks is not None: + map_Ks_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_ks.{map_format}" + ) + mtl_str += f"map_Ks texture_ks.{map_format}\n" + self._save_rgb_image( + map_Ks_save_path, + map_Ks, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Ks", + step=step, + ) + save_paths.append(map_Ks_save_path) + else: + mtl_str += f"Ks {Ks[0]} {Ks[1]} {Ks[2]}\n" + if map_Bump is not None: + map_Bump_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_nrm.{map_format}" + ) + mtl_str += f"map_Bump texture_nrm.{map_format}\n" + self._save_rgb_image( + map_Bump_save_path, + map_Bump, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Bump", + step=step, + ) + save_paths.append(map_Bump_save_path) + if map_Pm is not None: + map_Pm_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_metallic.{map_format}" + ) + mtl_str += f"map_Pm texture_metallic.{map_format}\n" + self._save_grayscale_image( + map_Pm_save_path, + map_Pm, + data_range=(0, 1), + cmap=None, + name=f"{matname}_refl", + step=step, + ) + save_paths.append(map_Pm_save_path) + if map_Pr is not None: + map_Pr_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_roughness.{map_format}" + ) + mtl_str += f"map_Pr texture_roughness.{map_format}\n" + self._save_grayscale_image( + map_Pr_save_path, + map_Pr, + data_range=(0, 1), + cmap=None, + name=f"{matname}_Ns", + step=step, + ) + save_paths.append(map_Pr_save_path) + with open(self.get_save_path(filename), "w") as f: + f.write(mtl_str) + return save_paths + + def save_glb( + self, + filename: str, + mesh: Mesh, + save_mat: bool = False, + save_normal: bool = False, + save_uv: bool = False, + save_vertex_color: bool = False, + map_Kd: Optional[Float[Tensor, "H W 3"]] = None, + map_Ks: Optional[Float[Tensor, "H W 3"]] = None, + map_Bump: Optional[Float[Tensor, "H W 3"]] = None, + map_Pm: Optional[Float[Tensor, "H W 1"]] = None, + map_Pr: Optional[Float[Tensor, "H W 1"]] = None, + map_format: str = "jpg", + ) -> List[str]: + save_paths: List[str] = [] + if not filename.endswith(".glb"): + filename += ".glb" + v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data( + mesh.t_pos_idx + ) + v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None + if save_normal: + v_nrm = self.convert_data(mesh.v_nrm) + if save_uv: + v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data( + mesh.t_tex_idx + ) + if save_vertex_color: + v_rgb = self.convert_data(mesh.v_rgb) + + obj_save_path = self._save_glb( + filename, + v_pos, + t_pos_idx, + v_nrm=v_nrm, + v_tex=v_tex, + t_tex_idx=t_tex_idx, + v_rgb=v_rgb, + ) + save_paths.append(obj_save_path) + return save_paths + + def _save_glb( + self, + filename, + v_pos, + t_pos_idx, + v_nrm=None, + v_tex=None, + t_tex_idx=None, + v_rgb=None, + matname=None, + mtllib=None, + ) -> str: + import trimesh + + mesh = trimesh.Trimesh( + vertices=v_pos, faces=t_pos_idx, vertex_normals=v_nrm, vertex_colors=v_rgb + ) + # not tested + if v_tex is not None: 
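+ # Caveat (hence the "not tested" note above): reassigning mesh.visual below replaces the
+ # ColorVisuals built from v_rgb in the constructor, so the exported GLB keeps the UVs but
+ # drops vertex colors; a material with a texture image would be needed for a textured export.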
+ mesh.visual = trimesh.visual.TextureVisuals(uv=v_tex) + + save_path = self.get_save_path(filename) + mesh.export(save_path) + return save_path + + def save_file(self, filename, src_path, delete=False) -> str: + save_path = self.get_save_path(filename) + shutil.copyfile(src_path, save_path) + if delete: + os.remove(src_path) + return save_path + + def save_json(self, filename, payload) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(json.dumps(payload)) + return save_path \ No newline at end of file diff --git a/utils/smpl_util.py b/utils/smpl_util.py new file mode 100644 index 0000000000000000000000000000000000000000..54dc43e5cc6a30f9c1f2a3c9558db8ecb12d3cf8 --- /dev/null +++ b/utils/smpl_util.py @@ -0,0 +1,1263 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np +import cv2 +import pymeshlab +import torch +import torchvision +import trimesh +import json +from pytorch3d.io import load_obj +import os +from termcolor import colored +import os.path as osp +from scipy.spatial import cKDTree +import _pickle as cPickle +import open3d as o3d + +from pytorch3d.structures import Meshes +import torch.nn.functional as F +# from lib.pymaf.utils.imutils import uncrop +# from lib.common.render_utils import Pytorch3dRasterizer, face_vertices + +from pytorch3d.renderer.mesh import rasterize_meshes +from PIL import Image, ImageFont, ImageDraw +from kaolin.ops.mesh import check_sign +from kaolin.metrics.trianglemesh import point_to_mesh_distance + +from pytorch3d.loss import (mesh_laplacian_smoothing, mesh_normal_consistency) + +# import tinyobjloader + + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def obj_loader(path): + # Create reader. + reader = tinyobjloader.ObjReader() + + # Load .obj(and .mtl) using default configuration + ret = reader.ParseFromFile(path) + + if ret == False: + print("Failed to load : ", path) + return None + + # note here for wavefront obj, #v might not equal to #vt, same as #vn. 
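+ # numpy_indices() returns one (vertex, normal, texcoord) index triple per face corner,
+ # so each triangle spans 9 ints below and columns 0/3/6 hold the vertex indices.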
+ attrib = reader.GetAttrib() + verts = np.array(attrib.vertices).reshape(-1, 3) + + shapes = reader.GetShapes() + tri = shapes[0].mesh.numpy_indices().reshape(-1, 9) + faces = tri[:, [0, 3, 6]] + + return verts, faces + + +class HoppeMesh: + + def __init__(self, verts, faces): + ''' + The HoppeSDF calculates signed distance towards a predefined oriented point cloud + http://hhoppe.com/recon.pdf + For clean and high-resolution pcl data, this is the fastest and accurate approximation of sdf + :param points: pts + :param normals: normals + ''' + self.trimesh = trimesh.Trimesh(verts, faces, process=True) + self.verts = np.array(self.trimesh.vertices) + self.faces = np.array(self.trimesh.faces) + self.vert_normals, self.faces_normals = compute_normal( + self.verts, self.faces) + + def contains(self, points): + + labels = check_sign( + torch.as_tensor(self.verts).unsqueeze(0), + torch.as_tensor(self.faces), + torch.as_tensor(points).unsqueeze(0)) + return labels.squeeze(0).numpy() + + def triangles(self): + return self.verts[self.faces] # [n, 3, 3] + + +def tensor2variable(tensor, device): + # [1,23,3,3] + return torch.tensor(tensor, device=device, requires_grad=True) + + +class GMoF(torch.nn.Module): + + def __init__(self, rho=1): + super(GMoF, self).__init__() + self.rho = rho + + def extra_repr(self): + return 'rho = {}'.format(self.rho) + + def forward(self, residual): + dist = torch.div(residual, residual + self.rho**2) + return self.rho**2 * dist + + +def mesh_edge_loss(meshes, target_length: float = 0.0): + """ + Computes mesh edge length regularization loss averaged across all meshes + in a batch. Each mesh contributes equally to the final loss, regardless of + the number of edges per mesh in the batch by weighting each mesh with the + inverse number of edges. For example, if mesh 3 (out of N) has only E=4 + edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to + contribute to the final loss. + + Args: + meshes: Meshes object with a batch of meshes. + target_length: Resting value for the edge length. + + Returns: + loss: Average loss across the batch. Returns 0 if meshes contains + no meshes or all empty meshes. + """ + if meshes.isempty(): + return torch.tensor([0.0], + dtype=torch.float32, + device=meshes.device, + requires_grad=True) + + N = len(meshes) + edges_packed = meshes.edges_packed() # (sum(E_n), 3) + verts_packed = meshes.verts_packed() # (sum(V_n), 3) + edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), ) + num_edges_per_mesh = meshes.num_edges_per_mesh() # N + + # Determine the weight for each edge based on the number of edges in the + # mesh it corresponds to. + # TODO (nikhilar) Find a faster way of computing the weights for each edge + # as this is currently a bottleneck for meshes with a large number of faces. 
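+ # gather() maps each packed edge to the edge count E_n of the mesh it belongs to,
+ # so the 1/E_n weights below make every mesh contribute equally to the batch loss.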
+ weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx) + weights = 1.0 / weights.float() + + verts_edges = verts_packed[edges_packed] + v0, v1 = verts_edges.unbind(1) + loss = ((v0 - v1).norm(dim=1, p=2) - target_length)**2.0 + loss_vertex = loss * weights + # loss_outlier = torch.topk(loss, 100)[0].mean() + # loss_all = (loss_vertex.sum() + loss_outlier.mean()) / N + loss_all = loss_vertex.sum() / N + + return loss_all + + +def remesh(obj_path, perc, device): + + ms = pymeshlab.MeshSet() + ms.load_new_mesh(obj_path) + ms.apply_coord_laplacian_smoothing() + ms.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.PercentageValue(perc), adaptive=True) + # ms.remeshing_isotropic_explicit_remeshing( + # targetlen=pymeshlab.Percentage(perc), adaptive=True) + ms.save_current_mesh(obj_path.replace("recon", "remesh")) + polished_mesh = trimesh.load_mesh(obj_path.replace("recon", "remesh")) + verts_pr = torch.tensor( + polished_mesh.vertices).float().unsqueeze(0).to(device) + faces_pr = torch.tensor(polished_mesh.faces).long().unsqueeze(0).to(device) + + return verts_pr, faces_pr + + +def possion(mesh, obj_path): + + mesh.export(obj_path) + ms = pymeshlab.MeshSet() + ms.load_new_mesh(obj_path) + ms.surface_reconstruction_screened_poisson(depth=10) + ms.set_current_mesh(1) + ms.save_current_mesh(obj_path) + + return trimesh.load(obj_path) + + +def get_mask(tensor, dim): + + mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0 + mask = mask.type_as(tensor) + + return mask + + +def blend_rgb_norm(rgb, norm, mask): + + # [0,0,0] or [127,127,127] should be marked as mask + final = rgb * (1 - mask) + norm * (mask) + + return final.astype(np.uint8) + + +def unwrap(image, data): + + img_uncrop = uncrop( + np.array( + Image.fromarray(image).resize( + data['uncrop_param']['box_shape'][:2])), + data['uncrop_param']['center'], data['uncrop_param']['scale'], + data['uncrop_param']['crop_shape']) + + img_orig = cv2.warpAffine(img_uncrop, + np.linalg.inv(data['uncrop_param']['M'])[:2, :], + data['uncrop_param']['ori_shape'][::-1][1:], + flags=cv2.INTER_CUBIC) + + return img_orig + + +# Losses to smooth / regularize the mesh shape +def update_mesh_shape_prior_losses(mesh, losses): + + # and (b) the edge length of the predicted mesh + losses["edge"]['value'] = mesh_edge_loss(mesh) + # mesh normal consistency + losses["nc"]['value'] = mesh_normal_consistency(mesh) + # mesh laplacian smoothing + losses["laplacian"]['value'] = mesh_laplacian_smoothing(mesh, + method="uniform") + + +def rename(old_dict, old_name, new_name): + new_dict = {} + for key, value in zip(old_dict.keys(), old_dict.values()): + new_key = key if key != old_name else new_name + new_dict[new_key] = old_dict[key] + return new_dict + + +def load_checkpoint(model, cfg): + + model_dict = model.state_dict() + main_dict = {} + normal_dict = {} + + device = torch.device(f"cuda:{cfg['test_gpus'][0]}") + + if os.path.exists(cfg.resume_path) and cfg.resume_path.endswith("ckpt"): + main_dict = torch.load(cfg.resume_path, + map_location=device)['state_dict'] + + main_dict = { + k: v + for k, v in main_dict.items() + if k in model_dict and v.shape == model_dict[k].shape and ( + 'reconEngine' not in k) and ("normal_filter" not in k) and ( + 'voxelization' not in k) + } + print(colored(f"Resume MLP weights from {cfg.resume_path}", 'green')) + + if os.path.exists(cfg.normal_path) and cfg.normal_path.endswith("ckpt"): + normal_dict = torch.load(cfg.normal_path, + map_location=device)['state_dict'] + + for key in normal_dict.keys(): + normal_dict = 
rename(normal_dict, key, + key.replace("netG", "netG.normal_filter")) + + normal_dict = { + k: v + for k, v in normal_dict.items() + if k in model_dict and v.shape == model_dict[k].shape + } + print(colored(f"Resume normal model from {cfg.normal_path}", 'green')) + + model_dict.update(main_dict) + model_dict.update(normal_dict) + model.load_state_dict(model_dict) + + model.netG = model.netG.to(device) + model.reconEngine = model.reconEngine.to(device) + + model.netG.training = False + model.netG.eval() + + del main_dict + del normal_dict + del model_dict + + return model + + +def read_smpl_constants(folder): + """Load smpl vertex code""" + smpl_vtx_std = np.loadtxt(os.path.join(folder, 'vertices.txt')) + min_x = np.min(smpl_vtx_std[:, 0]) + max_x = np.max(smpl_vtx_std[:, 0]) + min_y = np.min(smpl_vtx_std[:, 1]) + max_y = np.max(smpl_vtx_std[:, 1]) + min_z = np.min(smpl_vtx_std[:, 2]) + max_z = np.max(smpl_vtx_std[:, 2]) + + smpl_vtx_std[:, 0] = (smpl_vtx_std[:, 0] - min_x) / (max_x - min_x) + smpl_vtx_std[:, 1] = (smpl_vtx_std[:, 1] - min_y) / (max_y - min_y) + smpl_vtx_std[:, 2] = (smpl_vtx_std[:, 2] - min_z) / (max_z - min_z) + smpl_vertex_code = np.float32(np.copy(smpl_vtx_std)) + """Load smpl faces & tetrahedrons""" + smpl_faces = np.loadtxt(os.path.join(folder, 'faces.txt'), + dtype=np.int32) - 1 + smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] + + smpl_vertex_code[smpl_faces[:, 1]] + + smpl_vertex_code[smpl_faces[:, 2]]) / 3.0 + smpl_tetras = np.loadtxt(os.path.join(folder, 'tetrahedrons.txt'), + dtype=np.int32) - 1 + + return smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras + + +def surface_field_deformation(xyz, de_nn_verts, de_nn_normals, ori_nn_verts, ori_nn_normals): + ''' + xyz: [B, N, 3] + de_nn_verts: [B, N, 3] + de_nn_normals: [B, N, 3] + ori_nn_verts: [B, N, 3] + ori_nn_normals: [B, N, 3] + ''' + vector=xyz-de_nn_verts # [B, N, 3] + delta=torch.sum(vector*de_nn_normals, dim=-1, keepdim=True)*ori_nn_normals + ori_xyz=ori_nn_verts+delta + + return ori_xyz # the deformed xyz + + +def feat_select(feat, select): + + # feat [B, featx2, N] + # select [B, 1, N] + # return [B, feat, N] + + dim = feat.shape[1] // 2 + idx = torch.tile((1-select), (1, dim, 1))*dim + \ + torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select) + feat_select = torch.gather(feat, 1, idx.long()) + + return feat_select + +def get_visibility_color(xy, z, faces): + """get the visibility of vertices + + Args: + xy (torch.tensor): [N,2] + z (torch.tensor): [N,1] + faces (torch.tensor): [N,3] + size (int): resolution of rendered image + """ + + xyz = torch.cat((xy, -z), dim=1) + xyz = (xyz + 1.0) / 2.0 + faces = faces.long() + + rasterizer = Pytorch3dRasterizer(image_size=2**12) + meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...]) + raster_settings = rasterizer.raster_settings + + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, + ) + + vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :]) + vis_mask = torch.zeros(size=(z.shape[0], 1)) + vis_mask[vis_vertices_id] = 1.0 + + # 新增的部分: 检测边缘像素 + edge_mask = torch.zeros_like(pix_to_face) + offset=1 + for i in range(-1-offset, 2+offset): + for j in 
range(-1-offset, 2+offset): + if i == 0 and j == 0: + continue + shifted = torch.roll(pix_to_face, shifts=(i,j), dims=(0,1)) + edge_mask = torch.logical_or(edge_mask, shifted == -1) + + # 更新可见性掩码 + edge_faces = torch.unique(pix_to_face[edge_mask]) + edge_vertices = torch.unique(faces[edge_faces]) + vis_mask[edge_vertices] = 0.0 + + return vis_mask + + +def get_visibility(xy, z, faces): + """get the visibility of vertices + + Args: + xy (torch.tensor): [N,2] + z (torch.tensor): [N,1] + faces (torch.tensor): [N,3] + size (int): resolution of rendered image + """ + + xyz = torch.cat((xy, -z), dim=1) + xyz = (xyz + 1.0) / 2.0 + faces = faces.long() + + rasterizer = Pytorch3dRasterizer(image_size=2**12) + meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...]) + raster_settings = rasterizer.raster_settings + + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, + ) + + vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :]) + vis_mask = torch.zeros(size=(z.shape[0], 1)) + vis_mask[vis_vertices_id] = 1.0 + + # print("------------------------\n") + # print(f"keep points : {vis_mask.sum()/len(vis_mask)}") + + return vis_mask + + +def barycentric_coordinates_of_projection(points, vertices): + ''' https://github.com/MPI-IS/mesh/blob/master/mesh/geometry/barycentric_coordinates_of_projection.py + ''' + """Given a point, gives projected coords of that point to a triangle + in barycentric coordinates. + See + **Heidrich**, Computing the Barycentric Coordinates of a Projected Point, JGT 05 + at http://www.cs.ubc.ca/~heidrich/Papers/JGT.05.pdf + + :param p: point to project. [B, 3] + :param v0: first vertex of triangles. [B, 3] + :returns: barycentric coordinates of ``p``'s projection in triangle defined by ``q``, ``u``, ``v`` + vectorized so ``p``, ``q``, ``u``, ``v`` can all be ``3xN`` + """ + #(p, q, u, v) + v0, v1, v2 = vertices[:, 0], vertices[:, 1], vertices[:, 2] + p = points + + q = v0 + u = v1 - v0 + v = v2 - v0 + n = torch.cross(u, v) + s = torch.sum(n * n, dim=1) + # If the triangle edges are collinear, cross-product is zero, + # which makes "s" 0, which gives us divide by zero. So we + # make the arbitrary choice to set s to epsv (=numpy.spacing(1)), + # the closest thing to zero + s[s == 0] = 1e-6 + oneOver4ASquared = 1.0 / s + w = p - q + b2 = torch.sum(torch.cross(u, w) * n, dim=1) * oneOver4ASquared + b1 = torch.sum(torch.cross(w, v) * n, dim=1) * oneOver4ASquared + weights = torch.stack((1 - b1 - b2, b1, b2), dim=-1) + # check barycenric weights + # p_n = v0*weights[:,0:1] + v1*weights[:,1:2] + v2*weights[:,2:3] + return weights + + +def cal_sdf_batch(verts, faces, cmaps, vis, points): + + # verts [B, N_vert, 3] + # faces [B, N_face, 3] + # triangles [B, N_face, 3, 3] + # points [B, N_point, 3] + # cmaps [B, N_vert, 3] + + Bsize = points.shape[0] + + normals = Meshes(verts, faces).verts_normals_padded() + + # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth + # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929 + # 2. 
fill mouth holes with 30 more faces + + if verts.shape[1] == 10475: + faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask] + mouth_faces = torch.as_tensor( + SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(Bsize, 1, + 1).to(faces.device) + faces = torch.cat([faces, mouth_faces], dim=1) + + triangles = face_vertices(verts, faces) + normals = face_vertices(normals, faces) + cmaps = face_vertices(cmaps, faces) + vis = face_vertices(vis, faces) + + residues, pts_ind, _ = point_to_mesh_distance(points, triangles) + closest_triangles = torch.gather( + triangles, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, + 3)).view(-1, 3, 3) + closest_normals = torch.gather( + normals, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, + 3)).view(-1, 3, 3) + closest_cmaps = torch.gather( + cmaps, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, + 3)).view(-1, 3, 3) + closest_vis = torch.gather(vis, 1, pts_ind[:, :, None, + None].expand(-1, -1, 3, + 1)).view(-1, 3, 1) + bary_weights = barycentric_coordinates_of_projection( + points.view(-1, 3), closest_triangles) + + pts_cmap = (closest_cmaps * bary_weights[:, :, None]).sum(1).unsqueeze(0) + pts_vis = (closest_vis * + bary_weights[:, :, None]).sum(1).unsqueeze(0).ge(1e-1) + pts_norm = (closest_normals * + bary_weights[:, :, None]).sum(1).unsqueeze(0) * torch.tensor( + [-1.0, 1.0, -1.0]).type_as(normals) + pts_norm = F.normalize(pts_norm, dim=2) + pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3)) + + pts_signs = 2.0 * (check_sign(verts, faces[0], points).float() - 0.5) + pts_sdf = (pts_dist * pts_signs).unsqueeze(-1) + + return pts_sdf.view(Bsize, -1, + 1), pts_norm.view(Bsize, -1, 3), pts_cmap.view( + Bsize, -1, 3), pts_vis.view(Bsize, -1, 1) + + +def orthogonal(points, calibrations, transforms=None): + ''' + Compute the orthogonal projections of 3D points into the image plane by given projection matrix + :param points: [B, 3, N] Tensor of 3D points + :param calibrations: [B, 3, 4] Tensor of projection matrix + :param transforms: [B, 2, 3] Tensor of image transform matrix + :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane + ''' + rot = calibrations[:, :3, :3] + trans = calibrations[:, :3, 3:4] + pts = torch.baddbmm(trans, rot, points) # [B, 3, N] + if transforms is not None: + scale = transforms[:2, :2] + shift = transforms[:2, 2:3] + pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :]) + return pts + + +def projection(points, calib): + if torch.is_tensor(points): + calib = torch.as_tensor(calib) if not torch.is_tensor(calib) else calib + return torch.mm(calib[:3, :3], points.T).T + calib[:3, 3] + else: + return np.matmul(calib[:3, :3], points.T).T + calib[:3, 3] + + +def load_calib(calib_path): + calib_data = np.loadtxt(calib_path, dtype=float) + extrinsic = calib_data[:4, :4] + intrinsic = calib_data[4:8, :4] + calib_mat = np.matmul(intrinsic, extrinsic) + calib_mat = torch.from_numpy(calib_mat).float() + return calib_mat + + +def load_obj_mesh_for_Hoppe(mesh_file): + vertex_data = [] + face_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + 
[values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + vertices = np.array(vertex_data) + faces = np.array(face_data) + faces[faces > 0] -= 1 + + normals, _ = compute_normal(vertices, faces) + + return vertices, normals, faces + + +def load_obj_mesh_with_color(mesh_file): + vertex_data = [] + color_data = [] + face_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + c = list(map(float, values[4:7])) + color_data.append(c) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + [values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + vertices = np.array(vertex_data) + colors = np.array(color_data) + faces = np.array(face_data) + faces[faces > 0] -= 1 + + return vertices, colors, faces + + +def load_obj_mesh(mesh_file, with_normal=False, with_texture=False): + vertex_data = [] + norm_data = [] + uv_data = [] + + face_data = [] + face_norm_data = [] + face_uv_data = [] + + if isinstance(mesh_file, str): + f = open(mesh_file, "r") + else: + f = mesh_file + for line in f: + if isinstance(line, bytes): + line = line.decode("utf-8") + if line.startswith('#'): + continue + values = line.split() + if not values: + continue + + if values[0] == 'v': + v = list(map(float, values[1:4])) + vertex_data.append(v) + elif values[0] == 'vn': + vn = list(map(float, values[1:4])) + norm_data.append(vn) + elif values[0] == 'vt': + vt = list(map(float, values[1:3])) + uv_data.append(vt) + + elif values[0] == 'f': + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + f = list( + map(lambda x: int(x.split('/')[0]), + [values[3], values[4], values[1]])) + face_data.append(f) + # tri mesh + else: + f = list(map(lambda x: int(x.split('/')[0]), values[1:4])) + face_data.append(f) + + # deal with texture + if len(values[1].split('/')) >= 2: + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[1]), values[1:4])) + face_uv_data.append(f) + f = list( + map(lambda x: int(x.split('/')[1]), + [values[3], values[4], values[1]])) + face_uv_data.append(f) + # tri mesh + elif len(values[1].split('/')[1]) != 0: + f = list(map(lambda x: int(x.split('/')[1]), values[1:4])) + face_uv_data.append(f) + # deal with normal + if len(values[1].split('/')) == 3: + # quad mesh + if len(values) > 4: + f = list(map(lambda x: int(x.split('/')[2]), values[1:4])) + face_norm_data.append(f) + f = list( + map(lambda x: int(x.split('/')[2]), + [values[3], values[4], values[1]])) + face_norm_data.append(f) + # tri mesh + elif len(values[1].split('/')[2]) != 0: + f = list(map(lambda x: int(x.split('/')[2]), values[1:4])) + face_norm_data.append(f) + + vertices = np.array(vertex_data) + faces = np.array(face_data) + faces[faces > 0] -= 1 + + if with_texture and with_normal: + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) + face_uvs[face_uvs > 0] -= 1 + norms = np.array(norm_data) + if norms.shape[0] 
== 0: + norms, _ = compute_normal(vertices, faces) + face_normals = faces + else: + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) + face_normals[face_normals > 0] -= 1 + return vertices, faces, norms, face_normals, uvs, face_uvs + + if with_texture: + uvs = np.array(uv_data) + face_uvs = np.array(face_uv_data) - 1 + return vertices, faces, uvs, face_uvs + + if with_normal: + norms = np.array(norm_data) + norms = normalize_v3(norms) + face_normals = np.array(face_norm_data) - 1 + return vertices, faces, norms, face_normals + + return vertices, faces + + +def normalize_v3(arr): + ''' Normalize a numpy array of 3 component vectors shape=(n,3) ''' + lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2) + eps = 0.00000001 + lens[lens < eps] = eps + arr[:, 0] /= lens + arr[:, 1] /= lens + arr[:, 2] /= lens + return arr + + +def compute_normal(vertices, faces): + # Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal + vert_norms = np.zeros(vertices.shape, dtype=vertices.dtype) + # Create an indexed view into the vertex array using the array of three indices for triangles + tris = vertices[faces] + # Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle + face_norms = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0]) + # n is now an array of normals per triangle. The length of each normal is dependent the vertices, + # we need to normalize these, so that our next step weights each normal equally. + normalize_v3(face_norms) + # now we have a normalized array of normals, one per triangle, i.e., per triangle normals. + # But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle, + # the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards. 
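+ # Caveat: the fancy-index "+=" used below applies only once per repeated index within a
+ # column; np.add.at(vert_norms, faces[:, 0], face_norms) would accumulate every occurrence
+ # of a duplicated vertex, which the final renormalization only partially hides.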
+ # The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array + vert_norms[faces[:, 0]] += face_norms + vert_norms[faces[:, 1]] += face_norms + vert_norms[faces[:, 2]] += face_norms + normalize_v3(vert_norms) + + return vert_norms, face_norms + + +def save_obj_mesh(mesh_path, verts, faces): + file = open(mesh_path, 'w') + for v in verts: + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def save_obj_mesh_with_color(mesh_path, verts, faces, colors): + file = open(mesh_path, 'w') + + for idx, v in enumerate(verts): + c = colors[idx] + file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % + (v[0], v[1], v[2], c[0], c[1], c[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def calculate_mIoU(outputs, labels): + + SMOOTH = 1e-6 + + outputs = outputs.int() + labels = labels.int() + + intersection = ( + outputs + & labels).float().sum() # Will be zero if Truth=0 or Prediction=0 + union = (outputs | labels).float().sum() # Will be zzero if both are 0 + + iou = (intersection + SMOOTH) / (union + SMOOTH + ) # We smooth our devision to avoid 0/0 + + thresholded = torch.clamp( + 20 * (iou - 0.5), 0, + 10).ceil() / 10 # This is equal to comparing with thresolds + + return thresholded.mean().detach().cpu().numpy( + ) # Or thresholded.mean() if you are interested in average across the batch + + +def mask_filter(mask, number=1000): + """only keep {number} True items within a mask + + Args: + mask (bool array): [N, ] + number (int, optional): total True item. Defaults to 1000. + """ + true_ids = np.where(mask)[0] + keep_ids = np.random.choice(true_ids, size=number) + filter_mask = np.isin(np.arange(len(mask)), keep_ids) + + return filter_mask + + +def query_mesh(path): + + verts, faces_idx, _ = load_obj(path) + + return verts, faces_idx.verts_idx + + +def add_alpha(colors, alpha=0.7): + + colors_pad = np.pad(colors, ((0, 0), (0, 1)), + mode='constant', + constant_values=alpha) + + return colors_pad + + +def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type='smpl'): + + font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf") + font = ImageFont.truetype(font_path, 30) + grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), + nrow=nrow) + grid_img = Image.fromarray( + ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * + 255.0).astype(np.uint8)) + + # add text + draw = ImageDraw.Draw(grid_img) + grid_size = 512 + if loss is not None: + draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font) + + if type == 'smpl': + for col_id, col_txt in enumerate([ + 'image', 'smpl-norm(render)', 'cloth-norm(pred)', 'diff-norm', + 'diff-mask' + ]): + draw.text((10 + (col_id * grid_size), 5), + col_txt, (255, 0, 0), + font=font) + elif type == 'cloth': + for col_id, col_txt in enumerate( + ['cloth-norm(recon)']): + draw.text((10 + (col_id * grid_size), 5), + col_txt, (255, 0, 0), + font=font) + for col_id, col_txt in enumerate(['0', '90', '180', '270']): + draw.text((10 + (col_id * grid_size), grid_size * 2 + 5), + col_txt, (255, 0, 0), + font=font) + else: + print(f"{type} should be 'smpl' or 'cloth'") + + grid_img = grid_img.resize((grid_img.size[0], grid_img.size[1]), + Image.LANCZOS) + + return grid_img + + +def clean_mesh(verts, faces): + + device = verts.device + + mesh_lst = 
trimesh.Trimesh(verts.detach().cpu().numpy(), + faces.detach().cpu().numpy()) + mesh_lst = mesh_lst.split(only_watertight=False) + comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst] + mesh_clean = mesh_lst[comp_num.index(max(comp_num))] + + final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device) + final_faces = torch.as_tensor(mesh_clean.faces).int().to(device) + + return final_verts, final_faces + + +def merge_mesh(verts_A, faces_A, verts_B, faces_B, color=False): + + sep_mesh = trimesh.Trimesh(np.concatenate([verts_A, verts_B], axis=0), + np.concatenate( + [faces_A, faces_B + faces_A.max() + 1], + axis=0), + maintain_order=True, + process=False) + if color: + colors = np.ones_like(sep_mesh.vertices) + colors[:verts_A.shape[0]] *= np.array([255.0, 0.0, 0.0]) + colors[verts_A.shape[0]:] *= np.array([0.0, 255.0, 0.0]) + sep_mesh.visual.vertex_colors = colors + + # union_mesh = trimesh.boolean.union([trimesh.Trimesh(verts_A, faces_A), + # trimesh.Trimesh(verts_B, faces_B)], engine='blender') + + return sep_mesh + + +def mesh_move(mesh_lst, step, scale=1.0): + + trans = np.array([1.0, 0.0, 0.0]) * step + + resize_matrix = trimesh.transformations.scale_and_translate( + scale=(scale), translate=trans) + + results = [] + + for mesh in mesh_lst: + mesh.apply_transform(resize_matrix) + results.append(mesh) + + return results + + +def rescale_smpl(fitted_path, scale=100, translate=(0, 0, 0)): + + fitted_body = trimesh.load(fitted_path, + process=False, + maintain_order=True, + skip_materials=True) + resize_matrix = trimesh.transformations.scale_and_translate( + scale=(scale), translate=translate) + + fitted_body.apply_transform(resize_matrix) + + return np.array(fitted_body.vertices) + + +class SMPLX(): + + def __init__(self): + + self.current_dir = "smpl_related" # new smplx file in ECON folder + + self.smpl_verts_path = osp.join(self.current_dir, + "smpl_data/smpl_verts.npy") + self.smpl_faces_path = osp.join(self.current_dir, + "smpl_data/smpl_faces.npy") + self.smplx_verts_path = osp.join(self.current_dir, + "smpl_data/smplx_verts.npy") + self.smplx_faces_path = osp.join(self.current_dir, + "smpl_data/smplx_faces.npy") + self.cmap_vert_path = osp.join(self.current_dir, + "smpl_data/smplx_cmap.npy") + + self.smplx_to_smplx_path = osp.join(self.current_dir, + "smpl_data/smplx_to_smpl.pkl") + + self.smplx_eyeball_fid = osp.join(self.current_dir, + "smpl_data/eyeball_fid.npy") + self.smplx_fill_mouth_fid = osp.join(self.current_dir, + "smpl_data/fill_mouth_fid.npy") + + self.smplx_faces = np.load(self.smplx_faces_path) + self.smplx_verts = np.load(self.smplx_verts_path) + self.smpl_verts = np.load(self.smpl_verts_path) + self.smpl_faces = np.load(self.smpl_faces_path) + + self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid) + self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid) + + self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, 'rb')) + + self.model_dir = osp.join(self.current_dir, "models") + # self.tedra_dir = osp.join(self.current_dir, "../tedra_data") + + + + # copy from econ + self.smplx_flame_vid_path = osp.join( + self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy" + ) + self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl") + # self.smpl_vert_seg_path = osp.join( + # osp.dirname(__file__), "../../lib/common/smpl_vert_segmentation.json" + # ) + self.smpl_vert_seg_path = osp.join(self.current_dir, "smpl_vert_segmentation.json") + self.front_flame_path = osp.join(self.current_dir, 
"smpl_data/FLAME_face_mask_ids.npy") + self.smplx_vertex_lmkid_path = osp.join( + self.current_dir, "smpl_data/smplx_vertex_lmkid.npy" + ) + + self.smplx_vertex_lmkid = np.load(self.smplx_vertex_lmkid_path) + self.smpl_vert_seg = json.load(open(self.smpl_vert_seg_path)) + self.smpl_mano_vid = np.concatenate( + [ + self.smpl_vert_seg["rightHand"], self.smpl_vert_seg["rightHandIndex1"], + self.smpl_vert_seg["leftHand"], self.smpl_vert_seg["leftHandIndex1"] + ] + ) + + self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True) + self.smplx_mano_vid = np.concatenate( + [self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]] + ) + self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True) + self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)] + + + # hands + self.smplx_mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_mano_vid), 1.0 + ) + self.smpl_mano_vertex_mask = torch.zeros(self.smpl_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smpl_mano_vid), 1.0 + ) + + # face + self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_front_flame_vid), 1.0 + ) + self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0 + ) + + + self.ghum_smpl_pairs = torch.tensor( + [ + (0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19), + (15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40), + (23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29), + (32, 32) + ] + ).long() + + # smpl-smplx correspondence + self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73] + self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68] + self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist() + + self.extra_joint_ids = np.array( + [ + 61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72, + 74, 73 + ] + ) + + self.extra_joint_ids += 68 + + self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist()) + + + def cmap_smpl_vids(self, type): + + # keys: + # closest_faces - [6890, 3] with smplx vert_idx + # bc - [6890, 3] with barycentric weights + + cmap_smplx = torch.as_tensor(np.load(self.cmap_vert_path)).float() + if type == 'smplx': + return cmap_smplx + elif type == 'smpl': + bc = torch.as_tensor(self.smplx_to_smpl['bc'].astype(np.float32)) + closest_faces = self.smplx_to_smpl['closest_faces'].astype( + np.int32) + + cmap_smpl = torch.einsum('bij, bi->bj', cmap_smplx[closest_faces], + bc) + + return cmap_smpl + + + +# copy from ECON + +def apply_face_mask(mesh, face_mask): + + mesh.update_faces(face_mask) + mesh.remove_unreferenced_vertices() + + return mesh + + +def apply_vertex_mask(mesh, vertex_mask): + + faces_mask = vertex_mask[mesh.faces].any(dim=1) + mesh = apply_face_mask(mesh, faces_mask) + + return mesh + + +def apply_vertex_face_mask(mesh, vertex_mask, face_mask): + + faces_mask = vertex_mask[mesh.faces].any(dim=1) * torch.tensor(face_mask) + mesh.update_faces(faces_mask) + mesh.remove_unreferenced_vertices() + + return mesh + + +def clean_floats(mesh): + thres = mesh.vertices.shape[0] * 1e-2 + mesh_lst = mesh.split(only_watertight=False) + clean_mesh_lst = [mesh for mesh in mesh_lst if mesh.vertices.shape[0] > thres] + 
return sum(clean_mesh_lst) + +def isin(input, test_elements): + # expand the dims of the input and the test elements for broadcasting + input = input.unsqueeze(-1) + test_elements = test_elements.unsqueeze(0) + + # compare the elements of the two tensors + comparison_result = torch.eq(input, test_elements) + + # sum along the newly added dim to check whether each input element appears among the test elements + isin_result = comparison_result.sum(-1).bool() + + return isin_result + + + +def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=True): + + smpl_tree = cKDTree(smpl_obj.vertices) + SMPL_container = SMPLX() + + from lib.dataset.PointFeat import ECON_PointFeat + + part_extractor = ECON_PointFeat( + torch.tensor(part_mesh.vertices).unsqueeze(0).to(device), + torch.tensor(part_mesh.faces).unsqueeze(0).to(device) + ) + + (part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device)) + + remove_mask = part_dist < thres + + if region == "hand": + _, idx = smpl_tree.query(full_mesh.vertices, k=1) + full_lmkid = SMPL_container.smplx_vertex_lmkid[idx] + remove_mask = torch.logical_and( + remove_mask, + torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0) + ) + + elif region == "face": + _, idx = smpl_tree.query(full_mesh.vertices, k=5) + face_space_mask = isin( + torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid) + ) + remove_mask = torch.logical_and( + remove_mask, + face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0) + ) + + BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1) + full_mesh.update_faces(BNI_part_mask.detach().cpu()) + full_mesh.remove_unreferenced_vertices() + + if clean: + full_mesh = clean_floats(full_mesh) + + return full_mesh + +def keep_largest(mesh): + mesh_lst = mesh.split(only_watertight=False) + keep_mesh = mesh_lst[0] + for mesh in mesh_lst: + if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]: + keep_mesh = mesh + return keep_mesh + + +def poisson(mesh, obj_path, depth=10, decimation=True): + + pcd_path = obj_path[:-4] + "_soups.ply" + assert (mesh.vertex_normals.shape[1] == 3) + mesh.export(pcd_path) + pcl = o3d.io.read_point_cloud(pcd_path) + with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm: + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( + pcl, depth=depth, n_threads=6 + ) + + # only keep the largest component + largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles))) + largest_mesh.export(obj_path) + + if decimation: + # mesh decimation for faster rendering + low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000) + return low_res_mesh + else: + return largest_mesh \ No newline at end of file diff --git a/utils/snapshot.py b/utils/snapshot.py new file mode 100644 index 0000000000000000000000000000000000000000..6e25871ef6e236fbf55b7d5f75b70a50a6ce311b --- /dev/null +++ b/utils/snapshot.py @@ -0,0 +1,28 @@ +from copy import deepcopy +from time import time +from typing import Any +import torch +from dataclasses import dataclass + +from core.opt import MeshOptimizer + + +@dataclass +class Snapshot: + step:int + time:float + vertices:torch.Tensor #V,3 + faces:torch.Tensor #F,3 + optimizer:Any=None + +def snapshot(opt:MeshOptimizer): + opt = deepcopy(opt) + opt._vertices.requires_grad_(False) + + return Snapshot( + step=opt._step, + time=time()-opt._start, + vertices=opt.vertices, + faces=opt.faces, + optimizer=opt, + ) \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..abd31564ea2d820453e9feb1728cc179adf64a05 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,82 @@ +from torchvision.utils import make_grid +from PIL import Image, ImageDraw, ImageFont +import numpy as np +import torch +import math +import cv2 +def make_grid_(imgs, save_file, nrow=10, pad_value=1): + if isinstance(imgs, list): + if isinstance(imgs[0], Image.Image): + imgs = [torch.from_numpy(np.array(img)/255.) for img in imgs] + elif isinstance(imgs[0], np.ndarray): + imgs = [torch.from_numpy(img/255.) for img in imgs] + imgs = torch.stack(imgs, 0).permute(0, 3, 1, 2) + if isinstance(imgs, np.ndarray): + imgs = torch.from_numpy(imgs) + + img_grid = make_grid(imgs, nrow=nrow, padding=2, pad_value=pad_value) + img_grid = img_grid.permute(1, 2, 0).numpy() + img_grid = (img_grid * 255).astype(np.uint8) + img_grid = Image.fromarray(img_grid) + img_grid.save(save_file) + +def draw_caption(img, text, pos, size=100, color=(128, 128, 128)): + draw = ImageDraw.Draw(img) + # font = ImageFont.truetype(size= size) + font = ImageFont.load_default() + font = font.font_variant(size=size) + draw.text(pos, text, color, font=font) + return img + + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + +def resize_img(input_image, max_side=1280, min_side=1024, size=None, + pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image \ No newline at end of file diff --git a/utils/video_utils.py b/utils/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23c9ca5fa7d7dfab74ae2ff4f0dce5914ba09b67 --- /dev/null +++ b/utils/video_utils.py @@ -0,0 +1,44 @@ +from moviepy.editor import VideoFileClip, VideoClip, ImageClip, CompositeVideoClip, ImageSequenceClip +from moviepy.editor import VideoFileClip, clips_array +import os +import 
argparse + +def write_video(vclip, fps=25, save_path='res.mp4'): + if isinstance(vclip, list): + vclip = ImageSequenceClip(vclip, fps=fps) + vclip.write_videofile(save_path, codec="libx264") + +def load_video(vpath): + return VideoFileClip(vpath) + +def concat_video_clips(clips, videos_per_row=3): + if isinstance(clips[0], str): + clips = [VideoFileClip(v) for v in clips] + elif not isinstance(clips[0], VideoFileClip): + print(f'Find {len(clips)} clips') + + min_duration = min(clip.duration for clip in clips) + clips = [clip.subclip(0, min_duration) for clip in clips] + rows = [clips[i:i + videos_per_row] for i in range(0, len(clips), videos_per_row)] + final_clip = clips_array(rows) + return final_clip + +def concat_video_and_frames(vpath, frames): + vclip1 = VideoClip(vpath) + vclip2 = ImageSequenceClip(frames, fps=25) + clips = concat_video_clips([vclip1, vclip2]) + return clips + +def concat_img_video(vclip, img): + if isinstance(vclip, str): + vclip = VideoFileClip(vclip) + + image_clip = ImageClip(img).set_duration(vclip.duration) + image_clip = image_clip.resize(height=vclip.size[1]) + total_width = vclip.size[0] + image_clip.size[0] + total_height = max(vclip.size[1], image_clip.size[1]) + composite_clip = CompositeVideoClip([ + image_clip.set_position('left'), vclip.set_position((image_clip.size[0], 0)), ], + size=(total_width, total_height)) + return composite_clip + diff --git a/utils/view.py b/utils/view.py new file mode 100644 index 0000000000000000000000000000000000000000..b249cf62872256e7f997831c142dd4b1011edf97 --- /dev/null +++ b/utils/view.py @@ -0,0 +1,252 @@ +from copy import deepcopy +from re import S +from typing import Union +import torch +import open3d as o3d +import open3d.visualization.gui as gui +import open3d.visualization.rendering as rendering +from core.opt import MeshOptimizer +import numpy as np +from util.func import to_numpy + +from util.snapshot import Snapshot + + +class Viewer: + def __init__(self, + target_vertices:torch.Tensor, #V,3 + target_faces:torch.Tensor, #F,3 + snapshots:list[Snapshot], + vertex_colors:dict[str,list[np.array]] + ): + self._target_vertices = target_vertices + self._target_faces = target_faces + self._snapshots = snapshots + self._vertex_colors = vertex_colors + + self._window_o3 = gui.Application.instance.create_window("Continuous Remeshing",1000,800) + self._window_o3.set_on_layout(self._layout) + self._scene_widget = gui.SceneWidget() + self._scene_widget.scene = rendering.Open3DScene(self._window_o3.renderer) + self._scene_widget.scene.set_background([.5, .5, .5, 1]) + bbox = o3d.geometry.AxisAlignedBoundingBox([-1, -1, -1], [1, 1, 1]) + self._scene_widget.setup_camera(60, bbox, [0, 0, 0]) + self._window_o3.add_child(self._scene_widget) + self._scene_widget.set_on_mouse(self._on_mouse) + + #lights + self._scene_widget.scene.scene.enable_sun_light(False) + self._scene_widget.scene.scene.set_indirect_light(gui.Application.instance.resource_path + '/park2') + self._scene_widget.scene.scene.enable_indirect_light(True) + self._scene_widget.scene.scene.set_indirect_light_intensity(45000) + self._scene_widget.scene.show_skybox(False) + + #right panel + margins = gui.Margins(*[self._window_o3.theme.default_margin]*4) + spacing = self._window_o3.theme.default_layout_spacing + self._right_panel = gui.Vert(spacing,margins) + + def make_checkbox(name,checked): + checkbox = gui.Checkbox(name) + checkbox.checked = checked + checkbox.set_on_checked(lambda *args:self._update()) + self._right_panel.add_child(checkbox) + return checkbox + + 
self._mesh_checkbox = make_checkbox("Show Mesh",True) + self._colorbox = gui.Combobox() + for item in ['Gray','Relative Velocity nu','Reference Edge Length l_ref',*self._vertex_colors.keys()]: + self._colorbox.add_item(item) + self._colorbox.set_on_selection_changed(lambda *args:self._update()) + self._right_panel.add_child(self._colorbox) + + self._clim_slider = gui.Slider(gui.Slider.DOUBLE) + self._clim_slider.double_value = .2 + self._clim_slider.set_limits(1e-3, 1) + self._clim_slider.set_on_value_changed(lambda *args:self._update()) + self._right_panel.add_child(self._clim_slider) + + self._edges_checkbox = make_checkbox("Show Edges",True) + self._target_mesh_checkbox = make_checkbox("Show Target Mesh",False) + self._right_panel.add_child(gui.Label('Ctrl-Click Mesh For Plot!')) + self._target_edges_checkbox = make_checkbox("Show Target Edges",False) + self._positions_checkbox = make_checkbox("Plot Positions",False) + self._gradients_checkbox = make_checkbox("Plot Gradients",False) + self._m1_checkbox = make_checkbox("Plot m1",False) + self._m2_checkbox = make_checkbox("Plot m2",False) + self._nu_checkbox = make_checkbox("Plot nu",True) + self._lref_checkbox = make_checkbox("Plot l_ref",True) + self._window_o3.add_child(self._right_panel) + + #bottom panel + self._bottom_panel = gui.VGrid(cols=2,spacing=spacing) + self._snapshot_slider = gui.Slider(gui.Slider.INT) + self._snapshot_slider.int_value = len(self._snapshots)-1 + self._snapshot_slider.set_limits(0, len(self._snapshots)-1) + self._snapshot_slider.set_on_value_changed(lambda *args:self._update()) + self._bottom_panel.add_child(self._snapshot_slider) + self._window_o3.add_child(self._bottom_panel) + + self._update() + + def _update(self): + snapshot = self._snapshots[self._snapshot_slider.int_value] + + self._scene_widget.scene.clear_geometry() + + self._scene_widget.scene.show_axes(True) + + MaterialType = rendering.MaterialRecord if hasattr(rendering,'MaterialRecord') else rendering.Material + + def add_mesh(name,color,vertices,faces,show_mesh,show_edges,vertex_colors=None): + vertices_np = vertices.detach().cpu().numpy() + vertices_o3 = o3d.utility.Vector3dVector(vertices_np) + faces_o3 = o3d.utility.Vector3iVector(faces.type(torch.int32).cpu().numpy()) + triangleMesh = o3d.geometry.TriangleMesh(vertices_o3,faces_o3) + triangleMesh.compute_vertex_normals() + if vertex_colors is not None: + vertex_colors_np = to_numpy(vertex_colors) + triangleMesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors_np) + + if show_mesh: + material = MaterialType() + if vertex_colors is None: + material.shader = "defaultLit" + material.base_color = color + self._scene_widget.scene.add_geometry(f"{name}_triangleMesh", triangleMesh, material) + + if show_edges: + edges_material = MaterialType() + edges_material.base_color = [0,0,0,1] + edges_material.shader = "unlitLine" + edges_lineset = o3d.geometry.LineSet.create_from_triangle_mesh(triangleMesh) + edges_lineset.points = o3d.utility.Vector3dVector(vertices_np + 1e-4 * np.asarray(triangleMesh.vertex_normals)) + self._scene_widget.scene.add_geometry(f"{name}_edges_lineset", edges_lineset, edges_material) + + clim = self._clim_slider.double_value + if self._colorbox.selected_text=='Relative Velocity nu' and isinstance(snapshot.optimizer, MeshOptimizer): + vertex_colors = snapshot.optimizer._nu + elif self._colorbox.selected_text=='Reference Edge Length l_ref' and isinstance(snapshot.optimizer, MeshOptimizer): + vertex_colors = snapshot.optimizer._ref_len + elif self._colorbox.selected_text 
in self._vertex_colors.keys(): + vertex_colors = self._vertex_colors[self._colorbox.selected_text][self._snapshot_slider.int_value] + else: + vertex_colors = None + + if vertex_colors is not None: + c = (to_numpy(vertex_colors) / clim).clip(0,1) + vertex_colors = np.stack((c,1-c,np.zeros_like(c)),axis=-1) + + add_mesh("mesh",[.5,.5,.5,1],snapshot.vertices,snapshot.faces,self._mesh_checkbox.checked,self._edges_checkbox.checked, vertex_colors) + add_mesh("target",[.5,.5,1,1],self._target_vertices,self._target_faces,self._target_mesh_checkbox.checked,self._target_edges_checkbox.checked) + + def _layout(self,layout_context): + r = self._window_o3.content_rect + + h = self._bottom_panel.calc_preferred_size(layout_context, gui.Widget.Constraints()).height + self._bottom_panel.frame = gui.Rect(0, r.height - h, r.width, h) + r.height -= h + + w = 250 + self._right_panel.frame = gui.Rect(r.width - w, 0, w, r.height) + r.width -= w + + self._scene_widget.frame = r + + def _on_mouse(self, event): + if event.type == gui.MouseEvent.Type.BUTTON_DOWN and event.is_modifier_down(gui.KeyModifier.CTRL): + self._hit_test(event) + return gui.Widget.EventCallbackResult.HANDLED + return gui.Widget.EventCallbackResult.IGNORED + + def _hit_test(self,event): + def depth_callback(depth_image): + f = self._scene_widget.frame + depth = np.asarray(depth_image)[event.y - f.y, event.x - f.x] + if depth == 1.0: # clicked on nothing (i.e. the far plane) + return + #need to flip y https://github.com/isl-org/Open3D/issues/4244 + pos = self._scene_widget.scene.camera.unproject(event.x - f.x, f.height - event.y, depth, f.width, f.height) + + opt = self._snapshots[self._snapshot_slider.int_value].optimizer + vertex = (opt.vertices.cpu() - torch.tensor(pos)).norm(dim=-1).argmin().item() + + gui.Application.instance.post_to_main_thread(self._window_o3, lambda:self._on_click(pos,vertex)) + + self._scene_widget.scene.scene.render_to_depth_image(depth_callback) + + def _on_click(self,pos,vertex): + self._show_plot(vertex) + + def _show_plot(self,vertex): + ind = self._snapshot_slider.int_value + device = self._snapshots[0].vertices.device + vert_ind = torch.zeros(len(self._snapshots),dtype=torch.long,device=device) + vert_ind[ind] = vertex + + def trace(i): + nonlocal cur_pos + vert_ind[i] = (self._snapshots[i].vertices - cur_pos).norm(dim=-1).argmin(dim=0) + cur_pos = self._snapshots[i].vertices[vert_ind[i]] + + cur_pos = self._snapshots[ind].vertices[vertex] + for i in range(ind-1,-1,-1): + trace(i) + + cur_pos = self._snapshots[ind].vertices[vertex] + for i in range(ind+1,len(self._snapshots)): + trace(i) + + dims = slice(None,None) + + grad_scale = 100 + + from cycler import cycler + import matplotlib.pyplot as plt + plt.gca().set_prop_cycle(cycler(linestyle=['-', '--', ':'][dims])) + + def extract(prop): + values = [prop(self._snapshots[i].optimizer,vert_ind[i]) for i in range(0,len(vert_ind))] + if isinstance(values[0],torch.Tensor): + values = torch.stack(values).cpu() + return values + + s = [s.optimizer._step for s in self._snapshots] + + if self._positions_checkbox.checked: + plt.plot(s,extract(lambda opt,v:opt.vertices[v,dims]),'b',label='pos') + if self._gradients_checkbox.checked: + plt.plot(s,grad_scale*extract(lambda opt,v:opt.vertices.grad[v,dims]),'k',label='grad') + + m1 = extract(lambda opt,v:opt._m1[v]) + m2 = extract(lambda opt,v:opt._m2[v]) + velocity = m1 / m2[:,None].sqrt().add_(1e-8) #V,3 + speed = velocity.norm(dim=-1) + if self._m1_checkbox.checked: + plt.plot(s,grad_scale*extract(lambda 
opt,v:opt._m1[v,dims]),'r',label='m1') + if self._m2_checkbox.checked: + plt.plot(s,grad_scale*extract(lambda opt,v:opt._m2[v].sqrt()),'-m',label='m2.sqrt()') + if self._nu_checkbox.checked: + plt.plot(s,speed,color='orange',label='speed') + plt.plot(s,extract(lambda opt,v:opt._nu[v]),'-c',label='nu') + if self._lref_checkbox.checked: + plt.plot(s,extract(lambda opt,v:opt._ref_len[v]),color='gray',label='l_ref') + + plt.axvline(x=ind, color='k') + plt.legend() + plt.grid() + plt.show() + + +def show( + target_vertices:torch.Tensor, #V,3 + target_faces:torch.Tensor, #F,3 + snapshots:list[Snapshot], + vertex_colors:dict[str,list[np.array]]={} + ): + for vc in vertex_colors.values(): + assert [c.shape[0] for c in vc] == [s.vertices.shape[0] for s in snapshots] + + gui.Application.instance.initialize() + viewer = Viewer(target_vertices,target_faces,snapshots,vertex_colors) + gui.Application.instance.run() \ No newline at end of file
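Note: `utils/snapshot.py` and `utils/view.py` keep the upstream continuous-remeshing import paths (`core.opt`, `util.func`, `util.snapshot`), so they are meant to be used from inside that project layout rather than from this repo root. As a rough sketch of how the two files fit together: the `MeshOptimizer` constructor and its `zero_grad()`/`step()`/`remesh()` calls below are assumed from the upstream project, and `toy_loss` is only a placeholder objective, not the loss used by PSHuman.
```
# Hedged sketch, not part of this diff: record Snapshots during a
# continuous-remeshing optimization and browse them in the Open3D viewer.
import torch

from core.opt import MeshOptimizer   # upstream module, also imported by utils/snapshot.py
from util.snapshot import snapshot   # corresponds to utils/snapshot.py here
from util.view import show           # corresponds to utils/view.py here


def toy_loss(vertices, target_vertices):
    # placeholder objective: pull every vertex towards its nearest target vertex
    dists = torch.cdist(vertices, target_vertices)   # V,T
    return dists.min(dim=1).values.mean()


def optimize_and_view(src_v, src_f, target_v, target_f, steps=500, every=25):
    opt = MeshOptimizer(src_v, src_f)     # assumed upstream API
    snapshots = [snapshot(opt)]
    for it in range(steps):
        opt.zero_grad()
        toy_loss(opt.vertices, target_v).backward()
        opt.step()
        opt.remesh()
        if (it + 1) % every == 0:
            snapshots.append(snapshot(opt))
    # Ctrl-click a vertex in the viewer window to plot m1 / m2 / nu / l_ref over time
    show(target_v, target_f, snapshots)
```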