diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c18dd8d83ceed1806b50b0aaa46beb7e335fff13 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3854349ba00cabf62532fbba9515e00696b12d81 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black diff --git a/README.md b/README.md index b0cefcedb8c4abedd9c625e8748ceeb2c027f226..edaea2acb3df2de920bcba0df26525be58558230 100644 --- a/README.md +++ b/README.md @@ -13,4 +13,4 @@ short_description: Scalable and Versatile 3D Generation from images Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference -Paper: https://huggingface.co/papers/2412.01506 \ No newline at end of file +Paper: https://huggingface.co/papers/2412.01506 diff --git a/app.py b/app.py index dac80298545e5b042dbbf46224fa25b7ad8bcf18..0ce920c56551aa919d79be9fc93543198e5d3e52 100644 --- a/app.py +++ b/app.py @@ -4,7 +4,8 @@ from gradio_litmodel3d import LitModel3D import os import shutil -os.environ['SPCONV_ALGO'] = 'native' + +os.environ["SPCONV_ALGO"] = "native" from typing import * import torch import numpy as np @@ -17,15 +18,24 @@ from trellis.utils import render_utils, postprocessing_utils MAX_SEED = np.iinfo(np.int32).max -TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp") os.makedirs(TMP_DIR, exist_ok=True) +pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") +pipeline.cuda() +try: + pipeline.preprocess_image( + Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)) + ) # Preload rembg +except: + pass + def start_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) os.makedirs(user_dir, exist_ok=True) - - + + def end_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shutil.rmtree(user_dir) @@ -48,10 +58,10 @@ def preprocess_image(image: Image.Image) -> Image.Image: def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]: """ Preprocess a list of input images. - + Args: images (List[Tuple[Image.Image, str]]): The input images. - + Returns: List[Image.Image]: The preprocessed images. 
""" @@ -62,41 +72,41 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: return { - 'gaussian': { + "gaussian": { **gs.init_params, - '_xyz': gs._xyz.cpu().numpy(), - '_features_dc': gs._features_dc.cpu().numpy(), - '_scaling': gs._scaling.cpu().numpy(), - '_rotation': gs._rotation.cpu().numpy(), - '_opacity': gs._opacity.cpu().numpy(), + "_xyz": gs._xyz.cpu().numpy(), + "_features_dc": gs._features_dc.cpu().numpy(), + "_scaling": gs._scaling.cpu().numpy(), + "_rotation": gs._rotation.cpu().numpy(), + "_opacity": gs._opacity.cpu().numpy(), }, - 'mesh': { - 'vertices': mesh.vertices.cpu().numpy(), - 'faces': mesh.faces.cpu().numpy(), + "mesh": { + "vertices": mesh.vertices.cpu().numpy(), + "faces": mesh.faces.cpu().numpy(), }, } - - + + def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]: gs = Gaussian( - aabb=state['gaussian']['aabb'], - sh_degree=state['gaussian']['sh_degree'], - mininum_kernel_size=state['gaussian']['mininum_kernel_size'], - scaling_bias=state['gaussian']['scaling_bias'], - opacity_bias=state['gaussian']['opacity_bias'], - scaling_activation=state['gaussian']['scaling_activation'], + aabb=state["gaussian"]["aabb"], + sh_degree=state["gaussian"]["sh_degree"], + mininum_kernel_size=state["gaussian"]["mininum_kernel_size"], + scaling_bias=state["gaussian"]["scaling_bias"], + opacity_bias=state["gaussian"]["opacity_bias"], + scaling_activation=state["gaussian"]["scaling_activation"], ) - gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda') - gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda') - gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda') - gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda') - gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda') - + gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda") + gs._features_dc = torch.tensor(state["gaussian"]["_features_dc"], device="cuda") + gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda") + gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda") + gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda") + mesh = edict( - vertices=torch.tensor(state['mesh']['vertices'], device='cuda'), - faces=torch.tensor(state['mesh']['faces'], device='cuda'), + vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"), + faces=torch.tensor(state["mesh"]["faces"], device="cuda"), ) - + return gs, mesh @@ -170,12 +180,14 @@ def image_to_3d( }, mode=multiimage_algo, ) - video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color'] - video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal'] - video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))] - video_path = os.path.join(user_dir, 'sample.mp4') + video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"] + video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"] + video = [ + np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video)) + ] + video_path = os.path.join(user_dir, "sample.mp4") imageio.mimsave(video_path, video, fps=15) - state = pack_state(outputs['gaussian'][0], outputs['mesh'][0]) + state = pack_state(outputs["gaussian"][0], outputs["mesh"][0]) torch.cuda.empty_cache() return state, video_path @@ -200,8 +212,10 @@ def extract_glb( """ user_dir = 
os.path.join(TMP_DIR, str(req.session_hash)) gs, mesh = unpack_state(state) - glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False) - glb_path = os.path.join(user_dir, 'sample.glb') + glb = postprocessing_utils.to_glb( + gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False + ) + glb_path = os.path.join(user_dir, "sample.glb") glb.export(glb_path) torch.cuda.empty_cache() return glb_path, glb_path @@ -220,19 +234,21 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]: """ user_dir = os.path.join(TMP_DIR, str(req.session_hash)) gs, _ = unpack_state(state) - gaussian_path = os.path.join(user_dir, 'sample.ply') + gaussian_path = os.path.join(user_dir, "sample.ply") gs.save_ply(gaussian_path) torch.cuda.empty_cache() return gaussian_path, gaussian_path def prepare_multi_example() -> List[Image.Image]: - multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")])) + multi_case = list( + set([i.split("_")[0] for i in os.listdir("assets/example_multi_image")]) + ) images = [] for case in multi_case: _images = [] for i in range(1, 4): - img = Image.open(f'assets/example_multi_image/{case}_{i}.png') + img = Image.open(f"assets/example_multi_image/{case}_{i}.png") W, H = img.size img = img.resize((int(W / H * 512), 512)) _images.append(np.array(img)) @@ -246,71 +262,113 @@ def split_image(image: Image.Image) -> List[Image.Image]: """ image = np.array(image) alpha = image[..., 3] - alpha = np.any(alpha>0, axis=0) + alpha = np.any(alpha > 0, axis=0) start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist() end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist() images = [] for s, e in zip(start_pos, end_pos): - images.append(Image.fromarray(image[:, s:e+1])) + images.append(Image.fromarray(image[:, s : e + 1])) return [preprocess_image(image) for image in images] with gr.Blocks(delete_cache=(600, 600)) as demo: - gr.Markdown(""" + gr.Markdown( + """ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/) * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background. * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it. - + ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction. - """) - + """ + ) + with gr.Row(): with gr.Column(): with gr.Tabs() as input_tabs: with gr.Tab(label="Single Image", id=0) as single_image_input_tab: - image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300) + image_prompt = gr.Image( + label="Image Prompt", + format="png", + image_mode="RGBA", + type="pil", + height=300, + ) with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab: - multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3) - gr.Markdown(""" - Input different views of the object in separate images. - + multiimage_prompt = gr.Gallery( + label="Image Prompt", + format="png", + type="pil", + height=300, + columns=3, + ) + gr.Markdown( + """ + Input different views of the object in separate images. + *NOTE: this is an experimental algorithm without training a specialized model.
It may not produce the best results for all images, especially those with different poses or inconsistent details.* - """) - + """ + ) + with gr.Accordion(label="Generation Settings", open=False): seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) gr.Markdown("Stage 1: Sparse Structure Generation") with gr.Row(): - ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) - ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + ss_guidance_strength = gr.Slider( + 0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1 + ) + ss_sampling_steps = gr.Slider( + 1, 50, label="Sampling Steps", value=12, step=1 + ) gr.Markdown("Stage 2: Structured Latent Generation") with gr.Row(): - slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1) - slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) - multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic") + slat_guidance_strength = gr.Slider( + 0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1 + ) + slat_sampling_steps = gr.Slider( + 1, 50, label="Sampling Steps", value=12, step=1 + ) + multiimage_algo = gr.Radio( + ["stochastic", "multidiffusion"], + label="Multi-image Algorithm", + value="stochastic", + ) generate_btn = gr.Button("Generate") - + with gr.Accordion(label="GLB Extraction Settings", open=False): - mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.0, step=0.01) - texture_size = gr.Slider(512, 2048, label="Texture Size", value=2048, step=512) - + mesh_simplify = gr.Slider( + 0.0, 0.98, label="Simplify", value=0.0, step=0.01 + ) + texture_size = gr.Slider( + 512, 2048, label="Texture Size", value=2048, step=512 + ) + with gr.Row(): extract_glb_btn = gr.Button("Extract GLB", interactive=False) extract_gs_btn = gr.Button("Extract Gaussian", interactive=False) - gr.Markdown(""" + gr.Markdown( + """ *NOTE: the Gaussian file can be very large (~50MB), so it may take a while to display and download.* - """) + """ + ) with gr.Column(): - video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) - model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300) - + video_output = gr.Video( + label="Generated 3D Asset", autoplay=True, loop=True, height=300 + ) + model_output = LitModel3D( + label="Extracted GLB/Gaussian", exposure=10.0, height=300 + ) + with gr.Row(): - download_glb = gr.DownloadButton(label="Download GLB", interactive=False) - download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False) - + download_glb = gr.DownloadButton( + label="Download GLB", interactive=False + ) + download_gs = gr.DownloadButton( + label="Download Gaussian", interactive=False + ) + is_multiimage = gr.State(False) output_buf = gr.State() @@ -318,7 +376,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo: with gr.Row() as single_image_example: examples = gr.Examples( examples=[ - f'assets/example_image/{image}' + f"assets/example_image/{image}" for image in os.listdir("assets/example_image") ], inputs=[image_prompt], @@ -340,16 +398,20 @@ with gr.Blocks(delete_cache=(600, 600)) as demo: # Handlers demo.load(start_session) demo.unload(end_session) - + single_image_input_tab.select( - lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]), - outputs=[is_multiimage, single_image_example,
multiimage_example] + lambda: tuple( + [False, gr.Row.update(visible=True), gr.Row.update(visible=False)] + ), + outputs=[is_multiimage, single_image_example, multiimage_example], ) multiimage_input_tab.select( - lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]), - outputs=[is_multiimage, single_image_example, multiimage_example] + lambda: tuple( + [True, gr.Row.update(visible=False), gr.Row.update(visible=True)] + ), + outputs=[is_multiimage, single_image_example, multiimage_example], ) - + image_prompt.upload( preprocess_image, inputs=[image_prompt], @@ -361,13 +423,19 @@ with gr.Blocks(delete_cache=(600, 600)) as demo: outputs=[multiimage_prompt], ) - generate_btn.click( - get_seed, - inputs=[randomize_seed, seed], - outputs=[seed], - ).then( + generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed],).then( image_to_3d, - inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo], + inputs=[ + image_prompt, + multiimage_prompt, + is_multiimage, + seed, + ss_guidance_strength, + ss_sampling_steps, + slat_guidance_strength, + slat_sampling_steps, + multiimage_algo, + ], outputs=[output_buf, video_output], ).then( lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]), @@ -387,7 +455,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo: lambda: gr.Button(interactive=True), outputs=[download_glb], ) - + extract_gs_btn.click( extract_gaussian, inputs=[output_buf], @@ -401,14 +469,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo: lambda: gr.Button(interactive=False), outputs=[download_glb], ) - + # Launch the Gradio app if __name__ == "__main__": - pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") - pipeline.cuda() - try: - pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg - except: - pass demo.launch() diff --git a/extensions/nvdiffrast/README.md b/extensions/nvdiffrast/README.md index 3eeb4115c839a7703c5cac22fe6e89828ad29f2c..d75b1d1de85698235fb8e7b742195fe3be669e45 100644 --- a/extensions/nvdiffrast/README.md +++ b/extensions/nvdiffrast/README.md @@ -21,7 +21,7 @@ We do not currently accept outside code contributions in the form of pull reques Environment map stored as part of `samples/data/envphong.npz` is derived from a Wave Engine [sample material](https://github.com/WaveEngine/Samples-2.5/tree/master/Materials/EnvironmentMap/Content/Assets/CubeMap.cubemap) -originally shared under +originally shared under [MIT License](https://github.com/WaveEngine/Samples-2.5/blob/master/LICENSE.md). Mesh and texture stored as part of `samples/data/earth.npz` are derived from [3D Earth Photorealistic 2K](https://www.turbosquid.com/3d-models/3d-realistic-earth-photorealistic-2k-1279125) diff --git a/extensions/nvdiffrast/nvdiffrast/__init__.py b/extensions/nvdiffrast/nvdiffrast/__init__.py index fd28a0879ef844ef791dca19abdc8416c2468e58..5cd55e3ca1541fc09c861cc8f03f94a0346f0400 100644 --- a/extensions/nvdiffrast/nvdiffrast/__init__.py +++ b/extensions/nvdiffrast/nvdiffrast/__init__.py @@ -6,4 +6,4 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
-__version__ = '0.3.3' +__version__ = "0.3.3" diff --git a/extensions/nvdiffrast/nvdiffrast/common/antialias.cu b/extensions/nvdiffrast/nvdiffrast/common/antialias.cu index 95cc3bab582661a7deb6064daa616adf7121ea36..5d306f7ad5f96a59cb4c8200ebb08ba71cabe860 100644 --- a/extensions/nvdiffrast/nvdiffrast/common/antialias.cu +++ b/extensions/nvdiffrast/nvdiffrast/common/antialias.cu @@ -112,7 +112,7 @@ static __device__ __forceinline__ void evhash_insert_vertex(const AntialiasKerne { if (va == vb) return; - + uint64_t v0 = (uint32_t)min(va, vb) + 1; // canonical vertex order uint64_t v1 = (uint32_t)max(va, vb) + 1; uint64_t vk = v0 | (v1 << 32); // hash key diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp index 3c1c3a7fd137618d6d20217b5ee4d9b964d3f9b8..28cd6bea939d4ffca45ce90361c01a02367c5863 100644 --- a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp @@ -60,4 +60,3 @@ private: //------------------------------------------------------------------------ } // namespace CR - diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp index d594acdfeb2a83133726a6dfd594b3ccad0d74cc..e0ea1931b4680de32882af48d0080596a4102717 100644 --- a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp @@ -99,4 +99,3 @@ private: //------------------------------------------------------------------------ } // namespace CR - diff --git a/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu b/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu index 3bd2a7a7ab3111ae12f6cdce73906eeb9bbf6935..94993edecddc5d7215e82f34966217b5fcdbae35 100644 --- a/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu +++ b/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu @@ -94,9 +94,9 @@ static __forceinline__ __device__ void InterpolateFwdKernelTemplate(const Interp float dvdx = db.z; float dvdy = db.w; - // Calculate the pixel differentials of chosen attributes. + // Calculate the pixel differentials of chosen attributes. for (int i=0; i < p.numDiffAttr; i++) - { + { // Input attribute index. int j = p.diff_attrs_all ? i : p.diffAttrs[i]; if (j < 0) diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture.cpp b/extensions/nvdiffrast/nvdiffrast/common/texture.cpp index 51633e10120b4dc465e5283241a38c95db31f8dc..039b98a4a92ac42a472112f0887e5cb63bd4b934 100644 --- a/extensions/nvdiffrast/nvdiffrast/common/texture.cpp +++ b/extensions/nvdiffrast/nvdiffrast/common/texture.cpp @@ -47,7 +47,7 @@ void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p) // Append level size to error message. snprintf(buf, bufsz, "mip %-2d ", level); - msg += buf; + msg += buf; if (ew) snprintf(buf, bufsz, " err "); else snprintf(buf, bufsz, "%5d ", w); msg += buf; diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py index be51deef13e0ecfbd5bfe8bc376af24a18db7224..10a828427d56eae5eb54ef6b47301a645a9b9178 100644 --- a/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py @@ -11,22 +11,26 @@ import numpy as np import os from . 
import plugin_loader -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # Helpers. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # OpenGL-related linker options depending on platform. def _get_gl_opts(): libs = { - 'posix': ['GL', 'EGL'], - 'nt': ['gdi32', 'opengl32', 'user32', 'setgpu'], + "posix": ["GL", "EGL"], + "nt": ["gdi32", "opengl32", "user32", "setgpu"], } - return ['-l' + x for x in libs[os.name]] + return ["-l" + x for x in libs[os.name]] + # Load the cpp plugin. def _get_plugin(): - fn = os.path.join(os.path.dirname(__file__), 'tf_all.cu') - return plugin_loader.get_plugin(fn, extra_nvcc_options=_get_gl_opts() + ['-DNVDR_TENSORFLOW']) + fn = os.path.join(os.path.dirname(__file__), "tf_all.cu") + return plugin_loader.get_plugin( + fn, extra_nvcc_options=_get_gl_opts() + ["-DNVDR_TENSORFLOW"] + ) + # Convert parameter to a numpy array if possible. def _get_constant(x, dtype): @@ -35,19 +39,24 @@ def _get_constant(x, dtype): except (TypeError, ValueError): return None + # Tests for a construction-time constantness instead of tf.constant node because # the latter can be overridden in Session.run() feed_dict at evaluation time. def _is_constant(x, dtype): if isinstance(x, np.ndarray): - return np.can_cast(x.dtype, dtype, 'unsafe') + return np.can_cast(x.dtype, dtype, "unsafe") else: return _get_constant(x, dtype) is not None -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Rasterize. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- -def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True, grad_db=True): + +def rasterize( + pos, tri, resolution, ranges=None, tri_const=False, output_db=True, grad_db=True +): assert tri_const is True or tri_const is False assert output_db is True or output_db is False @@ -63,15 +72,19 @@ def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True pos = tf.convert_to_tensor(pos, dtype=tf.float32) resolution = tf.convert_to_tensor(resolution, dtype=tf.int32) if ranges is None: - ranges = tf.convert_to_tensor(np.zeros(shape=[0, 2], dtype=np.int32)) # Empty tensor. + ranges = tf.convert_to_tensor( + np.zeros(shape=[0, 2], dtype=np.int32) + ) # Empty tensor. else: - ranges = tf.convert_to_tensor(ranges, dtype=tf.int32) # Convert input to tensor. + ranges = tf.convert_to_tensor( + ranges, dtype=tf.int32 + ) # Convert input to tensor. # Infer as much about the output shape as possible. out_shape = [None, None, None, 4] - if pos.shape.rank == 3: # Instanced mode. + if pos.shape.rank == 3: # Instanced mode. out_shape[0] = pos.shape[0].value - elif pos.shape.rank == 2: # Range mode. + elif pos.shape.rank == 2: # Range mode. if ranges.shape.rank not in [None, 0]: out_shape[0] = ranges.shape[0].value if resolution_c is not None: @@ -81,24 +94,32 @@ def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True # Output pixel differentials. 
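[Editor's note on the `rasterize()` shape inference above, before the gradient stubs that follow: a rank-3 `pos` tensor selects instanced mode (one vertex buffer per minibatch element), while a rank-2 `pos` selects range mode (a single shared vertex buffer plus per-element `[start, count]` slices of the triangle buffer). A shape-only sketch, with sizes invented for illustration:]

```python
import numpy as np

num_batch, num_vertices = 4, 1024

# Instanced mode: pos has rank 3, one copy of the vertex buffer per element.
pos_instanced = np.zeros((num_batch, num_vertices, 4), dtype=np.float32)

# Range mode: pos has rank 2 and is shared; `ranges` holds [start, count]
# pairs indexing into the triangle buffer, one pair per minibatch element.
pos_shared = np.zeros((num_vertices, 4), dtype=np.float32)
ranges = np.array([[0, 100], [100, 150], [250, 80], [330, 90]], dtype=np.int32)
```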
@tf.custom_gradient def func_db(pos): - out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 1, tri_const) + out, out_db = _get_plugin().rasterize_fwd( + pos, tri, resolution, ranges, 1, tri_const + ) out.set_shape(out_shape) out_db.set_shape(out_shape) + def grad(dy, ddb): if grad_db: return _get_plugin().rasterize_grad_db(pos, tri, out, dy, ddb) else: return _get_plugin().rasterize_grad(pos, tri, out, dy) + return (out, out_db), grad # Do not output pixel differentials. @tf.custom_gradient def func(pos): - out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 0, tri_const) + out, out_db = _get_plugin().rasterize_fwd( + pos, tri, resolution, ranges, 0, tri_const + ) out.set_shape(out_shape) - out_db.set_shape(out_shape[:-1] + [0]) # Zero channels in out_db. + out_db.set_shape(out_shape[:-1] + [0]) # Zero channels in out_db. + def grad(dy, _): return _get_plugin().rasterize_grad(pos, tri, out, dy) + return (out, out_db), grad # Choose stub. @@ -107,15 +128,17 @@ def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True else: return func(pos) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Interpolate. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): # Sanitize the list of pixel differential attributes. if diff_attrs is None: diff_attrs = [] - elif diff_attrs != 'all': + elif diff_attrs != "all": diff_attrs = _get_constant(diff_attrs, np.int32) assert (diff_attrs is not None) and len(diff_attrs.shape) == 1 diff_attrs = diff_attrs.tolist() @@ -130,16 +153,23 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): # Infer output shape. out_shape = [None, None, None, None] if rast.shape.rank is not None: - out_shape = [rast.shape[0].value, rast.shape[1].value, rast.shape[2].value, None] + out_shape = [ + rast.shape[0].value, + rast.shape[1].value, + rast.shape[2].value, + None, + ] if attr.shape.rank in [2, 3]: out_shape[3] = attr.shape[-1].value # Output pixel differentials for at least some attributes. @tf.custom_gradient def func_da(attr, rast, rast_db): - diff_attrs_all = int(diff_attrs == 'all') + diff_attrs_all = int(diff_attrs == "all") diff_attrs_list = [] if diff_attrs_all else diff_attrs - out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + out, out_da = _get_plugin().interpolate_fwd_da( + attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list + ) # Infer number of channels in out_da. if not diff_attrs_all: @@ -154,7 +184,10 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): out_da.set_shape([out_shape[0], out_shape[1], out_shape[2], da_channels]) def grad(dy, dda): - return _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + return _get_plugin().interpolate_grad_da( + attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list + ) + return (out, out_da), grad # No pixel differentials for any attribute. @@ -162,9 +195,11 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): def func(attr, rast): out, out_da = _get_plugin().interpolate_fwd(attr, rast, tri) out.set_shape(out_shape) - out_da.set_shape(out_shape[:-1] + [0]) # Zero channels in out_da. 
+ out_da.set_shape(out_shape[:-1] + [0]) # Zero channels in out_da. + def grad(dy, _): return _get_plugin().interpolate_grad(attr, rast, tri, dy) + return (out, out_da), grad # Choose stub. @@ -173,16 +208,26 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): else: return func(attr, rast) -#---------------------------------------------------------------------------- -# Texture. -#---------------------------------------------------------------------------- -def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_const=False, max_mip_level=None): +# ---------------------------------------------------------------------------- +# Texture. +# ---------------------------------------------------------------------------- + + +def texture( + tex, + uv, + uv_da=None, + filter_mode="auto", + boundary_mode="wrap", + tex_const=False, + max_mip_level=None, +): assert tex_const is True or tex_const is False # Default filter mode. - if filter_mode == 'auto': - filter_mode = 'linear-mipmap-linear' if (uv_da is not None) else 'linear' + if filter_mode == "auto": + filter_mode = "linear-mipmap-linear" if (uv_da is not None) else "linear" # Known constant texture? tex_const = tex_const or _is_constant(tex, np.float32) @@ -198,7 +243,7 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c # Convert inputs to tensors. tex = tf.convert_to_tensor(tex, dtype=tf.float32) uv = tf.convert_to_tensor(uv, dtype=tf.float32) - if 'mipmap' in filter_mode: + if "mipmap" in filter_mode: uv_da = tf.convert_to_tensor(uv_da, dtype=tf.float32) # Infer output shape. @@ -207,37 +252,83 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c assert uv.shape.rank == 4 out_shape = [uv.shape[0].value, uv.shape[1].value, uv.shape[2].value, None] if tex.shape.rank is not None: - assert tex.shape.rank == (5 if boundary_mode == 'cube' else 4) + assert tex.shape.rank == (5 if boundary_mode == "cube" else 4) out_shape[-1] = tex.shape[-1].value # If mipping disabled via max level=0, we may as well use simpler filtering internally. - if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']: - filter_mode = 'linear' + if max_mip_level == 0 and filter_mode in [ + "linear-mipmap-nearest", + "linear-mipmap-linear", + ]: + filter_mode = "linear" # Convert filter mode to internal enumeration. - filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3} + filter_mode_dict = { + "nearest": 0, + "linear": 1, + "linear-mipmap-nearest": 2, + "linear-mipmap-linear": 3, + } filter_mode_enum = filter_mode_dict[filter_mode] # Convert boundary mode to internal enumeration. - boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3} + boundary_mode_dict = {"cube": 0, "wrap": 1, "clamp": 2, "zero": 3} boundary_mode_enum = boundary_mode_dict[boundary_mode] # Linear-mipmap-linear: Mipmaps enabled, all gradients active. 
@tf.custom_gradient def func_linear_mipmap_linear(tex, uv, uv_da): - out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level) + out, mip = _get_plugin().texture_fwd_mip( + tex, + uv, + uv_da, + filter_mode_enum, + boundary_mode_enum, + tex_const, + max_mip_level, + ) out.set_shape(out_shape) + def grad(dy): - return _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level) + return _get_plugin().texture_grad_linear_mipmap_linear( + tex, + uv, + dy, + uv_da, + mip, + filter_mode_enum, + boundary_mode_enum, + max_mip_level, + ) + return out, grad # Linear-mipmap-nearest: Mipmaps enabled, no gradients to uv_da. @tf.custom_gradient def func_linear_mipmap_nearest(tex, uv): - out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level) + out, mip = _get_plugin().texture_fwd_mip( + tex, + uv, + uv_da, + filter_mode_enum, + boundary_mode_enum, + tex_const, + max_mip_level, + ) out.set_shape(out_shape) + def grad(dy): - return _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level) + return _get_plugin().texture_grad_linear_mipmap_nearest( + tex, + uv, + dy, + uv_da, + mip, + filter_mode_enum, + boundary_mode_enum, + max_mip_level, + ) + return out, grad # Linear: Mipmaps disabled, no uv_da, no gradients to uv_da. @@ -245,8 +336,12 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c def func_linear(tex, uv): out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) out.set_shape(out_shape) + def grad(dy): - return _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return _get_plugin().texture_grad_linear( + tex, uv, dy, filter_mode_enum, boundary_mode_enum + ) + return out, grad # Nearest: Mipmaps disabled, no uv_da, no gradients to uv_da or uv. @@ -254,23 +349,29 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c def func_nearest(tex): out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) out.set_shape(out_shape) + def grad(dy): - return _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return _get_plugin().texture_grad_nearest( + tex, uv, dy, filter_mode_enum, boundary_mode_enum + ) + return out, grad # Choose stub. - if filter_mode == 'linear-mipmap-linear': + if filter_mode == "linear-mipmap-linear": return func_linear_mipmap_linear(tex, uv, uv_da) - elif filter_mode == 'linear-mipmap-nearest': + elif filter_mode == "linear-mipmap-nearest": return func_linear_mipmap_nearest(tex, uv) - elif filter_mode == 'linear': + elif filter_mode == "linear": return func_linear(tex, uv) - elif filter_mode == 'nearest': + elif filter_mode == "nearest": return func_nearest(tex) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Antialias. 
-#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def antialias(color, rast, pos, tri, tri_const=False, pos_gradient_boost=1.0): assert tri_const is True or tri_const is False @@ -289,15 +390,22 @@ def antialias(color, rast, pos, tri, tri_const=False, pos_gradient_boost=1.0): @tf.custom_gradient def func(color, pos): - color_out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, tri_const) + color_out, work_buffer = _get_plugin().antialias_fwd( + color, rast, pos, tri, tri_const + ) color_out.set_shape(color.shape) + def grad(dy): - grad_color, grad_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer) + grad_color, grad_pos = _get_plugin().antialias_grad( + color, rast, pos, tri, dy, work_buffer + ) if pos_gradient_boost != 1.0: grad_pos = grad_pos * pos_gradient_boost return grad_color, grad_pos + return color_out, grad return func(color, pos) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py index 3918aecdab6bb4192e8810bd872abf9a1fc30971..a2efe228cc0e3d6c6abe08704afe2a00cf32daea 100644 --- a/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py @@ -14,15 +14,16 @@ import hashlib import tempfile import shutil import tensorflow as tf -from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module +from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # Global options. _nvdiffrast_cache_dir = None + def set_cache_dir(path: str) -> None: - '''Set CUDA kernel compilation temp dir. + """Set CUDA kernel compilation temp dir. If `set_cache_dir` is not called, the cache directory will default to one of the below: @@ -33,103 +34,164 @@ def set_cache_dir(path: str) -> None: Args: path: Where to save CUDA kernel build temporaries - ''' + """ global _nvdiffrast_cache_dir _nvdiffrast_cache_dir = path + def make_cache_dir_path(*paths: str) -> str: if _nvdiffrast_cache_dir is not None: return os.path.join(_nvdiffrast_cache_dir, *paths) - if 'NVDIFFRAST_CACHE_DIR' in os.environ: - return os.path.join(os.environ['NVDIFFRAST_CACHE_DIR'], *paths) - if 'HOME' in os.environ: - return os.path.join(os.environ['HOME'], '.cache', 'nvdiffrast', *paths) - if 'USERPROFILE' in os.environ: - return os.path.join(os.environ['USERPROFILE'], '.cache', 'nvdiffrast', *paths) - return os.path.join(tempfile.gettempdir(), '.cache', 'nvdiffrast', *paths) - -cuda_cache_version_tag = 'v1' -do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! -verbose = True # Print status messages to stdout. 
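[Editor's note on `make_cache_dir_path` above: the kernel-build cache location resolves in order from an explicit `set_cache_dir()` call, then the `NVDIFFRAST_CACHE_DIR` environment variable, then `~/.cache/nvdiffrast` under `HOME`/`USERPROFILE`, and finally the system temp directory. A minimal sketch of overriding it, assuming the package layout shown in this diff; the scratch path is an example, not from the source:]

```python
import os

# Option 1: environment variable, picked up by make_cache_dir_path().
os.environ["NVDIFFRAST_CACHE_DIR"] = "/scratch/nvdiffrast-cache"  # example path

# Option 2: explicit setter, which takes precedence over the environment.
from nvdiffrast.tensorflow import plugin_loader

plugin_loader.set_cache_dir("/scratch/nvdiffrast-cache")  # example path
```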
- -#---------------------------------------------------------------------------- + if "NVDIFFRAST_CACHE_DIR" in os.environ: + return os.path.join(os.environ["NVDIFFRAST_CACHE_DIR"], *paths) + if "HOME" in os.environ: + return os.path.join(os.environ["HOME"], ".cache", "nvdiffrast", *paths) + if "USERPROFILE" in os.environ: + return os.path.join(os.environ["USERPROFILE"], ".cache", "nvdiffrast", *paths) + return os.path.join(tempfile.gettempdir(), ".cache", "nvdiffrast", *paths) + + +cuda_cache_version_tag = "v1" +do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! +verbose = True # Print status messages to stdout. + +# ---------------------------------------------------------------------------- # Internal helper funcs. + def _find_compiler_bindir(): - hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + hostx64_paths = sorted( + glob.glob( + "C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64" + ), + 
reverse=True, + ) if hostx64_paths != []: return hostx64_paths[0] - vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin' + vc_bin_dir = "C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin" if os.path.isdir(vc_bin_dir): return vc_bin_dir return None + def _get_compute_cap(device): caps_str = device.physical_device_desc - m = re.search('compute capability: (\\d+).(\\d+)', caps_str) + m = re.search("compute capability: (\\d+).(\\d+)", caps_str) major = m.group(1) minor = m.group(2) return (major, minor) + def _get_cuda_gpu_arch_string(): - gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] + gpus = [x for x in device_lib.list_local_devices() if x.device_type == "GPU"] if len(gpus) == 0: - raise RuntimeError('No GPU devices found') + raise RuntimeError("No GPU devices found") (major, minor) = _get_compute_cap(gpus[0]) - return 'sm_%s%s' % (major, minor) + return "sm_%s%s" % (major, minor) + def _run_cmd(cmd): with os.popen(cmd) as pipe: output = pipe.read() status = pipe.close() if status is not None: - raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) + raise RuntimeError( + "NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s" + % (cmd, output) + ) + def _prepare_nvcc_cli(opts): - cmd = 'nvcc ' + opts.strip() - cmd += ' --disable-warnings' + cmd = "nvcc " + opts.strip() + cmd += " --disable-warnings" cmd += ' --include-path "%s"' % tf.sysconfig.get_include() - cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') - cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') - cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') + cmd += ' --include-path "%s"' % os.path.join( + tf.sysconfig.get_include(), "external", "protobuf_archive", "src" + ) + cmd += ' --include-path "%s"' % os.path.join( + tf.sysconfig.get_include(), "external", "com_google_absl" + ) + cmd += ' --include-path "%s"' % os.path.join( + tf.sysconfig.get_include(), "external", "eigen_archive" + ) compiler_bindir = _find_compiler_bindir() if compiler_bindir is None: # Require that _find_compiler_bindir succeeds on Windows. Allow # nvcc to use whatever is the default on Linux. - if os.name == 'nt': - raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) + if os.name == "nt": + raise RuntimeError( + 'Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' + % __file__ + ) else: cmd += ' --compiler-bindir "%s"' % compiler_bindir - cmd += ' 2>&1' + cmd += " 2>&1" return cmd -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Main entry point. _plugin_cache = dict() + def get_plugin(cuda_file, extra_nvcc_options=[]): cuda_file_base = os.path.basename(cuda_file) cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) @@ -140,80 +202,112 @@ def get_plugin(cuda_file, extra_nvcc_options=[]): # Setup plugin. if verbose: - print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) + print( + 'Setting up TensorFlow plugin "%s": ' % cuda_file_base, end="", flush=True + ) try: # Hash CUDA source. 
md5 = hashlib.md5() - with open(cuda_file, 'rb') as f: + with open(cuda_file, "rb") as f: md5.update(f.read()) - md5.update(b'\n') + md5.update(b"\n") # Hash headers included by the CUDA code by running it through the preprocessor. if not do_not_hash_included_headers: if verbose: - print('Preprocessing... ', end='', flush=True) + print("Preprocessing... ", end="", flush=True) with tempfile.TemporaryDirectory() as tmp_dir: - tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) - _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) - with open(tmp_file, 'rb') as f: - bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros - good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') + tmp_file = os.path.join( + tmp_dir, cuda_file_name + "_tmp" + cuda_file_ext + ) + _run_cmd( + _prepare_nvcc_cli( + '"%s" --preprocess -o "%s" --keep --keep-dir "%s"' + % (cuda_file, tmp_file, tmp_dir) + ) + ) + with open(tmp_file, "rb") as f: + bad_file_str = ('"' + cuda_file.replace("\\", "/") + '"').encode( + "utf-8" + ) # __FILE__ in error check macros + good_file_str = ('"' + cuda_file_base + '"').encode("utf-8") for ln in f: - if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas + if not ln.startswith(b"# ") and not ln.startswith( + b"#line " + ): # ignore line number pragmas ln = ln.replace(bad_file_str, good_file_str) md5.update(ln) - md5.update(b'\n') + md5.update(b"\n") # Select compiler options. - compile_opts = '' - if os.name == 'nt': - compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') - compile_opts += ' --library-path="%s"' % (os.path.dirname(__file__) + r"\..\lib") # Find libraries during compilation. - elif os.name == 'posix': - compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so') - compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' + compile_opts = "" + if os.name == "nt": + compile_opts += '"%s"' % os.path.join( + tf.sysconfig.get_lib(), "python", "_pywrap_tensorflow_internal.lib" + ) + compile_opts += ' --library-path="%s"' % ( + os.path.dirname(__file__) + r"\..\lib" + ) # Find libraries during compilation. + elif os.name == "posix": + compile_opts += '"%s"' % os.path.join( + tf.sysconfig.get_lib(), "python", "_pywrap_tensorflow_internal.so" + ) + compile_opts += " --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0'" else: - assert False # not Windows or Linux, w00t? - compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() - compile_opts += ' --use_fast_math' + assert False # not Windows or Linux, w00t? + compile_opts += " --gpu-architecture=%s" % _get_cuda_gpu_arch_string() + compile_opts += " --use_fast_math" for opt in extra_nvcc_options: - compile_opts += ' ' + opt + compile_opts += " " + opt nvcc_cmd = _prepare_nvcc_cli(compile_opts) # Hash build configuration. - md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') - md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') - md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') + md5.update(("nvcc_cmd: " + nvcc_cmd).encode("utf-8") + b"\n") + md5.update(("tf.VERSION: " + tf.VERSION).encode("utf-8") + b"\n") + md5.update( + ("cuda_cache_version_tag: " + cuda_cache_version_tag).encode("utf-8") + + b"\n" + ) # Compile if not already compiled. 
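[Editor's note, summarizing the caching scheme assembled above: the compiled binary's name is keyed on an MD5 of the CUDA source, its preprocessed headers, the full nvcc command line, the TensorFlow version, and a manual cache tag, so any change to code or build flags yields a fresh compile. A toy sketch of the same keying idea; the function name and inputs are illustrative, not from the source:]

```python
import hashlib

def cuda_plugin_cache_key(source: bytes, nvcc_cmd: str, tf_version: str,
                          cache_tag: str = "v1") -> str:
    # Mirrors the update sequence above: source bytes first, then build config.
    md5 = hashlib.md5()
    md5.update(source + b"\n")
    md5.update(("nvcc_cmd: " + nvcc_cmd).encode("utf-8") + b"\n")
    md5.update(("tf.VERSION: " + tf_version).encode("utf-8") + b"\n")
    md5.update(("cuda_cache_version_tag: " + cache_tag).encode("utf-8") + b"\n")
    return md5.hexdigest()

print(cuda_plugin_cache_key(b"__global__ void k() {}", "nvcc --use_fast_math", "1.15.0"))
```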
- bin_file_ext = '.dll' if os.name == 'nt' else '.so' + bin_file_ext = ".dll" if os.name == "nt" else ".so" cuda_cache_path = make_cache_dir_path() - bin_file = os.path.join(make_cache_dir_path(), cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) + bin_file = os.path.join( + make_cache_dir_path(), cuda_file_name + "_" + md5.hexdigest() + bin_file_ext + ) if not os.path.isfile(bin_file): if verbose: - print('Compiling... ', end='', flush=True) + print("Compiling... ", end="", flush=True) with tempfile.TemporaryDirectory() as tmp_dir: - tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) - _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) + tmp_file = os.path.join(tmp_dir, cuda_file_name + "_tmp" + bin_file_ext) + _run_cmd( + nvcc_cmd + + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' + % (cuda_file, tmp_file, tmp_dir) + ) os.makedirs(cuda_cache_path, exist_ok=True) - intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) + intermediate_file = os.path.join( + cuda_cache_path, + cuda_file_name + "_" + uuid.uuid4().hex + "_tmp" + bin_file_ext, + ) shutil.copyfile(tmp_file, intermediate_file) - os.rename(intermediate_file, bin_file) # atomic + os.rename(intermediate_file, bin_file) # atomic # Load. if verbose: - print('Loading... ', end='', flush=True) + print("Loading... ", end="", flush=True) plugin = tf.load_op_library(bin_file) # Add to cache. _plugin_cache[cuda_file] = plugin if verbose: - print('Done.', flush=True) + print("Done.", flush=True) return plugin except: if verbose: - print('Failed!', flush=True) + print("Failed!", flush=True) raise -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu index 9b14962a8b40e12bfab1ca3a7107d5f5e943a125..13cd6171d75d3cc86034f8e290a3f8a1f15e33ed 100644 --- a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu @@ -100,13 +100,13 @@ struct AntialiasFwdOp : public OpKernel // (Re-)calculate opposite vertex hash. if (!p.evHash || !p.tri_const) - { + { if (p.allocTriangles < p.numTriangles) { p.allocTriangles = max(p.allocTriangles, 64); while (p.allocTriangles < p.numTriangles) p.allocTriangles <<= 1; // Must be power of two. - + // (Re-)allocate memory for the hash. OP_CHECK_CUDA_ERROR(ctx, cudaFree(p.evHash)); OP_CHECK_CUDA_ERROR(ctx, cudaMalloc(&p.evHash, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4))); diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu index 612ce1afc5ce41a25496523b193725c1edac64de..cfe3037b86abcca1b9d776d88689bed16796a5df 100644 --- a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu @@ -112,7 +112,7 @@ struct InterpolateFwdOp : public OpKernel // Verify that buffers are aligned to allow float2/float4 operations. 
OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4")); - OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); if (ENABLE_DA) OP_REQUIRES(ctx, !((uintptr_t)p.outDA & 7), errors::Internal("out_da output tensor not aligned to float2")); @@ -158,7 +158,7 @@ struct InterpolateGradOp : public OpKernel InterpolateGradOp(OpKernelConstruction* ctx): OpKernel(ctx) { memset(&m_attribs, 0, sizeof(m_attribs)); - interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA); + interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA); } void Compute(OpKernelContext* ctx) @@ -247,7 +247,7 @@ struct InterpolateGradOp : public OpKernel OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_rast_shape, &grad_rast_db_tensor)); p.gradRasterDB = grad_rast_db_tensor->flat().data(); } - + // Clear attribute gradients. cudaMemsetAsync(p.gradAttr, 0, attr_depth * p.numVertices * p.numAttr * sizeof(float), stream); @@ -257,10 +257,10 @@ struct InterpolateGradOp : public OpKernel if (ENABLE_DA) { OP_REQUIRES(ctx, !((uintptr_t)p.dda & 7), errors::Internal("dda input tensor not aligned to float2")); - OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); OP_REQUIRES(ctx, !((uintptr_t)p.gradRasterDB & 15), errors::Internal("grad_rast_db output tensor not aligned to float4")); } - + // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu index c5382fed28236da09d20a04c0524a937383daf5a..f9e87a79577dc87a8f3e5a580d60040b219f69f1 100644 --- a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu @@ -503,7 +503,7 @@ REGISTER_OP("TextureGradLinearMipmapNearest") .Attr ("filter_mode: int") .Attr ("boundary_mode: int") .Attr ("max_mip_level: int"); - + REGISTER_OP("TextureGradLinearMipmapLinear") .Input ("tex: float") .Input ("uv: float") @@ -516,10 +516,10 @@ REGISTER_OP("TextureGradLinearMipmapLinear") .Attr ("filter_mode: int") .Attr ("boundary_mode: int") .Attr ("max_mip_level: int"); - + REGISTER_KERNEL_BUILDER(Name("TextureGradNearest") .Device(DEVICE_GPU), TextureGradOp); REGISTER_KERNEL_BUILDER(Name("TextureGradLinear") .Device(DEVICE_GPU), TextureGradOp); REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapNearest").Device(DEVICE_GPU), TextureGradOp); REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapLinear") .Device(DEVICE_GPU), TextureGradOp); - + //------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/__init__.py b/extensions/nvdiffrast/nvdiffrast/torch/__init__.py index d28f95e7a9e423b5efb322c39e343a069caf0fe8..8b1c5e2314a13ab5c765cdbaf9cca0032d757f65 100644 --- a/extensions/nvdiffrast/nvdiffrast/torch/__init__.py +++ b/extensions/nvdiffrast/nvdiffrast/torch/__init__.py @@ -6,5 +6,30 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is 
strictly prohibited. -from .ops import RasterizeCudaContext, RasterizeGLContext, get_log_level, set_log_level, rasterize, DepthPeeler, interpolate, texture, texture_construct_mip, antialias, antialias_construct_topology_hash -__all__ = ["RasterizeCudaContext", "RasterizeGLContext", "get_log_level", "set_log_level", "rasterize", "DepthPeeler", "interpolate", "texture", "texture_construct_mip", "antialias", "antialias_construct_topology_hash"] +from .ops import ( + RasterizeCudaContext, + RasterizeGLContext, + get_log_level, + set_log_level, + rasterize, + DepthPeeler, + interpolate, + texture, + texture_construct_mip, + antialias, + antialias_construct_topology_hash, +) + +__all__ = [ + "RasterizeCudaContext", + "RasterizeGLContext", + "get_log_level", + "set_log_level", + "rasterize", + "DepthPeeler", + "interpolate", + "texture", + "texture_construct_mip", + "antialias", + "antialias_construct_topology_hash", +] diff --git a/extensions/nvdiffrast/nvdiffrast/torch/ops.py b/extensions/nvdiffrast/nvdiffrast/torch/ops.py index edf8540fda5aed6736a72b44b993031157a9cf4b..a670d8b463e057cf19585bf58c73b7610db446c9 100644 --- a/extensions/nvdiffrast/nvdiffrast/torch/ops.py +++ b/extensions/nvdiffrast/nvdiffrast/torch/ops.py @@ -14,13 +14,15 @@ import torch import torch.utils.cpp_extension from . import _C -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # C++/Cuda plugin compiler/loader. _cached_plugin = {} + + def _get_plugin(gl=False): assert isinstance(gl, bool) - + # Modified with precompiled torch CUDA extension if not gl: return _C @@ -30,16 +32,27 @@ def _get_plugin(gl=False): return _cached_plugin[gl] # Make sure we can find the necessary compiler and libary binaries. - if os.name == 'nt': + if os.name == "nt": lib_dir = os.path.dirname(__file__) + r"\..\lib" + def find_cl_path(): import glob + def get_sort_key(x): # Primary criterion is VS version, secondary is edition, third is internal MSVC version. - x = x.split('\\')[3:] - x[1] = {'BuildTools': '~0', 'Community': '~1', 'Pro': '~2', 'Professional': '~3', 'Enterprise': '~4'}.get(x[1], x[1]) + x = x.split("\\")[3:] + x[1] = { + "BuildTools": "~0", + "Community": "~1", + "Pro": "~2", + "Professional": "~3", + "Enterprise": "~4", + }.get(x[1], x[1]) return x - vs_relative_path = r"\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64" + + vs_relative_path = ( + r"\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64" + ) paths = glob.glob(r"C:\Program Files" + vs_relative_path) paths += glob.glob(r"C:\Program Files (x86)" + vs_relative_path) if paths: @@ -49,104 +62,126 @@ def _get_plugin(gl=False): if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: - raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") - os.environ['PATH'] += ';' + cl_path + raise RuntimeError( + "Could not locate a supported Microsoft Visual C++ installation" + ) + os.environ["PATH"] += ";" + cl_path # Compiler options. - common_opts = ['-DNVDR_TORCH'] + common_opts = ["-DNVDR_TORCH"] cc_opts = [] - if os.name == 'nt': - cc_opts += ['/wd4067', '/wd4624'] # Disable warnings in torch headers. + if os.name == "nt": + cc_opts += ["/wd4067", "/wd4624"] # Disable warnings in torch headers. # Linker options for the GL-interfacing plugin. 
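[Editor's note on the torch `ops.py` hunk above: this vendored copy short-circuits `_get_plugin(gl=False)` to the precompiled `_C` extension imported at module top, so only the OpenGL-interfacing plugin still goes through JIT compilation. A condensed sketch of that dispatch, simplified; the real function also caches the loaded module and configures the compiler:]

```python
def _get_plugin_sketch(gl: bool = False):
    """Illustrative only -- not the function from the diff."""
    if not gl:
        # CUDA rasterizer path: return the prebuilt extension immediately.
        from nvdiffrast.torch import _C
        return _C
    # GL path: still JIT-compiled via torch.utils.cpp_extension.load(...),
    # omitted here for brevity.
    raise NotImplementedError("GL plugin is built on first use")
```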
ldflags = [] if gl: - if os.name == 'posix': - ldflags = ['-lGL', '-lEGL'] - elif os.name == 'nt': - libs = ['gdi32', 'opengl32', 'user32', 'setgpu'] - ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs] + if os.name == "posix": + ldflags = ["-lGL", "-lEGL"] + elif os.name == "nt": + libs = ["gdi32", "opengl32", "user32", "setgpu"] + ldflags = ["/LIBPATH:" + lib_dir] + ["/DEFAULTLIB:" + x for x in libs] # List of source files. if gl: source_files = [ - '../common/common.cpp', - '../common/glutil.cpp', - '../common/rasterize_gl.cpp', - 'torch_bindings_gl.cpp', - 'torch_rasterize_gl.cpp', + "../common/common.cpp", + "../common/glutil.cpp", + "../common/rasterize_gl.cpp", + "torch_bindings_gl.cpp", + "torch_rasterize_gl.cpp", ] else: source_files = [ - '../common/cudaraster/impl/Buffer.cpp', - '../common/cudaraster/impl/CudaRaster.cpp', - '../common/cudaraster/impl/RasterImpl.cu', - '../common/cudaraster/impl/RasterImpl.cpp', - '../common/common.cpp', - '../common/rasterize.cu', - '../common/interpolate.cu', - '../common/texture.cu', - '../common/texture.cpp', - '../common/antialias.cu', - 'torch_bindings.cpp', - 'torch_rasterize.cpp', - 'torch_interpolate.cpp', - 'torch_texture.cpp', - 'torch_antialias.cpp', + "../common/cudaraster/impl/Buffer.cpp", + "../common/cudaraster/impl/CudaRaster.cpp", + "../common/cudaraster/impl/RasterImpl.cu", + "../common/cudaraster/impl/RasterImpl.cpp", + "../common/common.cpp", + "../common/rasterize.cu", + "../common/interpolate.cu", + "../common/texture.cu", + "../common/texture.cpp", + "../common/antialias.cu", + "torch_bindings.cpp", + "torch_rasterize.cpp", + "torch_interpolate.cpp", + "torch_texture.cpp", + "torch_antialias.cpp", ] # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. - os.environ['TORCH_CUDA_ARCH_LIST'] = '' + os.environ["TORCH_CUDA_ARCH_LIST"] = "" # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin. - if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')): - logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin") + if gl and (os.name == "posix") and ("libGLEW" in os.environ.get("LD_PRELOAD", "")): + logging.getLogger("nvdiffrast").warning( + "Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin" + ) # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. - plugin_name = 'nvdiffrast_plugin' + ('_gl' if gl else '') + plugin_name = "nvdiffrast_plugin" + ("_gl" if gl else "") try: - lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock') + lock_fn = os.path.join( + torch.utils.cpp_extension._get_build_directory(plugin_name, False), "lock" + ) if os.path.exists(lock_fn): - logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn) + logging.getLogger("nvdiffrast").warning( + "Lock file exists in build directory: '%s'" % lock_fn + ) except: pass # Speed up compilation on Windows. - if os.name == 'nt': + if os.name == "nt": # Skip telemetry sending step in vcvarsall.bat - os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1' + os.environ["VSCMD_SKIP_SENDTELEMETRY"] = "1" # Opportunistically patch distutils to cache MSVC environments. 
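[Editor's note on one detail from the hunk above: clearing `TORCH_CUDA_ARCH_LIST` makes `torch.utils.cpp_extension` detect and target only the GPU actually present, instead of a container-provided list of stale architectures. To pin architectures explicitly instead, set the variable before the plugin is first loaded; a sketch, with the value as an example:]

```python
import os

# Must happen before the first _get_plugin() call triggers compilation.
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.6"  # e.g. build for Ampere only
```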
try: import distutils._msvccompiler import functools - if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'): - distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env) + + if not hasattr(distutils._msvccompiler._get_vc_env, "__wrapped__"): + distutils._msvccompiler._get_vc_env = functools.lru_cache()( + distutils._msvccompiler._get_vc_env + ) except: pass # Compile and load. source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] - torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=common_opts+cc_opts, extra_cuda_cflags=common_opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False) + torch.utils.cpp_extension.load( + name=plugin_name, + sources=source_paths, + extra_cflags=common_opts + cc_opts, + extra_cuda_cflags=common_opts + ["-lineinfo"], + extra_ldflags=ldflags, + with_cuda=True, + verbose=False, + ) # Import, cache, and return the compiled module. _cached_plugin[gl] = importlib.import_module(plugin_name) return _cached_plugin[gl] -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Log level. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def get_log_level(): - '''Get current log level. + """Get current log level. Returns: Current log level in nvdiffrast. See `set_log_level()` for possible values. - ''' + """ return _get_plugin().get_log_level() + def set_log_level(level): - '''Set log level. + """Set log level. Log levels follow the convention on the C++ side of Torch: 0 = Info, @@ -156,19 +191,21 @@ def set_log_level(level): The default log level is 1. Args: - level: New log level as integer. Internal nvdiffrast messages of this + level: New log level as integer. Internal nvdiffrast messages of this severity or higher will be printed, while messages of lower severity will be silent. - ''' + """ _get_plugin().set_log_level(level) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # CudaRaster state wrapper. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class RasterizeCudaContext: def __init__(self, device=None): - '''Create a new Cuda rasterizer context. + """Create a new Cuda rasterizer context. The context is deleted and internal storage is released when the object is destroyed. @@ -180,7 +217,7 @@ class RasterizeCudaContext: device. Returns: The newly created Cuda rasterizer context. - ''' + """ if device is None: cuda_device_idx = torch.cuda.current_device() else: @@ -190,13 +227,15 @@ class RasterizeCudaContext: self.output_db = True self.active_depth_peeler = None -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # GL state wrapper. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class RasterizeGLContext: - def __init__(self, output_db=True, mode='automatic', device=None): - '''Create a new OpenGL rasterizer context. 
+ def __init__(self, output_db=True, mode="automatic", device=None): + """Create a new OpenGL rasterizer context. Creating an OpenGL context is a slow operation so you should usually reuse the same context in all calls to `rasterize()` on the same CPU thread. The OpenGL context @@ -220,9 +259,9 @@ class RasterizeGLContext: device. Returns: The newly created OpenGL rasterizer context. - ''' + """ assert output_db is True or output_db is False - assert mode in ['automatic', 'manual'] + assert mode in ["automatic", "manual"] self.output_db = output_db self.mode = mode if device is None: @@ -230,34 +269,42 @@ class RasterizeGLContext: else: with torch.cuda.device(device): cuda_device_idx = torch.cuda.current_device() - self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx) - self.active_depth_peeler = None # For error checking only. + self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper( + output_db, mode == "automatic", cuda_device_idx + ) + self.active_depth_peeler = None # For error checking only. def set_context(self): - '''Set (activate) OpenGL context in the current CPU thread. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Set (activate) OpenGL context in the current CPU thread. + Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.set_context() def release_context(self): - '''Release (deactivate) currently active OpenGL context. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Release (deactivate) currently active OpenGL context. + Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.release_context() -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Rasterize. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class _rasterize_func(torch.autograd.Function): @staticmethod def forward(ctx, raster_ctx, pos, tri, resolution, ranges, grad_db, peeling_idx): if isinstance(raster_ctx, RasterizeGLContext): - out, out_db = _get_plugin(gl=True).rasterize_fwd_gl(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx) + out, out_db = _get_plugin(gl=True).rasterize_fwd_gl( + raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx + ) else: - out, out_db = _get_plugin().rasterize_fwd_cuda(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx) + out, out_db = _get_plugin().rasterize_fwd_cuda( + raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx + ) ctx.save_for_backward(pos, tri, out) ctx.saved_grad_db = grad_db return out, out_db @@ -271,9 +318,10 @@ class _rasterize_func(torch.autograd.Function): g_pos = _get_plugin().rasterize_grad(pos, tri, out, dy) return None, g_pos, None, None, None, None, None + # Op wrapper. def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True): - '''Rasterize triangles. + """Rasterize triangles. All input tensors must be contiguous and reside in GPU memory except for the `ranges` tensor that, if specified, has to reside in CPU memory. 
The @@ -301,7 +349,7 @@ def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True): [minibatch_size, height, width, 4] and contain said derivatives in order (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape [minibatch_size, height, width, 0]. - ''' + """ assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext)) assert grad_db is True or grad_db is False grad_db = grad_db and glctx.output_db @@ -310,30 +358,34 @@ def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True): assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor) resolution = tuple(resolution) if ranges is None: - ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu') + ranges = torch.empty(size=(0, 2), dtype=torch.int32, device="cpu") else: assert isinstance(ranges, torch.Tensor) # Check that context is not currently reserved for depth peeling. if glctx.active_depth_peeler is not None: - return RuntimeError("Cannot call rasterize() during depth peeling operation, use rasterize_next_layer() instead") + return RuntimeError( + "Cannot call rasterize() during depth peeling operation, use rasterize_next_layer() instead" + ) # Instantiate the function. return _rasterize_func.apply(glctx, pos, tri, resolution, ranges, grad_db, -1) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Depth peeler context manager for rasterizing multiple depth layers. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class DepthPeeler: def __init__(self, glctx, pos, tri, resolution, ranges=None, grad_db=True): - '''Create a depth peeler object for rasterizing multiple depth layers. + """Create a depth peeler object for rasterizing multiple depth layers. Arguments are the same as in `rasterize()`. Returns: The newly created depth peeler. - ''' + """ assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext)) assert grad_db is True or grad_db is False grad_db = grad_db and glctx.output_db @@ -342,7 +394,7 @@ class DepthPeeler: assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor) resolution = tuple(resolution) if ranges is None: - ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu') + ranges = torch.empty(size=(0, 2), dtype=torch.int32, device="cpu") else: assert isinstance(ranges, torch.Tensor) @@ -359,7 +411,9 @@ class DepthPeeler: if self.raster_ctx is None: raise RuntimeError("Cannot re-enter a terminated depth peeling operation") if self.raster_ctx.active_depth_peeler is not None: - raise RuntimeError("Cannot have multiple depth peelers active simultaneously in a rasterization context") + raise RuntimeError( + "Cannot have multiple depth peelers active simultaneously in a rasterization context" + ) self.raster_ctx.active_depth_peeler = self self.peeling_idx = 0 return self @@ -367,7 +421,9 @@ class DepthPeeler: def __exit__(self, *args): assert self.raster_ctx.active_depth_peeler is self self.raster_ctx.active_depth_peeler = None - self.raster_ctx = None # Remove all references to input tensor so they're not left dangling. + self.raster_ctx = ( + None # Remove all references to input tensor so they're not left dangling. + ) self.pos = None self.tri = None self.resolution = None @@ -377,29 +433,40 @@ class DepthPeeler: return None def rasterize_next_layer(self): - '''Rasterize next depth layer. 
+ """Rasterize next depth layer. Operation is equivalent to `rasterize()` except that previously reported surface points are culled away. Returns: A tuple of two tensors as in `rasterize()`. - ''' + """ assert self.raster_ctx.active_depth_peeler is self assert self.peeling_idx >= 0 - result = _rasterize_func.apply(self.raster_ctx, self.pos, self.tri, self.resolution, self.ranges, self.grad_db, self.peeling_idx) + result = _rasterize_func.apply( + self.raster_ctx, + self.pos, + self.tri, + self.resolution, + self.ranges, + self.grad_db, + self.peeling_idx, + ) self.peeling_idx += 1 return result -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Interpolate. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # Output pixel differentials for at least some attributes. class _interpolate_func_da(torch.autograd.Function): @staticmethod def forward(ctx, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list): - out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + out, out_da = _get_plugin().interpolate_fwd_da( + attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list + ) ctx.save_for_backward(attr, rast, tri, rast_db) ctx.saved_misc = diff_attrs_all, diff_attrs_list return out, out_da @@ -408,9 +475,12 @@ class _interpolate_func_da(torch.autograd.Function): def backward(ctx, dy, dda): attr, rast, tri, rast_db = ctx.saved_tensors diff_attrs_all, diff_attrs_list = ctx.saved_misc - g_attr, g_rast, g_rast_db = _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + g_attr, g_rast, g_rast_db = _get_plugin().interpolate_grad_da( + attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list + ) return g_attr, g_rast, None, g_rast_db, None, None + # No pixel differential for any attribute. class _interpolate_func(torch.autograd.Function): @staticmethod @@ -425,6 +495,7 @@ class _interpolate_func(torch.autograd.Function): g_attr, g_rast = _get_plugin().interpolate_grad(attr, rast, tri, dy) return g_attr, g_rast, None + # Op wrapper. def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): """Interpolate vertex attributes. @@ -433,13 +504,13 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): will be contiguous and reside in GPU memory. Args: - attr: Attribute tensor with dtype `torch.float32`. - Shape is [num_vertices, num_attributes] in range mode, or + attr: Attribute tensor with dtype `torch.float32`. + Shape is [num_vertices, num_attributes] in range mode, or [minibatch_size, num_vertices, num_attributes] in instanced mode. Broadcasting is supported along the minibatch axis. rast: Main output tensor from `rasterize()`. tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`. - rast_db: (Optional) Tensor containing image-space derivatives of barycentrics, + rast_db: (Optional) Tensor containing image-space derivatives of barycentrics, i.e., the second output tensor from `rasterize()`. Enables computing image-space derivatives of attributes. diff_attrs: (Optional) List of attribute indices for which image-space @@ -459,12 +530,12 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): # Sanitize the list of pixel differential attributes. 
if diff_attrs is None: diff_attrs = [] - elif diff_attrs != 'all': + elif diff_attrs != "all": diff_attrs = np.asarray(diff_attrs, np.int32) assert len(diff_attrs.shape) == 1 diff_attrs = diff_attrs.tolist() - diff_attrs_all = int(diff_attrs == 'all') + diff_attrs_all = int(diff_attrs == "all") diff_attrs_list = [] if diff_attrs_all else diff_attrs # Check inputs. @@ -474,18 +545,32 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): # Choose stub. if diff_attrs: - return _interpolate_func_da.apply(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + return _interpolate_func_da.apply( + attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list + ) else: return _interpolate_func.apply(attr, rast, tri) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Texture -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # Linear-mipmap-linear and linear-mipmap-nearest: Mipmaps enabled. class _texture_func_mip(torch.autograd.Function): @staticmethod - def forward(ctx, filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack): + def forward( + ctx, + filter_mode, + tex, + uv, + uv_da, + mip_level_bias, + mip_wrapper, + filter_mode_enum, + boundary_mode_enum, + *mip_stack + ): empty = torch.tensor([]) if uv_da is None: uv_da = empty @@ -493,7 +578,16 @@ class _texture_func_mip(torch.autograd.Function): mip_level_bias = empty if mip_wrapper is None: mip_wrapper = _get_plugin().TextureMipWrapper() - out = _get_plugin().texture_fwd_mip(tex, uv, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + out = _get_plugin().texture_fwd_mip( + tex, + uv, + uv_da, + mip_level_bias, + mip_wrapper, + mip_stack, + filter_mode_enum, + boundary_mode_enum, + ) ctx.save_for_backward(tex, uv, uv_da, mip_level_bias, *mip_stack) ctx.saved_misc = filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum return out @@ -502,12 +596,50 @@ class _texture_func_mip(torch.autograd.Function): def backward(ctx, dy): tex, uv, uv_da, mip_level_bias, *mip_stack = ctx.saved_tensors filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum = ctx.saved_misc - if filter_mode == 'linear-mipmap-linear': - g_tex, g_uv, g_uv_da, g_mip_level_bias, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) - return (None, g_tex, g_uv, g_uv_da, g_mip_level_bias, None, None, None) + tuple(g_mip_stack) - else: # linear-mipmap-nearest - g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) - return (None, g_tex, g_uv, None, None, None, None, None) + tuple(g_mip_stack) + if filter_mode == "linear-mipmap-linear": + ( + g_tex, + g_uv, + g_uv_da, + g_mip_level_bias, + g_mip_stack, + ) = _get_plugin().texture_grad_linear_mipmap_linear( + tex, + uv, + dy, + uv_da, + mip_level_bias, + mip_wrapper, + mip_stack, + filter_mode_enum, + boundary_mode_enum, + ) + return ( + None, + g_tex, + g_uv, + g_uv_da, + g_mip_level_bias, + None, + None, + None, + ) + tuple(g_mip_stack) + else: # linear-mipmap-nearest + g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest( + tex, + uv, + dy, + uv_da, + 
mip_level_bias, + mip_wrapper, + mip_stack, + filter_mode_enum, + boundary_mode_enum, + ) + return (None, g_tex, g_uv, None, None, None, None, None) + tuple( + g_mip_stack + ) + # Linear and nearest: Mipmaps disabled. class _texture_func(torch.autograd.Function): @@ -522,15 +654,29 @@ class _texture_func(torch.autograd.Function): def backward(ctx, dy): tex, uv = ctx.saved_tensors filter_mode, filter_mode_enum, boundary_mode_enum = ctx.saved_misc - if filter_mode == 'linear': - g_tex, g_uv = _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + if filter_mode == "linear": + g_tex, g_uv = _get_plugin().texture_grad_linear( + tex, uv, dy, filter_mode_enum, boundary_mode_enum + ) return None, g_tex, g_uv, None, None - else: # nearest - g_tex = _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + else: # nearest + g_tex = _get_plugin().texture_grad_nearest( + tex, uv, dy, filter_mode_enum, boundary_mode_enum + ) return None, g_tex, None, None, None + # Op wrapper. -def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='auto', boundary_mode='wrap', max_mip_level=None): +def texture( + tex, + uv, + uv_da=None, + mip_level_bias=None, + mip=None, + filter_mode="auto", + boundary_mode="wrap", + max_mip_level=None, +): """Perform texture sampling. All input tensors must be contiguous and reside in GPU memory. The output tensor @@ -580,8 +726,12 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut """ # Default filter mode. - if filter_mode == 'auto': - filter_mode = 'linear-mipmap-linear' if (uv_da is not None or mip_level_bias is not None) else 'linear' + if filter_mode == "auto": + filter_mode = ( + "linear-mipmap-linear" + if (uv_da is not None or mip_level_bias is not None) + else "linear" + ) # Sanitize inputs. if max_mip_level is None: @@ -592,23 +742,33 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut # Check inputs. assert isinstance(tex, torch.Tensor) and isinstance(uv, torch.Tensor) - if 'mipmap' in filter_mode: - assert isinstance(uv_da, torch.Tensor) or isinstance(mip_level_bias, torch.Tensor) + if "mipmap" in filter_mode: + assert isinstance(uv_da, torch.Tensor) or isinstance( + mip_level_bias, torch.Tensor + ) # If mipping disabled via max level=0, we may as well use simpler filtering internally. - if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']: - filter_mode = 'linear' + if max_mip_level == 0 and filter_mode in [ + "linear-mipmap-nearest", + "linear-mipmap-linear", + ]: + filter_mode = "linear" # Convert filter mode to internal enumeration. - filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3} + filter_mode_dict = { + "nearest": 0, + "linear": 1, + "linear-mipmap-nearest": 2, + "linear-mipmap-linear": 3, + } filter_mode_enum = filter_mode_dict[filter_mode] # Convert boundary mode to internal enumeration. - boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3} + boundary_mode_dict = {"cube": 0, "wrap": 1, "clamp": 2, "zero": 3} boundary_mode_enum = boundary_mode_dict[boundary_mode] # Construct a mipmap if necessary. 
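One subtlety of the mode handling above: with `filter_mode="auto"` and neither `uv_da` nor `mip_level_bias` supplied, sampling resolves to plain `"linear"`, so the mipmap-construction branch that follows is skipped entirely. A minimal sketch (toy tensors, CUDA assumed):

```python
import torch
import nvdiffrast.torch as dr

tex = torch.rand(1, 64, 64, 3, device="cuda")  # [minibatch, H, W, C]
uv = torch.rand(1, 32, 32, 2, device="cuda")   # sampling coords in [0, 1]

# No uv_da / mip_level_bias given, so 'auto' means non-mipmapped 'linear'.
out = dr.texture(tex, uv, filter_mode="auto", boundary_mode="wrap")
print(out.shape)  # torch.Size([1, 32, 32, 3])
```
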
- if 'mipmap' in filter_mode: + if "mipmap" in filter_mode: mip_wrapper, mip_stack = None, [] if mip is not None: assert isinstance(mip, (_get_plugin().TextureMipWrapper, list)) @@ -618,13 +778,28 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut else: mip_wrapper = mip else: - mip_wrapper = _get_plugin().texture_construct_mip(tex, max_mip_level, boundary_mode == 'cube') + mip_wrapper = _get_plugin().texture_construct_mip( + tex, max_mip_level, boundary_mode == "cube" + ) # Choose stub. - if filter_mode == 'linear-mipmap-linear' or filter_mode == 'linear-mipmap-nearest': - return _texture_func_mip.apply(filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack) + if filter_mode == "linear-mipmap-linear" or filter_mode == "linear-mipmap-nearest": + return _texture_func_mip.apply( + filter_mode, + tex, + uv, + uv_da, + mip_level_bias, + mip_wrapper, + filter_mode_enum, + boundary_mode_enum, + *mip_stack + ) else: - return _texture_func.apply(filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum) + return _texture_func.apply( + filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum + ) + # Mipmap precalculation for cases where the texture stays constant. def texture_construct_mip(tex, max_mip_level=None, cube_mode=False): @@ -639,7 +814,7 @@ def texture_construct_mip(tex, max_mip_level=None, cube_mode=False): cube_mode: Must be set to True if `tex` specifies a cube map texture. Returns: - An opaque object containing the mipmap stack. This can be supplied in a call to `texture()` + An opaque object containing the mipmap stack. This can be supplied in a call to `texture()` in the `mip` argument. """ @@ -652,14 +827,18 @@ def texture_construct_mip(tex, max_mip_level=None, cube_mode=False): assert max_mip_level >= 0 return _get_plugin().texture_construct_mip(tex, max_mip_level, cube_mode) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Antialias. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class _antialias_func(torch.autograd.Function): @staticmethod def forward(ctx, color, rast, pos, tri, topology_hash, pos_gradient_boost): - out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, topology_hash) + out, work_buffer = _get_plugin().antialias_fwd( + color, rast, pos, tri, topology_hash + ) ctx.save_for_backward(color, rast, pos, tri) ctx.saved_misc = pos_gradient_boost, work_buffer return out @@ -668,11 +847,14 @@ class _antialias_func(torch.autograd.Function): def backward(ctx, dy): color, rast, pos, tri = ctx.saved_tensors pos_gradient_boost, work_buffer = ctx.saved_misc - g_color, g_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer) + g_color, g_pos = _get_plugin().antialias_grad( + color, rast, pos, tri, dy, work_buffer + ) if pos_gradient_boost != 1.0: g_pos = g_pos * pos_gradient_boost return g_color, None, g_pos, None, None, None + # Op wrapper. def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0): """Perform antialiasing. @@ -711,13 +893,16 @@ def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0) topology_hash = _get_plugin().antialias_construct_topology_hash(tri) # Instantiate the function. 
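Because `antialias()` reconstructs the topology hash whenever one is not supplied, callers with a static triangle list can hoist that work out of the render loop. A self-contained sketch under the same assumptions as the earlier round-trip example:

```python
import torch
import nvdiffrast.torch as dr

glctx = dr.RasterizeCudaContext()
pos = torch.tensor([[[-0.8, -0.8, 0.0, 1.0],
                     [ 0.8, -0.8, 0.0, 1.0],
                     [ 0.0,  0.8, 0.0, 1.0]]], device="cuda")
tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device="cuda")
col = torch.ones(1, 3, 3, device="cuda")  # constant white attribute

rast, _ = dr.rasterize(glctx, pos, tri, resolution=[64, 64])
rgb, _ = dr.interpolate(col, rast, tri)

# Hash the (static) topology once, then reuse it across frames.
topo = dr.antialias_construct_topology_hash(tri)
aa = dr.antialias(rgb, rast, pos, tri, topology_hash=topo)
```
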
- return _antialias_func.apply(color, rast, pos, tri, topology_hash, pos_gradient_boost) + return _antialias_func.apply( + color, rast, pos, tri, topology_hash, pos_gradient_boost + ) + # Topology hash precalculation for cases where the triangle array stays constant. def antialias_construct_topology_hash(tri): """Construct a topology hash for a triangle tensor. - This function can be used for constructing a topology hash for a triangle tensor that is + This function can be used for constructing a topology hash for a triangle tensor that is known to remain constant. This avoids reconstructing it every time `antialias()` is called. Args: @@ -725,10 +910,11 @@ def antialias_construct_topology_hash(tri): GPU memory. Returns: - An opaque object containing the topology hash. This can be supplied in a call to + An opaque object containing the topology hash. This can be supplied in a call to `antialias()` in the `topology_hash` argument. """ assert isinstance(tri, torch.Tensor) return _get_plugin().antialias_construct_topology_hash(tri) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/setup copy.py b/extensions/nvdiffrast/setup copy.py index f7f9dede9649583be8fdd2ba6aa6c3aab184ed54..6c721647d2525bc2a0b455d555ea6a064d6e8570 100644 --- a/extensions/nvdiffrast/setup copy.py +++ b/extensions/nvdiffrast/setup copy.py @@ -24,28 +24,31 @@ setuptools.setup( url="https://github.com/NVlabs/nvdiffrast", packages=setuptools.find_packages(), package_data={ - 'nvdiffrast': [ - 'common/*.h', - 'common/*.inl', - 'common/*.cu', - 'common/*.cpp', - 'common/cudaraster/*.hpp', - 'common/cudaraster/impl/*.cpp', - 'common/cudaraster/impl/*.hpp', - 'common/cudaraster/impl/*.inl', - 'common/cudaraster/impl/*.cu', - 'lib/*.h', - 'torch/*.h', - 'torch/*.inl', - 'torch/*.cpp', - 'tensorflow/*.cu', - ] + (['lib/*.lib'] if os.name == 'nt' else []) + "nvdiffrast": [ + "common/*.h", + "common/*.inl", + "common/*.cu", + "common/*.cpp", + "common/cudaraster/*.hpp", + "common/cudaraster/impl/*.cpp", + "common/cudaraster/impl/*.hpp", + "common/cudaraster/impl/*.inl", + "common/cudaraster/impl/*.cu", + "lib/*.h", + "torch/*.h", + "torch/*.inl", + "torch/*.cpp", + "tensorflow/*.cu", + ] + + (["lib/*.lib"] if os.name == "nt" else []) }, include_package_data=True, - install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + install_requires=[ + "numpy" + ], # note: can't require torch here as it will install torch even for a TensorFlow container classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires=">=3.6", ) diff --git a/extensions/nvdiffrast/setup.py b/extensions/nvdiffrast/setup.py index 507cb06f18fbc948e81fd7791f87489d8c35347b..66d35e5d00bcb6156c686e95359043324258c362 100644 --- a/extensions/nvdiffrast/setup.py +++ b/extensions/nvdiffrast/setup.py @@ -48,35 +48,35 @@ setuptools.setup( CUDAExtension( name="nvdiffrast.torch._C", sources=[ - 'nvdiffrast/common/cudaraster/impl/Buffer.cpp', - 'nvdiffrast/common/cudaraster/impl/CudaRaster.cpp', - 'nvdiffrast/common/cudaraster/impl/RasterImpl_.cu', - 'nvdiffrast/common/cudaraster/impl/RasterImpl.cpp', - 'nvdiffrast/common/common.cpp', - 'nvdiffrast/common/rasterize.cu', - 'nvdiffrast/common/interpolate.cu', - 'nvdiffrast/common/texture_.cu', - 'nvdiffrast/common/texture.cpp', - 
'nvdiffrast/common/antialias.cu', - 'nvdiffrast/torch/torch_bindings.cpp', - 'nvdiffrast/torch/torch_rasterize.cpp', - 'nvdiffrast/torch/torch_interpolate.cpp', - 'nvdiffrast/torch/torch_texture.cpp', - 'nvdiffrast/torch/torch_antialias.cpp', + "nvdiffrast/common/cudaraster/impl/Buffer.cpp", + "nvdiffrast/common/cudaraster/impl/CudaRaster.cpp", + "nvdiffrast/common/cudaraster/impl/RasterImpl_.cu", + "nvdiffrast/common/cudaraster/impl/RasterImpl.cpp", + "nvdiffrast/common/common.cpp", + "nvdiffrast/common/rasterize.cu", + "nvdiffrast/common/interpolate.cu", + "nvdiffrast/common/texture_.cu", + "nvdiffrast/common/texture.cpp", + "nvdiffrast/common/antialias.cu", + "nvdiffrast/torch/torch_bindings.cpp", + "nvdiffrast/torch/torch_rasterize.cpp", + "nvdiffrast/torch/torch_interpolate.cpp", + "nvdiffrast/torch/torch_texture.cpp", + "nvdiffrast/torch/torch_antialias.cpp", ], extra_compile_args={ - 'cxx': ['-DNVDR_TORCH'], - 'nvcc': ['-DNVDR_TORCH', '-lineinfo'], + "cxx": ["-DNVDR_TORCH"], + "nvcc": ["-DNVDR_TORCH", "-lineinfo"], }, ) ], - cmdclass={ - 'build_ext': BuildExtension - }, - install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + cmdclass={"build_ext": BuildExtension}, + install_requires=[ + "numpy" + ], # note: can't require torch here as it will install torch even for a TensorFlow container classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires=">=3.6", ) diff --git a/requirements.txt b/requirements.txt index 210a29d55bba54fa6606d7675e66c789ad62738d..ce7779d1b975de2469463a095be34624bbea272d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,4 @@ https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaus https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true spaces plyfile==1.1 -utils3d \ No newline at end of file +utils3d diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py index d90e9f9ab48e7028a370a0df663182f4b8ccadc5..5f6e3c8aa898a8549f295f635523ea8bf1b7f5cd 100644 --- a/trellis/models/__init__.py +++ b/trellis/models/__init__.py @@ -1,20 +1,21 @@ import importlib __attributes = { - 'SparseStructureEncoder': 'sparse_structure_vae', - 'SparseStructureDecoder': 'sparse_structure_vae', - 'SparseStructureFlowModel': 'sparse_structure_flow', - 'SLatEncoder': 'structured_latent_vae', - 'SLatGaussianDecoder': 'structured_latent_vae', - 'SLatRadianceFieldDecoder': 'structured_latent_vae', - 'SLatMeshDecoder': 'structured_latent_vae', - 'SLatFlowModel': 'structured_latent_flow', + "SparseStructureEncoder": "sparse_structure_vae", + "SparseStructureDecoder": "sparse_structure_vae", + "SparseStructureFlowModel": "sparse_structure_flow", + "SLatEncoder": "structured_latent_vae", + "SLatGaussianDecoder": "structured_latent_vae", + "SLatRadianceFieldDecoder": "structured_latent_vae", + "SLatMeshDecoder": "structured_latent_vae", + "SLatFlowModel": "structured_latent_flow", } __submodules = [] __all__ = list(__attributes.keys()) + __submodules + def __getattr__(name): if name not in globals(): if name in __attributes: @@ -41,6 +42,7 @@ def from_pretrained(path: str, **kwargs): import os import json from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") if is_local: @@ -48,23 +50,29 @@ def from_pretrained(path: str, **kwargs): model_file = 
f"{path}.safetensors" else: from huggingface_hub import hf_hub_download - path_parts = path.split('/') - repo_id = f'{path_parts[0]}/{path_parts[1]}' - model_name = '/'.join(path_parts[2:]) + + path_parts = path.split("/") + repo_id = f"{path_parts[0]}/{path_parts[1]}" + model_name = "/".join(path_parts[2:]) config_file = hf_hub_download(repo_id, f"{model_name}.json") model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = json.load(f) - model = __getattr__(config['name'])(**config['args'], **kwargs) + model = __getattr__(config["name"])(**config["args"], **kwargs) model.load_state_dict(load_file(model_file)) return model # For Pylance -if __name__ == '__main__': +if __name__ == "__main__": from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder from .sparse_structure_flow import SparseStructureFlowModel - from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder + from .structured_latent_vae import ( + SLatEncoder, + SLatGaussianDecoder, + SLatRadianceFieldDecoder, + SLatMeshDecoder, + ) from .structured_latent_flow import SLatFlowModel diff --git a/trellis/models/sparse_structure_flow.py b/trellis/models/sparse_structure_flow.py index aee71a9686fd3795960cf1df970e9b8db0ebd57a..ebcf213d3131a96d0eedcdd97714f96aec3e0042 100644 --- a/trellis/models/sparse_structure_flow.py +++ b/trellis/models/sparse_structure_flow.py @@ -4,7 +4,10 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from ..modules.utils import convert_module_to_f16, convert_module_to_f32 -from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.transformer import ( + AbsolutePositionEmbedder, + ModulatedTransformerCrossBlock, +) from ..modules.spatial import patchify, unpatchify @@ -12,6 +15,7 @@ class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. 
""" + def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( @@ -38,12 +42,16 @@ class TimestepEmbedder(nn.Module): # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( - -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + -np.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) return embedding def forward(self, t): @@ -93,34 +101,41 @@ class SparseStructureFlowModel(nn.Module): self.t_embedder = TimestepEmbedder(model_channels) if share_mod: self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True) ) if pe_mode == "ape": pos_embedder = AbsolutePositionEmbedder(model_channels, 3) - coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.meshgrid( + *[ + torch.arange(res, device=self.device) + for res in [resolution // patch_size] * 3 + ], + indexing="ij", + ) coords = torch.stack(coords, dim=-1).reshape(-1, 3) pos_emb = pos_embedder(coords) self.register_buffer("pos_emb", pos_emb) self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) - - self.blocks = nn.ModuleList([ - ModulatedTransformerCrossBlock( - model_channels, - cond_channels, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - attn_mode='full', - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - share_mod=share_mod, - qk_rms_norm=self.qk_rms_norm, - qk_rms_norm_cross=self.qk_rms_norm_cross, - ) - for _ in range(num_blocks) - ]) + + self.blocks = nn.ModuleList( + [ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode="full", + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ] + ) self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) @@ -154,6 +169,7 @@ class SparseStructureFlowModel(nn.Module): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + self.apply(_basic_init) # Initialize timestep embedding MLP: @@ -173,9 +189,14 @@ class SparseStructureFlowModel(nn.Module): nn.init.constant_(self.out_layer.weight, 0) nn.init.constant_(self.out_layer.bias, 0) - def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: - assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ - f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + def forward( + self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor + ) -> torch.Tensor: + assert [*x.shape] == [ + x.shape[0], + self.in_channels, + *[self.resolution] * 3, + ], f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" h = patchify(x, self.patch_size) h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() @@ 
-194,7 +215,9 @@ class SparseStructureFlowModel(nn.Module): h = F.layer_norm(h, h.shape[-1:]) h = self.out_layer(h) - h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = h.permute(0, 2, 1).view( + h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3 + ) h = unpatchify(h, self.patch_size).contiguous() return h diff --git a/trellis/models/sparse_structure_vae.py b/trellis/models/sparse_structure_vae.py index c3e09136cf294c4c1b47b0f09fa6ee57bad2166d..778ece61e1669b2c0e68d3bc6006cddff778c2ab 100644 --- a/trellis/models/sparse_structure_vae.py +++ b/trellis/models/sparse_structure_vae.py @@ -33,9 +33,15 @@ class ResBlock3d(nn.Module): self.norm1 = norm_layer(norm_type, channels) self.norm2 = norm_layer(norm_type, self.out_channels) self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) - self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) - self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() - + self.conv2 = zero_module( + nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + ) + self.skip_connection = ( + nn.Conv3d(channels, self.out_channels, 1) + if channels != self.out_channels + else nn.Identity() + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm1(x) h = F.silu(h) @@ -63,7 +69,9 @@ class DownsampleBlock3d(nn.Module): if mode == "conv": self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) elif mode == "avgpool": - assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + assert ( + in_channels == out_channels + ), "Pooling mode requires in_channels to be equal to out_channels" def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "conv"): @@ -86,9 +94,11 @@ class UpsampleBlock3d(nn.Module): self.out_channels = out_channels if mode == "conv": - self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + self.conv = nn.Conv3d(in_channels, out_channels * 8, 3, padding=1) elif mode == "nearest": - assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + assert ( + in_channels == out_channels + ), "Nearest mode requires in_channels to be equal to out_channels" def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "conv"): @@ -96,12 +106,12 @@ class UpsampleBlock3d(nn.Module): return pixel_shuffle_3d(x, 2) else: return F.interpolate(x, scale_factor=2, mode="nearest") - + class SparseStructureEncoder(nn.Module): """ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). - + Args: in_channels (int): Channels of the input. latent_channels (int): Channels of the latent representation. @@ -111,6 +121,7 @@ class SparseStructureEncoder(nn.Module): norm_type (Literal["group", "layer"]): Type of normalization layer. use_fp16 (bool): Whether to use FP16. 
""" + def __init__( self, in_channels: int, @@ -135,24 +146,21 @@ class SparseStructureEncoder(nn.Module): self.blocks = nn.ModuleList([]) for i, ch in enumerate(channels): - self.blocks.extend([ - ResBlock3d(ch, ch) - for _ in range(num_res_blocks) - ]) + self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)]) if i < len(channels) - 1: - self.blocks.append( - DownsampleBlock3d(ch, channels[i+1]) - ) - - self.middle_block = nn.Sequential(*[ - ResBlock3d(channels[-1], channels[-1]) - for _ in range(num_res_blocks_middle) - ]) + self.blocks.append(DownsampleBlock3d(ch, channels[i + 1])) + + self.middle_block = nn.Sequential( + *[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ] + ) self.out_layer = nn.Sequential( norm_layer(norm_type, channels[-1]), nn.SiLU(), - nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + nn.Conv3d(channels[-1], latent_channels * 2, 3, padding=1), ) if use_fp16: @@ -183,7 +191,9 @@ class SparseStructureEncoder(nn.Module): self.blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) - def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + def forward( + self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False + ) -> torch.Tensor: h = self.input_layer(x) h = h.type(self.dtype) @@ -201,16 +211,16 @@ class SparseStructureEncoder(nn.Module): z = mean + std * torch.randn_like(std) else: z = mean - + if return_raw: return z, mean, logvar return z - + class SparseStructureDecoder(nn.Module): """ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). - + Args: out_channels (int): Channels of the output. latent_channels (int): Channels of the latent representation. @@ -219,7 +229,8 @@ class SparseStructureDecoder(nn.Module): num_res_blocks_middle (int): Number of residual blocks in the middle. norm_type (Literal["group", "layer"]): Type of normalization layer. use_fp16 (bool): Whether to use FP16. - """ + """ + def __init__( self, out_channels: int, @@ -242,26 +253,23 @@ class SparseStructureDecoder(nn.Module): self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) - self.middle_block = nn.Sequential(*[ - ResBlock3d(channels[0], channels[0]) - for _ in range(num_res_blocks_middle) - ]) + self.middle_block = nn.Sequential( + *[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ] + ) self.blocks = nn.ModuleList([]) for i, ch in enumerate(channels): - self.blocks.extend([ - ResBlock3d(ch, ch) - for _ in range(num_res_blocks) - ]) + self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)]) if i < len(channels) - 1: - self.blocks.append( - UpsampleBlock3d(ch, channels[i+1]) - ) + self.blocks.append(UpsampleBlock3d(ch, channels[i + 1])) self.out_layer = nn.Sequential( norm_layer(norm_type, channels[-1]), nn.SiLU(), - nn.Conv3d(channels[-1], out_channels, 3, padding=1) + nn.Conv3d(channels[-1], out_channels, 3, padding=1), ) if use_fp16: @@ -273,7 +281,7 @@ class SparseStructureDecoder(nn.Module): Return the device of the model. """ return next(self.parameters()).device - + def convert_to_fp16(self) -> None: """ Convert the torso of the model to float16. 
@@ -291,12 +299,12 @@ class SparseStructureDecoder(nn.Module): self.dtype = torch.float32 self.blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) - + def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.input_layer(x) - + h = h.type(self.dtype) - + h = self.middle_block(h) for block in self.blocks: h = block(h) diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py index f1463d79bc472ce3ef6859a42e10a06de1f9ebf7..bcb19b99ba865aea023b25f9c3bd4b0febad843c 100644 --- a/trellis/models/structured_latent_flow.py +++ b/trellis/models/structured_latent_flow.py @@ -26,18 +26,26 @@ class SparseResBlock3d(nn.Module): self.out_channels = out_channels or channels self.downsample = downsample self.upsample = upsample - - assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + assert not ( + downsample and upsample + ), "Cannot downsample and upsample at the same time" self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) - self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.conv2 = zero_module( + sp.SparseConv3d(self.out_channels, self.out_channels, 3) + ) self.emb_layers = nn.Sequential( nn.SiLU(), nn.Linear(emb_channels, 2 * self.out_channels, bias=True), ) - self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.skip_connection = ( + sp.SparseLinear(channels, self.out_channels) + if channels != self.out_channels + else nn.Identity() + ) self.updown = None if self.downsample: self.updown = sp.SparseDownsample(2) @@ -63,7 +71,7 @@ class SparseResBlock3d(nn.Module): h = h + self.skip_connection(x) return h - + class SLatFlowModel(nn.Module): def __init__( @@ -109,14 +117,17 @@ class SLatFlowModel(nn.Module): self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = torch.float16 if use_fp16 else torch.float32 - assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" - assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" + assert int(np.log2(patch_size)) == np.log2( + patch_size + ), "Patch size must be a power of 2" + assert np.log2(patch_size) == len( + io_block_channels + ), "Number of IO ResBlocks must match the number of stages" self.t_embedder = TimestepEmbedder(model_channels) if share_mod: self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True) ) if pe_mode == "ape": @@ -124,15 +135,19 @@ class SLatFlowModel(nn.Module): self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) self.input_blocks = nn.ModuleList([]) - for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): - self.input_blocks.extend([ - SparseResBlock3d( - chs, - model_channels, - out_channels=chs, - ) - for _ in range(num_io_res_blocks-1) - ]) + for chs, next_chs in zip( + io_block_channels, io_block_channels[1:] + [model_channels] + ): + self.input_blocks.extend( + [ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks - 1) + ] + ) self.input_blocks.append( SparseResBlock3d( chs, @@ -141,25 +156,30 @@ class SLatFlowModel(nn.Module): downsample=True, 
) ) - - self.blocks = nn.ModuleList([ - ModulatedSparseTransformerCrossBlock( - model_channels, - cond_channels, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - attn_mode='full', - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - share_mod=self.share_mod, - qk_rms_norm=self.qk_rms_norm, - qk_rms_norm_cross=self.qk_rms_norm_cross, - ) - for _ in range(num_blocks) - ]) + + self.blocks = nn.ModuleList( + [ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode="full", + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ] + ) self.out_blocks = nn.ModuleList([]) - for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): + for chs, prev_chs in zip( + reversed(io_block_channels), + [model_channels] + list(reversed(io_block_channels[1:])), + ): self.out_blocks.append( SparseResBlock3d( prev_chs * 2 if self.use_skip_connection else prev_chs, @@ -168,14 +188,16 @@ class SLatFlowModel(nn.Module): upsample=True, ) ) - self.out_blocks.extend([ - SparseResBlock3d( - chs * 2 if self.use_skip_connection else chs, - model_channels, - out_channels=chs, - ) - for _ in range(num_io_res_blocks-1) - ]) + self.out_blocks.extend( + [ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks - 1) + ] + ) self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) self.initialize_weights() @@ -212,6 +234,7 @@ class SLatFlowModel(nn.Module): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + self.apply(_basic_init) # Initialize timestep embedding MLP: @@ -231,7 +254,9 @@ class SLatFlowModel(nn.Module): nn.init.constant_(self.out_layer.weight, 0) nn.init.constant_(self.out_layer.bias, 0) - def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: + def forward( + self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor + ) -> sp.SparseTensor: h = self.input_layer(x).type(self.dtype) t_emb = self.t_embedder(t) if self.share_mod: @@ -244,7 +269,7 @@ class SLatFlowModel(nn.Module): for block in self.input_blocks: h = block(h, t_emb) skips.append(h.feats) - + if self.pe_mode == "ape": h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) for block in self.blocks: diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py index ab0bf6a850b1c146e081c32ad92c7c44ead5ef6e..8fa8b8f1138464ca68add0b496b08be5e489d2ed 100644 --- a/trellis/models/structured_latent_vae/base.py +++ b/trellis/models/structured_latent_vae/base.py @@ -13,15 +13,23 @@ def block_attn_config(self): """ for i in range(self.num_blocks): if self.attn_mode == "shift_window": - yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + yield "serialized", self.window_size, 0, ( + 16 * (i % 2), + ) * 3, sp.SerializeMode.Z_ORDER elif self.attn_mode == "shift_sequence": - yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), ( + 0, + 0, + 0, + ), sp.SerializeMode.Z_ORDER elif self.attn_mode == "shift_order": yield "serialized", 
self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] elif self.attn_mode == "full": yield "full", None, None, None, None elif self.attn_mode == "swin": - yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + yield "windowed", self.window_size, None, self.window_size // 2 * ( + i % 2 + ), None class SparseTransformerBase(nn.Module): @@ -29,6 +37,7 @@ class SparseTransformerBase(nn.Module): Sparse Transformer without output layers. Serve as the base class for encoder and decoder. """ + def __init__( self, in_channels: int, @@ -37,7 +46,9 @@ class SparseTransformerBase(nn.Module): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", window_size: Optional[int] = None, pe_mode: Literal["ape", "rope"] = "ape", use_fp16: bool = False, @@ -62,22 +73,26 @@ class SparseTransformerBase(nn.Module): self.pos_embedder = AbsolutePositionEmbedder(model_channels) self.input_layer = sp.SparseLinear(in_channels, model_channels) - self.blocks = nn.ModuleList([ - SparseTransformerBlock( - model_channels, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - attn_mode=attn_mode, - window_size=window_size, - shift_sequence=shift_sequence, - shift_window=shift_window, - serialize_mode=serialize_mode, - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - qk_rms_norm=self.qk_rms_norm, - ) - for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) - ]) + self.blocks = nn.ModuleList( + [ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config( + self + ) + ] + ) @property def device(self) -> torch.device: @@ -105,6 +120,7 @@ class SparseTransformerBase(nn.Module): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + self.apply(_basic_init) def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py index b893cfcfb2a166c7d57f96086a79317bd91884b9..5c6271774dac131791f0d24811bb217c86136609 100644 --- a/trellis/models/structured_latent_vae/decoder_gs.py +++ b/trellis/models/structured_latent_vae/decoder_gs.py @@ -18,7 +18,9 @@ class SLatGaussianDecoder(SparseTransformerBase): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", window_size: int = 8, pe_mode: Literal["ape", "rope"] = "ape", use_fp16: bool = False, @@ -57,26 +59,44 @@ class SLatGaussianDecoder(SparseTransformerBase): nn.init.constant_(self.out_layer.bias, 0) def _build_perturbation(self) -> None: - perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] + perturbation = [ 
+ hammersley_sequence(3, i, self.rep_config["num_gaussians"]) + for i in range(self.rep_config["num_gaussians"]) + ] perturbation = torch.tensor(perturbation).float() * 2 - 1 - perturbation = perturbation / self.rep_config['voxel_size'] + perturbation = perturbation / self.rep_config["voxel_size"] perturbation = torch.atanh(perturbation).to(self.device) - self.register_buffer('offset_perturbation', perturbation) + self.register_buffer("offset_perturbation", perturbation) def _calc_layout(self) -> None: self.layout = { - '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, - '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, - '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, - '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4}, - '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, + "_xyz": { + "shape": (self.rep_config["num_gaussians"], 3), + "size": self.rep_config["num_gaussians"] * 3, + }, + "_features_dc": { + "shape": (self.rep_config["num_gaussians"], 1, 3), + "size": self.rep_config["num_gaussians"] * 3, + }, + "_scaling": { + "shape": (self.rep_config["num_gaussians"], 3), + "size": self.rep_config["num_gaussians"] * 3, + }, + "_rotation": { + "shape": (self.rep_config["num_gaussians"], 4), + "size": self.rep_config["num_gaussians"] * 4, + }, + "_opacity": { + "shape": (self.rep_config["num_gaussians"], 1), + "size": self.rep_config["num_gaussians"], + }, } start = 0 for k, v in self.layout.items(): - v['range'] = (start, start + v['size']) - start += v['size'] + v["range"] = (start, start + v["size"]) + start += v["size"] self.out_channels = start - + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: """ Convert a batch of network outputs to 3D representations. 
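The `_calc_layout` bookkeeping above packs every per-Gaussian quantity into one flat channel vector by assigning each key a contiguous slice. A stand-alone sketch with an assumed `num_gaussians = 32` (illustrative, not the shipped config):

```python
num_gaussians = 32  # illustrative value
layout = {
    "_xyz":         {"size": num_gaussians * 3},
    "_features_dc": {"size": num_gaussians * 3},
    "_scaling":     {"size": num_gaussians * 3},
    "_rotation":    {"size": num_gaussians * 4},
    "_opacity":     {"size": num_gaussians},
}
start = 0
for v in layout.values():
    v["range"] = (start, start + v["size"])  # contiguous slice per key
    start += v["size"]

print(layout["_rotation"]["range"])  # (288, 416)
print(start)                         # out_channels = 448
```
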
@@ -92,24 +112,35 @@ class SLatGaussianDecoder(SparseTransformerBase): representation = Gaussian( sh_degree=0, aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], - mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], - scaling_bias = self.rep_config['scaling_bias'], - opacity_bias = self.rep_config['opacity_bias'], - scaling_activation = self.rep_config['scaling_activation'] + mininum_kernel_size=self.rep_config["3d_filter_kernel_size"], + scaling_bias=self.rep_config["scaling_bias"], + opacity_bias=self.rep_config["opacity_bias"], + scaling_activation=self.rep_config["scaling_activation"], ) xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution for k, v in self.layout.items(): - if k == '_xyz': - offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) - offset = offset * self.rep_config['lr'][k] - if self.rep_config['perturb_offset']: + if k == "_xyz": + offset = x.feats[x.layout[i]][ + :, v["range"][0] : v["range"][1] + ].reshape(-1, *v["shape"]) + offset = offset * self.rep_config["lr"][k] + if self.rep_config["perturb_offset"]: offset = offset + self.offset_perturbation - offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] + offset = ( + torch.tanh(offset) + / self.resolution + * 0.5 + * self.rep_config["voxel_size"] + ) _xyz = xyz.unsqueeze(1) + offset setattr(representation, k, _xyz.flatten(0, 1)) else: - feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) - feats = feats * self.rep_config['lr'][k] + feats = ( + x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]] + .reshape(-1, *v["shape"]) + .flatten(0, 1) + ) + feats = feats * self.rep_config["lr"][k] setattr(representation, k, feats) ret.append(representation) return ret diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py index 75c1b1ec7b6fdc28e787be283e55589b36461e50..dee196559d5c56f8d720dae2c0c5cd7b6935a344 100644 --- a/trellis/models/structured_latent_vae/decoder_mesh.py +++ b/trellis/models/structured_latent_vae/decoder_mesh.py @@ -19,12 +19,13 @@ class SparseSubdivideBlock3d(nn.Module): out_channels: if specified, the number of output channels. num_groups: the number of groups for the group norm. 
""" + def __init__( self, channels: int, resolution: int, out_channels: Optional[int] = None, - num_groups: int = 32 + num_groups: int = 32, ): super().__init__() self.channels = channels @@ -33,24 +34,34 @@ class SparseSubdivideBlock3d(nn.Module): self.out_channels = out_channels or channels self.act_layers = nn.Sequential( - sp.SparseGroupNorm32(num_groups, channels), - sp.SparseSiLU() + sp.SparseGroupNorm32(num_groups, channels), sp.SparseSiLU() ) - + self.sub = sp.SparseSubdivide() - + self.out_layers = nn.Sequential( - sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), + sp.SparseConv3d( + channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}" + ), sp.SparseGroupNorm32(num_groups, self.out_channels), sp.SparseSiLU(), - zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), + zero_module( + sp.SparseConv3d( + self.out_channels, + self.out_channels, + 3, + indice_key=f"res_{self.out_resolution}", + ) + ), ) - + if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") - + self.skip_connection = sp.SparseConv3d( + channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}" + ) + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: """ Apply the block to a Tensor, conditioned on a timestep embedding. @@ -78,7 +89,9 @@ class SLatMeshDecoder(SparseTransformerBase): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", window_size: int = 8, pe_mode: Literal["ape", "rope"] = "ape", use_fp16: bool = False, @@ -102,20 +115,24 @@ class SLatMeshDecoder(SparseTransformerBase): ) self.resolution = resolution self.rep_config = representation_config - self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + self.mesh_extractor = SparseFeatures2Mesh( + res=self.resolution * 4, use_color=self.rep_config.get("use_color", False) + ) self.out_channels = self.mesh_extractor.feats_channels - self.upsample = nn.ModuleList([ - SparseSubdivideBlock3d( - channels=model_channels, - resolution=resolution, - out_channels=model_channels // 4 - ), - SparseSubdivideBlock3d( - channels=model_channels // 4, - resolution=resolution * 2, - out_channels=model_channels // 8 - ) - ]) + self.upsample = nn.ModuleList( + [ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4, + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8, + ), + ] + ) self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) self.initialize_weights() @@ -140,8 +157,8 @@ class SLatMeshDecoder(SparseTransformerBase): Convert the torso of the model to float32. """ super().convert_to_fp32() - self.upsample.apply(convert_module_to_f32) - + self.upsample.apply(convert_module_to_f32) + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: """ Convert a batch of network outputs to 3D representations. 
diff --git a/trellis/models/structured_latent_vae/decoder_rf.py b/trellis/models/structured_latent_vae/decoder_rf.py index 968bb30596647224292da0392dfdefeed49d214d..f9e01e5e202f5d4be567a6e06895604976c05ae7 100644 --- a/trellis/models/structured_latent_vae/decoder_rf.py +++ b/trellis/models/structured_latent_vae/decoder_rf.py @@ -18,7 +18,9 @@ class SLatRadianceFieldDecoder(SparseTransformerBase): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", window_size: int = 8, pe_mode: Literal["ape", "rope"] = "ape", use_fp16: bool = False, @@ -57,16 +59,25 @@ class SLatRadianceFieldDecoder(SparseTransformerBase): def _calc_layout(self) -> None: self.layout = { - 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, - 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, - 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, + "trivec": { + "shape": (self.rep_config["rank"], 3, self.rep_config["dim"]), + "size": self.rep_config["rank"] * 3 * self.rep_config["dim"], + }, + "density": { + "shape": (self.rep_config["rank"],), + "size": self.rep_config["rank"], + }, + "features_dc": { + "shape": (self.rep_config["rank"], 1, 3), + "size": self.rep_config["rank"] * 3, + }, } start = 0 for k, v in self.layout.items(): - v['range'] = (start, start + v['size']) - start += v['size'] - self.out_channels = start - + v["range"] = (start, start + v["size"]) + start += v["size"] + self.out_channels = start + def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: """ Convert a batch of network outputs to 3D representations. 
@@ -83,15 +94,28 @@ class SLatRadianceFieldDecoder(SparseTransformerBase): sh_degree=0, resolution=self.resolution, aabb=[-0.5, -0.5, -0.5, 1, 1, 1], - rank=self.rep_config['rank'], - dim=self.rep_config['dim'], - device='cuda', + rank=self.rep_config["rank"], + dim=self.rep_config["dim"], + device="cuda", ) representation.density_shift = 0.0 - representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution - representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') + representation.position = ( + x.coords[x.layout[i]][:, 1:].float() + 0.5 + ) / self.resolution + representation.depth = torch.full( + (representation.position.shape[0], 1), + int(np.log2(self.resolution)), + dtype=torch.uint8, + device="cuda", + ) for k, v in self.layout.items(): - setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) + setattr( + representation, + k, + x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]].reshape( + -1, *v["shape"] + ), + ) representation.trivec = representation.trivec + 1 ret.append(representation) return ret diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py index 8370921d8d61954b43dcf3e251b8d9b315f4f536..cafbe0cba797362ab2e04dc79feb8b2385441ffa 100644 --- a/trellis/models/structured_latent_vae/encoder.py +++ b/trellis/models/structured_latent_vae/encoder.py @@ -17,7 +17,9 @@ class SLatEncoder(SparseTransformerBase): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", window_size: int = 8, pe_mode: Literal["ape", "rope"] = "ape", use_fp16: bool = False, @@ -56,7 +58,7 @@ class SLatEncoder(SparseTransformerBase): h = h.type(x.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.out_layer(h) - + # Sample from the posterior distribution mean, logvar = h.feats.chunk(2, dim=-1) if sample_posterior: @@ -65,7 +67,7 @@ class SLatEncoder(SparseTransformerBase): else: z = mean z = h.replace(z) - + if return_raw: return z, mean, logvar else: diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py index f452320d5dbc4c0aa1664e33f76c56ff4bbe2039..53c98f18eeca9ce966a1fa5bd6f95f28b5a9cf68 100755 --- a/trellis/modules/attention/__init__.py +++ b/trellis/modules/attention/__init__.py @@ -1,32 +1,39 @@ from typing import * -BACKEND = 'flash_attn' +BACKEND = "flash_attn" DEBUG = False + def __from_env(): import os - + global BACKEND global DEBUG - - env_attn_backend = os.environ.get('ATTN_BACKEND') - env_sttn_debug = os.environ.get('ATTN_DEBUG') - - if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + + env_attn_backend = os.environ.get("ATTN_BACKEND") + env_sttn_debug = os.environ.get("ATTN_DEBUG") + + if env_attn_backend is not None and env_attn_backend in [ + "xformers", + "flash_attn", + "sdpa", + "naive", + ]: BACKEND = env_attn_backend if env_sttn_debug is not None: - DEBUG = env_sttn_debug == '1' + DEBUG = env_sttn_debug == "1" print(f"[ATTENTION] Using backend: {BACKEND}") - + __from_env() - -def set_backend(backend: Literal['xformers', 'flash_attn']): + +def set_backend(backend: Literal["xformers", "flash_attn"]): global BACKEND BACKEND = backend + 
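Because `__from_env()` resolves `BACKEND` at import time, callers that want a specific backend must set the environment variables before the first trellis import (or call `set_backend` afterwards). A usage sketch, assuming the package is installed:

import os

os.environ["ATTN_BACKEND"] = "sdpa"  # xformers | flash_attn | sdpa | naive
os.environ["ATTN_DEBUG"] = "1"

# The import must come after the env vars are set; commented out here so
# the sketch stays dependency-free.
# from trellis.modules.attention import BACKEND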
def set_debug(debug: bool): global DEBUG DEBUG = debug diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py index d9ebf6380a78906d4c6e969c63223fb7b398e5a7..56179c747c967efbd1c3843598f9caae4baf8bb4 100755 --- a/trellis/modules/attention/full_attn.py +++ b/trellis/modules/attention/full_attn.py @@ -3,20 +3,20 @@ import torch import math from . import DEBUG, BACKEND -if BACKEND == 'xformers': +if BACKEND == "xformers": import xformers.ops as xops -elif BACKEND == 'flash_attn': +elif BACKEND == "flash_attn": import flash_attn -elif BACKEND == 'sdpa': +elif BACKEND == "sdpa": from torch.nn.functional import scaled_dot_product_attention as sdpa -elif BACKEND == 'naive': +elif BACKEND == "naive": pass else: raise ValueError(f"Unknown attention backend: {BACKEND}") __all__ = [ - 'scaled_dot_product_attention', + "scaled_dot_product_attention", ] @@ -24,14 +24,14 @@ def _naive_sdpa(q, k, v): """ Naive implementation of scaled dot product attention. """ - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] scale_factor = 1 / math.sqrt(q.size(-1)) attn_weight = q @ k.transpose(-2, -1) * scale_factor attn_weight = torch.softmax(attn_weight, dim=-1) out = attn_weight @ v - out = out.permute(0, 2, 1, 3) # [N, L, H, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] return out @@ -45,6 +45,7 @@ def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: """ ... + @overload def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: """ @@ -56,8 +57,11 @@ def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Ten """ ... + @overload -def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: +def scaled_dot_product_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: """ Apply scaled dot product attention. @@ -71,64 +75,79 @@ def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tens """ ... 
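Ahead of the dispatcher: the `naive` fallback implements textbook scaled dot product attention and can be checked against torch's reference kernel. A self-contained sketch in the dispatcher's [N, L, H, C] convention, with illustrative shapes (not part of the patch):

import math
import torch
import torch.nn.functional as F

def naive_sdpa(q, k, v):  # q, k, v: [N, L, H, C]
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))  # -> [N, H, L, C]
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
    return (attn @ v).permute(0, 2, 1, 3)  # back to [N, L, H, C]

q, k, v = (torch.randn(2, 16, 4, 8) for _ in range(3))
ref = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
).permute(0, 2, 1, 3)
print(torch.allclose(naive_sdpa(q, k, v), ref, atol=1e-5))  # True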
+ def scaled_dot_product_attention(*args, **kwargs): - arg_names_dict = { - 1: ['qkv'], - 2: ['q', 'kv'], - 3: ['q', 'k', 'v'] - } + arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]} num_all_args = len(args) + len(kwargs) - assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" - for key in arg_names_dict[num_all_args][len(args):]: + assert ( + num_all_args in arg_names_dict + ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args) :]: assert key in kwargs, f"Missing argument {key}" if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + qkv = args[0] if len(args) > 0 else kwargs["qkv"] + assert ( + len(qkv.shape) == 5 and qkv.shape[2] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" device = qkv.device elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" - assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + q = args[0] if len(args) > 0 else kwargs["q"] + kv = args[1] if len(args) > 1 else kwargs["kv"] + assert ( + q.shape[0] == kv.shape[0] + ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert ( + len(kv.shape) == 5 + ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" device = q.device elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] - assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" - assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" - assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" - device = q.device - - if BACKEND == 'xformers': + q = args[0] if len(args) > 0 else kwargs["q"] + k = args[1] if len(args) > 1 else kwargs["k"] + v = args[2] if len(args) > 2 else kwargs["v"] + assert ( + q.shape[0] == k.shape[0] == v.shape[0] + ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert ( + len(k.shape) == 4 + ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert ( + len(v.shape) == 4 + ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == "xformers": if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) out = xops.memory_efficient_attention(q, k, v) - elif BACKEND == 'flash_attn': + elif BACKEND == "flash_attn": if num_all_args == 1: out = flash_attn.flash_attn_qkvpacked_func(qkv) elif num_all_args == 2: out = flash_attn.flash_attn_kvpacked_func(q, kv) elif num_all_args == 3: out = flash_attn.flash_attn_func(q, k, v) - elif BACKEND == 'sdpa': + elif BACKEND == "sdpa": if 
num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = sdpa(q, k, v) # [N, H, L, C] - out = out.permute(0, 2, 1, 3) # [N, L, H, C] - elif BACKEND == 'naive': + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == "naive": if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: @@ -136,5 +155,5 @@ def scaled_dot_product_attention(*args, **kwargs): out = _naive_sdpa(q, k, v) else: raise ValueError(f"Unknown attention module: {BACKEND}") - + return out diff --git a/trellis/modules/attention/modules.py b/trellis/modules/attention/modules.py index dbe6235c27134f0477e48d3e12de3068c6a500ef..78f379e997fa68db86214bf2921041f073d37467 100755 --- a/trellis/modules/attention/modules.py +++ b/trellis/modules/attention/modules.py @@ -8,11 +8,11 @@ from .full_attn import scaled_dot_product_attention class MultiHeadRMSNorm(nn.Module): def __init__(self, dim: int, heads: int): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(heads, dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: - return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + return (F.normalize(x.float(), dim=-1) * self.gamma * self.scale).to(x.dtype) class RotaryPositionEmbedder(nn.Module): @@ -23,21 +23,25 @@ class RotaryPositionEmbedder(nn.Module): self.in_channels = in_channels self.freq_dim = hidden_size // in_channels // 2 self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim - self.freqs = 1.0 / (10000 ** self.freqs) - + self.freqs = 1.0 / (10000**self.freqs) + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: self.freqs = self.freqs.to(indices.device) phases = torch.outer(indices, self.freqs) phases = torch.polar(torch.ones_like(phases), phases) return phases - + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_rotated = x_complex * phases - x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + x_embed = ( + torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + ) return x_embed - - def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + + def forward( + self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: q (sp.SparseTensor): [..., N, D] tensor of queries @@ -48,24 +52,38 @@ class RotaryPositionEmbedder(nn.Module): indices = torch.arange(q.shape[-2], device=q.device) if len(q.shape) > 2: indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) - + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) if phases.shape[1] < self.hidden_size // 2: - phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), - torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) - )], dim=-1) + phases = torch.cat( + [ + phases, + torch.polar( + torch.ones( + *phases.shape[:-1], + self.hidden_size // 2 - 
phases.shape[1], + device=phases.device, + ), + torch.zeros( + *phases.shape[:-1], + self.hidden_size // 2 - phases.shape[1], + device=phases.device, + ), + ), + ], + dim=-1, + ) q_embed = self._rotary_embedding(q, phases) k_embed = self._rotary_embedding(k, phases) return q_embed, k_embed - + class MultiHeadAttention(nn.Module): def __init__( self, channels: int, num_heads: int, - ctx_channels: Optional[int]=None, + ctx_channels: Optional[int] = None, type: Literal["self", "cross"] = "self", attn_mode: Literal["full", "windowed"] = "full", window_size: Optional[int] = None, @@ -78,11 +96,13 @@ class MultiHeadAttention(nn.Module): assert channels % num_heads == 0 assert type in ["self", "cross"], f"Invalid attention type: {type}" assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" - assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" - + assert ( + type == "self" or attn_mode == "full" + ), "Cross-attention only supports full attention" + if attn_mode == "windowed": raise NotImplementedError("Windowed attention is not yet implemented") - + self.channels = channels self.head_dim = channels // num_heads self.ctx_channels = ctx_channels if ctx_channels is not None else channels @@ -99,17 +119,22 @@ class MultiHeadAttention(nn.Module): else: self.to_q = nn.Linear(channels, channels, bias=qkv_bias) self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) - + if self.qk_rms_norm: self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - + self.to_out = nn.Linear(channels, channels) if use_rope: self.rope = RotaryPositionEmbedder(channels) - - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: B, L, C = x.shape if self._type == "self": qkv = self.to_qkv(x) diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py index 09035726081fb7afda2c62504d5474cfa483c58f..6cd152e22a506d1175b98d36da44deb3c7a252c4 100644 --- a/trellis/modules/norm.py +++ b/trellis/modules/norm.py @@ -5,21 +5,21 @@ import torch.nn as nn class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: return super().forward(x.float()).type(x.dtype) - + class GroupNorm32(nn.GroupNorm): """ A GroupNorm layer that converts to float32 before the forward pass. 
""" + def forward(self, x: torch.Tensor) -> torch.Tensor: return super().forward(x.float()).type(x.dtype) - - + + class ChannelLayerNorm32(LayerNorm32): def forward(self, x: torch.Tensor) -> torch.Tensor: DIM = x.dim() x = x.permute(0, *range(2, DIM), 1).contiguous() x = super().forward(x) - x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + x = x.permute(0, DIM - 1, *range(1, DIM - 1)).contiguous() return x - \ No newline at end of file diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py index 726756c16dcfe0f04de0d2ea5bdce499fa220160..9682919642683847d06347a38e2245539264b00e 100755 --- a/trellis/modules/sparse/__init__.py +++ b/trellis/modules/sparse/__init__.py @@ -1,81 +1,88 @@ from typing import * -BACKEND = 'spconv' +BACKEND = "spconv" DEBUG = False -ATTN = 'flash_attn' +ATTN = "flash_attn" + def __from_env(): import os - + global BACKEND global DEBUG global ATTN - - env_sparse_backend = os.environ.get('SPARSE_BACKEND') - env_sparse_debug = os.environ.get('SPARSE_DEBUG') - env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + + env_sparse_backend = os.environ.get("SPARSE_BACKEND") + env_sparse_debug = os.environ.get("SPARSE_DEBUG") + env_sparse_attn = os.environ.get("SPARSE_ATTN_BACKEND") if env_sparse_attn is None: - env_sparse_attn = os.environ.get('ATTN_BACKEND') + env_sparse_attn = os.environ.get("ATTN_BACKEND") - if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + if env_sparse_backend is not None and env_sparse_backend in [ + "spconv", + "torchsparse", + ]: BACKEND = env_sparse_backend if env_sparse_debug is not None: - DEBUG = env_sparse_debug == '1' - if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + DEBUG = env_sparse_debug == "1" + if env_sparse_attn is not None and env_sparse_attn in ["xformers", "flash_attn"]: ATTN = env_sparse_attn - + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") - + __from_env() - -def set_backend(backend: Literal['spconv', 'torchsparse']): + +def set_backend(backend: Literal["spconv", "torchsparse"]): global BACKEND BACKEND = backend + def set_debug(debug: bool): global DEBUG DEBUG = debug -def set_attn(attn: Literal['xformers', 'flash_attn']): + +def set_attn(attn: Literal["xformers", "flash_attn"]): global ATTN ATTN = attn - - + + import importlib __attributes = { - 'SparseTensor': 'basic', - 'sparse_batch_broadcast': 'basic', - 'sparse_batch_op': 'basic', - 'sparse_cat': 'basic', - 'sparse_unbind': 'basic', - 'SparseGroupNorm': 'norm', - 'SparseLayerNorm': 'norm', - 'SparseGroupNorm32': 'norm', - 'SparseLayerNorm32': 'norm', - 'SparseReLU': 'nonlinearity', - 'SparseSiLU': 'nonlinearity', - 'SparseGELU': 'nonlinearity', - 'SparseActivation': 'nonlinearity', - 'SparseLinear': 'linear', - 'sparse_scaled_dot_product_attention': 'attention', - 'SerializeMode': 'attention', - 'sparse_serialized_scaled_dot_product_self_attention': 'attention', - 'sparse_windowed_scaled_dot_product_self_attention': 'attention', - 'SparseMultiHeadAttention': 'attention', - 'SparseConv3d': 'conv', - 'SparseInverseConv3d': 'conv', - 'SparseDownsample': 'spatial', - 'SparseUpsample': 'spatial', - 'SparseSubdivide' : 'spatial' + "SparseTensor": "basic", + "sparse_batch_broadcast": "basic", + "sparse_batch_op": "basic", + "sparse_cat": "basic", + "sparse_unbind": "basic", + "SparseGroupNorm": "norm", + "SparseLayerNorm": "norm", + "SparseGroupNorm32": "norm", + "SparseLayerNorm32": "norm", + "SparseReLU": "nonlinearity", + "SparseSiLU": "nonlinearity", + 
"SparseGELU": "nonlinearity", + "SparseActivation": "nonlinearity", + "SparseLinear": "linear", + "sparse_scaled_dot_product_attention": "attention", + "SerializeMode": "attention", + "sparse_serialized_scaled_dot_product_self_attention": "attention", + "sparse_windowed_scaled_dot_product_self_attention": "attention", + "SparseMultiHeadAttention": "attention", + "SparseConv3d": "conv", + "SparseInverseConv3d": "conv", + "SparseDownsample": "spatial", + "SparseUpsample": "spatial", + "SparseSubdivide": "spatial", } -__submodules = ['transformer'] +__submodules = ["transformer"] __all__ = list(__attributes.keys()) + __submodules + def __getattr__(name): if name not in globals(): if name in __attributes: @@ -91,7 +98,7 @@ def __getattr__(name): # For Pylance -if __name__ == '__main__': +if __name__ == "__main__": from .basic import * from .norm import * from .nonlinearity import * diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py index e9e27aeb98419621f3f9999fd3b11eebf2b90a40..96f695c943f14ec0c7924eb9d56f84798e30fc53 100755 --- a/trellis/modules/sparse/attention/full_attn.py +++ b/trellis/modules/sparse/attention/full_attn.py @@ -3,16 +3,16 @@ import torch from .. import SparseTensor from .. import DEBUG, ATTN -if ATTN == 'xformers': +if ATTN == "xformers": import xformers.ops as xops -elif ATTN == 'flash_attn': +elif ATTN == "flash_attn": import flash_attn else: raise ValueError(f"Unknown attention module: {ATTN}") __all__ = [ - 'sparse_scaled_dot_product_attention', + "sparse_scaled_dot_product_attention", ] @@ -26,8 +26,11 @@ def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: """ ... + @overload -def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: +def sparse_scaled_dot_product_attention( + q: SparseTensor, kv: Union[SparseTensor, torch.Tensor] +) -> SparseTensor: """ Apply scaled dot product attention to a sparse tensor. @@ -37,8 +40,11 @@ def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, """ ... + @overload -def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: +def sparse_scaled_dot_product_attention( + q: torch.Tensor, kv: SparseTensor +) -> torch.Tensor: """ Apply scaled dot product attention to a sparse tensor. @@ -48,8 +54,11 @@ def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> to """ ... + @overload -def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: +def sparse_scaled_dot_product_attention( + q: SparseTensor, k: SparseTensor, v: SparseTensor +) -> SparseTensor: """ Apply scaled dot product attention to a sparse tensor. @@ -63,8 +72,11 @@ def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: Spa """ ... + @overload -def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: +def sparse_scaled_dot_product_attention( + q: SparseTensor, k: torch.Tensor, v: torch.Tensor +) -> SparseTensor: """ Apply scaled dot product attention to a sparse tensor. @@ -75,8 +87,11 @@ def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: tor """ ... 
+ @overload -def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: +def sparse_scaled_dot_product_attention( + q: torch.Tensor, k: SparseTensor, v: SparseTensor +) -> torch.Tensor: """ Apply scaled dot product attention to a sparse tensor. @@ -87,106 +102,158 @@ def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: Spa """ ... + def sparse_scaled_dot_product_attention(*args, **kwargs): - arg_names_dict = { - 1: ['qkv'], - 2: ['q', 'kv'], - 3: ['q', 'k', 'v'] - } + arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]} num_all_args = len(args) + len(kwargs) - assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" - for key in arg_names_dict[num_all_args][len(args):]: + assert ( + num_all_args in arg_names_dict + ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args) :]: assert key in kwargs, f"Missing argument {key}" if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" - assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + qkv = args[0] if len(args) > 0 else kwargs["qkv"] + assert isinstance( + qkv, SparseTensor + ), f"qkv must be a SparseTensor, got {type(qkv)}" + assert ( + len(qkv.shape) == 4 and qkv.shape[1] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" device = qkv.device s = qkv - q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + q_seqlen = [ + qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0]) + ] kv_seqlen = q_seqlen - qkv = qkv.feats # [T, 3, H, C] + qkv = qkv.feats # [T, 3, H, C] elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ - isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ - f"Invalid types, got {type(q)} and {type(kv)}" - assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + q = args[0] if len(args) > 0 else kwargs["q"] + kv = args[1] if len(args) > 1 else kwargs["kv"] + assert ( + isinstance(q, SparseTensor) + and isinstance(kv, (SparseTensor, torch.Tensor)) + or isinstance(q, torch.Tensor) + and isinstance(kv, SparseTensor) + ), f"Invalid types, got {type(q)} and {type(kv)}" + assert ( + q.shape[0] == kv.shape[0] + ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" device = q.device if isinstance(q, SparseTensor): - assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + assert ( + len(q.shape) == 3 + ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" s = q q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] - q = q.feats # [T_Q, H, C] + q = q.feats # [T_Q, H, C] else: - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" s = None N, L, H, C = q.shape q_seqlen = [L] * N - q = q.reshape(N * L, H, C) # [T_Q, H, C] + q = q.reshape(N * L, H, C) # [T_Q, H, C] if isinstance(kv, SparseTensor): - assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, 
expected [N, *, 2, H, C]" - kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] - kv = kv.feats # [T_KV, 2, H, C] + assert ( + len(kv.shape) == 4 and kv.shape[1] == 2 + ), f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [ + kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0]) + ] + kv = kv.feats # [T_KV, 2, H, C] else: - assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + assert ( + len(kv.shape) == 5 + ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" N, L, _, H, C = kv.shape kv_seqlen = [L] * N - kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] - assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ - isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ - f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" - assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + q = args[0] if len(args) > 0 else kwargs["q"] + k = args[1] if len(args) > 1 else kwargs["k"] + v = args[2] if len(args) > 2 else kwargs["v"] + assert ( + isinstance(q, SparseTensor) + and isinstance(k, (SparseTensor, torch.Tensor)) + and type(k) == type(v) + or isinstance(q, torch.Tensor) + and isinstance(k, SparseTensor) + and isinstance(v, SparseTensor) + ), f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert ( + q.shape[0] == k.shape[0] == v.shape[0] + ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" device = q.device if isinstance(q, SparseTensor): - assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + assert ( + len(q.shape) == 3 + ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" s = q q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] - q = q.feats # [T_Q, H, Ci] + q = q.feats # [T_Q, H, Ci] else: - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" s = None N, L, H, CI = q.shape q_seqlen = [L] * N q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] if isinstance(k, SparseTensor): - assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" - assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" - kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] - k = k.feats # [T_KV, H, Ci] - v = v.feats # [T_KV, H, Co] + assert ( + len(k.shape) == 3 + ), f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert ( + len(v.shape) == 3 + ), f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [ + k.layout[i].stop - k.layout[i].start for i in range(k.shape[0]) + ] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] else: - assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" - assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + assert ( + len(k.shape) == 4 + ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert ( + len(v.shape) == 4 + ), f"Invalid shape 
for v, got {v.shape}, expected [N, L, H, Co]" N, L, H, CI, CO = *k.shape, v.shape[-1] kv_seqlen = [L] * N - k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] - v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] if DEBUG: if s is not None: for i in range(s.shape[0]): - assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + assert ( + s.coords[s.layout[i]] == i + ).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" if num_all_args in [2, 3]: - assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + assert q.shape[:2] == [ + 1, + sum(q_seqlen), + ], f"SparseScaledDotProductSelfAttention: q shape mismatch" if num_all_args == 3: - assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" - assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + assert k.shape[:2] == [ + 1, + sum(kv_seqlen), + ], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [ + 1, + sum(kv_seqlen), + ], f"SparseScaledDotProductSelfAttention: v shape mismatch" - if ATTN == 'xformers': + if ATTN == "xformers": if num_all_args == 1: q, k, v = qkv.unbind(dim=1) elif num_all_args == 2: @@ -196,19 +263,35 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): v = v.unsqueeze(0) mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) out = xops.memory_efficient_attention(q, k, v, mask)[0] - elif ATTN == 'flash_attn': - cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + elif ATTN == "flash_attn": + cu_seqlens_q = ( + torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]) + .int() + .to(device) + ) if num_all_args in [2, 3]: - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + cu_seqlens_kv = ( + torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)] + ) + .int() + .to(device) + ) if num_all_args == 1: - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + out = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens_q, max(q_seqlen) + ) elif num_all_args == 2: - out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + out = flash_attn.flash_attn_varlen_kvpacked_func( + q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen) + ) elif num_all_args == 3: - out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + out = flash_attn.flash_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen) + ) else: raise ValueError(f"Unknown attention module: {ATTN}") - + if s is not None: return s.replace(out) else: diff --git a/trellis/modules/sparse/attention/modules.py b/trellis/modules/sparse/attention/modules.py index 5d2fe782b0947700e308e9ec0325e7e91c84e3c2..e04d077f96bdfb7070f2a5b98dc166e20da7bca1 100755 --- a/trellis/modules/sparse/attention/modules.py +++ b/trellis/modules/sparse/attention/modules.py @@ -4,7 +4,10 @@ import torch.nn as nn import torch.nn.functional as F from .. 
import SparseTensor from .full_attn import sparse_scaled_dot_product_attention -from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .serialized_attn import ( + SerializeMode, + sparse_serialized_scaled_dot_product_self_attention, +) from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention from ...attention import RotaryPositionEmbedder @@ -12,16 +15,18 @@ from ...attention import RotaryPositionEmbedder class SparseMultiHeadRMSNorm(nn.Module): def __init__(self, dim: int, heads: int): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(heads, dim)) - def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + def forward( + self, x: Union[SparseTensor, torch.Tensor] + ) -> Union[SparseTensor, torch.Tensor]: x_type = x.dtype x = x.float() if isinstance(x, SparseTensor): x = x.replace(F.normalize(x.feats, dim=-1)) else: - x = F.normalize(x, dim=-1) + x = F.normalize(x, dim=-1) return (x * self.gamma * self.scale).to(x_type) @@ -44,9 +49,17 @@ class SparseMultiHeadAttention(nn.Module): super().__init__() assert channels % num_heads == 0 assert type in ["self", "cross"], f"Invalid attention type: {type}" - assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" - assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" - assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + assert attn_mode in [ + "full", + "serialized", + "windowed", + ], f"Invalid attention mode: {attn_mode}" + assert ( + type == "self" or attn_mode == "full" + ), "Cross-attention only supports full attention" + assert ( + type == "self" or use_rope is False + ), "Rotary position embeddings only supported for self-attention" self.channels = channels self.ctx_channels = ctx_channels if ctx_channels is not None else channels self.num_heads = num_heads @@ -64,31 +77,37 @@ class SparseMultiHeadAttention(nn.Module): else: self.to_q = nn.Linear(channels, channels, bias=qkv_bias) self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) - + if self.qk_rms_norm: self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) - + self.to_out = nn.Linear(channels, channels) if use_rope: self.rope = RotaryPositionEmbedder(channels) @staticmethod - def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + def _linear( + module: nn.Linear, x: Union[SparseTensor, torch.Tensor] + ) -> Union[SparseTensor, torch.Tensor]: if isinstance(x, SparseTensor): return x.replace(module(x.feats)) else: return module(x) @staticmethod - def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + def _reshape_chs( + x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...] 
+ ) -> Union[SparseTensor, torch.Tensor]: if isinstance(x, SparseTensor): return x.reshape(*shape) else: return x.reshape(*x.shape[:2], *shape) - def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + def _fused_pre( + self, x: Union[SparseTensor, torch.Tensor], num_fused: int + ) -> Union[SparseTensor, torch.Tensor]: if isinstance(x, SparseTensor): x_feats = x.feats.unsqueeze(0) else: @@ -97,12 +116,16 @@ class SparseMultiHeadAttention(nn.Module): return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats def _rope(self, qkv: SparseTensor) -> SparseTensor: - q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] q, k = self.rope(q, k, qkv.coords[:, 1:]) - qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) return qkv - - def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + + def forward( + self, + x: Union[SparseTensor, torch.Tensor], + context: Optional[Union[SparseTensor, torch.Tensor]] = None, + ) -> Union[SparseTensor, torch.Tensor]: if self._type == "self": qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) @@ -117,7 +140,11 @@ class SparseMultiHeadAttention(nn.Module): h = sparse_scaled_dot_product_attention(qkv) elif self.attn_mode == "serialized": h = sparse_serialized_scaled_dot_product_self_attention( - qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + qkv, + self.window_size, + serialize_mode=self.serialize_mode, + shift_sequence=self.shift_sequence, + shift_window=self.shift_window, ) elif self.attn_mode == "windowed": h = sparse_windowed_scaled_dot_product_self_attention( diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py index 5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26..c3f341d5d7f3a8c3ec70c9f401da9aa738586b8a 100755 --- a/trellis/modules/sparse/attention/serialized_attn.py +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -5,16 +5,16 @@ import math from .. import SparseTensor from .. import DEBUG, ATTN -if ATTN == 'xformers': +if ATTN == "xformers": import xformers.ops as xops -elif ATTN == 'flash_attn': +elif ATTN == "flash_attn": import flash_attn else: raise ValueError(f"Unknown attention module: {ATTN}") __all__ = [ - 'sparse_serialized_scaled_dot_product_self_attention', + "sparse_serialized_scaled_dot_product_self_attention", ] @@ -29,7 +29,7 @@ SerializeModes = [ SerializeMode.Z_ORDER, SerializeMode.Z_ORDER_TRANSPOSED, SerializeMode.HILBERT, - SerializeMode.HILBERT_TRANSPOSED + SerializeMode.HILBERT_TRANSPOSED, ] @@ -38,7 +38,7 @@ def calc_serialization( window_size: int, serialize_mode: SerializeMode = SerializeMode.Z_ORDER, shift_sequence: int = 0, - shift_window: Tuple[int, int, int] = (0, 0, 0) + shift_window: Tuple[int, int, int] = (0, 0, 0), ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: """ Calculate serialization and partitioning for a set of coordinates. 
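`calc_serialization` orders the sparse voxels along a space-filling curve and then cuts the ordered sequence into fixed-size attention windows. A toy sketch of that idea, with a naive Morton (z-order) encoder standing in for the `vox2seq` kernel (illustrative only, not the CUDA implementation):

import torch

def morton_code(coords: torch.Tensor) -> torch.Tensor:
    # coords: [N, 3] integer tensor with values < 2**10
    code = torch.zeros(coords.shape[0], dtype=torch.int64)
    for bit in range(10):
        for axis in range(3):
            code |= ((coords[:, axis].long() >> bit) & 1) << (3 * bit + axis)
    return code

coords = torch.randint(0, 16, (20, 3))
order = torch.argsort(morton_code(coords))  # serialize along the curve
window_size = 8
windows = order.split(window_size)  # partition into attention windows
print([len(w) for w in windows])  # e.g. [8, 8, 4]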
@@ -58,32 +58,38 @@ def calc_serialization( seq_lens = [] seq_batch_indices = [] offsets = [0] - - if 'vox2seq' not in globals(): + + if "vox2seq" not in globals(): import vox2seq # Serialize the input serialize_coords = tensor.coords[:, 1:].clone() - serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + serialize_coords += torch.tensor( + shift_window, dtype=torch.int32, device=tensor.device + ).reshape(1, 3) if serialize_mode == SerializeMode.Z_ORDER: - code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + code = vox2seq.encode(serialize_coords, mode="z_order", permute=[0, 1, 2]) elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: - code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + code = vox2seq.encode(serialize_coords, mode="z_order", permute=[1, 0, 2]) elif serialize_mode == SerializeMode.HILBERT: - code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[0, 1, 2]) elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: - code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[1, 0, 2]) else: raise ValueError(f"Unknown serialize mode: {serialize_mode}") - + for bi, s in enumerate(tensor.layout): num_points = s.stop - s.start num_windows = (num_points + window_size - 1) // window_size valid_window_size = num_points / num_windows - to_ordered = torch.argsort(code[s.start:s.stop]) + to_ordered = torch.argsort(code[s.start : s.stop]) if num_windows == 1: fwd_indices.append(to_ordered) - bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + bwd_indices.append( + torch.zeros_like(to_ordered).scatter_( + 0, to_ordered, torch.arange(num_points, device=tensor.device) + ) + ) fwd_indices[-1] += s.start bwd_indices[-1] += offsets[-1] seq_lens.append(num_points) @@ -92,18 +98,39 @@ def calc_serialization( else: # Partition the input offset = 0 - mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] - split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] - bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + mids = [ + (i + 0.5) * valid_window_size + shift_sequence + for i in range(num_windows) + ] + split = [ + math.floor(i * valid_window_size + shift_sequence) + for i in range(num_windows + 1) + ] + bwd_index = torch.zeros( + (num_points,), dtype=torch.int64, device=tensor.device + ) for i in range(num_windows): mid = mids[i] valid_start = split[i] valid_end = split[i + 1] padded_start = math.floor(mid - 0.5 * window_size) padded_end = padded_start + window_size - fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + fwd_indices.append( + to_ordered[ + torch.arange(padded_start, padded_end, device=tensor.device) + % num_points + ] + ) offset += valid_start - padded_start - bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + bwd_index.scatter_( + 0, + fwd_indices[-1][ + valid_start - padded_start : valid_end - padded_start + ], + torch.arange( + offset, offset + valid_end - valid_start, device=tensor.device + ), + ) offset += padded_end - valid_start fwd_indices[-1] += s.start 
seq_lens.extend([window_size] * num_windows) @@ -115,14 +142,14 @@ def calc_serialization( bwd_indices = torch.cat(bwd_indices) return fwd_indices, bwd_indices, seq_lens, seq_batch_indices - + def sparse_serialized_scaled_dot_product_self_attention( qkv: SparseTensor, window_size: int, serialize_mode: SerializeMode = SerializeMode.Z_ORDER, shift_sequence: int = 0, - shift_window: Tuple[int, int, int] = (0, 0, 0) + shift_window: Tuple[int, int, int] = (0, 0, 0), ) -> SparseTensor: """ Apply serialized scaled dot product self attention to a sparse tensor. @@ -135,59 +162,89 @@ def sparse_serialized_scaled_dot_product_self_attention( shift_window (Tuple[int, int, int]): The shift of serialized coordinates. shift (int): The shift to use. """ - assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" - - serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' - serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + assert ( + len(qkv.shape) == 4 and qkv.shape[1] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = ( + f"serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}" + ) + serialization_spatial_cache = qkv.get_spatial_cache( + serialization_spatial_cache_name + ) if serialization_spatial_cache is None: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) - qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization( + qkv, window_size, serialize_mode, shift_sequence, shift_window + ) + qkv.register_spatial_cache( + serialization_spatial_cache_name, + (fwd_indices, bwd_indices, seq_lens, seq_batch_indices), + ) else: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + ( + fwd_indices, + bwd_indices, + seq_lens, + seq_batch_indices, + ) = serialization_spatial_cache M = fwd_indices.shape[0] T = qkv.feats.shape[0] H = qkv.feats.shape[2] C = qkv.feats.shape[3] - - qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] if DEBUG: start = 0 qkv_coords = qkv.coords[fwd_indices] for i in range(len(seq_lens)): - assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert ( + qkv_coords[start : start + seq_lens[i], 0] == seq_batch_indices[i] + ).all(), ( + f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + ) start += seq_lens[i] if all([seq_len == window_size for seq_len in seq_lens]): B = len(seq_lens) N = window_size qkv_feats = qkv_feats.reshape(B, N, 3, H, C) - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] - out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] - elif ATTN == 'flash_attn': - out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == "flash_attn": + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] else: raise ValueError(f"Unknown attention module: {ATTN}") - out = out.reshape(B * N, H, C) # [M, H, C] + out = 
out.reshape(B * N, H, C) # [M, H, C] else: - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] - q = q.unsqueeze(0) # [1, M, H, C] - k = k.unsqueeze(0) # [1, M, H, C] - v = v.unsqueeze(0) # [1, M, H, C] + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) - out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] - elif ATTN == 'flash_attn': - cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ - .to(qkv.device).int() - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] - - out = out[bwd_indices] # [T, H, C] + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == "flash_attn": + cu_seqlens = ( + torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], + dim=0, + ) + .to(qkv.device) + .int() + ) + out = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_feats, cu_seqlens, max(seq_lens) + ) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] if DEBUG: qkv_coords = qkv_coords[bwd_indices] - assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + assert torch.equal( + qkv_coords, qkv.coords + ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" return qkv.replace(out) diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py index cd642c5252e29a3a5e59fad7ed3880b7b00bcf9a..c1fd9b830dc2ab8314c1094d4486fb4cdd174712 100755 --- a/trellis/modules/sparse/attention/windowed_attn.py +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -4,23 +4,23 @@ import math from .. import SparseTensor from .. import DEBUG, ATTN -if ATTN == 'xformers': +if ATTN == "xformers": import xformers.ops as xops -elif ATTN == 'flash_attn': +elif ATTN == "flash_attn": import flash_attn else: raise ValueError(f"Unknown attention module: {ATTN}") __all__ = [ - 'sparse_windowed_scaled_dot_product_self_attention', + "sparse_windowed_scaled_dot_product_self_attention", ] def calc_window_partition( tensor: SparseTensor, window_size: Union[int, Tuple[int, ...]], - shift_window: Union[int, Tuple[int, ...]] = 0 + shift_window: Union[int, Tuple[int, ...]] = 0, ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: """ Calculate serialization and partitioning for a set of coordinates. @@ -37,33 +37,43 @@ def calc_window_partition( (List[int]): Sequence batch indices. 
""" DIM = tensor.coords.shape[1] - 1 - shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + shift_window = ( + (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + ) window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size shifted_coords = tensor.coords.clone().detach() - shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_coords[:, 1:] += torch.tensor( + shift_window, device=tensor.device, dtype=torch.int32 + ).unsqueeze(0) MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] - shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) - shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + shifted_coords[:, 1:] //= torch.tensor( + window_size, device=tensor.device, dtype=torch.int32 + ).unsqueeze(0) + shifted_indices = ( + shifted_coords + * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0) + ).sum(dim=1) fwd_indices = torch.argsort(shifted_indices) bwd_indices = torch.empty_like(fwd_indices) bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) seq_lens = torch.bincount(shifted_indices) - seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + seq_batch_indices = ( + torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) + // OFFSET[0] + ) mask = seq_lens != 0 seq_lens = seq_lens[mask].tolist() seq_batch_indices = seq_batch_indices[mask].tolist() return fwd_indices, bwd_indices, seq_lens, seq_batch_indices - + def sparse_windowed_scaled_dot_product_self_attention( - qkv: SparseTensor, - window_size: int, - shift_window: Tuple[int, int, int] = (0, 0, 0) + qkv: SparseTensor, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0) ) -> SparseTensor: """ Apply windowed scaled dot product self attention to a sparse tensor. @@ -74,62 +84,95 @@ def sparse_windowed_scaled_dot_product_self_attention( shift_window (Tuple[int, int, int]): The shift of serialized coordinates. shift (int): The shift to use. 
""" - assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" - - serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' - serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + assert ( + len(qkv.shape) == 4 and qkv.shape[1] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f"window_partition_{window_size}_{shift_window}" + serialization_spatial_cache = qkv.get_spatial_cache( + serialization_spatial_cache_name + ) if serialization_spatial_cache is None: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) - qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition( + qkv, window_size, shift_window + ) + qkv.register_spatial_cache( + serialization_spatial_cache_name, + (fwd_indices, bwd_indices, seq_lens, seq_batch_indices), + ) else: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + ( + fwd_indices, + bwd_indices, + seq_lens, + seq_batch_indices, + ) = serialization_spatial_cache M = fwd_indices.shape[0] T = qkv.feats.shape[0] H = qkv.feats.shape[2] C = qkv.feats.shape[3] - - qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] if DEBUG: start = 0 qkv_coords = qkv.coords[fwd_indices] for i in range(len(seq_lens)): - seq_coords = qkv_coords[start:start+seq_lens[i]] - assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" - assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ - f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + seq_coords = qkv_coords[start : start + seq_lens[i]] + assert ( + seq_coords[:, 0] == seq_batch_indices[i] + ).all(), ( + f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + ) + assert ( + seq_coords[:, 1:].max(dim=0).values + - seq_coords[:, 1:].min(dim=0).values + < window_size + ).all(), ( + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + ) start += seq_lens[i] if all([seq_len == window_size for seq_len in seq_lens]): B = len(seq_lens) N = window_size qkv_feats = qkv_feats.reshape(B, N, 3, H, C) - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] - out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] - elif ATTN == 'flash_attn': - out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == "flash_attn": + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] else: raise ValueError(f"Unknown attention module: {ATTN}") - out = out.reshape(B * N, H, C) # [M, H, C] + out = out.reshape(B * N, H, C) # [M, H, C] else: - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] - q = q.unsqueeze(0) # [1, M, H, C] - k = k.unsqueeze(0) # [1, M, H, C] - v = v.unsqueeze(0) # [1, M, H, C] + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] mask = 
xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) - out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] - elif ATTN == 'flash_attn': - cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ - .to(qkv.device).int() - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] - - out = out[bwd_indices] # [T, H, C] + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == "flash_attn": + cu_seqlens = ( + torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], + dim=0, + ) + .to(qkv.device) + .int() + ) + out = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_feats, cu_seqlens, max(seq_lens) + ) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] if DEBUG: qkv_coords = qkv_coords[bwd_indices] - assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + assert torch.equal( + qkv_coords, qkv.coords + ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" return qkv.replace(out) diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py index 8837f44052f6d573d09e3bfb897e659e10516bb5..561fa7953e12455fb4312074186849e17b20fe58 100755 --- a/trellis/modules/sparse/basic.py +++ b/trellis/modules/sparse/basic.py @@ -2,22 +2,23 @@ from typing import * import torch import torch.nn as nn from . import BACKEND, DEBUG -SparseTensorData = None # Lazy import + +SparseTensorData = None # Lazy import __all__ = [ - 'SparseTensor', - 'sparse_batch_broadcast', - 'sparse_batch_op', - 'sparse_cat', - 'sparse_unbind', + "SparseTensor", + "sparse_batch_broadcast", + "sparse_batch_op", + "sparse_cat", + "sparse_unbind", ] class SparseTensor: """ Sparse tensor with support for both torchsparse and spconv backends. - + Parameters: - feats (torch.Tensor): Features of the sparse tensor. - coords (torch.Tensor): Coordinates of the sparse tensor. @@ -29,64 +30,89 @@ class SparseTensor: - Data corresponding to a same batch should be contiguous. - Coords should be in [0, 1023] """ + @overload - def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + def __init__( + self, + feats: torch.Tensor, + coords: torch.Tensor, + shape: Optional[torch.Size] = None, + layout: Optional[List[slice]] = None, + **kwargs, + ): + ... @overload - def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + def __init__( + self, + data, + shape: Optional[torch.Size] = None, + layout: Optional[List[slice]] = None, + **kwargs, + ): + ... 
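Worth spelling out the serialization trick that calc_window_partition (in the attention diff above) relies on: each voxel's (batch, window) cell is flattened into a single mixed-radix integer, one argsort then makes all voxels of a window contiguous, and the inverse permutation restores the original order after attention. A self-contained toy version, a minimal sketch only (plain torch; the 2-D coords and window size are made up for illustration):

import torch

# Toy stand-in for calc_window_partition, 2-D for brevity. coords is
# [N, 1 + DIM]; column 0 is the batch index, the rest are voxel positions.
coords = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 5, 5], [0, 4, 5]])
window_size = 4

win = coords.clone()
win[:, 1:] //= window_size                        # window cell of each voxel
nx, ny = (win[:, 1:].max(dim=0).values + 1).tolist()
offset = torch.tensor([nx * ny, ny, 1])           # mixed-radix weights: batch, x, y
flat = (win * offset).sum(dim=1)                  # one scalar id per (batch, window)

fwd = torch.argsort(flat)                         # same-window voxels become contiguous
bwd = torch.empty_like(fwd)
bwd[fwd] = torch.arange(fwd.shape[0])             # inverse permutation
counts = torch.bincount(flat)
seq_lens = counts[counts != 0].tolist()           # -> [2, 2]: two windows of two voxels

assert torch.equal(coords[fwd][bwd], coords)      # the round trip is lossless

The batch index gets the largest radix weight, so windows from different batch elements can never collide even after the coordinate shift.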
def __init__(self, *args, **kwargs): # Lazy import of sparse tensor backend global SparseTensorData if SparseTensorData is None: import importlib - if BACKEND == 'torchsparse': - SparseTensorData = importlib.import_module('torchsparse').SparseTensor - elif BACKEND == 'spconv': - SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor - + + if BACKEND == "torchsparse": + SparseTensorData = importlib.import_module("torchsparse").SparseTensor + elif BACKEND == "spconv": + SparseTensorData = importlib.import_module( + "spconv.pytorch" + ).SparseConvTensor + method_id = 0 if len(args) != 0: method_id = 0 if isinstance(args[0], torch.Tensor) else 1 else: - method_id = 1 if 'data' in kwargs else 0 + method_id = 1 if "data" in kwargs else 0 if method_id == 0: feats, coords, shape, layout = args + (None,) * (4 - len(args)) - if 'feats' in kwargs: - feats = kwargs['feats'] - del kwargs['feats'] - if 'coords' in kwargs: - coords = kwargs['coords'] - del kwargs['coords'] - if 'shape' in kwargs: - shape = kwargs['shape'] - del kwargs['shape'] - if 'layout' in kwargs: - layout = kwargs['layout'] - del kwargs['layout'] + if "feats" in kwargs: + feats = kwargs["feats"] + del kwargs["feats"] + if "coords" in kwargs: + coords = kwargs["coords"] + del kwargs["coords"] + if "shape" in kwargs: + shape = kwargs["shape"] + del kwargs["shape"] + if "layout" in kwargs: + layout = kwargs["layout"] + del kwargs["layout"] if shape is None: shape = self.__cal_shape(feats, coords) if layout is None: layout = self.__cal_layout(coords, shape[0]) - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": self.data = SparseTensorData(feats, coords, **kwargs) - elif BACKEND == 'spconv': + elif BACKEND == "spconv": spatial_shape = list(coords.max(0)[0] + 1)[1:] - self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data = SparseTensorData( + feats.reshape(feats.shape[0], -1), + coords, + spatial_shape, + shape[0], + **kwargs, + ) self.data._features = feats elif method_id == 1: data, shape, layout = args + (None,) * (3 - len(args)) - if 'data' in kwargs: - data = kwargs['data'] - del kwargs['data'] - if 'shape' in kwargs: - shape = kwargs['shape'] - del kwargs['shape'] - if 'layout' in kwargs: - layout = kwargs['layout'] - del kwargs['layout'] + if "data" in kwargs: + data = kwargs["data"] + del kwargs["data"] + if "shape" in kwargs: + shape = kwargs["shape"] + del kwargs["shape"] + if "layout" in kwargs: + layout = kwargs["layout"] + del kwargs["layout"] self.data = data if shape is None: @@ -96,73 +122,84 @@ class SparseTensor: self._shape = shape self._layout = layout - self._scale = kwargs.get('scale', (1, 1, 1)) - self._spatial_cache = kwargs.get('spatial_cache', {}) + self._scale = kwargs.get("scale", (1, 1, 1)) + self._spatial_cache = kwargs.get("spatial_cache", {}) if DEBUG: try: - assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" - assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" - assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + assert ( + self.feats.shape[0] == self.coords.shape[0] + ), f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape( + self.feats, self.coords + ), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout( + self.coords, self.shape[0] + ), f"Invalid 
layout: {self.layout}" for i in range(self.shape[0]): - assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + assert torch.all( + self.coords[self.layout[i], 0] == i + ), f"The data of batch {i} is not contiguous" except Exception as e: - print('Debugging information:') + print("Debugging information:") print(f"- Shape: {self.shape}") print(f"- Layout: {self.layout}") print(f"- Scale: {self._scale}") print(f"- Coords: {self.coords}") raise e - + def __cal_shape(self, feats, coords): shape = [] shape.append(coords[:, 0].max().item() + 1) shape.extend([*feats.shape[1:]]) return torch.Size(shape) - + def __cal_layout(self, coords, batch_size): seq_len = torch.bincount(coords[:, 0], minlength=batch_size) - offset = torch.cumsum(seq_len, dim=0) - layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + offset = torch.cumsum(seq_len, dim=0) + layout = [ + slice((offset[i] - seq_len[i]).item(), offset[i].item()) + for i in range(batch_size) + ] return layout - + @property def shape(self) -> torch.Size: return self._shape - + def dim(self) -> int: return len(self.shape) - + @property def layout(self) -> List[slice]: return self._layout @property def feats(self) -> torch.Tensor: - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": return self.data.F - elif BACKEND == 'spconv': + elif BACKEND == "spconv": return self.data.features - + @feats.setter def feats(self, value: torch.Tensor): - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": self.data.F = value - elif BACKEND == 'spconv': + elif BACKEND == "spconv": self.data.features = value @property def coords(self) -> torch.Tensor: - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": return self.data.C - elif BACKEND == 'spconv': + elif BACKEND == "spconv": return self.data.indices - + @coords.setter def coords(self, value: torch.Tensor): - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": self.data.C = value - elif BACKEND == 'spconv': + elif BACKEND == "spconv": self.data.indices = value @property @@ -174,12 +211,18 @@ class SparseTensor: return self.feats.device @overload - def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + def to(self, dtype: torch.dtype) -> "SparseTensor": + ... @overload - def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... - - def to(self, *args, **kwargs) -> 'SparseTensor': + def to( + self, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> "SparseTensor": + ... 
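The DEBUG block and the feats/coords properties above encode the two class invariants: rows belonging to one batch element are contiguous (so layout can be plain slices), and the same API works whether the payload is torchsparse (.F/.C) or spconv (.features/.indices). A minimal sketch of the layout computation on toy coords (plain torch; cal_layout is an illustrative stand-in for the private __cal_layout):

import torch

def cal_layout(coords: torch.Tensor, batch_size: int):
    # One slice per batch into the flat voxel list; relies on the invariant
    # that all rows of a batch are contiguous and ordered by batch index.
    seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
    offset = torch.cumsum(seq_len, dim=0)
    return [
        slice((offset[b] - seq_len[b]).item(), offset[b].item())
        for b in range(batch_size)
    ]

coords = torch.tensor([[0, 1, 2, 3], [0, 4, 5, 6], [1, 0, 0, 0]])
print(cal_layout(coords, 2))   # [slice(0, 2, None), slice(2, 3, None)]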
+ + def to(self, *args, **kwargs) -> "SparseTensor": device = None dtype = None if len(args) == 2: @@ -189,13 +232,13 @@ class SparseTensor: dtype = args[0] else: device = args[0] - if 'dtype' in kwargs: + if "dtype" in kwargs: assert dtype is None, "to() received multiple values for argument 'dtype'" - dtype = kwargs['dtype'] - if 'device' in kwargs: + dtype = kwargs["dtype"] + if "device" in kwargs: assert device is None, "to() received multiple values for argument 'device'" - device = kwargs['device'] - + device = kwargs["device"] + new_feats = self.feats.to(device=device, dtype=dtype) new_coords = self.coords.to(device=device) return self.replace(new_feats, new_coords) @@ -204,46 +247,48 @@ class SparseTensor: new_feats = self.feats.type(dtype) return self.replace(new_feats) - def cpu(self) -> 'SparseTensor': + def cpu(self) -> "SparseTensor": new_feats = self.feats.cpu() new_coords = self.coords.cpu() return self.replace(new_feats, new_coords) - - def cuda(self) -> 'SparseTensor': + + def cuda(self) -> "SparseTensor": new_feats = self.feats.cuda() new_coords = self.coords.cuda() return self.replace(new_feats, new_coords) - def half(self) -> 'SparseTensor': + def half(self) -> "SparseTensor": new_feats = self.feats.half() return self.replace(new_feats) - - def float(self) -> 'SparseTensor': + + def float(self) -> "SparseTensor": new_feats = self.feats.float() return self.replace(new_feats) - - def detach(self) -> 'SparseTensor': + + def detach(self) -> "SparseTensor": new_coords = self.coords.detach() new_feats = self.feats.detach() return self.replace(new_feats, new_coords) def dense(self) -> torch.Tensor: - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": return self.data.dense() - elif BACKEND == 'spconv': + elif BACKEND == "spconv": return self.data.dense() - def reshape(self, *shape) -> 'SparseTensor': + def reshape(self, *shape) -> "SparseTensor": new_feats = self.feats.reshape(self.feats.shape[0], *shape) return self.replace(new_feats) - - def unbind(self, dim: int) -> List['SparseTensor']: + + def unbind(self, dim: int) -> List["SparseTensor"]: return sparse_unbind(self, dim) - def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + def replace( + self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None + ) -> "SparseTensor": new_shape = [self.shape[0]] new_shape.extend(feats.shape[1:]) - if BACKEND == 'torchsparse': + if BACKEND == "torchsparse": new_data = SparseTensorData( feats=feats, coords=self.data.coords if coords is None else coords, @@ -251,7 +296,7 @@ class SparseTensor: spatial_range=self.data.spatial_range, ) new_data._caches = self.data._caches - elif BACKEND == 'spconv': + elif BACKEND == "spconv": new_data = SparseTensorData( self.data.features.reshape(self.data.features.shape[0], -1), self.data.indices, @@ -259,7 +304,7 @@ class SparseTensor: self.data.batch_size, self.data.grid, self.data.voxel_num, - self.data.indice_dict + self.data.indice_dict, ) new_data._features = feats new_data.benchmark = self.data.benchmark @@ -270,26 +315,39 @@ class SparseTensor: new_data.int8_scale = self.data.int8_scale if coords is not None: new_data.indices = coords - new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + new_tensor = SparseTensor( + new_data, + shape=torch.Size(new_shape), + layout=self.layout, + scale=self._scale, + spatial_cache=self._spatial_cache, + ) return new_tensor @staticmethod - def full(aabb, dim, value, 
dtype=torch.float32, device=None) -> 'SparseTensor': + def full(aabb, dim, value, dtype=torch.float32, device=None) -> "SparseTensor": N, C = dim x = torch.arange(aabb[0], aabb[3] + 1) y = torch.arange(aabb[1], aabb[4] + 1) z = torch.arange(aabb[2], aabb[5] + 1) - coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) - coords = torch.cat([ - torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), - coords.repeat(N, 1), - ], dim=1).to(dtype=torch.int32, device=device) + coords = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1).reshape( + -1, 3 + ) + coords = torch.cat( + [ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], + dim=1, + ).to(dtype=torch.int32, device=device) feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) return SparseTensor(feats=feats, coords=coords) - def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + def __merge_sparse_cache(self, other: "SparseTensor") -> dict: new_cache = {} - for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + for k in set( + list(self._spatial_cache.keys()) + list(other._spatial_cache.keys()) + ): if k in self._spatial_cache: new_cache[k] = self._spatial_cache[k] if k in other._spatial_cache: @@ -299,10 +357,12 @@ class SparseTensor: new_cache[k].update(other._spatial_cache[k]) return new_cache - def __neg__(self) -> 'SparseTensor': + def __neg__(self) -> "SparseTensor": return self.replace(-self.feats) - - def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + + def __elemwise__( + self, other: Union[torch.Tensor, "SparseTensor"], op: callable + ) -> "SparseTensor": if isinstance(other, torch.Tensor): try: other = torch.broadcast_to(other, self.shape) @@ -317,28 +377,44 @@ class SparseTensor: new_tensor._spatial_cache = self.__merge_sparse_cache(other) return new_tensor - def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + def __add__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, torch.add) - def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + def __radd__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, torch.add) - - def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + + def __sub__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, torch.sub) - - def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + + def __rsub__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) - def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + def __mul__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, torch.mul) - def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + def __rmul__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, torch.mul) - def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + def __truediv__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> 
"SparseTensor": return self.__elemwise__(other, torch.div) - def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + def __rtruediv__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": return self.__elemwise__(other, lambda x, y: torch.div(y, x)) def __getitem__(self, idx): @@ -348,7 +424,9 @@ class SparseTensor: idx = range(*idx.indices(self.shape[0])) elif isinstance(idx, torch.Tensor): if idx.dtype == torch.bool: - assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + assert idx.shape == ( + self.shape[0], + ), f"Invalid index shape: {idx.shape}" idx = idx.nonzero().squeeze(1) elif idx.dtype in [torch.int32, torch.int64]: assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" @@ -356,7 +434,7 @@ class SparseTensor: raise ValueError(f"Unknown index type: {idx.dtype}") else: raise ValueError(f"Unknown index type: {type(idx)}") - + coords = [] feats = [] for new_idx, old_idx in enumerate(idx): @@ -392,7 +470,7 @@ class SparseTensor: def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: """ Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. - + Args: input (torch.Tensor): 1D tensor to broadcast. target (SparseTensor): Sparse tensor to broadcast to. @@ -405,10 +483,12 @@ def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Te return broadcasted -def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: +def sparse_batch_op( + input: SparseTensor, other: torch.Tensor, op: callable = torch.add +) -> SparseTensor: """ Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. - + Args: input (torch.Tensor): 1D tensor to broadcast. target (SparseTensor): Sparse tensor to broadcast to. @@ -420,7 +500,7 @@ def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = tor def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: """ Concatenate a list of sparse tensors. - + Args: inputs (List[SparseTensor]): List of sparse tensors to concatenate. """ @@ -447,7 +527,7 @@ def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: """ Unbind a sparse tensor along a dimension. - + Args: input (SparseTensor): Sparse tensor to unbind. dim (int): Dimension to unbind. diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py index 340a87126a8de574ee0276feb96b49824a2ce234..a6887f2449757d6a8198c7a17e10513f537cfd26 100755 --- a/trellis/modules/sparse/conv/__init__.py +++ b/trellis/modules/sparse/conv/__init__.py @@ -1,21 +1,26 @@ from .. 
import BACKEND -SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' +SPCONV_ALGO = "auto" # 'auto', 'implicit_gemm', 'native' + def __from_env(): import os - + global SPCONV_ALGO - env_spconv_algo = os.environ.get('SPCONV_ALGO') - if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + env_spconv_algo = os.environ.get("SPCONV_ALGO") + if env_spconv_algo is not None and env_spconv_algo in [ + "auto", + "implicit_gemm", + "native", + ]: SPCONV_ALGO = env_spconv_algo print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") - + __from_env() -if BACKEND == 'torchsparse': +if BACKEND == "torchsparse": from .conv_torchsparse import * -elif BACKEND == 'spconv': +elif BACKEND == "spconv": from .conv_spconv import * diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py index 524bcd4a845b2d6bd090a5f74bc8859978727528..b6e907980d0c7d282e62cb9be810fb70e76e4a6c 100755 --- a/trellis/modules/sparse/conv/conv_spconv.py +++ b/trellis/modules/sparse/conv/conv_spconv.py @@ -4,21 +4,54 @@ from .. import SparseTensor from .. import DEBUG from . import SPCONV_ALGO + class SparseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + padding=None, + bias=True, + indice_key=None, + ): super(SparseConv3d, self).__init__() - if 'spconv' not in globals(): + if "spconv" not in globals(): import spconv.pytorch as spconv algo = None - if SPCONV_ALGO == 'native': + if SPCONV_ALGO == "native": algo = spconv.ConvAlgo.Native - elif SPCONV_ALGO == 'implicit_gemm': + elif SPCONV_ALGO == "implicit_gemm": algo = spconv.ConvAlgo.MaskImplicitGemm if stride == 1 and (padding is None): - self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + self.conv = spconv.SubMConv3d( + in_channels, + out_channels, + kernel_size, + dilation=dilation, + bias=bias, + indice_key=indice_key, + algo=algo, + ) else: - self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) - self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.conv = spconv.SparseConv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + indice_key=indice_key, + algo=algo, + ) + self.stride = ( + tuple(stride) + if isinstance(stride, (list, tuple)) + else (stride, stride, stride) + ) self.padding = padding def forward(self, x: SparseTensor) -> SparseTensor: @@ -30,42 +63,65 @@ class SparseConv3d(nn.Module): if spatial_changed and (x.shape[0] != 1): # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords fwd = new_data.indices[:, 0].argsort() - bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + bwd = torch.zeros_like(fwd).scatter_( + 0, fwd, torch.arange(fwd.shape[0], device=fwd.device) + ) sorted_feats = new_data.features[fwd] sorted_coords = new_data.indices[fwd] unsorted_data = new_data new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore out = SparseTensor( - new_data, shape=torch.Size(new_shape), layout=new_layout, + new_data, + 
shape=torch.Size(new_shape), + layout=new_layout, scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), spatial_cache=x._spatial_cache, ) if spatial_changed and (x.shape[0] != 1): - out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) - out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) - + out.register_spatial_cache( + f"conv_{self.stride}_unsorted_data", unsorted_data + ) + out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd) + return out class SparseInverseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + bias=True, + indice_key=None, + ): super(SparseInverseConv3d, self).__init__() - if 'spconv' not in globals(): + if "spconv" not in globals(): import spconv.pytorch as spconv - self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) - self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.conv = spconv.SparseInverseConv3d( + in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key + ) + self.stride = ( + tuple(stride) + if isinstance(stride, (list, tuple)) + else (stride, stride, stride) + ) def forward(self, x: SparseTensor) -> SparseTensor: spatial_changed = any(s != 1 for s in self.stride) if spatial_changed: # recover the original spconv order - data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') - bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data") + bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd") data = data.replace_feature(x.feats[bwd]) if DEBUG: - assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + assert torch.equal( + data.indices, x.coords[bwd] + ), "Recover the original order failed" else: data = x.data @@ -73,7 +129,9 @@ class SparseInverseConv3d(nn.Module): new_shape = [x.shape[0], self.conv.out_channels] new_layout = None if spatial_changed else x.layout out = SparseTensor( - new_data, shape=torch.Size(new_shape), layout=new_layout, + new_data, + shape=torch.Size(new_shape), + layout=new_layout, scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), spatial_cache=x._spatial_cache, ) diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py index 1d612582d4b31f90aca3c00b693bbbc2550dc62c..6ab8798d8ffbedb0655170b49fb8c9806f1a324c 100755 --- a/trellis/modules/sparse/conv/conv_torchsparse.py +++ b/trellis/modules/sparse/conv/conv_torchsparse.py @@ -4,35 +4,73 @@ from .. 
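The pairing in conv_spconv.py above is easy to miss: a strided spconv layer emits voxels out of batch order, so SparseConv3d argsorts them back into contiguous batches and stashes both the unsorted tensor and the inverse permutation in the spatial cache under a stride-keyed name; SparseInverseConv3d later reads the same keys to restore the original spconv order before deconvolving. The permutation bookkeeping in isolation, as a small sketch (plain torch, toy batch indices):

import torch

rows = torch.tensor([1, 0, 1, 0])        # batch index per output voxel, out of order
fwd = rows.argsort()                     # forward permutation: groups batches
bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0]))

sorted_rows = rows[fwd]                  # contiguous per batch: [0, 0, 1, 1]
assert torch.equal(sorted_rows[bwd], rows)   # bwd undoes fwd, as the inverse conv needs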
import SparseTensor class SparseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + bias=True, + indice_key=None, + ): super(SparseConv3d, self).__init__() - if 'torchsparse' not in globals(): + if "torchsparse" not in globals(): import torchsparse - self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + self.conv = torchsparse.nn.Conv3d( + in_channels, out_channels, kernel_size, stride, 0, dilation, bias + ) def forward(self, x: SparseTensor) -> SparseTensor: out = self.conv(x.data) new_shape = [x.shape[0], self.conv.out_channels] - out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out = SparseTensor( + out, + shape=torch.Size(new_shape), + layout=x.layout if all(s == 1 for s in self.conv.stride) else None, + ) out._spatial_cache = x._spatial_cache - out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + out._scale = tuple( + [s * stride for s, stride in zip(x._scale, self.conv.stride)] + ) return out class SparseInverseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + bias=True, + indice_key=None, + ): super(SparseInverseConv3d, self).__init__() - if 'torchsparse' not in globals(): + if "torchsparse" not in globals(): import torchsparse - self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + self.conv = torchsparse.nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride, + 0, + dilation, + bias, + transposed=True, + ) def forward(self, x: SparseTensor) -> SparseTensor: - out = self.conv(x.data) + out = self.conv(x.data) new_shape = [x.shape[0], self.conv.out_channels] - out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out = SparseTensor( + out, + shape=torch.Size(new_shape), + layout=x.layout if all(s == 1 for s in self.conv.stride) else None, + ) out._spatial_cache = x._spatial_cache - out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + out._scale = tuple( + [s // stride for s, stride in zip(x._scale, self.conv.stride)] + ) return out - - - diff --git a/trellis/modules/sparse/linear.py b/trellis/modules/sparse/linear.py index a854e77ce87d1a190b9730d91f363a821ff250bd..70eb81c9d1562c9258ba0db02d848b0f7afa998a 100755 --- a/trellis/modules/sparse/linear.py +++ b/trellis/modules/sparse/linear.py @@ -2,9 +2,7 @@ import torch import torch.nn as nn from . import SparseTensor -__all__ = [ - 'SparseLinear' -] +__all__ = ["SparseLinear"] class SparseLinear(nn.Linear): diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py index f200098dd82011a3aeee1688b9eb17018fa78295..ae2bd56a07c1dd9040999cbf0ca77a092bab1b28 100755 --- a/trellis/modules/sparse/nonlinearity.py +++ b/trellis/modules/sparse/nonlinearity.py @@ -2,18 +2,13 @@ import torch import torch.nn as nn from . 
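SparseLinear and the activation wrappers in the hunks around this point all appear to follow one pattern: run the dense op on .feats only and rebuild with .replace(), so coordinates, layout, and spatial caches survive untouched. A hypothetical minimal wrapper showing the shape of the pattern (SparseWrap is an illustrative name, not a repo class):

import torch
import torch.nn as nn

class SparseWrap(nn.Module):
    # Wraps any feature-wise dense module so it operates on a SparseTensor-like
    # object exposing .feats and .replace(); structure is never touched.
    def __init__(self, op: nn.Module):
        super().__init__()
        self.op = op

    def forward(self, x):
        return x.replace(self.op(x.feats))

relu_sparse = SparseWrap(nn.ReLU())      # behaves like the SparseReLU wrapper here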
import SparseTensor -__all__ = [ - 'SparseReLU', - 'SparseSiLU', - 'SparseGELU', - 'SparseActivation' -] +__all__ = ["SparseReLU", "SparseSiLU", "SparseGELU", "SparseActivation"] class SparseReLU(nn.ReLU): def forward(self, input: SparseTensor) -> SparseTensor: return input.replace(super().forward(input.feats)) - + class SparseSiLU(nn.SiLU): def forward(self, input: SparseTensor) -> SparseTensor: @@ -32,4 +27,3 @@ class SparseActivation(nn.Module): def forward(self, input: SparseTensor) -> SparseTensor: return input.replace(self.activation(input.feats)) - diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py index 6b38a36682c098210000dc31d68ddc31ccd2929d..921e338f6a8ec80088b985623041a0ba0f0d8fd8 100755 --- a/trellis/modules/sparse/norm.py +++ b/trellis/modules/sparse/norm.py @@ -4,10 +4,10 @@ from . import SparseTensor from . import DEBUG __all__ = [ - 'SparseGroupNorm', - 'SparseLayerNorm', - 'SparseGroupNorm32', - 'SparseLayerNorm32', + "SparseGroupNorm", + "SparseLayerNorm", + "SparseGroupNorm32", + "SparseLayerNorm32", ] @@ -19,7 +19,9 @@ class SparseGroupNorm(nn.GroupNorm): nfeats = torch.zeros_like(input.feats) for k in range(input.shape[0]): if DEBUG: - assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + assert ( + input.coords[input.layout[k], 0] == k + ).all(), f"SparseGroupNorm: batch index mismatch" bfeats = input.feats[input.layout[k]] bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) bfeats = super().forward(bfeats) @@ -47,12 +49,15 @@ class SparseGroupNorm32(SparseGroupNorm): """ A GroupNorm layer that converts to float32 before the forward pass. """ + def forward(self, x: SparseTensor) -> SparseTensor: return super().forward(x.float()).type(x.dtype) + class SparseLayerNorm32(SparseLayerNorm): """ A LayerNorm layer that converts to float32 before the forward pass. """ + def forward(self, x: SparseTensor) -> SparseTensor: return super().forward(x.float()).type(x.dtype) diff --git a/trellis/modules/sparse/spatial.py b/trellis/modules/sparse/spatial.py index ad7121473f335b307e2f7ea5f05c964d3aec0440..5f9e7f92daba2725c6b3e12b32a24013e2fcbf15 100755 --- a/trellis/modules/sparse/spatial.py +++ b/trellis/modules/sparse/spatial.py @@ -3,11 +3,7 @@ import torch import torch.nn as nn from . import SparseTensor -__all__ = [ - 'SparseDownsample', - 'SparseUpsample', - 'SparseSubdivide' -] +__all__ = ["SparseDownsample", "SparseUpsample", "SparseSubdivide"] class SparseDownsample(nn.Module): @@ -15,6 +11,7 @@ class SparseDownsample(nn.Module): Downsample a sparse tensor by a factor of `factor`. Implemented as average pooling. """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): super(SparseDownsample, self).__init__() self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor @@ -22,36 +19,47 @@ class SparseDownsample(nn.Module): def forward(self, input: SparseTensor) -> SparseTensor: DIM = input.coords.shape[-1] - 1 factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM - assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' + assert DIM == len( + factor + ), "Input coordinates must have the same dimension as the downsample factor." 
coord = list(input.coords.unbind(dim=-1)) for i, f in enumerate(factor): - coord[i+1] = coord[i+1] // f + coord[i + 1] = coord[i + 1] // f - MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + MAX = [coord[i + 1].max().item() + 1 for i in range(DIM)] OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] code = sum([c * o for c, o in zip(coord, OFFSET)]) code, idx = code.unique(return_inverse=True) new_feats = torch.scatter_reduce( - torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), + torch.zeros( + code.shape[0], + input.feats.shape[1], + device=input.feats.device, + dtype=input.feats.dtype, + ), dim=0, index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), src=input.feats, - reduce='mean' + reduce="mean", ) new_coords = torch.stack( - [code // OFFSET[0]] + - [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], - dim=-1 + [code // OFFSET[0]] + + [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)], + dim=-1, + ) + out = SparseTensor( + new_feats, + new_coords, + input.shape, ) - out = SparseTensor(new_feats, new_coords, input.shape,) out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) out._spatial_cache = input._spatial_cache - out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) - out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) - out.register_spatial_cache(f'upsample_{factor}_idx', idx) + out.register_spatial_cache(f"upsample_{factor}_coords", input.coords) + out.register_spatial_cache(f"upsample_{factor}_layout", input.layout) + out.register_spatial_cache(f"upsample_{factor}_idx", idx) return out @@ -61,6 +69,7 @@ class SparseUpsample(nn.Module): Upsample a sparse tensor by a factor of `factor`. Implemented as nearest neighbor interpolation. """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): super(SparseUpsample, self).__init__() self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor @@ -68,24 +77,30 @@ class SparseUpsample(nn.Module): def forward(self, input: SparseTensor) -> SparseTensor: DIM = input.coords.shape[-1] - 1 factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM - assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + assert DIM == len( + factor + ), "Input coordinates must have the same dimension as the upsample factor." - new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') - new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') - idx = input.get_spatial_cache(f'upsample_{factor}_idx') + new_coords = input.get_spatial_cache(f"upsample_{factor}_coords") + new_layout = input.get_spatial_cache(f"upsample_{factor}_layout") + idx = input.get_spatial_cache(f"upsample_{factor}_idx") if any([x is None for x in [new_coords, new_layout, idx]]): - raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') + raise ValueError( + "Upsample cache not found. SparseUpsample must be paired with SparseDownsample." + ) new_feats = input.feats[idx] out = SparseTensor(new_feats, new_coords, input.shape, new_layout) out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) out._spatial_cache = input._spatial_cache return out - + + class SparseSubdivide(nn.Module): """ Upsample a sparse tensor by a factor of `factor`. Implemented as nearest neighbor interpolation. 
""" + def __init__(self): super(SparseSubdivide, self).__init__() @@ -96,15 +111,20 @@ class SparseSubdivide(nn.Module): n_coords = torch.nonzero(n_cube) n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) factor = n_coords.shape[0] - assert factor == 2 ** DIM + assert factor == 2**DIM # print(n_coords.shape) new_coords = input.coords.clone() new_coords[:, 1:] *= 2 - new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) - - new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) - out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to( + new_coords.dtype + ) + + new_feats = input.feats.unsqueeze(1).expand( + input.feats.shape[0], factor, *input.feats.shape[1:] + ) + out = SparseTensor( + new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape + ) out._scale = input._scale * 2 out._spatial_cache = input._spatial_cache return out - diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py index b08b0d4e5bc24060a2cdc8df75d06dce122972bd..92deb200ea54da8f150354f1e7baff1931c2df1b 100644 --- a/trellis/modules/sparse/transformer/__init__.py +++ b/trellis/modules/sparse/transformer/__init__.py @@ -1,2 +1,2 @@ from .blocks import * -from .modulated import * \ No newline at end of file +from .modulated import * diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py index 9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0..c223eaeb56d4912eb287901ed27c0fd25ecae667 100644 --- a/trellis/modules/sparse/transformer/blocks.py +++ b/trellis/modules/sparse/transformer/blocks.py @@ -25,12 +25,15 @@ class SparseTransformerBlock(nn.Module): """ Sparse Transformer block (MSA + FFN). """ + def __init__( self, channels: int, num_heads: int, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", window_size: Optional[int] = None, shift_sequence: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, @@ -73,7 +76,9 @@ class SparseTransformerBlock(nn.Module): def forward(self, x: SparseTensor) -> SparseTensor: if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, use_reentrant=False + ) else: return self._forward(x) @@ -82,13 +87,16 @@ class SparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN). 
""" + def __init__( self, channels: int, ctx_channels: int, num_heads: int, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", window_size: Optional[int] = None, shift_sequence: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, @@ -146,6 +154,8 @@ class SparseTransformerCrossBlock(nn.Module): def forward(self, x: SparseTensor, context: torch.Tensor): if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, context, use_reentrant=False + ) else: return self._forward(x, context) diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py index 4a8416559f39acbed9e5996e9891c97f95c80c8f..9e7a482f476be8b4e51bf45c0a17dfe9542e9f7b 100644 --- a/trellis/modules/sparse/transformer/modulated.py +++ b/trellis/modules/sparse/transformer/modulated.py @@ -11,12 +11,15 @@ class ModulatedSparseTransformerBlock(nn.Module): """ Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. """ + def __init__( self, channels: int, num_heads: int, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", window_size: Optional[int] = None, shift_sequence: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, @@ -50,15 +53,23 @@ class ModulatedSparseTransformerBlock(nn.Module): ) if not share_mod: self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) ) def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(mod).chunk(6, dim=1) h = x.replace(self.norm1(x.feats)) h = h * (1 + scale_msa) + shift_msa h = self.attn(h) @@ -73,7 +84,9 @@ class ModulatedSparseTransformerBlock(nn.Module): def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, use_reentrant=False + ) else: return self._forward(x, mod) @@ -82,13 +95,16 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
""" + def __init__( self, channels: int, ctx_channels: int, num_heads: int, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", window_size: Optional[int] = None, shift_sequence: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, @@ -99,7 +115,6 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, - ): super().__init__() self.use_checkpoint = use_checkpoint @@ -135,15 +150,25 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): ) if not share_mod: self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) ) - def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + def _forward( + self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor + ) -> SparseTensor: if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(mod).chunk(6, dim=1) h = x.replace(self.norm1(x.feats)) h = h * (1 + scale_msa) + shift_msa h = self.self_attn(h) @@ -159,8 +184,12 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): x = x + h return x - def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + def forward( + self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor + ) -> SparseTensor: if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, context, use_reentrant=False + ) else: return self._forward(x, mod, context) diff --git a/trellis/modules/spatial.py b/trellis/modules/spatial.py index 79e268d36c2ba49b0275744022a1a1e19983dae3..6c7cec71ce4da627d47b803d9bbbfd9f38a36a49 100644 --- a/trellis/modules/spatial.py +++ b/trellis/modules/spatial.py @@ -9,7 +9,7 @@ def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: C_ = C // scale_factor**3 x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) - x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + x = x.reshape(B, C_, H * scale_factor, W * scale_factor, D * scale_factor) return x @@ -23,11 +23,18 @@ def patchify(x: torch.Tensor, patch_size: int): """ DIM = x.dim() - 2 for d in range(2, DIM + 2): - assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + assert ( + x.shape[d] % patch_size == 0 + ), f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" - x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) - x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) - x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + x = x.reshape( + *x.shape[:2], + *sum([[x.shape[d] // 
patch_size, patch_size] for d in range(2, DIM + 2)], []), + ) + x = x.permute( + 0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]) + ) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size**DIM), *(x.shape[-DIM:])) return x @@ -40,9 +47,18 @@ def unpatchify(x: torch.Tensor, patch_size: int): patch_size (int): Patch size """ DIM = x.dim() - 2 - assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + assert ( + x.shape[1] % (patch_size**DIM) == 0 + ), f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" - x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.reshape( + x.shape[0], + x.shape[1] // (patch_size**DIM), + *([patch_size] * DIM), + *(x.shape[-DIM:]), + ) x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) - x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + x = x.reshape( + x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)] + ) return x diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py index b08b0d4e5bc24060a2cdc8df75d06dce122972bd..92deb200ea54da8f150354f1e7baff1931c2df1b 100644 --- a/trellis/modules/transformer/__init__.py +++ b/trellis/modules/transformer/__init__.py @@ -1,2 +1,2 @@ from .blocks import * -from .modulated import * \ No newline at end of file +from .modulated import * diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py index c37eb7ed92f4aacfc9e974a63b247589d95977da..0fc47e8c1eb736e909d75f3e3f070f2bfb372117 100644 --- a/trellis/modules/transformer/blocks.py +++ b/trellis/modules/transformer/blocks.py @@ -9,14 +9,15 @@ class AbsolutePositionEmbedder(nn.Module): """ Embeds spatial positions into vector representations. """ + def __init__(self, channels: int, in_channels: int = 3): super().__init__() self.channels = channels self.in_channels = in_channels self.freq_dim = channels // in_channels // 2 self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim - self.freqs = 1.0 / (10000 ** self.freqs) - + self.freqs = 1.0 / (10000**self.freqs) + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: """ Create sinusoidal position embeddings. @@ -38,11 +39,19 @@ class AbsolutePositionEmbedder(nn.Module): x (torch.Tensor): (N, D) tensor of spatial positions """ N, D = x.shape - assert D == self.in_channels, "Input dimension must match number of input channels" + assert ( + D == self.in_channels + ), "Input dimension must match number of input channels" embed = self._sin_cos_embedding(x.reshape(-1)) embed = embed.reshape(N, -1) if embed.shape[1] < self.channels: - embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + embed = torch.cat( + [ + embed, + torch.zeros(N, self.channels - embed.shape[1], device=embed.device), + ], + dim=-1, + ) return embed @@ -63,6 +72,7 @@ class TransformerBlock(nn.Module): """ Transformer block (MSA + FFN). 
""" + def __init__( self, channels: int, @@ -107,7 +117,9 @@ class TransformerBlock(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, use_reentrant=False + ) else: return self._forward(x) @@ -116,6 +128,7 @@ class TransformerCrossBlock(nn.Module): """ Transformer cross-attention block (MSA + MCA + FFN). """ + def __init__( self, channels: int, @@ -176,7 +189,8 @@ class TransformerCrossBlock(nn.Module): def forward(self, x: torch.Tensor, context: torch.Tensor): if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, context, use_reentrant=False + ) else: return self._forward(x, context) - \ No newline at end of file diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py index d4aeca0689e68f656b08f7aa822b7be839aa727d..0f1c60e64b77a5c4e9d39692fa688f93e68706f6 100644 --- a/trellis/modules/transformer/modulated.py +++ b/trellis/modules/transformer/modulated.py @@ -10,6 +10,7 @@ class ModulatedTransformerBlock(nn.Module): """ Transformer block (MSA + FFN) with adaptive layer norm conditioning. """ + def __init__( self, channels: int, @@ -45,15 +46,23 @@ class ModulatedTransformerBlock(nn.Module): ) if not share_mod: self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) ) def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(mod).chunk(6, dim=1) h = self.norm1(x) h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) h = self.attn(h) @@ -68,7 +77,9 @@ class ModulatedTransformerBlock(nn.Module): def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, use_reentrant=False + ) else: return self._forward(x, mod) @@ -77,6 +88,7 @@ class ModulatedTransformerCrossBlock(nn.Module): """ Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
""" + def __init__( self, channels: int, @@ -125,15 +137,23 @@ class ModulatedTransformerCrossBlock(nn.Module): ) if not share_mod: self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) ) def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(mod).chunk(6, dim=1) h = self.norm1(x) h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) h = self.self_attn(h) @@ -151,7 +171,8 @@ class ModulatedTransformerCrossBlock(nn.Module): def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, context, use_reentrant=False + ) else: return self._forward(x, mod, context) - \ No newline at end of file diff --git a/trellis/modules/utils.py b/trellis/modules/utils.py index f0afb1b6c767aa2ad00bad96649fb30315e696ea..a9b657326fa11ba3fd721d6994c24e8bb64ba483 100755 --- a/trellis/modules/utils.py +++ b/trellis/modules/utils.py @@ -14,6 +14,7 @@ FP16_MODULES = ( sp.SparseLinear, ) + def convert_module_to_f16(l): """ Convert primitive modules to float16. diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py index f9e8548b894aeb3d354c739320ed3288be9c7b0e..0e84651373d542d6e918da6309edf041aebe99c9 100644 --- a/trellis/pipelines/__init__.py +++ b/trellis/pipelines/__init__.py @@ -11,14 +11,16 @@ def from_pretrained(path: str): """ import os import json + is_local = os.path.exists(f"{path}/pipeline.json") if is_local: config_file = f"{path}/pipeline.json" else: from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = json.load(f) - return globals()[config['name']].from_pretrained(path) + return globals()[config["name"]].from_pretrained(path) diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py index 3a9e0df4ec5fb915d57d30189cac854e3f095620..c6cf8a802479ebbe589615285698a0d9307ad491 100644 --- a/trellis/pipelines/base.py +++ b/trellis/pipelines/base.py @@ -8,6 +8,7 @@ class Pipeline: """ A base class for pipelines. 
""" + def __init__( self, models: dict[str, nn.Module] = None, @@ -25,20 +26,21 @@ class Pipeline: """ import os import json + is_local = os.path.exists(f"{path}/pipeline.json") if is_local: config_file = f"{path}/pipeline.json" else: from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") - with open(config_file, 'r') as f: - args = json.load(f)['args'] + with open(config_file, "r") as f: + args = json.load(f)["args"] _models = { - k: models.from_pretrained(f"{path}/{v}") - for k, v in args['models'].items() + k: models.from_pretrained(f"{path}/{v}") for k, v in args["models"].items() } new_pipeline = Pipeline(_models) @@ -48,10 +50,10 @@ class Pipeline: @property def device(self) -> torch.device: for model in self.models.values(): - if hasattr(model, 'device'): + if hasattr(model, "device"): return model.device for model in self.models.values(): - if hasattr(model, 'parameters'): + if hasattr(model, "parameters"): return next(model.parameters()).device raise RuntimeError("No device found.") diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py index 54d412fc5d8eb662081a92a56ad078243988c2f9..fc0c23732748373aefab09e10aa4fee22210de8f 100755 --- a/trellis/pipelines/samplers/__init__.py +++ b/trellis/pipelines/samplers/__init__.py @@ -1,2 +1,6 @@ from .base import Sampler -from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler \ No newline at end of file +from .flow_euler import ( + FlowEulerSampler, + FlowEulerCfgSampler, + FlowEulerGuidanceIntervalSampler, +) diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py index 1966ce787009a5ee0c1ed06dce491525ff1dbcbf..aba7e29b110d70cfb046b4dfbfdd98f4a85340a7 100644 --- a/trellis/pipelines/samplers/base.py +++ b/trellis/pipelines/samplers/base.py @@ -8,13 +8,8 @@ class Sampler(ABC): """ @abstractmethod - def sample( - self, - model, - **kwargs - ): + def sample(self, model, **kwargs): """ Sample from a model. """ pass - \ No newline at end of file diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py index d79124cf1b07515e8f0b88684e271028b1e3a71d..b439988d35f0a2012a084c1e3797c52f8d722a54 100644 --- a/trellis/pipelines/samplers/flow_euler.py +++ b/trellis/pipelines/samplers/flow_euler.py @@ -15,6 +15,7 @@ class FlowEulerSampler(Sampler): Args: sigma_min: The minimum scale of noise in flow. """ + def __init__( self, sigma_min: float, @@ -32,11 +33,15 @@ class FlowEulerSampler(Sampler): def _v_to_xstart_eps(self, x_t, t, v): assert x_t.shape == v.shape eps = (1 - t) * v + x_t - x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + x_0 = (1 - self.sigma_min) * x_t - ( + self.sigma_min + (1 - self.sigma_min) * t + ) * v return x_0, eps def _inference_model(self, model, x_t, t, cond=None, **kwargs): - t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + t = torch.tensor( + [1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32 + ) return model(x_t, t, cond, **kwargs) def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): @@ -46,17 +51,11 @@ class FlowEulerSampler(Sampler): @torch.no_grad() def sample_once( - self, - model, - x_t, - t: float, - t_prev: float, - cond: Optional[Any] = None, - **kwargs + self, model, x_t, t: float, t_prev: float, cond: Optional[Any] = None, **kwargs ): """ Sample x_{t-1} from the model using Euler method. 
- + Args: model: The model to sample from. x_t: The [N x C x ...] tensor of noisy inputs at time t. @@ -70,7 +69,9 @@ class FlowEulerSampler(Sampler): - 'pred_x_prev': x_{t-1}. - 'pred_x_0': a prediction of x_0. """ - pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps, pred_v = self._get_model_prediction( + model, x_t, t, cond, **kwargs + ) pred_x_prev = x_t - (t - t_prev) * pred_v return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) @@ -87,7 +88,7 @@ class FlowEulerSampler(Sampler): ): """ Generate samples from the model using Euler method. - + Args: model: The model to sample from. noise: The initial noise tensor. @@ -121,6 +122,7 @@ class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): """ Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. """ + @torch.no_grad() def sample( self, @@ -136,7 +138,7 @@ class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): ): """ Generate samples from the model using Euler method. - + Args: model: The model to sample from. noise: The initial noise tensor. @@ -154,13 +156,24 @@ class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): - 'pred_x_t': a list of prediction of x_t. - 'pred_x_0': a list of prediction of x_0. """ - return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + return super().sample( + model, + noise, + cond, + steps, + rescale_t, + verbose, + neg_cond=neg_cond, + cfg_strength=cfg_strength, + **kwargs + ) class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): """ Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. """ + @torch.no_grad() def sample( self, @@ -177,7 +190,7 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa ): """ Generate samples from the model using Euler method. - + Args: model: The model to sample from. noise: The initial noise tensor. @@ -196,4 +209,15 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa - 'pred_x_t': a list of prediction of x_t. - 'pred_x_0': a list of prediction of x_0. """ - return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + return super().sample( + model, + noise, + cond, + steps, + rescale_t, + verbose, + neg_cond=neg_cond, + cfg_strength=cfg_strength, + cfg_interval=cfg_interval, + **kwargs + ) diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py index 7074a4d5fea20a8f799416aa6571faca4f9eea06..e6eafea078ba0e76d82b0046ffe2e12005f2d7b0 100644 --- a/trellis/pipelines/samplers/guidance_interval_mixin.py +++ b/trellis/pipelines/samplers/guidance_interval_mixin.py @@ -6,7 +6,9 @@ class GuidanceIntervalSamplerMixin: A mixin class for samplers that apply classifier-free guidance with interval. 
""" - def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + def _inference_model( + self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs + ): if cfg_interval[0] <= t <= cfg_interval[1]: pred = super()._inference_model(model, x_t, t, cond, **kwargs) neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py index f781e3489ab17def756d5cd676b8858b4ba9b156..8307b232bf35b652c0f451f1274e0a88912854da 100644 --- a/trellis/pipelines/trellis_image_to_3d.py +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -26,6 +26,7 @@ class TrellisImageTo3DPipeline(Pipeline): slat_normalization (dict): The normalization parameters for the structured latent. image_cond_model (str): The name of the image conditioning model. """ + def __init__( self, models: dict[str, nn.Module] = None, @@ -53,33 +54,45 @@ class TrellisImageTo3DPipeline(Pipeline): Args: path (str): The path to the model. Can be either local path or a Hugging Face repository. """ - pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path) + pipeline = super( + TrellisImageTo3DPipeline, TrellisImageTo3DPipeline + ).from_pretrained(path) new_pipeline = TrellisImageTo3DPipeline() new_pipeline.__dict__ = pipeline.__dict__ args = pipeline._pretrained_args - new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) - new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + new_pipeline.sparse_structure_sampler = getattr( + samplers, args["sparse_structure_sampler"]["name"] + )(**args["sparse_structure_sampler"]["args"]) + new_pipeline.sparse_structure_sampler_params = args["sparse_structure_sampler"][ + "params" + ] - new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) - new_pipeline.slat_sampler_params = args['slat_sampler']['params'] + new_pipeline.slat_sampler = getattr(samplers, args["slat_sampler"]["name"])( + **args["slat_sampler"]["args"] + ) + new_pipeline.slat_sampler_params = args["slat_sampler"]["params"] - new_pipeline.slat_normalization = args['slat_normalization'] + new_pipeline.slat_normalization = args["slat_normalization"] - new_pipeline._init_image_cond_model(args['image_cond_model']) + new_pipeline._init_image_cond_model(args["image_cond_model"]) return new_pipeline - + def _init_image_cond_model(self, name: str): """ Initialize the image conditioning model. 
""" - dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True) + dinov2_model = torch.hub.load("facebookresearch/dinov2", name, pretrained=True) dinov2_model.eval() - self.models['image_cond_model'] = dinov2_model - transform = transforms.Compose([ - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) + self.models["image_cond_model"] = dinov2_model + transform = transforms.Compose( + [ + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) self.image_cond_model_transform = transform def preprocess_image(self, input: Image.Image) -> Image.Image: @@ -88,29 +101,42 @@ class TrellisImageTo3DPipeline(Pipeline): """ # if has alpha channel, use it directly; otherwise, remove background has_alpha = False - if input.mode == 'RGBA': + if input.mode == "RGBA": alpha = np.array(input)[:, :, 3] if not np.all(alpha == 255): has_alpha = True if has_alpha: output = input else: - input = input.convert('RGB') + input = input.convert("RGB") max_size = max(input.size) scale = min(1, 1024 / max_size) if scale < 1: - input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) - if getattr(self, 'rembg_session', None) is None: - self.rembg_session = rembg.new_session('u2net') + input = input.resize( + (int(input.width * scale), int(input.height * scale)), + Image.Resampling.LANCZOS, + ) + if getattr(self, "rembg_session", None) is None: + self.rembg_session = rembg.new_session("u2net") output = rembg.remove(input, session=self.rembg_session) output_np = np.array(output) alpha = output_np[:, :, 3] bbox = np.argwhere(alpha > 0.8 * 255) - bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + bbox = ( + np.min(bbox[:, 1]), + np.min(bbox[:, 0]), + np.max(bbox[:, 1]), + np.max(bbox[:, 0]), + ) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) size = int(size * 1.2) - bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + bbox = ( + center[0] - size // 2, + center[1] - size // 2, + center[0] + size // 2, + center[1] + size // 2, + ) output = output.crop(bbox) # type: ignore output = output.resize((518, 518), Image.Resampling.LANCZOS) output = np.array(output).astype(np.float32) / 255 @@ -119,7 +145,9 @@ class TrellisImageTo3DPipeline(Pipeline): return output @torch.no_grad() - def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor: + def encode_image( + self, image: Union[torch.Tensor, list[Image.Image]] + ) -> torch.Tensor: """ Encode the image. 
@@ -132,19 +160,21 @@ class TrellisImageTo3DPipeline(Pipeline): if isinstance(image, torch.Tensor): assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" elif isinstance(image, list): - assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + assert all( + isinstance(i, Image.Image) for i in image + ), "Image list should be list of PIL images" image = [i.resize((518, 518), Image.LANCZOS) for i in image] - image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [np.array(i.convert("RGB")).astype(np.float32) / 255 for i in image] image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] image = torch.stack(image).to(self.device) else: raise ValueError(f"Unsupported type of image: {type(image)}") - + image = self.image_cond_model_transform(image).to(self.device) - features = self.models['image_cond_model'](image, is_training=True)['x_prenorm'] + features = self.models["image_cond_model"](image, is_training=True)["x_prenorm"] patchtokens = F.layer_norm(features, features.shape[-1:]) return patchtokens - + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict: """ Get the conditioning information for the model. @@ -158,8 +188,8 @@ class TrellisImageTo3DPipeline(Pipeline): cond = self.encode_image(image) neg_cond = torch.zeros_like(cond) return { - 'cond': cond, - 'neg_cond': neg_cond, + "cond": cond, + "neg_cond": neg_cond, } def sample_sparse_structure( @@ -170,35 +200,33 @@ class TrellisImageTo3DPipeline(Pipeline): ) -> torch.Tensor: """ Sample sparse structures with the given conditioning. - + Args: cond (dict): The conditioning information. num_samples (int): The number of samples to generate. sampler_params (dict): Additional parameters for the sampler. """ # Sample occupancy latent - flow_model = self.models['sparse_structure_flow_model'] + flow_model = self.models["sparse_structure_flow_model"] reso = flow_model.resolution - noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to( + self.device + ) sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} z_s = self.sparse_structure_sampler.sample( - flow_model, - noise, - **cond, - **sampler_params, - verbose=True + flow_model, noise, **cond, **sampler_params, verbose=True ).samples - + # Decode occupancy latent - decoder = self.models['sparse_structure_decoder'] - coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() + decoder = self.models["sparse_structure_decoder"] + coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int() return coords def decode_slat( self, slat: sp.SparseTensor, - formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + formats: List[str] = ["mesh", "gaussian", "radiance_field"], ) -> dict: """ Decode the structured latent. @@ -211,14 +239,14 @@ class TrellisImageTo3DPipeline(Pipeline): dict: The decoded structured latent. 
""" ret = {} - if 'mesh' in formats: - ret['mesh'] = self.models['slat_decoder_mesh'](slat) - if 'gaussian' in formats: - ret['gaussian'] = self.models['slat_decoder_gs'](slat) - if 'radiance_field' in formats: - ret['radiance_field'] = self.models['slat_decoder_rf'](slat) + if "mesh" in formats: + ret["mesh"] = self.models["slat_decoder_mesh"](slat) + if "gaussian" in formats: + ret["gaussian"] = self.models["slat_decoder_gs"](slat) + if "radiance_field" in formats: + ret["radiance_field"] = self.models["slat_decoder_rf"](slat) return ret - + def sample_slat( self, cond: dict, @@ -227,31 +255,27 @@ class TrellisImageTo3DPipeline(Pipeline): ) -> sp.SparseTensor: """ Sample structured latent with the given conditioning. - + Args: cond (dict): The conditioning information. coords (torch.Tensor): The coordinates of the sparse structure. sampler_params (dict): Additional parameters for the sampler. """ # Sample structured latent - flow_model = self.models['slat_flow_model'] + flow_model = self.models["slat_flow_model"] noise = sp.SparseTensor( feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), coords=coords, ) sampler_params = {**self.slat_sampler_params, **sampler_params} slat = self.slat_sampler.sample( - flow_model, - noise, - **cond, - **sampler_params, - verbose=True + flow_model, noise, **cond, **sampler_params, verbose=True ).samples - std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) - mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) + std = torch.tensor(self.slat_normalization["std"])[None].to(slat.device) + mean = torch.tensor(self.slat_normalization["mean"])[None].to(slat.device) slat = slat * std + mean - + return slat @torch.no_grad() @@ -262,7 +286,7 @@ class TrellisImageTo3DPipeline(Pipeline): seed: int = 42, sparse_structure_sampler_params: dict = {}, slat_sampler_params: dict = {}, - formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + formats: List[str] = ["mesh", "gaussian", "radiance_field"], preprocess_image: bool = True, ) -> dict: """ @@ -279,7 +303,9 @@ class TrellisImageTo3DPipeline(Pipeline): image = self.preprocess_image(image) cond = self.get_cond([image]) torch.manual_seed(seed) - coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) + coords = self.sample_sparse_structure( + cond, num_samples, sparse_structure_sampler_params + ) slat = self.sample_slat(cond, coords, slat_sampler_params) return self.decode_slat(slat, formats) @@ -289,56 +315,80 @@ class TrellisImageTo3DPipeline(Pipeline): sampler_name: str, num_images: int, num_steps: int, - mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + mode: Literal["stochastic", "multidiffusion"] = "stochastic", ): """ Inject a sampler with multiple images as condition. - + Args: sampler_name (str): The name of the sampler to inject. num_images (int): The number of images to condition on. num_steps (int): The number of steps to run the sampler for. """ sampler = getattr(self, sampler_name) - setattr(sampler, f'_old_inference_model', sampler._inference_model) + setattr(sampler, f"_old_inference_model", sampler._inference_model) - if mode == 'stochastic': + if mode == "stochastic": if num_images > num_steps: - print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. " - "This may lead to performance degradation.\033[0m") + print( + f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. 
" + "This may lead to performance degradation.\033[0m" + ) cond_indices = (np.arange(num_steps) % num_images).tolist() + def _new_inference_model(self, model, x_t, t, cond, **kwargs): cond_idx = cond_indices.pop(0) - cond_i = cond[cond_idx:cond_idx+1] + cond_i = cond[cond_idx : cond_idx + 1] return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs) - - elif mode =='multidiffusion': + + elif mode == "multidiffusion": from .samplers import FlowEulerSampler - def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + + def _new_inference_model( + self, + model, + x_t, + t, + cond, + neg_cond, + cfg_strength, + cfg_interval, + **kwargs, + ): if cfg_interval[0] <= t <= cfg_interval[1]: preds = [] for i in range(len(cond)): - preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + preds.append( + FlowEulerSampler._inference_model( + self, model, x_t, t, cond[i : i + 1], **kwargs + ) + ) pred = sum(preds) / len(preds) - neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs) + neg_pred = FlowEulerSampler._inference_model( + self, model, x_t, t, neg_cond, **kwargs + ) return (1 + cfg_strength) * pred - cfg_strength * neg_pred else: preds = [] for i in range(len(cond)): - preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + preds.append( + FlowEulerSampler._inference_model( + self, model, x_t, t, cond[i : i + 1], **kwargs + ) + ) pred = sum(preds) / len(preds) return pred - + else: raise ValueError(f"Unsupported mode: {mode}") - + sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler)) yield sampler._inference_model = sampler._old_inference_model - delattr(sampler, f'_old_inference_model') + delattr(sampler, f"_old_inference_model") @torch.no_grad() def run_multi_image( @@ -348,9 +398,9 @@ class TrellisImageTo3DPipeline(Pipeline): seed: int = 42, sparse_structure_sampler_params: dict = {}, slat_sampler_params: dict = {}, - formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + formats: List[str] = ["mesh", "gaussian", "radiance_field"], preprocess_image: bool = True, - mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + mode: Literal["stochastic", "multidiffusion"] = "stochastic", ) -> dict: """ Run the pipeline with multiple images as condition @@ -365,12 +415,21 @@ class TrellisImageTo3DPipeline(Pipeline): if preprocess_image: images = [self.preprocess_image(image) for image in images] cond = self.get_cond(images) - cond['neg_cond'] = cond['neg_cond'][:1] + cond["neg_cond"] = cond["neg_cond"][:1] torch.manual_seed(seed) - ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps') - with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode): - coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) - slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') - with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode): + ss_steps = { + **self.sparse_structure_sampler_params, + **sparse_structure_sampler_params, + }.get("steps") + with self.inject_sampler_multi_image( + "sparse_structure_sampler", len(images), ss_steps, mode=mode + ): + coords = self.sample_sparse_structure( + cond, num_samples, sparse_structure_sampler_params + ) + slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get("steps") + with 
self.inject_sampler_multi_image( + "slat_sampler", len(images), slat_steps, mode=mode + ): slat = self.sample_slat(cond, coords, slat_sampler_params) return self.decode_slat(slat, formats) diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py index 0339355c56b8d17f72e926650d140a658452fbe9..36304379f4804ff17c71ede17d142b30fbdbff44 100755 --- a/trellis/renderers/__init__.py +++ b/trellis/renderers/__init__.py @@ -1,15 +1,16 @@ import importlib __attributes = { - 'OctreeRenderer': 'octree_renderer', - 'GaussianRenderer': 'gaussian_render', - 'MeshRenderer': 'mesh_renderer', + "OctreeRenderer": "octree_renderer", + "GaussianRenderer": "gaussian_render", + "MeshRenderer": "mesh_renderer", } __submodules = [] __all__ = list(__attributes.keys()) + __submodules + def __getattr__(name): if name not in globals(): if name in __attributes: @@ -25,7 +26,7 @@ def __getattr__(name): # For Pylance -if __name__ == '__main__': +if __name__ == "__main__": from .octree_renderer import OctreeRenderer from .gaussian_render import GaussianRenderer - from .mesh_renderer import MeshRenderer \ No newline at end of file + from .mesh_renderer import MeshRenderer diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py index 57108e3cccf6aab8e3059431557c461de46aff1a..803a5a14f25a209d51732e767cd5fb525972d9b2 100755 --- a/trellis/renderers/gaussian_render.py +++ b/trellis/renderers/gaussian_render.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. # # For inquiries contact george.drettakis@inria.fr @@ -20,10 +20,10 @@ from easydict import EasyDict as edict def intrinsics_to_projection( - intrinsics: torch.Tensor, - near: float, - far: float, - ) -> torch.Tensor: + intrinsics: torch.Tensor, + near: float, + far: float, +) -> torch.Tensor: """ OpenCV intrinsics to OpenGL perspective matrix @@ -40,25 +40,40 @@ def intrinsics_to_projection( ret[0, 0] = 2 * fx ret[1, 1] = 2 * fy ret[0, 2] = 2 * cx - 1 - ret[1, 2] = - 2 * cy + 1 + ret[1, 2] = -2 * cy + 1 ret[2, 2] = far / (far - near) ret[2, 3] = near * far / (near - far) - ret[3, 2] = 1. + ret[3, 2] = 1.0 return ret -def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): +def render( + viewpoint_camera, + pc: Gaussian, + pipe, + bg_color: torch.Tensor, + scaling_modifier=1.0, + override_color=None, +): """ - Render the scene. - + Render the scene. + Background tensor (bg_color) must be on GPU! """ # lazy import - if 'GaussianRasterizer' not in globals(): - from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings - + if "GaussianRasterizer" not in globals(): + from diff_gaussian_rasterization import ( + GaussianRasterizer, + GaussianRasterizationSettings, + ) + # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means - screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + screenspace_points = ( + torch.zeros_like( + pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" + ) + + 0 + ) try: screenspace_points.retain_grad() except: @@ -66,9 +81,13 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali # Set up rasterization configuration tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) - + kernel_size = pipe.kernel_size - subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda") + subpixel_offset = torch.zeros( + (int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), + dtype=torch.float32, + device="cuda", + ) raster_settings = GaussianRasterizationSettings( image_height=int(viewpoint_camera.image_height), @@ -84,9 +103,9 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali sh_degree=pc.active_sh_degree, campos=viewpoint_camera.camera_center, prefiltered=False, - debug=pipe.debug + debug=pipe.debug, ) - + rasterizer = GaussianRasterizer(raster_settings=raster_settings) means3D = pc.get_xyz @@ -110,9 +129,13 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali colors_precomp = None if override_color is None: if pipe.convert_SHs_python: - shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) - dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) - dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + shs_view = pc.get_features.transpose(1, 2).view( + -1, 3, (pc.max_sh_degree + 1) ** 2 + ) + dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( + pc.get_features.shape[0], 1 + ) + dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) else: @@ -120,24 +143,28 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali else: colors_precomp = override_color - # Rasterize visible Gaussians to image, obtain their radii (on screen). + # Rasterize visible Gaussians to image, obtain their radii (on screen). rendered_image, radii = rasterizer( - means3D = means3D, - means2D = means2D, - shs = shs, - colors_precomp = colors_precomp, - opacities = opacity, - scales = scales, - rotations = rotations, - cov3D_precomp = cov3D_precomp + means3D=means3D, + means2D=means2D, + shs=shs, + colors_precomp=colors_precomp, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=cov3D_precomp, ) # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. 
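A standalone check (separate from the patch) of the degree-0 spherical-harmonics color path used above: eval_sh then reduces to C0 * sh_dc, so the RGB2SH / SH2RGB helpers from sh_utils.py (further down in this patch) are exact inverses. The C0 value below is the standard Y_0^0 constant and is assumed to match the one defined in sh_utils.py.

import torch

C0 = 0.28209479177387814  # standard Y_0^0 constant (assumed equal to sh_utils.C0)

def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5

rgb = torch.tensor([0.9, 0.4, 0.1])
sh_dc = RGB2SH(rgb)
# mirrors colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) for degree 0
recovered = torch.clamp_min(C0 * sh_dc + 0.5, 0.0)
print(torch.allclose(recovered, rgb))  # True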
- return edict({"render": rendered_image, + return edict( + { + "render": rendered_image, "viewspace_points": screenspace_points, - "visibility_filter" : radii > 0, - "radii": radii}) + "visibility_filter": radii > 0, + "radii": radii, + } + ) class GaussianRenderer: @@ -149,30 +176,34 @@ class GaussianRenderer: """ def __init__(self, rendering_options={}) -> None: - self.pipe = edict({ - "kernel_size": 0.1, - "convert_SHs_python": False, - "compute_cov3D_python": False, - "scale_modifier": 1.0, - "debug": False - }) - self.rendering_options = edict({ - "resolution": None, - "near": None, - "far": None, - "ssaa": 1, - "bg_color": 'random', - }) + self.pipe = edict( + { + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False, + } + ) + self.rendering_options = edict( + { + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": "random", + } + ) self.rendering_options.update(rendering_options) self.bg_color = None - + def render( - self, - gausssian: Gaussian, - extrinsics: torch.Tensor, - intrinsics: torch.Tensor, - colors_overwrite: torch.Tensor = None - ) -> edict: + self, + gausssian: Gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: """ Render the gausssian. @@ -190,13 +221,15 @@ class GaussianRenderer: near = self.rendering_options["near"] far = self.rendering_options["far"] ssaa = self.rendering_options["ssaa"] - - if self.rendering_options["bg_color"] == 'random': + + if self.rendering_options["bg_color"] == "random": self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") if np.random.rand() < 0.5: self.bg_color += 1 else: - self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + self.bg_color = torch.tensor( + self.rendering_options["bg_color"], dtype=torch.float32, device="cuda" + ) view = extrinsics perspective = intrinsics_to_projection(intrinsics, near, far) @@ -205,27 +238,40 @@ class GaussianRenderer: focaly = intrinsics[1, 1] fovx = 2 * torch.atan(0.5 / focalx) fovy = 2 * torch.atan(0.5 / focaly) - - camera_dict = edict({ - "image_height": resolution * ssaa, - "image_width": resolution * ssaa, - "FoVx": fovx, - "FoVy": fovy, - "znear": near, - "zfar": far, - "world_view_transform": view.T.contiguous(), - "projection_matrix": perspective.T.contiguous(), - "full_proj_transform": (perspective @ view).T.contiguous(), - "camera_center": camera - }) + + camera_dict = edict( + { + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera, + } + ) # Render - render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier) + render_ret = render( + camera_dict, + gausssian, + self.pipe, + self.bg_color, + override_color=colors_overwrite, + scaling_modifier=self.pipe.scale_modifier, + ) if ssaa > 1: - render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - - ret = edict({ - 'color': render_ret['render'] - }) + render_ret.render = F.interpolate( + render_ret.render[None], + size=(resolution, resolution), + mode="bilinear", + 
align_corners=False, + antialias=True, + ).squeeze() + + ret = edict({"color": render_ret["render"]}) return ret diff --git a/trellis/renderers/mesh_renderer.py b/trellis/renderers/mesh_renderer.py index 1c3afa75a18a3f234c8e2848b71af22393f61399..2882b037f3db10425a8e8ece645a7098badf8af8 100644 --- a/trellis/renderers/mesh_renderer.py +++ b/trellis/renderers/mesh_renderer.py @@ -13,10 +13,10 @@ import torch.nn.functional as F def intrinsics_to_projection( - intrinsics: torch.Tensor, - near: float, - far: float, - ) -> torch.Tensor: + intrinsics: torch.Tensor, + near: float, + far: float, +) -> torch.Tensor: """ OpenCV intrinsics to OpenGL perspective matrix @@ -33,10 +33,10 @@ def intrinsics_to_projection( ret[0, 0] = 2 * fx ret[1, 1] = 2 * fy ret[0, 2] = 2 * cx - 1 - ret[1, 2] = - 2 * cy + 1 + ret[1, 2] = -2 * cy + 1 ret[2, 2] = far / (far - near) ret[2, 3] = near * far / (near - far) - ret[3, 2] = 1. + ret[3, 2] = 1.0 return ret @@ -47,25 +47,23 @@ class MeshRenderer: Args: rendering_options (dict): Rendering options. glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. - """ - def __init__(self, rendering_options={}, device='cuda'): - self.rendering_options = edict({ - "resolution": None, - "near": None, - "far": None, - "ssaa": 1 - }) + """ + + def __init__(self, rendering_options={}, device="cuda"): + self.rendering_options = edict( + {"resolution": None, "near": None, "far": None, "ssaa": 1} + ) self.rendering_options.update(rendering_options) self.glctx = dr.RasterizeCudaContext(device=device) - self.device=device - + self.device = device + def render( - self, - mesh : MeshExtractResult, - extrinsics: torch.Tensor, - intrinsics: torch.Tensor, - return_types = ["mask", "normal", "depth"] - ) -> edict: + self, + mesh: MeshExtractResult, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types=["mask", "normal", "depth"], + ) -> edict: """ Render the mesh. 
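A standalone numeric check (separate from the patch) of intrinsics_to_projection, which appears in all three renderers: with normalized intrinsics, points on the near and far planes should land at NDC depths 0 and 1. The fx/fy/cx/cy extraction is assumed from the function's docstring and usage.

import torch

def intrinsics_to_projection(intrinsics, near, far):
    # same construction as the helper above
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    ret[0, 0] = 2 * fx
    ret[1, 1] = 2 * fy
    ret[0, 2] = 2 * cx - 1
    ret[1, 2] = -2 * cy + 1
    ret[2, 2] = far / (far - near)
    ret[2, 3] = near * far / (near - far)
    ret[3, 2] = 1.0
    return ret

K = torch.tensor([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]])
P = intrinsics_to_projection(K, near=0.1, far=100.0)
near_pt = P @ torch.tensor([0.0, 0.0, 0.1, 1.0])
far_pt = P @ torch.tensor([0.0, 0.0, 100.0, 1.0])
print(near_pt[2] / near_pt[3], far_pt[2] / far_pt[3])  # ~0.0 and ~1.0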
@@ -87,51 +85,78 @@ class MeshRenderer: near = self.rendering_options["near"] far = self.rendering_options["far"] ssaa = self.rendering_options["ssaa"] - + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: - default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device) - ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types} + default_img = torch.zeros( + (1, resolution, resolution, 3), dtype=torch.float32, device=self.device + ) + ret_dict = { + k: default_img + if k in ["normal", "normal_map", "color"] + else default_img[..., :1] + for k in return_types + } return ret_dict - + perspective = intrinsics_to_projection(intrinsics, near, far) - + RT = extrinsics.unsqueeze(0) full_proj = (perspective @ extrinsics).unsqueeze(0) - + vertices = mesh.vertices.unsqueeze(0) - vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + vertices_homo = torch.cat( + [vertices, torch.ones_like(vertices[..., :1])], dim=-1 + ) vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2)) vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) faces_int = mesh.faces.int() rast, _ = dr.rasterize( - self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)) - + self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa) + ) + out_dict = edict() for type in return_types: img = None - if type == "mask" : - img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int) + if type == "mask": + img = dr.antialias( + (rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int + ) elif type == "depth": - img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0] + img = dr.interpolate( + vertices_camera[..., 2:3].contiguous(), rast, faces_int + )[0] img = dr.antialias(img, rast, vertices_clip, faces_int) - elif type == "normal" : + elif type == "normal": img = dr.interpolate( - mesh.face_normal.reshape(1, -1, 3), rast, - torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3) + mesh.face_normal.reshape(1, -1, 3), + rast, + torch.arange( + mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int + ).reshape(-1, 3), )[0] img = dr.antialias(img, rast, vertices_clip, faces_int) # normalize norm pictures img = (img + 1) / 2 - elif type == "normal_map" : - img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0] + elif type == "normal_map": + img = dr.interpolate( + mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int + )[0] img = dr.antialias(img, rast, vertices_clip, faces_int) - elif type == "color" : - img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0] + elif type == "color": + img = dr.interpolate( + mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int + )[0] img = dr.antialias(img, rast, vertices_clip, faces_int) if ssaa > 1: - img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = F.interpolate( + img.permute(0, 3, 1, 2), + (resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ) img = img.squeeze() else: img = img.permute(0, 3, 1, 2).squeeze() diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py index 136069cdb0645b5759d5d17f7815612a1dfc7bea..829fa15175f5211870efdeb7e72a0245b199e5e1 100755 --- a/trellis/renderers/octree_renderer.py +++ 
b/trellis/renderers/octree_renderer.py @@ -9,10 +9,10 @@ from ..representations.octree import DfsOctree def intrinsics_to_projection( - intrinsics: torch.Tensor, - near: float, - far: float, - ) -> torch.Tensor: + intrinsics: torch.Tensor, + near: float, + far: float, +) -> torch.Tensor: """ OpenCV intrinsics to OpenGL perspective matrix @@ -29,23 +29,38 @@ def intrinsics_to_projection( ret[0, 0] = 2 * fx ret[1, 1] = 2 * fy ret[0, 2] = 2 * cx - 1 - ret[1, 2] = - 2 * cy + 1 + ret[1, 2] = -2 * cy + 1 ret[2, 2] = far / (far - near) ret[2, 3] = near * far / (near - far) - ret[3, 2] = 1. + ret[3, 2] = 1.0 return ret -def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): +def render( + viewpoint_camera, + octree: DfsOctree, + pipe, + bg_color: torch.Tensor, + scaling_modifier=1.0, + used_rank=None, + colors_overwrite=None, + aux=None, + halton_sampler=None, +): """ - Render the scene. - + Render the scene. + Background tensor (bg_color) must be on GPU! """ # lazy import - if 'OctreeTrivecRasterizer' not in globals(): - from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer - + if "OctreeTrivecRasterizer" not in globals(): + from diffoctreerast import ( + OctreeVoxelRasterizer, + OctreeGaussianRasterizer, + OctreeTrivecRasterizer, + OctreeDecoupolyRasterizer, + ) + # Set up rasterization configuration tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) @@ -96,69 +111,73 @@ def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, if octree.primitive == "voxel": renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) rgb, depth, alpha, distloss = renderer( - positions = positions, - densities = densities, - shs = shs, - colors_precomp = colors_precomp, - depths = depths, - aabb = octree.aabb, - aux = aux, + positions=positions, + densities=densities, + shs=shs, + colors_precomp=colors_precomp, + depths=depths, + aabb=octree.aabb, + aux=aux, ) - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha - ret['distloss'] = distloss + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha + ret["distloss"] = distloss elif octree.primitive == "gaussian": renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) rgb, depth, alpha = renderer( - positions = positions, - opacities = opacities, - shs = shs, - colors_precomp = colors_precomp, - depths = depths, - aabb = octree.aabb, - aux = aux, + positions=positions, + opacities=opacities, + shs=shs, + colors_precomp=colors_precomp, + depths=depths, + aabb=octree.aabb, + aux=aux, ) - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha elif octree.primitive == "trivec": - raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] + raster_settings.used_rank = ( + used_rank if used_rank is not None else trivecs.shape[1] + ) renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) rgb, depth, alpha, percent_depth = renderer( - positions = positions, - trivecs = trivecs, - densities = densities, - shs = shs, - colors_precomp = colors_precomp, - colors_overwrite = colors_overwrite, - depths = depths, - aabb = octree.aabb, - aux = aux, - halton_sampler = halton_sampler, + positions=positions, + trivecs=trivecs, + densities=densities, + shs=shs, + 
colors_precomp=colors_precomp, + colors_overwrite=colors_overwrite, + depths=depths, + aabb=octree.aabb, + aux=aux, + halton_sampler=halton_sampler, ) - ret['percent_depth'] = percent_depth - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha + ret["percent_depth"] = percent_depth + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha elif octree.primitive == "decoupoly": - raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] + raster_settings.used_rank = ( + used_rank if used_rank is not None else decoupolys_V.shape[1] + ) renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) rgb, depth, alpha = renderer( - positions = positions, - decoupolys_V = decoupolys_V, - decoupolys_g = decoupolys_g, - densities = densities, - shs = shs, - colors_precomp = colors_precomp, - depths = depths, - aabb = octree.aabb, - aux = aux, + positions=positions, + decoupolys_V=decoupolys_V, + decoupolys_g=decoupolys_g, + densities=densities, + shs=shs, + colors_precomp=colors_precomp, + depths=depths, + aabb=octree.aabb, + aux=aux, ) - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha - + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha + return ret @@ -174,37 +193,43 @@ class OctreeRenderer: try: import diffoctreerast except ImportError: - print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") + print( + "\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m" + ) self.unsupported = True else: self.unsupported = False - - self.pipe = edict({ - "with_distloss": False, - "with_aux": False, - "scale_modifier": 1.0, - "used_rank": None, - "jitter": False, - "debug": False, - }) - self.rendering_options = edict({ - "resolution": None, - "near": None, - "far": None, - "ssaa": 1, - "bg_color": 'random', - }) + + self.pipe = edict( + { + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + } + ) + self.rendering_options = edict( + { + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": "random", + } + ) self.halton_sampler = qmc.Halton(2, scramble=False) self.rendering_options.update(rendering_options) self.bg_color = None - + def render( - self, - octree: DfsOctree, - extrinsics: torch.Tensor, - intrinsics: torch.Tensor, - colors_overwrite: torch.Tensor = None, - ) -> edict: + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: """ Render the octree. 
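A standalone sketch (separate from the patch) of the supersampled anti-aliasing pattern shared by GaussianRenderer, MeshRenderer, and OctreeRenderer in this patch: render at ssaa times the target resolution, then downsample with antialiased bilinear interpolation.

import torch
import torch.nn.functional as F

resolution, ssaa = 512, 2
frame = torch.rand(3, resolution * ssaa, resolution * ssaa)  # stand-in render
out = F.interpolate(
    frame[None],
    size=(resolution, resolution),
    mode="bilinear",
    align_corners=False,
    antialias=True,
).squeeze(0)
print(out.shape)  # torch.Size([3, 512, 512])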
@@ -227,27 +252,53 @@ class OctreeRenderer: near = self.rendering_options["near"] far = self.rendering_options["far"] ssaa = self.rendering_options["ssaa"] - + if self.unsupported: image = np.zeros((512, 512, 3), dtype=np.uint8) - text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[ + 0 + ] origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 - image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) + image = cv2.putText( + image, + "Unsupported", + origin, + cv2.FONT_HERSHEY_SIMPLEX, + 2, + (255, 255, 255), + 3, + cv2.LINE_AA, + ) return { - 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, + "color": torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) + / 255, } - - if self.rendering_options["bg_color"] == 'random': + + if self.rendering_options["bg_color"] == "random": self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") if np.random.rand() < 0.5: self.bg_color += 1 else: - self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + self.bg_color = torch.tensor( + self.rendering_options["bg_color"], dtype=torch.float32, device="cuda" + ) if self.pipe["with_aux"]: aux = { - 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, - 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + "grad_color2": torch.zeros( + (octree.num_leaf_nodes, 3), + dtype=torch.float32, + requires_grad=True, + device="cuda", + ) + + 0, + "contributions": torch.zeros( + (octree.num_leaf_nodes, 1), + dtype=torch.float32, + requires_grad=True, + device="cuda", + ) + + 0, } for k in aux.keys(): aux[k].requires_grad_() @@ -262,39 +313,77 @@ class OctreeRenderer: focaly = intrinsics[1, 1] fovx = 2 * torch.atan(0.5 / focalx) fovy = 2 * torch.atan(0.5 / focaly) - - camera_dict = edict({ - "image_height": resolution * ssaa, - "image_width": resolution * ssaa, - "FoVx": fovx, - "FoVy": fovy, - "znear": near, - "zfar": far, - "world_view_transform": view.T.contiguous(), - "projection_matrix": perspective.T.contiguous(), - "full_proj_transform": (perspective @ view).T.contiguous(), - "camera_center": camera - }) + + camera_dict = edict( + { + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera, + } + ) # Render - render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) + render_ret = render( + camera_dict, + octree, + self.pipe, + self.bg_color, + aux=aux, + colors_overwrite=colors_overwrite, + scaling_modifier=self.pipe.scale_modifier, + used_rank=self.pipe.used_rank, + halton_sampler=self.halton_sampler, + ) if ssaa > 1: - render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, 
antialias=True).squeeze() - render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - if hasattr(render_ret, 'percent_depth'): - render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.rgb = F.interpolate( + render_ret.rgb[None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + render_ret.depth = F.interpolate( + render_ret.depth[None, None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + render_ret.alpha = F.interpolate( + render_ret.alpha[None, None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + if hasattr(render_ret, "percent_depth"): + render_ret.percent_depth = F.interpolate( + render_ret.percent_depth[None, None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() - ret = edict({ - 'color': render_ret.rgb, - 'depth': render_ret.depth, - 'alpha': render_ret.alpha, - }) - if self.pipe["with_distloss"] and 'distloss' in render_ret: - ret['distloss'] = render_ret.distloss + ret = edict( + { + "color": render_ret.rgb, + "depth": render_ret.depth, + "alpha": render_ret.alpha, + } + ) + if self.pipe["with_distloss"] and "distloss" in render_ret: + ret["distloss"] = render_ret.distloss if self.pipe["with_aux"]: - ret['aux'] = aux - if hasattr(render_ret, 'percent_depth'): - ret['percent_depth'] = render_ret.percent_depth + ret["aux"] = aux + if hasattr(render_ret, "percent_depth"): + ret["percent_depth"] = render_ret.percent_depth return ret diff --git a/trellis/renderers/sh_utils.py b/trellis/renderers/sh_utils.py index bbca7d192aa3a7edf8c5b2d24dee535eac765785..d552d61ecfdf048162f8810f41a18f09d3f230f3 100755 --- a/trellis/renderers/sh_utils.py +++ b/trellis/renderers/sh_utils.py @@ -30,7 +30,7 @@ C2 = [ -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, - 0.5462742152960396 + 0.5462742152960396, ] C3 = [ -0.5900435899266435, @@ -39,7 +39,7 @@ C3 = [ 0.3731763325901154, -0.4570457994644658, 1.445305721320277, - -0.5900435899266435 + -0.5900435899266435, ] C4 = [ 2.5033429417967046, @@ -51,7 +51,7 @@ C4 = [ 0.47308734787878004, -1.7701307697799304, 0.6258357354491761, -] +] def eval_sh(deg, sh, dirs): @@ -74,45 +74,55 @@ def eval_sh(deg, sh, dirs): result = C0 * sh[..., 0] if deg > 0: x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - result = (result - - C1 * y * sh[..., 1] + - C1 * z * sh[..., 2] - - C1 * x * sh[..., 3]) + result = ( + result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] + ) if deg > 1: xx, yy, zz = x * x, y * y, z * z xy, yz, xz = x * y, y * z, x * z - result = (result + - C2[0] * xy * sh[..., 4] + - C2[1] * yz * sh[..., 5] + - C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + - C2[3] * xz * sh[..., 7] + - C2[4] * (xx - yy) * sh[..., 8]) + result = ( + result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8] + ) if deg > 2: - result = (result + - C3[0] * y * (3 * xx - yy) * sh[..., 9] + - C3[1] * xy * z * sh[..., 10] + - C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + - C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + - C3[4] * x * (4 * zz - xx - yy) * 
sh[..., 13] + - C3[5] * z * (xx - yy) * sh[..., 14] + - C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + result = ( + result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15] + ) if deg > 3: - result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + - C4[1] * yz * (3 * xx - yy) * sh[..., 17] + - C4[2] * xy * (7 * zz - 1) * sh[..., 18] + - C4[3] * yz * (7 * zz - 3) * sh[..., 19] + - C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + - C4[5] * xz * (7 * zz - 3) * sh[..., 21] + - C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + - C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + - C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + result = ( + result + + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] + * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) + * sh[..., 24] + ) return result + def RGB2SH(rgb): return (rgb - 0.5) / C0 + def SH2RGB(sh): - return sh * C0 + 0.5 \ No newline at end of file + return sh * C0 + 0.5 diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py index e3de6e180bd732836af876d748255595be2d4d74..cd65d9b8ce925c0db408faf54f451ca83b11f989 100755 --- a/trellis/representations/gaussian/__init__.py +++ b/trellis/representations/gaussian/__init__.py @@ -1 +1 @@ -from .gaussian_model import Gaussian \ No newline at end of file +from .gaussian_model import Gaussian diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py index 54ba16f1550e8edb1728605202cc31b6dd805d90..47d114fe3a0aea5b3962dfb1333d08c109bc48fc 100755 --- a/trellis/representations/gaussian/gaussian_model.py +++ b/trellis/representations/gaussian/gaussian_model.py @@ -7,27 +7,27 @@ import utils3d class Gaussian: def __init__( - self, - aabb : list, - sh_degree : int = 0, - mininum_kernel_size : float = 0.0, - scaling_bias : float = 0.01, - opacity_bias : float = 0.1, - scaling_activation : str = "exp", - device='cuda' - ): + self, + aabb: list, + sh_degree: int = 0, + mininum_kernel_size: float = 0.0, + scaling_bias: float = 0.01, + opacity_bias: float = 0.1, + scaling_activation: str = "exp", + device="cuda", + ): self.init_params = { - 'aabb': aabb, - 'sh_degree': sh_degree, - 'mininum_kernel_size': mininum_kernel_size, - 'scaling_bias': scaling_bias, - 'opacity_bias': opacity_bias, - 'scaling_activation': scaling_activation, + "aabb": aabb, + "sh_degree": sh_degree, + "mininum_kernel_size": mininum_kernel_size, + "scaling_bias": scaling_bias, + "opacity_bias": opacity_bias, + "scaling_activation": scaling_activation, } - + self.sh_degree = sh_degree self.active_sh_degree = sh_degree - self.mininum_kernel_size = mininum_kernel_size + self.mininum_kernel_size = mininum_kernel_size self.scaling_bias = scaling_bias self.opacity_bias = opacity_bias self.scaling_activation_type = scaling_activation @@ -48,7 +48,7 @@ class Gaussian: actual_covariance = L @ L.transpose(1, 2) symm = strip_symmetric(actual_covariance) return symm - + if 
self.scaling_activation_type == "exp": self.scaling_activation = torch.exp self.inverse_scaling_activation = torch.log @@ -62,74 +62,91 @@ class Gaussian: self.inverse_opacity_activation = inverse_sigmoid self.rotation_activation = torch.nn.functional.normalize - - self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() + + self.scale_bias = self.inverse_scaling_activation( + torch.tensor(self.scaling_bias) + ).cuda() self.rots_bias = torch.zeros((4)).cuda() self.rots_bias[0] = 1 - self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() + self.opacity_bias = self.inverse_opacity_activation( + torch.tensor(self.opacity_bias) + ).cuda() @property def get_scaling(self): scales = self.scaling_activation(self._scaling + self.scale_bias) - scales = torch.square(scales) + self.mininum_kernel_size ** 2 + scales = torch.square(scales) + self.mininum_kernel_size**2 scales = torch.sqrt(scales) return scales - + @property def get_rotation(self): return self.rotation_activation(self._rotation + self.rots_bias[None, :]) - + @property def get_xyz(self): return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] - + @property def get_features(self): - return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc - + return ( + torch.cat((self._features_dc, self._features_rest), dim=2) + if self._features_rest is not None + else self._features_dc + ) + @property def get_opacity(self): return self.opacity_activation(self._opacity + self.opacity_bias) - - def get_covariance(self, scaling_modifier = 1): - return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) - + + def get_covariance(self, scaling_modifier=1): + return self.covariance_activation( + self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :] + ) + def from_scaling(self, scales): - scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2) + scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size**2) self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias - + def from_rotation(self, rots): self._rotation = rots - self.rots_bias[None, :] - + def from_xyz(self, xyz): self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] - + def from_features(self, features): self._features_dc = features - + def from_opacity(self, opacities): self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias def construct_list_of_attributes(self): - l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + l = ["x", "y", "z", "nx", "ny", "nz"] # All channels except the 3 DC - for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): - l.append('f_dc_{}'.format(i)) - l.append('opacity') + for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): + l.append("f_dc_{}".format(i)) + l.append("opacity") for i in range(self._scaling.shape[1]): - l.append('scale_{}'.format(i)) + l.append("scale_{}".format(i)) for i in range(self._rotation.shape[1]): - l.append('rot_{}'.format(i)) + l.append("rot_{}".format(i)) return l - + def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): xyz = self.get_xyz.detach().cpu().numpy() normals = np.zeros_like(xyz) - f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_dc = ( + self._features_dc.detach() + .transpose(1, 2) + .flatten(start_dim=1) + .contiguous() + .cpu() + .numpy() + ) opacities = 
inverse_sigmoid(self.get_opacity).detach().cpu().numpy() scale = torch.log(self.get_scaling).detach().cpu().numpy() rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() - + if transform is not None: transform = np.array(transform) xyz = np.matmul(xyz, transform.T) @@ -137,20 +154,29 @@ class Gaussian: rotation = np.matmul(transform, rotation) rotation = utils3d.numpy.matrix_to_quaternion(rotation) - dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + dtype_full = [ + (attribute, "f4") for attribute in self.construct_list_of_attributes() + ] elements = np.empty(xyz.shape[0], dtype=dtype_full) - attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) + attributes = np.concatenate( + (xyz, normals, f_dc, opacities, scale, rotation), axis=1 + ) elements[:] = list(map(tuple, attributes)) - el = PlyElement.describe(elements, 'vertex') + el = PlyElement.describe(elements, "vertex") PlyData([el]).write(path) def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): plydata = PlyData.read(path) - xyz = np.stack((np.asarray(plydata.elements[0]["x"]), - np.asarray(plydata.elements[0]["y"]), - np.asarray(plydata.elements[0]["z"])), axis=1) + xyz = np.stack( + ( + np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"]), + ), + axis=1, + ) opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] features_dc = np.zeros((xyz.shape[0], 3, 1)) @@ -159,43 +185,65 @@ class Gaussian: features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) if self.sh_degree > 0: - extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] - extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) - assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3 + extra_f_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("f_rest_") + ] + extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) + assert len(extra_f_names) == 3 * (self.sh_degree + 1) ** 2 - 3 features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) for idx, attr_name in enumerate(extra_f_names): features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) - features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + features_extra = features_extra.reshape( + (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1) + ) - scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] - scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scale_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("scale_") + ] + scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) scales = np.zeros((xyz.shape[0], len(scale_names))) for idx, attr_name in enumerate(scale_names): scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) - rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] - rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rot_names = [ + p.name for p in plydata.elements[0].properties if p.name.startswith("rot") + ] + rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) rots = np.zeros((xyz.shape[0], len(rot_names))) for idx, attr_name in enumerate(rot_names): rots[:, idx] = 
np.asarray(plydata.elements[0][attr_name]) - + if transform is not None: transform = np.array(transform) xyz = np.matmul(xyz, transform) - rotation = utils3d.numpy.quaternion_to_matrix(rotation) - rotation = np.matmul(rotation, transform) - rotation = utils3d.numpy.matrix_to_quaternion(rotation) + # fix: the loaded quaternions live in rots; "rotation" was undefined here + rots = utils3d.numpy.quaternion_to_matrix(rots) + rots = np.matmul(rots, transform) + rots = utils3d.numpy.matrix_to_quaternion(rots) - + # convert to actual gaussian attributes xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) - features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + features_dc = ( + torch.tensor(features_dc, dtype=torch.float, device=self.device) + .transpose(1, 2) + .contiguous() + ) if self.sh_degree > 0: - features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() - opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device)) + features_extra = ( + torch.tensor(features_extra, dtype=torch.float, device=self.device) + .transpose(1, 2) + .contiguous() + ) + opacities = torch.sigmoid( + torch.tensor(opacities, dtype=torch.float, device=self.device) + ) scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) rots = torch.tensor(rots, dtype=torch.float, device=self.device) - + # convert to _hidden attributes self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] self._features_dc = features_dc @@ -204,6 +252,10 @@ class Gaussian: else: self._features_rest = None self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias - self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias + self._scaling = ( + self.inverse_scaling_activation( + torch.sqrt(torch.square(scales) - self.mininum_kernel_size**2) + ) + - self.scale_bias + ) self._rotation = rots - self.rots_bias[None, :] - \ No newline at end of file diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py index 541c0825229a2d86e84460b765879f86f724a59d..8f00beb02099a5734257bbcf71bda022f758b230 100755 --- a/trellis/representations/gaussian/general_utils.py +++ b/trellis/representations/gaussian/general_utils.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file.
# # For inquiries contact george.drettakis@inria.fr @@ -15,8 +15,10 @@ from datetime import datetime import numpy as np import random + def inverse_sigmoid(x): - return torch.log(x/(1-x)) + return torch.log(x / (1 - x)) + def PILtoTorch(pil_image, resolution): resized_image_PIL = pil_image.resize(resolution) @@ -26,6 +28,7 @@ def PILtoTorch(pil_image, resolution): else: return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + def get_expon_lr_func( lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 ): @@ -61,6 +64,7 @@ def get_expon_lr_func( return helper + def strip_lowerdiag(L): uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") @@ -72,45 +76,52 @@ def strip_lowerdiag(L): uncertainty[:, 5] = L[:, 2, 2] return uncertainty + def strip_symmetric(sym): return strip_lowerdiag(sym) + def build_rotation(r): - norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + norm = torch.sqrt( + r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] + ) q = r / norm[:, None] - R = torch.zeros((q.size(0), 3, 3), device='cuda') + R = torch.zeros((q.size(0), 3, 3), device="cuda") r = q[:, 0] x = q[:, 1] y = q[:, 2] z = q[:, 3] - R[:, 0, 0] = 1 - 2 * (y*y + z*z) - R[:, 0, 1] = 2 * (x*y - r*z) - R[:, 0, 2] = 2 * (x*z + r*y) - R[:, 1, 0] = 2 * (x*y + r*z) - R[:, 1, 1] = 1 - 2 * (x*x + z*z) - R[:, 1, 2] = 2 * (y*z - r*x) - R[:, 2, 0] = 2 * (x*z - r*y) - R[:, 2, 1] = 2 * (y*z + r*x) - R[:, 2, 2] = 1 - 2 * (x*x + y*y) + R[:, 0, 0] = 1 - 2 * (y * y + z * z) + R[:, 0, 1] = 2 * (x * y - r * z) + R[:, 0, 2] = 2 * (x * z + r * y) + R[:, 1, 0] = 2 * (x * y + r * z) + R[:, 1, 1] = 1 - 2 * (x * x + z * z) + R[:, 1, 2] = 2 * (y * z - r * x) + R[:, 2, 0] = 2 * (x * z - r * y) + R[:, 2, 1] = 2 * (y * z + r * x) + R[:, 2, 2] = 1 - 2 * (x * x + y * y) return R + def build_scaling_rotation(s, r): L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") R = build_rotation(r) - L[:,0,0] = s[:,0] - L[:,1,1] = s[:,1] - L[:,2,2] = s[:,2] + L[:, 0, 0] = s[:, 0] + L[:, 1, 1] = s[:, 1] + L[:, 2, 2] = s[:, 2] L = R @ L return L + def safe_state(silent): old_f = sys.stdout + class F: def __init__(self, silent): self.silent = silent @@ -118,7 +129,14 @@ def safe_state(silent): def write(self, x): if not self.silent: if x.endswith("\n"): - old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + old_f.write( + x.replace( + "\n", + " [{}]\n".format( + str(datetime.now().strftime("%d/%m %H:%M:%S")) + ), + ) + ) else: old_f.write(x) diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py index 6d3d200d9cc56769bffba50a5d718f25706ad416..9d96d08f270f1908ae67f13b703d7a8bde6c192f 100644 --- a/trellis/representations/mesh/cube2mesh.py +++ b/trellis/representations/mesh/cube2mesh.py @@ -13,24 +13,19 @@ from .flexicube import FlexiCubes class MeshExtractResult: - def __init__(self, - vertices, - faces, - vertex_attrs=None, - res=64 - ): + def __init__(self, vertices, faces, vertex_attrs=None, res=64): self.vertices = vertices self.faces = faces.long() self.vertex_attrs = vertex_attrs self.face_normal = self.comput_face_normals(vertices, faces) self.res = res - self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) + self.success = vertices.shape[0] != 0 and faces.shape[0] != 0 # training only self.tsdf_v = None self.tsdf_s = None self.reg_loss = None - + def comput_face_normals(self, verts, faces): i0 = faces[..., 0].long() i1 = faces[..., 1].long() 
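A standalone check (separate from the patch) of build_rotation above. This copy is device-agnostic (the original hardcodes device="cuda") and verifies that a 90-degree rotation about +z maps the x axis to the y axis.

import math
import torch

def build_rotation(r):
    # normalize (w, x, y, z) quaternions, then apply the standard formula
    q = r / torch.sqrt((r * r).sum(dim=1, keepdim=True))
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    R = torch.zeros((q.size(0), 3, 3))
    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - w * z)
    R[:, 0, 2] = 2 * (x * z + w * y)
    R[:, 1, 0] = 2 * (x * y + w * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - w * x)
    R[:, 2, 0] = 2 * (x * z - w * y)
    R[:, 2, 1] = 2 * (y * z + w * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R

q = torch.tensor([[math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)]])
print(build_rotation(q)[0] @ torch.tensor([1.0, 0.0, 0.0]))  # ~[0., 1., 0.]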
@@ -43,7 +38,7 @@ class MeshExtractResult: face_normals = torch.nn.functional.normalize(face_normals, dim=1) # print(face_normals.min(), face_normals.max(), face_normals.shape) return face_normals[:, None, :].repeat(1, 3, 1) - + def comput_v_normals(self, verts, faces): i0 = faces[..., 0].long() i1 = faces[..., 1].long() @@ -59,16 +54,16 @@ class MeshExtractResult: v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) v_normals = torch.nn.functional.normalize(v_normals, dim=1) - return v_normals + return v_normals class SparseFeatures2Mesh: def __init__(self, device="cuda", res=64, use_color=True): - ''' + """ a model to generate a mesh from sparse features structures using flexicube - ''' + """ super().__init__() - self.device=device + self.device = device self.res = res self.mesh_extractor = FlexiCubes(device=device) self.sdf_bias = -1.0 / res @@ -77,57 +72,74 @@ class SparseFeatures2Mesh: self.reg_v = verts.to(self.device) self.use_color = use_color self._calc_layout() - + def _calc_layout(self): LAYOUTS = { - 'sdf': {'shape': (8, 1), 'size': 8}, - 'deform': {'shape': (8, 3), 'size': 8 * 3}, - 'weights': {'shape': (21,), 'size': 21} + "sdf": {"shape": (8, 1), "size": 8}, + "deform": {"shape": (8, 3), "size": 8 * 3}, + "weights": {"shape": (21,), "size": 21}, } if self.use_color: - ''' + """ 6 channel color including normal map - ''' - LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} + """ + LAYOUTS["color"] = { + "shape": ( + 8, + 6, + ), + "size": 8 * 6, + } self.layouts = edict(LAYOUTS) start = 0 for k, v in self.layouts.items(): - v['range'] = (start, start + v['size']) - start += v['size'] + v["range"] = (start, start + v["size"]) + start += v["size"] self.feats_channels = start - - def get_layout(self, feats : torch.Tensor, name : str): + + def get_layout(self, feats: torch.Tensor, name: str): if name not in self.layouts: return None - return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) - - def __call__(self, cubefeats : SparseTensor, training=False): + return feats[ + :, self.layouts[name]["range"][0] : self.layouts[name]["range"][1] + ].reshape(-1, *self.layouts[name]["shape"]) + + def __call__(self, cubefeats: SparseTensor, training=False): """ Generates a mesh based on the specified sparse voxel structures. 
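The _calc_layout/get_layout pair above is channel bookkeeping: with use_color=True the per-voxel feature vector concatenates sdf (8x1), deform (8x3), weights (21), and color (8x6), 101 channels in total, and each field is recovered by slicing its recorded (start, end) range. A standalone sketch of the same layout logic, using a plain dict instead of edict:

import torch

# Mirrors the LAYOUTS table above (use_color=True): 8 + 24 + 21 + 48 = 101 channels.
LAYOUTS = {
    "sdf": {"shape": (8, 1), "size": 8},
    "deform": {"shape": (8, 3), "size": 8 * 3},
    "weights": {"shape": (21,), "size": 21},
    "color": {"shape": (8, 6), "size": 8 * 6},
}
start = 0
for v in LAYOUTS.values():
    v["range"] = (start, start + v["size"])  # contiguous channel slice per field
    start += v["size"]

feats = torch.randn(10, start)  # 10 sparse voxels, flattened feature channels
sdf = feats[:, slice(*LAYOUTS["sdf"]["range"])].reshape(-1, *LAYOUTS["sdf"]["shape"])
deform = feats[:, slice(*LAYOUTS["deform"]["range"])].reshape(-1, *LAYOUTS["deform"]["shape"])
print(sdf.shape, deform.shape)  # torch.Size([10, 8, 1]) torch.Size([10, 8, 3])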
Args: cube_attrs [Nx21] : Sparse Tensor attrs about cube weights - verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal + verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal Returns: - return the success tag and the regularization loss, + return the success tag and the regularization loss, """ # add sdf bias to verts_attrs coords = cubefeats.coords[:, 1:] feats = cubefeats.feats - - sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] + + sdf, deform, color, weights = [ + self.get_layout(feats, name) + for name in ["sdf", "deform", "color", "weights"] + ] sdf += self.sdf_bias v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] - v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) - v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) + v_pos, v_attrs, reg_loss = sparse_cube2verts( + coords, torch.cat(v_attrs, dim=-1), training=training + ) + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res + 1, sdf_init=True) weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) if self.use_color: - sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] + sdf_d, deform_d, colors_d = ( + v_attrs_d[..., 0], + v_attrs_d[..., 1:4], + v_attrs_d[..., 4:], + ) else: sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] colors_d = None - + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) - + vertices, faces, L_dev, colors = self.mesh_extractor( voxelgrid_vertices=x_nx3, scalar_field=sdf_d, @@ -137,13 +149,16 @@ class SparseFeatures2Mesh: alpha=weights_d[:, 12:20], gamma_f=weights_d[:, 20], voxelgrid_colors=colors_d, - training=training) - - mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) + training=training, + ) + + mesh = MeshExtractResult( + vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res + ) if training: if mesh.success: reg_loss += L_dev.mean() * 0.5 - reg_loss += (weights[:,:20]).abs().mean() * 0.2 + reg_loss += (weights[:, :20]).abs().mean() * 0.2 mesh.reg_loss = reg_loss mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) mesh.tsdf_s = v_attrs[:, 0] diff --git a/trellis/representations/mesh/flexicube.py b/trellis/representations/mesh/flexicube.py index 6f1fc167fb066175c84c8bd93ec9831b0c2a0e32..2f314be2bcfc79d1a5e12c8291170fa7271aa67a 100644 --- a/trellis/representations/mesh/flexicube.py +++ b/trellis/representations/mesh/flexicube.py @@ -9,56 +9,105 @@ import torch from .tables import * -__all__ = [ - 'FlexiCubes' -] +__all__ = ["FlexiCubes"] class FlexiCubes: def __init__(self, device="cuda"): self.device = device - self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) - self.num_vd_table = torch.tensor(num_vd_table, - dtype=torch.long, device=device, requires_grad=False) + self.dmc_table = torch.tensor( + dmc_table, dtype=torch.long, device=device, requires_grad=False + ) + self.num_vd_table = torch.tensor( + num_vd_table, dtype=torch.long, device=device, requires_grad=False + ) self.check_table = torch.tensor( - check_table, - dtype=torch.long, device=device, requires_grad=False) + check_table, dtype=torch.long, device=device, requires_grad=False + ) - self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) - self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) -
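In the training branch above, the extractor's regularization loss is a weighted sum: the deviation term returned by sparse_cube2verts, plus 0.5 times the mean FlexiCubes L_dev, plus 0.2 times the mean magnitude of the first 20 cube weights (beta and alpha; gamma_f at index 20 is exempt). A toy composition with hypothetical tensor values, just to make the weighting explicit:

import torch

reg_loss = torch.tensor(0.0)    # stand-in for the sparse_cube2verts term
L_dev = torch.rand(100)         # per-dual-vertex deviations (hypothetical)
weights = torch.randn(100, 21)  # 21 FlexiCubes weights per cube (hypothetical)

reg_loss = reg_loss + L_dev.mean() * 0.5                  # Eq. 8 deviation term
reg_loss = reg_loss + weights[:, :20].abs().mean() * 0.2  # keep beta/alpha raw weights small
print(float(reg_loss))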
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.tet_table = torch.tensor( + tet_table, dtype=torch.long, device=device, requires_grad=False + ) + self.quad_split_1 = torch.tensor( + [0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False + ) + self.quad_split_2 = torch.tensor( + [0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False + ) self.quad_split_train = torch.tensor( - [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + [0, 1, 1, 2, 2, 3, 3, 0], + dtype=torch.long, + device=device, + requires_grad=False, + ) - self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ - 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners = torch.tensor( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ], + dtype=torch.float, + device=device, + ) self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) - self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, - 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) - - self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], - dtype=torch.long, device=device) - self.dir_faces_table = torch.tensor([ - [[5, 4], [3, 2], [4, 5], [2, 3]], - [[5, 4], [1, 0], [4, 5], [0, 1]], - [[3, 2], [1, 0], [2, 3], [0, 1]] - ], dtype=torch.long, device=device) - self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) - - def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, - weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): + self.cube_edges = torch.tensor( + [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4], + dtype=torch.long, + device=device, + requires_grad=False, + ) + + self.edge_dir_table = torch.tensor( + [0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=torch.long, device=device + ) + self.dir_faces_table = torch.tensor( + [ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]], + ], + dtype=torch.long, + device=device, + ) + self.adj_pairs = torch.tensor( + [0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device + ) + + def __call__( + self, + voxelgrid_vertices, + scalar_field, + cube_idx, + resolution, + qef_reg_scale=1e-3, + weight_scale=0.99, + beta=None, + alpha=None, + gamma_f=None, + voxelgrid_colors=None, + training=False, + ): surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) if surf_cubes.sum() == 0: return ( torch.zeros((0, 3), device=self.device), torch.zeros((0, 3), dtype=torch.long, device=self.device), torch.zeros((0), device=self.device), - torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None + torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) + if voxelgrid_colors is not None + else None, ) beta, alpha, gamma_f = self._normalize_weights( - beta, alpha, gamma_f, surf_cubes, weight_scale) - + beta, alpha, gamma_f, surf_cubes, weight_scale + ) + if voxelgrid_colors is not None: voxelgrid_colors = torch.sigmoid(voxelgrid_colors) @@ -69,21 +118,46 @@ class FlexiCubes: ) vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( - voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, - case_ids, beta, alpha, gamma_f, 
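_normalize_weights (defined just below) maps the raw per-cube weights through tanh so beta and alpha become positive multipliers confined to [1 - weight_scale, 1 + weight_scale] around the default of 1. A quick range check of that mapping:

import torch

weight_scale = 0.99
raw = torch.randn(1000, 12) * 10           # unbounded network outputs (toy values)
beta = torch.tanh(raw) * weight_scale + 1  # the mapping used by _normalize_weights
assert beta.min() >= 1 - weight_scale and beta.max() <= 1 + weight_scale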
idx_map, qef_reg_scale, voxelgrid_colors) + voxelgrid_vertices, + cube_idx[surf_cubes], + surf_edges, + scalar_field, + case_ids, + beta, + alpha, + gamma_f, + idx_map, + qef_reg_scale, + voxelgrid_colors, + ) vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( - scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, - vd_idx_map, surf_edges_mask, training, vd_color) + scalar_field, + surf_edges, + vd, + vd_gamma, + edge_counts, + idx_map, + vd_idx_map, + surf_edges_mask, + training, + vd_color, + ) return vertices, faces, L_dev, vertices_color def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): """ Regularizer L_dev as in Equation 8 """ - dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + dist = torch.norm( + ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1 + ) mean_l2 = torch.zeros_like(vd[:, 0]) - mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() - mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + mean_l2 = (mean_l2).index_add_( + 0, edge_group_to_vd, dist + ) / vd_num_edges.squeeze(1).float() + mad = ( + dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0) + ).abs() return mad def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): @@ -93,12 +167,12 @@ class FlexiCubes: n_cubes = surf_cubes.shape[0] if beta is not None: - beta = (torch.tanh(beta) * weight_scale + 1) + beta = torch.tanh(beta) * weight_scale + 1 else: beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) if alpha is not None: - alpha = (torch.tanh(alpha) * weight_scale + 1) + alpha = torch.tanh(alpha) * weight_scale + 1 else: alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) @@ -112,11 +186,13 @@ class FlexiCubes: @torch.no_grad() def _get_case_id(self, occ_fx8, surf_cubes, res): """ - Obtains the ID of topology cases based on cell corner occupancy. This function resolves the - ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the supplementary material. It should be noted that this function assumes a regular grid. """ - case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + case_ids = ( + occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0) + ).sum(-1) problem_config = self.check_table.to(self.device)[case_ids] to_check = problem_config[..., 0] == 1 @@ -127,41 +203,53 @@ class FlexiCubes: # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). # This allows efficient checking on adjacent cubes. 
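_compute_reg_loss above is a grouped mean absolute deviation: distances from each dual vertex to its edge crossings are summed per group with index_add_, divided by the group size to get a mean, and the per-crossing deviation from that mean is returned. The same scatter-reduce pattern on a tiny example:

import torch

# Dual vertex 0 owns 3 edge crossings, dual vertex 1 owns 2 (hypothetical data).
group = torch.tensor([0, 0, 0, 1, 1])           # edge_group_to_vd analogue
dist = torch.tensor([1.0, 2.0, 3.0, 5.0, 7.0])  # |ue - vd| per crossing
counts = torch.tensor([3.0, 2.0])               # vd_num_edges analogue

mean_l2 = torch.zeros(2).index_add_(0, group, dist) / counts  # per-vertex mean distance
mad = (dist - mean_l2[group]).abs()                           # deviation from own group's mean
print(mean_l2)  # tensor([2., 6.])
print(mad)      # tensor([1., 0., 1., 1., 1.])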
- problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + problem_config_full = torch.zeros( + list(res) + [5], device=self.device, dtype=torch.long + ) vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 vol_idx_problem = vol_idx[surf_cubes][to_check] - problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + problem_config_full[ + vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2] + ] = problem_config vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] within_range = ( - vol_idx_problem_adj[..., 0] >= 0) & ( - vol_idx_problem_adj[..., 0] < res[0]) & ( - vol_idx_problem_adj[..., 1] >= 0) & ( - vol_idx_problem_adj[..., 1] < res[1]) & ( - vol_idx_problem_adj[..., 2] >= 0) & ( - vol_idx_problem_adj[..., 2] < res[2]) + (vol_idx_problem_adj[..., 0] >= 0) + & (vol_idx_problem_adj[..., 0] < res[0]) + & (vol_idx_problem_adj[..., 1] >= 0) + & (vol_idx_problem_adj[..., 1] < res[1]) + & (vol_idx_problem_adj[..., 2] >= 0) + & (vol_idx_problem_adj[..., 2] < res[2]) + ) vol_idx_problem = vol_idx_problem[within_range] vol_idx_problem_adj = vol_idx_problem_adj[within_range] problem_config = problem_config[within_range] - problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], - vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + problem_config_adj = problem_config_full[ + vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], + vol_idx_problem_adj[..., 2], + ] # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. - to_invert = (problem_config_adj[..., 0] == 1) - idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + to_invert = problem_config_adj[..., 0] == 1 + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][ + within_range + ][to_invert] case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) return case_ids @torch.no_grad() def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): """ - Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge - can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge and marks the cube edges with this index. """ occ_n = scalar_field < 0 all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges, _idx_map, counts = torch.unique( + all_edges, dim=0, return_inverse=True, return_counts=True + ) unique_edges = unique_edges.long() mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 @@ -169,7 +257,12 @@ class FlexiCubes: surf_edges_mask = mask_edges[_idx_map] counts = counts[_idx_map] - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 + mapping = ( + torch.ones( + (unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device + ) + * -1 + ) mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index # for a surface-intersecting edge. 
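_identify_surf_edges above relies on two standard tricks: torch.unique(..., dim=0, return_inverse=True, return_counts=True) deduplicates the 12 edges listed by every cube, and an edge crosses the surface iff exactly one endpoint has scalar_field < 0. In miniature, with toy values:

import torch

scalar_field = torch.tensor([-1.0, 0.5, -0.2, 0.3])  # toy SDF at 4 grid vertices
occ_n = scalar_field < 0                             # occupancy (inside the surface)
all_edges = torch.tensor([[0, 1], [2, 3], [0, 1], [0, 2]])  # edge (0, 1) listed twice

unique_edges, idx_map, counts = torch.unique(
    all_edges, dim=0, return_inverse=True, return_counts=True
)
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
print(unique_edges[mask_edges])  # tensor([[0, 1], [2, 3]]): endpoints with opposite signs
print(counts[idx_map])           # per original slot, how many cubes share that edge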
Non-surface-intersecting edges are marked with -1. @@ -180,7 +273,7 @@ class FlexiCubes: @torch.no_grad() def _identify_surf_cubes(self, scalar_field, cube_idx): """ - Identifies grid cubes that intersect with the underlying surface by checking if the signs at + Identifies grid cubes that intersect with the underlying surface by checking if the signs at all corners are not identical. """ occ_n = scalar_field < 0 @@ -195,9 +288,21 @@ class FlexiCubes: """ edge_dim = edges_weight.dim() - 2 assert edges_weight.shape[edge_dim] == 2 - edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - - torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] - , edge_dim) + edges_weight = torch.cat( + [ + torch.index_select( + input=edges_weight, + index=torch.tensor(1, device=self.device), + dim=edge_dim, + ), + -torch.index_select( + input=edges_weight, + index=torch.tensor(0, device=self.device), + dim=edge_dim, + ), + ], + edge_dim, + ) denominator = edges_weight.sum(edge_dim) ue = (edges_x * edges_weight).sum(edge_dim) / denominator return ue @@ -209,56 +314,112 @@ class FlexiCubes: A = norm_bxnx3 B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) - A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + A_reg = ( + (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale) + .unsqueeze(0) + .repeat(p_bxnx3.shape[0], 1, 1) + ) B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) A = torch.cat([A, A_reg], 1) B = torch.cat([B, B_reg], 1) dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) return dual_verts - def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, - case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): + def _compute_vd( + self, + voxelgrid_vertices, + surf_cubes_fx8, + surf_edges, + scalar_field, + case_ids, + beta, + alpha, + gamma_f, + idx_map, + qef_reg_scale, + voxelgrid_colors, + ): """ Computes the location of dual vertices as described in Section 4.2 """ - alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) - surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) - surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + alpha_nx12x2 = torch.index_select( + input=alpha, index=self.cube_edges, dim=1 + ).reshape(-1, 12, 2) + surf_edges_x = torch.index_select( + input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0 + ).reshape(-1, 2, 3) + surf_edges_s = torch.index_select( + input=scalar_field, index=surf_edges.reshape(-1), dim=0 + ).reshape(-1, 2, 1) zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) - + if voxelgrid_colors is not None: C = voxelgrid_colors.shape[-1] - surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) + surf_edges_c = torch.index_select( + input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0 + ).reshape(-1, 2, C) idx_map = idx_map.reshape(-1, 12) num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) - edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] - + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = ( + [], + [], + [], + [], + [], + ) + # if color is not None: # vd_color = [] total_num_vd = 0 - vd_idx_map = 
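_linear_interp above finds where the SDF's linear interpolant along an edge vanishes: for endpoint values (s0, s1) and positions (x0, x1) it weights the endpoints by (s1, -s0), giving ue = (s1*x0 - s0*x1) / (s1 - s0). A single-edge check:

import torch

s = torch.tensor([[-1.0], [3.0]])  # SDF at the two endpoints, opposite signs
x = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # endpoint positions
w = torch.stack([s[1], -s[0]])     # weights (s1, -s0), as in _linear_interp
ue = (x * w).sum(0) / w.sum(0)     # zero crossing of the linear interpolant
print(ue)  # tensor([0.2500, 0.0000, 0.0000]); indeed -1 + 0.25 * (3 - (-1)) == 0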
torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + vd_idx_map = torch.zeros( + (case_ids.shape[0], 12), + dtype=torch.long, + device=self.device, + requires_grad=False, + ) for num in torch.unique(num_vd): - cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + cur_cubes = ( + num_vd == num + ) # consider cubes with the same numbers of vd emitted (for batching) curr_num_vd = cur_cubes.sum() * num - curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) - curr_edge_group_to_vd = torch.arange( - curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape( + -1, num * 7 + ) + curr_edge_group_to_vd = ( + torch.arange(curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + + total_num_vd + ) total_num_vd += curr_num_vd - curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ - cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + curr_edge_group_to_cube = ( + torch.arange(idx_map.shape[0], device=self.device)[cur_cubes] + .unsqueeze(-1) + .repeat(1, num * 7) + .reshape_as(curr_edge_group) + ) - curr_mask = (curr_edge_group != -1) + curr_mask = curr_edge_group != -1 edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) - edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) - edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + edge_group_to_vd.append( + torch.masked_select( + curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask + ) + ) + edge_group_to_cube.append( + torch.masked_select(curr_edge_group_to_cube, curr_mask) + ) vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) - vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + vd_gamma.append( + torch.masked_select(gamma_f, cur_cubes) + .unsqueeze(-1) + .repeat(1, num) + .reshape(-1) + ) # if color is not None: # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) - + edge_group = torch.cat(edge_group) edge_group_to_vd = torch.cat(edge_group_to_vd) edge_group_to_cube = torch.cat(edge_group_to_cube) @@ -272,88 +433,149 @@ class FlexiCubes: vd = torch.zeros((total_num_vd, 3), device=self.device) beta_sum = torch.zeros((total_num_vd, 1), device=self.device) - idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + idx_group = torch.gather( + input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group + ) - x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) - s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) - + x_group = torch.index_select( + input=surf_edges_x, index=idx_group.reshape(-1), dim=0 + ).reshape(-1, 2, 3) + s_group = torch.index_select( + input=surf_edges_s, index=idx_group.reshape(-1), dim=0 + ).reshape(-1, 2, 1) zero_crossing_group = torch.index_select( - input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) - - alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + input=zero_crossing, index=idx_group.reshape(-1), dim=0 + ).reshape(-1, 3) + + alpha_group = torch.index_select( + input=alpha_nx12x2.reshape(-1, 2), + 
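The loop above processes cubes in batches keyed by how many dual vertices they emit: torch.unique over num_vd enumerates the distinct counts, cubes with equal counts are handled together, and masked_select strips the -1 padding from the fixed 7-slot edge groups. The bookkeeping in isolation, with made-up numbers and a shortened 3-slot padding:

import torch

num_vd = torch.tensor([1, 2, 1, 1, 2])  # dual vertices emitted per surface cube (toy)
edge_groups = torch.tensor([[3, 5, -1], [0, -1, -1]])  # -1-padded groups (3 slots here)

for num in torch.unique(num_vd):        # batch cubes emitting the same number of vd
    cur_cubes = num_vd == num
    print(f"{int(num)} vd -> {int(cur_cubes.sum())} cubes")

mask = edge_groups != -1                # drop padding before gathering edge data
print(torch.masked_select(edge_groups, mask))  # tensor([3, 5, 0])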
dim=0, + index=edge_group_to_cube * 12 + edge_group, + ).reshape(-1, 2, 1) ue_group = self._linear_interp(s_group * alpha_group, x_group) - beta_group = torch.gather(input=beta.reshape(-1), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_group = torch.gather( + input=beta.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group + ).reshape(-1, 1) beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) - vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum - - ''' + vd = ( + vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) + / beta_sum + ) + + """ interpolate colors use the same method as dual vertices - ''' + """ if voxelgrid_colors is not None: vd_color = torch.zeros((total_num_vd, C), device=self.device) - c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) + c_group = torch.index_select( + input=surf_edges_c, index=idx_group.reshape(-1), dim=0 + ).reshape(-1, 2, C) uc_group = self._linear_interp(s_group * alpha_group, c_group) - vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum + vd_color = ( + vd_color.index_add_( + 0, index=edge_group_to_vd, source=uc_group * beta_group + ) + / beta_sum + ) else: vd_color = None - - L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + L_dev = self._compute_reg_loss( + vd, zero_crossing_group, edge_group_to_vd, vd_num_edges + ) v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd - vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * - 12 + edge_group, src=v_idx[edge_group_to_vd]) + vd_idx_map = (vd_idx_map.reshape(-1)).scatter( + dim=0, + index=edge_group_to_cube * 12 + edge_group, + src=v_idx[edge_group_to_vd], + ) return vd, L_dev, vd_gamma, vd_idx_map, vd_color - def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color): + def _triangulate( + self, + scalar_field, + surf_edges, + vd, + vd_gamma, + edge_counts, + idx_map, + vd_idx_map, + surf_edges_mask, + training, + vd_color, + ): """ - Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into triangles based on the gamma parameter, as described in Section 4.3. """ with torch.no_grad(): - group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group_mask = ( + edge_counts == 4 + ) & surf_edges_mask # surface edges shared by 4 cubes. group = idx_map.reshape(-1)[group_mask] vd_idx = vd_idx_map[group_mask] edge_indices, indices = torch.sort(group, stable=True) quad_vd_idx = vd_idx[indices].reshape(-1, 4) # Ensure all face directions point towards the positive SDF to maintain consistent winding. 
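Each dual vertex above is a beta-weighted mean of its interpolated edge crossings: contributions are scattered with index_add_ and then normalized by the accumulated weight. The same pattern reduced to two vertices and hypothetical 2-D crossings:

import torch

vd_of_edge = torch.tensor([0, 0, 1])  # which dual vertex each crossing belongs to
ue = torch.tensor([[0.0, 0.0], [2.0, 2.0], [5.0, 5.0]])  # crossings (2-D for brevity)
beta = torch.tensor([[1.0], [3.0], [2.0]])               # per-crossing weights

vd = torch.zeros(2, 2).index_add_(0, vd_of_edge, ue * beta)
beta_sum = torch.zeros(2, 1).index_add_(0, vd_of_edge, beta)
print(vd / beta_sum)  # tensor([[1.5000, 1.5000], [5.0000, 5.0000]])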
- s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + s_edges = scalar_field[ + surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1) + ].reshape(-1, 2) flip_mask = s_edges[:, 0] > 0 - quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], - quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + quad_vd_idx = torch.cat( + ( + quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]], + ) + ) - quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + quad_gamma = torch.index_select( + input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0 + ).reshape(-1, 4) gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] if not training: - mask = (gamma_02 > gamma_13) - faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + mask = gamma_02 > gamma_13 + faces = torch.zeros( + (quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device + ) faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] faces = faces.reshape(-1, 3) else: - vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_quad = torch.index_select( + input=vd, index=quad_vd_idx.reshape(-1), dim=0 + ).reshape(-1, 4, 3) vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 weight_sum = (gamma_02 + gamma_13) + 1e-8 - vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) - + vd_center = ( + vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1) + ) / weight_sum.unsqueeze(-1) + if vd_color is not None: - color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) + color_quad = torch.index_select( + input=vd_color, index=quad_vd_idx.reshape(-1), dim=0 + ).reshape(-1, 4, vd_color.shape[-1]) color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 - color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + color_center = ( + color_02 * gamma_02.unsqueeze(-1) + + color_13 * gamma_13.unsqueeze(-1) + ) / weight_sum.unsqueeze(-1) vd_color = torch.cat([vd_color, color_center]) - - - vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + + vd_center_idx = ( + torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + ) vd = torch.cat([vd, vd_center]) faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) - faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + faces = torch.cat( + [faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1 + ).reshape(-1, 3) return vd, faces, s_edges, edge_indices, vd_color diff --git a/trellis/representations/mesh/tables.py b/trellis/representations/mesh/tables.py index 7c02dd7f4133aef487f623c02b11e3075cab0916..fc62621fecfce2573395847d3eac9a5d139bb86e 100644 --- a/trellis/representations/mesh/tables.py +++ b/trellis/representations/mesh/tables.py @@ -6,786 +6,2314 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
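_triangulate above ends by cutting each quad of four dual vertices along one diagonal: diagonal 0-2 when gamma_0 * gamma_2 > gamma_1 * gamma_3, otherwise diagonal 1-3, via the quad_split_1 and quad_split_2 index patterns. A sketch of the inference-time branch with two toy quads:

import torch

quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3])  # triangles sharing diagonal 0-2
quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2])  # triangles sharing diagonal 1-3
quad_vd_idx = torch.tensor([[10, 11, 12, 13], [20, 21, 22, 23]])  # toy vertex ids
gamma = torch.tensor([[0.9, 0.1, 0.9, 0.1],   # favors the 0-2 diagonal
                      [0.1, 0.9, 0.1, 0.9]])  # favors the 1-3 diagonal

mask = gamma[:, 0] * gamma[:, 2] > gamma[:, 1] * gamma[:, 3]
faces = torch.zeros(2, 6, dtype=torch.long)
faces[mask] = quad_vd_idx[mask][:, quad_split_1]
faces[~mask] = quad_vd_idx[~mask][:, quad_split_2]
print(faces.reshape(-1, 3))  # two triangles per quad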
dmc_table = [ -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], 
-[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 
3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], 
[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1]], -[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, 
-1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] + [ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 7, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 5, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 5, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 7, 9, -1, 
-1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 7, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 7, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 8, 11, -1, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 5, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 5, 8, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 8, 9, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 7, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 7, 8, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 7, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 9, 10, -1, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 7, 9, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 
5, 9, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 5, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 5, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 8, 9, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 7, 9, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 7, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [8, 9, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 7, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 9, 10, 11, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 8, 10, 11, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 5, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 8, 9, -1, -1, -1], + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 7, 9, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 7, 8, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, 
-1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [4, 6, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 6, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 5, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 5, 8, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 6, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 6, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 6, 8, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 6, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 6, 7, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [2, 3, 6, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 6, 7, 8, 9, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 6, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [2, 3, 4, 6, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 6, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [2, 3, 6, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 6, 7, 8, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 5, -1, -1, -1], + [2, 3, 6, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 5, 6, 7, 8], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 6, 8, 9, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 6, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], 
+ ], + [ + [0, 1, 2, 3, 5, 6, 8], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 10, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 9, 10, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 8, 9, 10, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 8, 11, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 6, 11, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 9, 10, -1, -1, -1], + [4, 6, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 6, 9, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [4, 5, 9, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 5, 10, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 5, 8, 10, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 6, 8, 9, 11, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 6, 9, 11, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 6, 8, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 6, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 6, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 6, 7, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 6, 7, 9, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [6, 7, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 6, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 6, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 6, 8, 9, 10], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [1, 3, 6, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], 
+ [ + [0, 1, 6, 7, 8, 10, -1], + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 5, 6, 7, 10], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 6, 7, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 6, 8, 9, 10], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 6, 9, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 7, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 7, 9, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [4, 6, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 6, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 6, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [6, 7, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 6, 7, 9, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 6, 7, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 6, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 11, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 8, 11, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 8, 9, 11, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], 
+ ], + [ + [4, 7, 8, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 7, 11, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 7, 9, 11, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 9, 10, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 8, 11, -1, -1, -1], + [4, 6, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 6, 10, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 6, 8, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [6, 7, 8, 9, 10, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 6, 7, 9, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 6, 7, 8, 10, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 6, 7, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 6, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 6, 8, 9, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 7, -1, -1, -1], + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 6, 9, -1, -1], + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 5, 6, 7, 9], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 6, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [1, 2, 4, 6, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 6, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 6, 7, 8, 9, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 2, 3, 6, 7, 9], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 6, 7, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 6, 
7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 6, 8, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 6, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 6, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [1, 3, 5, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 5, 6, 7, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 6, 9, 11, -1], + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 6, 7, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 6, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 6, 8, 9, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 6, 7, 8, 9, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 6, 7, 8, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [5, 7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [5, 7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [5, 7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 5, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [4, 5, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 5, 9, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [4, 7, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], 
+ [ + [0, 1, 4, 7, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 7, 8, 10, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [8, 9, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 7, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [2, 3, 5, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 7, 8, 9, 10], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 5, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 5, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [2, 3, 4, 5, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 5, 9, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 7, 9, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 7, 8, 9, 10], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 2, 3, 4, 7, 10], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 2, 3, 8, 10, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [1, 2, 5, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 5, 7, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 7, 8, 9, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 5, 8, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, 
-1, -1, -1], + ], + [ + [0, 1, 2, 3, 4, 5, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 5, 8, 9, 11], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 4, 7, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [1, 2, 4, 7, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 4, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 4, 7, 8, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 2, 3, 9, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 5, 7, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 5, 7, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 5, 8, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 5, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 5, 8, 9, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 4, 7, 9, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 4, 7, 8, 9, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, 
-1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], +] +num_vd_table = [ + 0, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 3, + 1, + 2, + 2, + 2, + 1, + 2, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 1, + 2, + 3, + 1, + 1, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 2, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 2, + 3, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 1, + 2, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 3, + 2, + 2, + 2, + 2, + 2, + 1, + 3, + 4, + 2, + 2, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 3, + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 3, + 2, + 3, + 2, + 4, + 2, + 2, + 2, + 2, + 1, + 2, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, ] -num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, -2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, -1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, -1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, -2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, -3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, -2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, -1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, -1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] check_table = [ -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 194], -[1, -1, 0, 0, 193], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 
0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 164], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 161], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 152], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 145], -[1, 0, 0, 1, 144], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 137], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 133], -[1, 0, 1, 0, 132], -[1, 1, 0, 0, 131], -[1, 1, 0, 0, 130], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 100], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 98], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 96], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 88], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 82], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 74], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 72], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 70], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 67], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 65], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 56], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 52], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 44], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 40], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 38], -[1, 0, -1, 0, 37], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 33], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 28], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 26], -[1, 0, 0, -1, 25], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 20], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 18], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 9], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 6], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0] + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 
0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 194], + [1, -1, 0, 0, 193], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 164], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 161], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 152], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 145], + [1, 0, 0, 1, 144], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 137], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 133], + [1, 0, 1, 0, 132], + [1, 1, 0, 0, 131], + [1, 1, 0, 0, 130], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 100], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 98], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 96], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 88], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 82], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 74], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 72], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 70], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 67], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 65], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 
0, 0, 56], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 52], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 44], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 40], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 38], + [1, 0, -1, 0, 37], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 33], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 28], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 26], + [1, 0, 0, -1, 25], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 20], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 18], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 9], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 6], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], ] tet_table = [ -[-1, -1, -1, -1, -1, -1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, -1], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, -1], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, -1, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, -1, 2, 4, 4, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, 5, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, -1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[-1, 1, 1, 4, 4, 1], -[0, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[8, 8, 8, 8, 8, 8], -[1, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 4, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 5, 5, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[6, 6, 6, 6, 6, 6], -[6, -1, 0, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 4, -1, 6, 4, 6], -[6, 4, 0, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 2, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 1, 1, 6, -1, 6], -[6, 1, 1, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 4], -[2, 2, 2, 2, 2, 2], -[6, 1, 1, 6, 4, 6], -[6, 1, 1, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 
6, 6, 6], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 5, 0, 5, 0, 5], -[5, 5, 5, 5, 5, 5], -[5, 5, 5, 5, 5, 5], -[0, 5, 0, 5, 0, 5], -[-1, 5, 0, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[4, 5, -1, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[4, 5, 0, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 6, 6, 6, 6, 6], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, -1, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[2, 5, 2, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 4], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 6, 2, 6, 6, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 1, 4, 1], -[0, 1, 1, 1, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 0, 0, 6, 0, 6], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[5, 5, 5, 5, 5, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 4, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[4, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[8, 8, 8, 8, 8, 8], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 1, 1, 4, 4, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 4, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[12, 12, 12, 12, 12, 12] -] \ No newline at end of file + [-1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [4, 4, 4, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, -1], + [1, 1, 1, 1, 1, 1], + [4, 4, 4, 4, 4, 4], + [0, 4, 0, 4, 4, -1], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [5, 5, 5, 5, 5, 5], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, -1, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, -1, 2, 4, 4, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 4, 4, 2], + [1, 1, 1, 1, 1, 1], + [2, 4, 2, 4, 4, 2], + [0, 
4, 0, 4, 4, 0], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 5, 2, 5, 5, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, -1, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [4, 1, 1, 4, 4, 1], + [0, 1, 1, 0, 0, 1], + [4, 0, 0, 4, 4, 0], + [2, 2, 2, 2, 2, 2], + [-1, 1, 1, 4, 4, 1], + [0, 1, 1, 4, 4, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [5, 1, 1, 5, 5, 1], + [0, 1, 1, 0, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [8, 8, 8, 8, 8, 8], + [1, 1, 1, 4, 4, 1], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 4, 4, 1], + [0, 4, 0, 4, 4, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 5, 5, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [5, 5, 5, 5, 5, 5], + [6, 6, 6, 6, 6, 6], + [6, -1, 0, 6, 0, 6], + [6, 0, 0, 6, 0, 6], + [6, 1, 1, 6, 1, 6], + [4, 4, 4, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [6, 4, -1, 6, 4, 6], + [6, 4, 0, 6, 4, 6], + [6, 0, 0, 6, 0, 6], + [6, 1, 1, 6, 1, 6], + [5, 5, 5, 5, 5, 5], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 2, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [2, 4, 2, 2, 4, 2], + [0, 4, 0, 4, 4, 0], + [2, 0, 2, 2, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [6, 1, 1, 6, -1, 6], + [6, 1, 1, 6, 0, 6], + [6, 0, 0, 6, 0, 6], + [6, 2, 2, 6, 2, 6], + [4, 1, 1, 4, 4, 1], + [0, 1, 1, 0, 0, 1], + [4, 0, 0, 4, 4, 4], + [2, 2, 2, 2, 2, 2], + [6, 1, 1, 6, 4, 6], + [6, 1, 1, 6, 4, 6], + [6, 0, 0, 6, 0, 6], + [6, 2, 2, 6, 2, 6], + [5, 1, 1, 5, 5, 1], + [0, 1, 1, 0, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [6, 6, 6, 6, 6, 6], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 4, 1], + [0, 4, 0, 4, 4, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 5, 0, 5, 0, 5], + [5, 5, 5, 5, 5, 5], + [5, 5, 5, 5, 5, 5], + [0, 5, 0, 5, 0, 5], + [-1, 5, 0, 5, 0, 5], + [1, 5, 1, 5, 1, 5], + [4, 5, -1, 5, 4, 5], + [0, 5, 0, 5, 0, 5], + [4, 5, 0, 5, 4, 5], + [1, 5, 1, 5, 1, 5], + [4, 4, 4, 4, 4, 4], + [0, 4, 0, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [6, 6, 6, 6, 6, 6], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 5, 2, 5, -1, 5], + [0, 5, 0, 5, 0, 5], + [2, 5, 2, 5, 0, 5], + [1, 5, 1, 5, 1, 5], + [2, 5, 2, 5, 4, 5], + [0, 5, 0, 5, 0, 5], + [2, 5, 2, 5, 4, 5], + [1, 5, 1, 5, 1, 5], + [2, 4, 2, 4, 4, 2], + [0, 4, 0, 4, 4, 4], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 6, 2, 6, 6, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [4, 1, 1, 1, 4, 1], + [0, 1, 1, 1, 0, 1], + [4, 0, 0, 4, 4, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 4, 1], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 
0, 0, 0, 0],
+    [4, 4, 4, 4, 4, 4],
+    [1, 1, 1, 1, 1, 1],
+    [6, 0, 0, 6, 0, 6],
+    [0, 0, 0, 0, 0, 0],
+    [6, 6, 6, 6, 6, 6],
+    [5, 5, 5, 5, 5, 5],
+    [5, 5, 0, 5, 0, 5],
+    [5, 5, 0, 5, 0, 5],
+    [5, 5, 1, 5, 1, 5],
+    [4, 4, 4, 4, 4, 4],
+    [0, 0, 0, 0, 0, 0],
+    [4, 4, 0, 4, 4, 4],
+    [1, 1, 1, 1, 1, 1],
+    [4, 4, 4, 4, 4, 4],
+    [4, 4, 0, 4, 4, 4],
+    [0, 0, 0, 0, 0, 0],
+    [1, 1, 1, 1, 1, 1],
+    [8, 8, 8, 8, 8, 8],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [1, 1, 1, 1, 1, 1],
+    [2, 2, 2, 2, 2, 2],
+    [0, 0, 0, 0, 0, 0],
+    [2, 2, 2, 2, 0, 2],
+    [1, 1, 1, 1, 1, 1],
+    [2, 2, 2, 2, 2, 2],
+    [0, 0, 0, 0, 0, 0],
+    [2, 2, 2, 2, 2, 2],
+    [1, 1, 1, 1, 1, 1],
+    [2, 2, 2, 2, 2, 2],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [4, 1, 1, 4, 4, 1],
+    [2, 2, 2, 2, 2, 2],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [1, 1, 1, 1, 1, 1],
+    [1, 1, 1, 1, 1, 1],
+    [1, 1, 1, 1, 0, 1],
+    [0, 0, 0, 0, 0, 0],
+    [2, 2, 2, 2, 2, 2],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [2, 4, 2, 4, 4, 2],
+    [1, 1, 1, 1, 1, 1],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [2, 2, 2, 2, 2, 2],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [2, 2, 2, 2, 2, 2],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [5, 5, 5, 5, 5, 5],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [4, 4, 4, 4, 4, 4],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [4, 4, 4, 4, 4, 4],
+    [1, 1, 1, 1, 1, 1],
+    [0, 0, 0, 0, 0, 0],
+    [0, 0, 0, 0, 0, 0],
+    [12, 12, 12, 12, 12, 12],
+]
diff --git a/trellis/representations/mesh/utils_cube.py b/trellis/representations/mesh/utils_cube.py
index 23913c97bb2d57dfa0384667c69f9860ea0a4155..cf1b96899175e90b4172cefca930211353f4b988 100644
--- a/trellis/representations/mesh/utils_cube.py
+++ b/trellis/representations/mesh/utils_cube.py
@@ -1,18 +1,40 @@
 import torch
-cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
-    1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int)
-cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
-cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
-                        2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False)
-
-def construct_dense_grid(res, device='cuda'):
-    '''construct a dense grid based on resolution'''
+
+cube_corners = torch.tensor(
+    [
+        [0, 0, 0],
+        [1, 0, 0],
+        [0, 1, 0],
+        [1, 1, 0],
+        [0, 0, 1],
+        [1, 0, 1],
+        [0, 1, 1],
+        [1, 1, 1],
+    ],
+    dtype=torch.int,
+)
+cube_neighbor = torch.tensor(
+    [[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]
+)
+cube_edges = torch.tensor(
+    [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4],
+    dtype=torch.long,
+    requires_grad=False,
+)
+
+
+def construct_dense_grid(res, device="cuda"):
+    """construct a dense grid based on resolution"""
     res_v = res + 1
-    vertsid = torch.arange(res_v ** 3, device=device)
+    vertsid = torch.arange(res_v**3, device=device)
     coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten()
-    cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2]
-    cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device))
-    verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1)
+    cube_corners_bias = (
+        cube_corners[:, 0] * res_v + cube_corners[:, 1]
+    ) * res_v + cube_corners[:, 2]
+    cube_fx8 = coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)
+    verts = torch.stack(
+        [vertsid // (res_v**2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1
+    )
     return verts, cube_fx8
@@ -23,7 +45,7 @@ def construct_voxel_grid(coords):
     return verts_unique, cubes
 
 
-def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
+def cubes_to_verts(num_verts, cubes, value, reduce="mean"):
     """
     Args:
         cubes [Vx8] verts index for each cube
@@ -31,12 +53,18 @@
         value [Vx8xM] value to be scattered
     Operation:
         reduced[cubes[i][j]][k] += value[i][k]
     """
-    M = value.shape[2] # number of channels
+    M = value.shape[2]  # number of channels
     reduced = torch.zeros(num_verts, M, device=cubes.device)
-    return torch.scatter_reduce(reduced, 0,
-        cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1),
-        value.flatten(0, 1), reduce=reduce, include_self=False)
-
+    return torch.scatter_reduce(
+        reduced,
+        0,
+        cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1),
+        value.flatten(0, 1),
+        reduce=reduce,
+        include_self=False,
+    )
+
+
 def sparse_cube2verts(coords, feats, training=True):
     new_coords, cubes = construct_voxel_grid(coords)
     new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats)
@@ -45,17 +73,16 @@
     else:
         con_loss = 0.0
     return new_coords, new_feats, con_loss
-
-def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True):
+
+def get_dense_attrs(coords: torch.Tensor, feats: torch.Tensor, res: int, sdf_init=True):
     F = feats.shape[-1]
     dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device)
     if sdf_init:
-        dense_attrs[..., 0] = 1 # initial outside sdf value
+        dense_attrs[..., 0] = 1  # initial outside sdf value
     dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats
     return dense_attrs.reshape(-1, F)
 
 
-def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res):
+def get_defomed_verts(v_pos: torch.Tensor, deform: torch.Tensor, res):
     return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)
-    
\ No newline at end of file
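For orientation, here is a minimal usage sketch of the grid helpers reformatted in the hunks above. It is illustrative only and not part of the patch: it assumes the patched `trellis.representations.mesh.utils_cube` module is importable, runs on CPU so no GPU is required (the module itself defaults to `device="cuda"`), and the values of `res`, `value`, and `deform` are made-up examples.

```python
import torch

from trellis.representations.mesh.utils_cube import (
    construct_dense_grid,
    cubes_to_verts,
    get_defomed_verts,  # spelled this way in the source
)

res = 4
# (res + 1)^3 integer lattice corners and res^3 cubes; cube_fx8[c] lists the
# 8 corner vertex ids of cube c, ordered like cube_corners.
verts, cube_fx8 = construct_dense_grid(res, device="cpu")
assert verts.shape == ((res + 1) ** 3, 3)
assert cube_fx8.shape == (res**3, 8)

# Scatter one feature channel from every cube corner onto the shared vertex
# grid; a vertex touched by several cubes receives the mean contribution.
value = torch.rand(res**3, 8, 1)
per_vertex = cubes_to_verts(verts.shape[0], cube_fx8, value)
assert per_vertex.shape == (verts.shape[0], 1)

# Map integer vertices into the unit cube centered at the origin, plus a
# tanh-bounded offset of strictly less than half a grid cell, so neighboring
# vertices cannot cross.
deform = torch.randn(verts.shape[0], 3)
v_deformed = get_defomed_verts(verts.float(), deform, res)
```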
torch.stack( + [vertsid // (res_v**2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1 + ) return verts, cube_fx8 @@ -23,7 +45,7 @@ def construct_voxel_grid(coords): return verts_unique, cubes -def cubes_to_verts(num_verts, cubes, value, reduce='mean'): +def cubes_to_verts(num_verts, cubes, value, reduce="mean"): """ Args: cubes [Vx8] verts index for each cube @@ -31,12 +53,18 @@ def cubes_to_verts(num_verts, cubes, value, reduce='mean'): Operation: reduced[cubes[i][j]][k] += value[i][k] """ - M = value.shape[2] # number of channels + M = value.shape[2] # number of channels reduced = torch.zeros(num_verts, M, device=cubes.device) - return torch.scatter_reduce(reduced, 0, - cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), - value.flatten(0, 1), reduce=reduce, include_self=False) - + return torch.scatter_reduce( + reduced, + 0, + cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), + value.flatten(0, 1), + reduce=reduce, + include_self=False, + ) + + def sparse_cube2verts(coords, feats, training=True): new_coords, cubes = construct_voxel_grid(coords) new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) @@ -45,17 +73,16 @@ def sparse_cube2verts(coords, feats, training=True): else: con_loss = 0.0 return new_coords, new_feats, con_loss - -def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): + +def get_dense_attrs(coords: torch.Tensor, feats: torch.Tensor, res: int, sdf_init=True): F = feats.shape[-1] dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) if sdf_init: - dense_attrs[..., 0] = 1 # initial outside sdf value + dense_attrs[..., 0] = 1 # initial outside sdf value dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats return dense_attrs.reshape(-1, F) -def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): +def get_defomed_verts(v_pos: torch.Tensor, deform: torch.Tensor, res): return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) - \ No newline at end of file diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py index f66a39a5a7498e2e99fe9d94d663796b3bc157b5..56b122b66d469d749169525ea1a619b3e2ac5b2e 100755 --- a/trellis/representations/octree/__init__.py +++ b/trellis/representations/octree/__init__.py @@ -1 +1 @@ -from .octree_dfs import DfsOctree \ No newline at end of file +from .octree_dfs import DfsOctree diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py index 9d1f7898f30414f304953cfb2d51d00511ec8325..85d9b4121890f67c2d10250b5f366238bf0caf63 100755 --- a/trellis/representations/octree/octree_dfs.py +++ b/trellis/representations/octree/octree_dfs.py @@ -4,17 +4,17 @@ import torch.nn.functional as F DEFAULT_TRIVEC_CONFIG = { - 'dim': 8, - 'rank': 8, + "dim": 8, + "rank": 8, } DEFAULT_VOXEL_CONFIG = { - 'solid': False, + "solid": False, } DEFAULT_DECOPOLY_CONFIG = { - 'degree': 8, - 'rank': 16, + "degree": 8, + "rank": 16, } @@ -51,14 +51,14 @@ class DfsOctree: """ def __init__( - self, - depth, - aabb=[0,0,0,1,1,1], - sh_degree=2, - primitive='voxel', - primitive_config={}, - device='cuda', - ): + self, + depth, + aabb=[0, 0, 0, 1, 1, 1], + sh_degree=2, + primitive="voxel", + primitive_config={}, + device="cuda", + ): self.max_depth = depth self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) self.device = device @@ -67,54 +67,124 @@ class DfsOctree: self.primitive = primitive self.primitive_config = primitive_config - self.structure = torch.tensor([[8, 1, 0]], 
dtype=torch.int32, device=self.device) + self.structure = torch.tensor( + [[8, 1, 0]], dtype=torch.int32, device=self.device + ) self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) - self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) - self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) - self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) + self.position[:, 0] = torch.tensor( + [0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device + ) + self.position[:, 1] = torch.tensor( + [0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device + ) + self.position[:, 2] = torch.tensor( + [0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device + ) self.depth[:, 0] = 1 - self.data = ['position', 'depth'] + self.data = ["position", "depth"] self.param_names = [] - if primitive == 'voxel': - self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) - self.data += ['features_dc', 'features_ac'] - self.param_names += ['features_dc', 'features_ac'] - if not primitive_config.get('solid', False): - self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) - self.data.append('density') - self.param_names.append('density') - elif primitive == 'gaussian': - self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + if primitive == "voxel": + self.features_dc = torch.zeros( + (8, 1, 3), dtype=torch.float32, device=self.device + ) + self.features_ac = torch.zeros( + (8, (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.data += ["features_dc", "features_ac"] + self.param_names += ["features_dc", "features_ac"] + if not primitive_config.get("solid", False): + self.density = torch.zeros( + (8, 1), dtype=torch.float32, device=self.device + ) + self.data.append("density") + self.param_names.append("density") + elif primitive == "gaussian": + self.features_dc = torch.zeros( + (8, 1, 3), dtype=torch.float32, device=self.device + ) + self.features_ac = torch.zeros( + (8, (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) - self.data += ['features_dc', 'features_ac', 'opacity'] - self.param_names += ['features_dc', 'features_ac', 'opacity'] - elif primitive == 'trivec': - self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) - self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) - self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ["features_dc", "features_ac", "opacity"] + self.param_names += ["features_dc", "features_ac", "opacity"] + elif primitive == "trivec": + self.trivec = torch.zeros( + (8, primitive_config["rank"], 3, primitive_config["dim"]), + dtype=torch.float32, + device=self.device, + ) + self.density = 
torch.zeros( + (8, primitive_config["rank"]), dtype=torch.float32, device=self.device + ) + self.features_dc = torch.zeros( + (8, primitive_config["rank"], 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.features_ac = torch.zeros( + (8, primitive_config["rank"], (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) self.density_shift = 0 - self.data += ['trivec', 'density', 'features_dc', 'features_ac'] - self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] - elif primitive == 'decoupoly': - self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) - self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) - self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) - self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ["trivec", "density", "features_dc", "features_ac"] + self.param_names += ["trivec", "density", "features_dc", "features_ac"] + elif primitive == "decoupoly": + self.decoupoly_V = torch.zeros( + (8, primitive_config["rank"], 3), + dtype=torch.float32, + device=self.device, + ) + self.decoupoly_g = torch.zeros( + (8, primitive_config["rank"], primitive_config["degree"]), + dtype=torch.float32, + device=self.device, + ) + self.density = torch.zeros( + (8, primitive_config["rank"]), dtype=torch.float32, device=self.device + ) + self.features_dc = torch.zeros( + (8, primitive_config["rank"], 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.features_ac = torch.zeros( + (8, primitive_config["rank"], (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) self.density_shift = 0 - self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] - self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + self.data += [ + "decoupoly_V", + "decoupoly_g", + "density", + "features_dc", + "features_ac", + ] + self.param_names += [ + "decoupoly_V", + "decoupoly_g", + "density", + "features_dc", + "features_ac", + ] self.setup_functions() def setup_functions(self): - self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) + self.density_activation = ( + (lambda x: torch.exp(x - 2)) + if self.primitive != "trivec" + else (lambda x: x) + ) self.opacity_activation = lambda x: torch.sigmoid(x - 6) self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 self.color_activation = lambda x: torch.sigmoid(x) @@ -122,7 +192,7 @@ class DfsOctree: @property def num_non_leaf_nodes(self): return self.structure.shape[0] - + @property def num_leaf_nodes(self): return self.depth.shape[0] @@ -130,11 +200,11 @@ class DfsOctree: @property def cur_depth(self): return self.depth.max().item() - + @property def occupancy(self): - return self.num_leaf_nodes / 8 ** self.cur_depth - + return self.num_leaf_nodes / 8**self.cur_depth + @property def get_xyz(self): return self.position @@ -145,10 +215,15 @@ class DfsOctree: @property def get_density(self): - if self.primitive == 'voxel' and self.voxel_config['solid']: - return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device) + if self.primitive == "voxel" and 
self.voxel_config["solid"]: + return torch.full( + (self.position.shape[0], 1), + 1000, + dtype=torch.float32, + device=self.device, + ) return self.density_activation(self.density) - + @property def get_opacity(self): return self.opacity_activation(self.density) @@ -172,9 +247,18 @@ class DfsOctree: return torch.cat([self.features_dc, self.features_ac], dim=-2) def state_dict(self): - ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive} - if hasattr(self, 'density_shift'): - ret['density_shift'] = self.density_shift + ret = { + "structure": self.structure, + "position": self.position, + "depth": self.depth, + "sh_degree": self.sh_degree, + "active_sh_degree": self.active_sh_degree, + "trivec_config": self.trivec_config, + "voxel_config": self.voxel_config, + "primitive": self.primitive, + } + if hasattr(self, "density_shift"): + ret["density_shift"] = self.density_shift for data in set(self.data + self.param_names): if not isinstance(getattr(self, data), nn.Module): ret[data] = getattr(self, data) @@ -183,7 +267,14 @@ class DfsOctree: return ret def load_state_dict(self, state_dict): - keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) + keys = list( + set( + self.data + + self.param_names + + list(state_dict.keys()) + + ["structure", "position", "depth"] + ) + ) for key in keys: if key not in state_dict: print(f"Warning: key {key} not found in the state_dict.") @@ -206,12 +297,14 @@ class DfsOctree: """ leaf_cnt = self.structure[:, 0] leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] - ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) + ret = torch.zeros( + (self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device + ) for i in range(8): if leaf_cnt_masks[i].sum() == 0: continue start = self.structure[leaf_cnt_masks[i], 2] - for j in range(i+1): + for j in range(i + 1): ret[leaf_cnt_masks[i]] += data[start + j] return ret @@ -229,7 +322,7 @@ class DfsOctree: if non_leaf_cnt_masks[i].sum() == 0: continue start = self.structure[non_leaf_cnt_masks[i], 1] - for j in range(i+1): + for j in range(i + 1): ret[non_leaf_cnt_masks[i]] += data[start + j] return ret @@ -241,67 +334,121 @@ class DfsOctree: mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. """ # Dont subdivide when the depth is the maximum. - mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max( + mask[self.depth.squeeze() == self.max_depth], 0 + ) # Dont merge when the depth is the minimum. - mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) + mask[self.depth.squeeze() == 1] = torch.clamp_min( + mask[self.depth.squeeze() == 1], 0 + ) # Gather control mask structre_ctrl = self.gather_from_leaf_children(mask) - structre_ctrl[structre_ctrl==-8] = -1 + structre_ctrl[structre_ctrl == -8] = -1 new_leaf_num = self.structure[:, 0].clone() # Modify the leaf num. structre_valid = structre_ctrl >= 0 - new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. + new_leaf_num[structre_valid] -= structre_ctrl[ + structre_valid + ] # Add the new nodes. 
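        # Editor's note (not part of the patch): gather_from_leaf_children sums the
        # per-leaf control mask over each parent's children, so a parent whose eight
        # leaves all voted to merge scores -8, collapsed to -1 above. Positive scores
        # count subdividing leaves, which is why they are subtracted from the leaf
        # count here, while negative scores mark parents that become leaves below.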
structre_delete = structre_ctrl < 0 merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) - new_leaf_num += merged_nodes # Delete the merged nodes. + new_leaf_num += merged_nodes # Delete the merged nodes. # Update the structure array to allocate new nodes. - mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) - mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. - mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. - new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + mem_offset = torch.zeros( + (self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device + ) + mem_offset.index_add_( + 0, self.structure[structre_valid, 1], structre_ctrl[structre_valid] + ) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange( + 0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device + ) + mem_offset.cumsum(0) new_structure_length = new_structre_idx[-1].item() new_structre_idx = new_structre_idx[:-1] - new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) - new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] + new_structure = torch.empty( + (new_structure_length, 3), dtype=torch.int32, device=self.device + ) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[ + structre_valid + ] # Initialize the new nodes. - new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) + new_node_mask = torch.ones( + (new_structure_length,), dtype=torch.bool, device=self.device + ) new_node_mask[new_structre_idx[structre_valid]] = False - new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. new_node_num = new_node_mask.sum().item() # Rebuild child ptr. non_leaf_cnt = 8 - new_structure[:, 0] - new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) + new_child_ptr = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + non_leaf_cnt.cumsum(0)[:-1], + ] + ) new_structure[:, 1] = new_child_ptr + 1 # Rebuild data ptr with old data. 
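        # Editor's note (not part of the patch): each structure row appears to hold
        # [num_leaf_children, child_ptr, data_ptr] in DFS order, with the root pinned
        # to slot 0; hence the "+ 1" offset when the child pointers are rebuilt above.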
- leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) + leaf_cnt = torch.zeros( + (new_structure_length,), dtype=torch.int32, device=self.device + ) leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) - old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + old_data_ptr = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + leaf_cnt.cumsum(0)[:-1], + ] + ) # Update the data array subdivide_mask = mask == 1 merge_mask = mask == -1 data_valid = ~(subdivide_mask | merge_mask) - mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) - mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes - mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes - mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes - mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes - new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + mem_offset = torch.zeros( + (self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device + ) + mem_offset.index_add_( + 0, + old_data_ptr[new_node_mask], + torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device), + ) # Add data array for new nodes + mem_offset[ + :-1 + ] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_( + 0, self.structure[structre_valid, 2], merged_nodes[structre_valid] + ) # Add data elements for merge nodes + new_data_idx = torch.arange( + 0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device + ) + mem_offset.cumsum(0) new_data_length = new_data_idx[-1].item() new_data_idx = new_data_idx[:-1] - new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} + new_data = { + data: torch.empty( + (new_data_length,) + getattr(self, data).shape[1:], + dtype=getattr(self, data).dtype, + device=self.device, + ) + for data in self.data + } for data in self.data: new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] # Rebuild data ptr leaf_cnt = new_structure[:, 0] - new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + new_data_ptr = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + leaf_cnt.cumsum(0)[:-1], + ] + ) new_structure[:, 2] = new_data_ptr # Initialize the new data array @@ -310,41 +457,112 @@ class DfsOctree: subdivide_data_ptr = new_structure[new_node_mask, 2] for data in self.data: for i in range(8): - if data == 'position': - offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 + if data == "position": + offset = ( + torch.tensor( + [i // 4, (i // 2) % 2, i % 2], + dtype=torch.float32, + device=self.device, + ) + - 0.5 + ) scale = 2 ** (-1.0 - self.depth[subdivide_mask]) - new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale - elif data == 'depth': - new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 - elif data == 'opacity': - new_data['opacity'][subdivide_data_ptr 
+ i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) - elif data == 'trivec': - offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 - coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) - axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) - coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 - new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) + new_data["position"][subdivide_data_ptr + i] = ( + self.position[subdivide_mask] + offset * scale + ) + elif data == "depth": + new_data["depth"][subdivide_data_ptr + i] = ( + self.depth[subdivide_mask] + 1 + ) + elif data == "opacity": + new_data["opacity"][ + subdivide_data_ptr + i + ] = self.inverse_opacity_activation( + torch.sqrt( + self.opacity_activation(self.opacity[subdivide_mask]) + ) + ) + elif data == "trivec": + offset = ( + torch.tensor( + [i // 4, (i // 2) % 2, i % 2], + dtype=torch.float32, + device=self.device, + ) + * 0.5 + ) + coord = ( + torch.linspace( + 0, + 0.5, + self.trivec.shape[-1], + dtype=torch.float32, + device=self.device, + )[None] + + offset[:, None] + ).reshape(1, 3, self.trivec.shape[-1], 1) + axis = ( + torch.linspace( + 0, 1, 3, dtype=torch.float32, device=self.device + ) + .reshape(1, 3, 1, 1) + .repeat(1, 1, self.trivec.shape[-1], 1) + ) + coord = ( + torch.stack([coord, axis], dim=3) + .reshape(1, 3, self.trivec.shape[-1], 2) + .expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) + * 2 + - 1 + ) + new_data["trivec"][subdivide_data_ptr + i] = F.grid_sample( + self.trivec[subdivide_mask], coord, align_corners=True + ) else: - new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[ + subdivide_mask + ] ## For merge nodes if merge_mask.sum() > 0: - merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device) - merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) + merge_data_ptr = torch.empty( + (merged_nodes.sum().item(),), dtype=torch.int32, device=self.device + ) + merge_nodes_cumsum = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + merged_nodes.cumsum(0)[:-1], + ] + ) for i in range(8): - merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = ( + new_structure[new_structre_idx[merged_nodes > i], 2] + i + ) old_merge_data_ptr = self.structure[structre_delete, 2] for data in self.data: - if data == 'position': + if data == "position": scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) - new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 - elif data == 'depth': - new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 - elif data == 'opacity': - new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) - elif data == 'trivec': - new_data['trivec'][merge_data_ptr] = 
self.trivec[old_merge_data_ptr] + new_data["position"][merge_data_ptr] = ( + ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() + * scale + + 0.5 * scale + - 0.5 + ) + elif data == "depth": + new_data["depth"][merge_data_ptr] = ( + self.depth[old_merge_data_ptr] - 1 + ) + elif data == "opacity": + new_data["opacity"][ + subdivide_data_ptr + i + ] = self.inverse_opacity_activation( + self.opacity_activation(self.opacity[subdivide_mask]) ** 2 + ) + elif data == "trivec": + new_data["trivec"][merge_data_ptr] = self.trivec[old_merge_data_ptr] else: - new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] + new_data[data][merge_data_ptr] = getattr(self, data)[ + old_merge_data_ptr + ] # Update the structure and data array self.structure = new_structure @@ -353,10 +571,10 @@ class DfsOctree: # Save data array control temp variables self.data_rearrange_buffer = { - 'subdivide_mask': subdivide_mask, - 'merge_mask': merge_mask, - 'data_valid': data_valid, - 'new_data_idx': new_data_idx, - 'new_data_length': new_data_length, - 'new_data': new_data - } + "subdivide_mask": subdivide_mask, + "merge_mask": merge_mask, + "data_valid": data_valid, + "new_data_idx": new_data_idx, + "new_data_length": new_data_length, + "new_data": new_data, + } diff --git a/trellis/representations/radiance_field/__init__.py b/trellis/representations/radiance_field/__init__.py index b72a1b7e76b509ee5a5e6979858eb17b4158a151..8f9f6b32ec0212eff46039f3121079be2bac176d 100755 --- a/trellis/representations/radiance_field/__init__.py +++ b/trellis/representations/radiance_field/__init__.py @@ -1 +1 @@ -from .strivec import Strivec \ No newline at end of file +from .strivec import Strivec diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py index 3b454d9c75521e33466055fe37c3fc1e37180a79..a848ef8ac91c38d956beb7ec46896d7e7810ea9d 100755 --- a/trellis/utils/general_utils.py +++ b/trellis/utils/general_utils.py @@ -4,20 +4,24 @@ import torch # Dictionary utils -def _dict_merge(dicta, dictb, prefix=''): +def _dict_merge(dicta, dictb, prefix=""): """ Merge two dictionaries. """ - assert isinstance(dicta, dict), 'input must be a dictionary' - assert isinstance(dictb, dict), 'input must be a dictionary' + assert isinstance(dicta, dict), "input must be a dictionary" + assert isinstance(dictb, dict), "input must be a dictionary" dict_ = {} all_keys = set(dicta.keys()).union(set(dictb.keys())) for key in all_keys: if key in dicta.keys() and key in dictb.keys(): if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): - dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + dict_[key] = _dict_merge( + dicta[key], dictb[key], prefix=f"{prefix}.{key}" + ) else: - raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + raise ValueError( + f"Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}" + ) elif key in dicta.keys(): dict_[key] = dicta[key] else: @@ -29,14 +33,14 @@ def dict_merge(dicta, dictb): """ Merge two dictionaries. """ - return _dict_merge(dicta, dictb, prefix='') + return _dict_merge(dicta, dictb, prefix="") def dict_foreach(dic, func, special_func={}): """ Recursively apply a function to all non-dictionary leaf values in a dictionary. 
""" - assert isinstance(dic, dict), 'input must be a dictionary' + assert isinstance(dic, dict), "input must be a dictionary" for key in dic.keys(): if isinstance(dic[key], dict): dic[key] = dict_foreach(dic[key], func) @@ -52,9 +56,11 @@ def dict_reduce(dicts, func, special_func={}): """ Reduce a list of dictionaries. Leaf values must be scalars. """ - assert isinstance(dicts, list), 'input must be a list of dictionaries' - assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' - assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + assert isinstance(dicts, list), "input must be a list of dictionaries" + assert all( + [isinstance(d, dict) for d in dicts] + ), "input must be a list of dictionaries" + assert len(dicts) > 0, "input must be a non-empty list of dictionaries" all_keys = set([key for dict_ in dicts for key in dict_.keys()]) reduced_dict = {} for key in all_keys: @@ -73,7 +79,7 @@ def dict_any(dic, func): """ Recursively apply a function to all non-dictionary leaf values in a dictionary. """ - assert isinstance(dic, dict), 'input must be a dictionary' + assert isinstance(dic, dict), "input must be a dictionary" for key in dic.keys(): if isinstance(dic[key], dict): if dict_any(dic[key], func): @@ -88,7 +94,7 @@ def dict_all(dic, func): """ Recursively apply a function to all non-dictionary leaf values in a dictionary. """ - assert isinstance(dic, dict), 'input must be a dictionary' + assert isinstance(dic, dict), "input must be a dictionary" for key in dic.keys(): if isinstance(dic[key], dict): if not dict_all(dic[key], func): @@ -99,11 +105,11 @@ def dict_all(dic, func): return True -def dict_flatten(dic, sep='.'): +def dict_flatten(dic, sep="."): """ Flatten a nested dictionary into a dictionary with no nested dictionaries. 
""" - assert isinstance(dic, dict), 'input must be a dictionary' + assert isinstance(dic, dict), "input must be a dictionary" flat_dict = {} for key in dic.keys(): if isinstance(dic[key], dict): @@ -128,21 +134,37 @@ def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): elif nrow is not None and ncol is None: ncol = (num_images + nrow - 1) // nrow else: - assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' - - grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + assert ( + nrow * ncol >= num_images + ), "nrow * ncol must be greater than or equal to the number of images" + + grid = np.zeros( + (nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), + dtype=images[0].dtype, + ) for i, img in enumerate(images): row = i // ncol col = i % ncol - grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + grid[ + row * img.shape[0] : (row + 1) * img.shape[0], + col * img.shape[1] : (col + 1) * img.shape[1], + ] = img return grid def notes_on_image(img, notes=None): - img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), "constant", constant_values=0) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if notes is not None: - img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.putText( + img, + notes, + (0, img.shape[0] - 4), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (255, 255, 255), + 1, + ) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img @@ -161,6 +183,7 @@ def save_image_with_notes(img, path, notes=None): # debug utils + def atol(x, y): """ Absolute tolerance. @@ -172,7 +195,9 @@ def rtol(x, y): """ Relative tolerance. """ - return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + return torch.abs(x - y) / torch.clamp_min( + torch.maximum(torch.abs(x), torch.abs(y)), 1e-12 + ) # print utils @@ -180,8 +205,7 @@ def indent(s, n=4): """ Indent a string. """ - lines = s.split('\n') + lines = s.split("\n") for i in range(1, len(lines)): - lines[i] = ' ' * n + lines[i] - return '\n'.join(lines) - + lines[i] = " " * n + lines[i] + return "\n".join(lines) diff --git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py index 0a8d9fb79ad73effecc9ebcbfe241a12f8022e0f..6248ccb96ed6a9e50d905e9ead457571b6324b05 100644 --- a/trellis/utils/postprocessing_utils.py +++ b/trellis/utils/postprocessing_utils.py @@ -27,7 +27,7 @@ def _fill_holes( resolution=128, num_views=500, debug=False, - verbose=False + verbose=False, ): """ Rasterize a mesh from multiple views and remove invisible faces. 
@@ -57,109 +57,178 @@ def _fill_holes( projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) views = [] for (yaw, pitch) in zip(yaws, pitchs): - orig = torch.tensor([ - torch.sin(yaw) * torch.cos(pitch), - torch.cos(yaw) * torch.cos(pitch), - torch.sin(pitch), - ]).cuda().float() * radius - view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + orig = ( + torch.tensor( + [ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ] + ) + .cuda() + .float() + * radius + ) + view = utils3d.torch.view_look_at( + orig, + torch.tensor([0, 0, 0]).float().cuda(), + torch.tensor([0, 0, 1]).float().cuda(), + ) views.append(view) views = torch.stack(views, dim=0) # Rasterize visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) - rastctx = utils3d.torch.RastContext(backend='cuda') - for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): + rastctx = utils3d.torch.RastContext(backend="cuda") + for i in tqdm( + range(views.shape[0]), + total=views.shape[0], + disable=not verbose, + desc="Rasterizing", + ): view = views[i] buffers = utils3d.torch.rasterize_triangle_faces( - rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection + rastctx, + verts[None], + faces, + resolution, + resolution, + view=view, + projection=projection, ) - face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 + face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1 face_id = torch.unique(face_id).long() visblity[face_id] += 1 visblity = visblity.float() / num_views - + # Mincut ## construct outer faces edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) - connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) - outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) + connected_components = utils3d.torch.compute_connected_components( + faces, edges, face2edge + ) + outer_face_indices = torch.zeros( + faces.shape[0], dtype=torch.bool, device=faces.device + ) for i in range(len(connected_components)): - outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices[connected_components[i]] = visblity[ + connected_components[i] + ] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) outer_face_indices = outer_face_indices.nonzero().reshape(-1) - + ## construct inner faces inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) if verbose: - tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') + tqdm.write(f"Found {inner_face_indices.shape[0]} invisible faces") if inner_face_indices.shape[0] == 0: return verts, faces - + ## Construct dual graph (faces as nodes, edges as edges) dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) dual_edge2edge = edges[dual_edge2edge] - dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) + dual_edges_weights = torch.norm( + verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1 + ) if verbose: - tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') + tqdm.write(f"Dual graph: {dual_edges.shape[0]} edges") ## solve mincut problem ### construct main graph g = igraph.Graph() 
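    # Editor's note (not part of the patch): this is a standard s-t min-cut
    # formulation. Faces are graph nodes, dual edges are weighted by the length
    # of the shared mesh edge, invisible faces attach to the source and outer
    # faces to the target, so the cheapest separating cut follows short loops
    # around the interior shells that should be removed.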
g.add_vertices(faces.shape[0]) g.add_edges(dual_edges.cpu().numpy()) - g.es['weight'] = dual_edges_weights.cpu().numpy() - + g.es["weight"] = dual_edges_weights.cpu().numpy() + ### source and target - g.add_vertex('s') - g.add_vertex('t') - + g.add_vertex("s") + g.add_vertex("t") + ### connect invisible faces to source - g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) - + g.add_edges( + [(f, "s") for f in inner_face_indices], + attributes={ + "weight": torch.ones(inner_face_indices.shape[0], dtype=torch.float32) + .cpu() + .numpy() + }, + ) + ### connect outer faces to target - g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) - + g.add_edges( + [(f, "t") for f in outer_face_indices], + attributes={ + "weight": torch.ones(outer_face_indices.shape[0], dtype=torch.float32) + .cpu() + .numpy() + }, + ) + ### solve mincut - cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) - remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) + cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist()) + remove_face_indices = torch.tensor( + [v for v in cut.partition[0] if v < faces.shape[0]], + dtype=torch.long, + device=faces.device, + ) if verbose: - tqdm.write(f'Mincut solved, start checking the cut') - + tqdm.write(f"Mincut solved, start checking the cut") + ### check if the cut is valid with each connected component - to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) + to_remove_cc = utils3d.torch.compute_connected_components( + faces[remove_face_indices] + ) if debug: - tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') + tqdm.write(f"Number of connected components of the cut: {len(to_remove_cc)}") valid_remove_cc = [] cutting_edges = [] for cc in to_remove_cc: #### check if the connected component has low visibility visblity_median = visblity[remove_face_indices[cc]].median() if debug: - tqdm.write(f'visblity_median: {visblity_median}') + tqdm.write(f"visblity_median: {visblity_median}") if visblity_median > 0.25: continue - + #### check if the cuting loop is small enough - cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) + cc_edge_indices, cc_edges_degree = torch.unique( + face2edge[remove_face_indices[cc]], return_counts=True + ) cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] - cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[ + ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices) + ] if len(cc_new_boundary_edge_indices) > 0: - cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) - cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components( + edges[cc_new_boundary_edge_indices] + ) + cc_new_boundary_edges_cc_center = [ + verts[edges[cc_new_boundary_edge_indices[edge_cc]]] + .mean(dim=1) + .mean(dim=0) + for edge_cc in cc_new_boundary_edge_cc + ] cc_new_boundary_edges_cc_area = [] for i, edge_cc in 
enumerate(cc_new_boundary_edge_cc): - _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] - _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] - cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) + _e1 = ( + verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] + - cc_new_boundary_edges_cc_center[i] + ) + _e2 = ( + verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] + - cc_new_boundary_edges_cc_center[i] + ) + cc_new_boundary_edges_cc_area.append( + torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5 + ) if debug: cutting_edges.append(cc_new_boundary_edge_indices) - tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') + tqdm.write(f"Area of the cutting loop: {cc_new_boundary_edges_cc_area}") if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): continue - + valid_remove_cc.append(cc) - + if debug: face_v = verts[faces].mean(dim=1).cpu().numpy() vis_dual_edges = dual_edges.cpu().numpy() @@ -168,14 +237,17 @@ def _fill_holes( vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] if len(valid_remove_cc) > 0: - vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] - utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) - + vis_colors[ + remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy() + ] = [255, 0, 0] + utils3d.io.write_ply( + "dbg_dual.ply", face_v, edges=vis_dual_edges, vertex_colors=vis_colors + ) + vis_verts = verts.cpu().numpy() vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() - utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) - - + utils3d.io.write_ply("dbg_cut.ply", vis_verts, edges=vis_edges) + if len(valid_remove_cc) > 0: remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) @@ -183,16 +255,18 @@ def _fill_holes( faces = faces[mask] faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) if verbose: - tqdm.write(f'Removed {(~mask).sum()} faces by mincut') + tqdm.write(f"Removed {(~mask).sum()} faces by mincut") else: if verbose: - tqdm.write(f'Removed 0 faces by mincut') - + tqdm.write(f"Removed 0 faces by mincut") + mesh = _meshfix.PyTMesh() mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) verts, faces = mesh.return_arrays() - verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) + verts, faces = torch.tensor( + verts, device="cuda", dtype=torch.float32 + ), torch.tensor(faces, device="cuda", dtype=torch.int32) return verts, faces @@ -227,21 +301,31 @@ def postprocess_mesh( """ if verbose: - tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + tqdm.write( + f"Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces" + ) # Simplify if simplify and simplify_ratio > 0: - mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = pv.PolyData( + vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1) + ) mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] if verbose: - tqdm.write(f'After 
decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + tqdm.write( + f"After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces" + ) # Remove invisible faces if fill_holes: - vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() + vertices, faces = ( + torch.tensor(vertices).cuda(), + torch.tensor(faces.astype(np.int32)).cuda(), + ) vertices, faces = _fill_holes( - vertices, faces, + vertices, + faces, max_hole_size=fill_holes_max_hole_size, max_hole_nbe=fill_holes_max_hole_nbe, resolution=fill_holes_resolution, @@ -251,7 +335,9 @@ def postprocess_mesh( ) vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() if verbose: - tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + tqdm.write( + f"After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces" + ) return vertices, faces @@ -284,7 +370,7 @@ def bake_texture( texture_size: int = 2048, near: float = 0.1, far: float = 10.0, - mode: Literal['fast', 'opt'] = 'opt', + mode: Literal["fast", "opt"] = "opt", lambda_tv: float = 1e-2, verbose: bool = False, ): @@ -310,71 +396,131 @@ def bake_texture( faces = torch.tensor(faces.astype(np.int32)).cuda() uvs = torch.tensor(uvs).cuda() observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations] - masks = [torch.tensor(m>0).bool().cuda() for m in masks] - views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics] - projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics] - - if mode == 'fast': - texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda() - texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda() - rastctx = utils3d.torch.RastContext(backend='cuda') - for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'): + masks = [torch.tensor(m > 0).bool().cuda() for m in masks] + views = [ + utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) + for extr in extrinsics + ] + projections = [ + utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) + for intr in intrinsics + ] + + if mode == "fast": + texture = torch.zeros( + (texture_size * texture_size, 3), dtype=torch.float32 + ).cuda() + texture_weights = torch.zeros( + (texture_size * texture_size), dtype=torch.float32 + ).cuda() + rastctx = utils3d.torch.RastContext(backend="cuda") + for observation, view, projection in tqdm( + zip(observations, views, projections), + total=len(observations), + disable=not verbose, + desc="Texture baking (fast)", + ): with torch.no_grad(): rast = utils3d.torch.rasterize_triangle_faces( - rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + rastctx, + vertices[None], + faces, + observation.shape[1], + observation.shape[0], + uv=uvs[None], + view=view, + projection=projection, ) - uv_map = rast['uv'][0].detach().flip(0) - mask = rast['mask'][0].detach().bool() & masks[0] - + uv_map = rast["uv"][0].detach().flip(0) + mask = rast["mask"][0].detach().bool() & masks[0] + # nearest neighbor interpolation uv_map = (uv_map * texture_size).floor().long() obs = observation[mask] uv_map = uv_map[mask] idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size texture = texture.scatter_add(0, idx.view(-1, 
1).expand(-1, 3), obs) - texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)) + texture_weights = texture_weights.scatter_add( + 0, + idx, + torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device), + ) mask = texture_weights > 0 texture[mask] /= texture_weights[mask][:, None] - texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8) + texture = np.clip( + texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255 + ).astype(np.uint8) # inpaint - mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size) + mask = ( + (texture_weights == 0) + .cpu() + .numpy() + .astype(np.uint8) + .reshape(texture_size, texture_size) + ) texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) - elif mode == 'opt': - rastctx = utils3d.torch.RastContext(backend='cuda') + elif mode == "opt": + rastctx = utils3d.torch.RastContext(backend="cuda") observations = [observations.flip(0) for observations in observations] masks = [m.flip(0) for m in masks] _uv = [] _uv_dr = [] - for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'): + for observation, view, projection in tqdm( + zip(observations, views, projections), + total=len(views), + disable=not verbose, + desc="Texture baking (opt): UV", + ): with torch.no_grad(): rast = utils3d.torch.rasterize_triangle_faces( - rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + rastctx, + vertices[None], + faces, + observation.shape[1], + observation.shape[0], + uv=uvs[None], + view=view, + projection=projection, ) - _uv.append(rast['uv'].detach()) - _uv_dr.append(rast['uv_dr'].detach()) + _uv.append(rast["uv"].detach()) + _uv_dr.append(rast["uv_dr"].detach()) - texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()) + texture = torch.nn.Parameter( + torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda() + ) optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): return start_lr * (end_lr / start_lr) ** (step / total_steps) def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): - return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) - + return end_lr + 0.5 * (start_lr - end_lr) * ( + 1 + np.cos(np.pi * step / total_steps) + ) + def tv_loss(texture): - return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \ - torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) - + return torch.nn.functional.l1_loss( + texture[:, :-1, :, :], texture[:, 1:, :, :] + ) + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) + total_steps = 2500 - with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar: + with tqdm( + total=total_steps, + disable=not verbose, + desc="Texture baking (opt): optimizing", + ) as pbar: for step in range(total_steps): optimizer.zero_grad() selected = np.random.randint(0, len(views)) - uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected] + uv, uv_dr, observation, mask = ( + _uv[selected], + _uv_dr[selected], + observations[selected], + masks[selected], + ) 
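                # Editor's note (not part of the patch): dr.texture is presumably
                # nvdiffrast's differentiable texture lookup, so the L1 photometric
                # loss below backpropagates directly into the texel grid; the
                # optional total-variation term regularizes texels that few views
                # observe, and the cosine schedule anneals the lr from 1e-2 to 1e-5.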
render = dr.texture(texture, uv, uv_dr)[0] loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) if lambda_tv > 0: @@ -382,16 +528,20 @@ def bake_texture( loss.backward() optimizer.step() # annealing - optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5) - pbar.set_postfix({'loss': loss.item()}) + optimizer.param_groups[0]["lr"] = cosine_anealing( + optimizer, step, total_steps, 1e-2, 1e-5 + ) + pbar.set_postfix({"loss": loss.item()}) pbar.update() - texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) + texture = np.clip( + texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255 + ).astype(np.uint8) mask = 1 - utils3d.torch.rasterize_triangle_faces( rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size - )['mask'][0].detach().cpu().numpy().astype(np.uint8) + )["mask"][0].detach().cpu().numpy().astype(np.uint8) texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) else: - raise ValueError(f'Unknown mode: {mode}') + raise ValueError(f"Unknown mode: {mode}") return texture @@ -421,15 +571,16 @@ def to_glb( """ vertices = mesh.vertices.cpu().numpy() faces = mesh.faces.cpu().numpy() - + # mesh postprocess vertices, faces = postprocess_mesh( - vertices, faces, + vertices, + faces, simplify=simplify > 0, simplify_ratio=simplify, fill_holes=fill_holes, fill_holes_max_hole_size=fill_holes_max_size, - fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), + fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)), fill_holes_resolution=1024, fill_holes_num_views=1000, debug=debug, @@ -440,16 +591,24 @@ def to_glb( vertices, faces, uvs = parametrize_mesh(vertices, faces) # bake texture - observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) + observations, extrinsics, intrinsics = render_multiview( + app_rep, resolution=1024, nviews=100 + ) masks = [np.any(observation > 0, axis=-1) for observation in observations] extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] texture = bake_texture( - vertices, faces, uvs, - observations, masks, extrinsics, intrinsics, - texture_size=texture_size, mode='opt', + vertices, + faces, + uvs, + observations, + masks, + extrinsics, + intrinsics, + texture_size=texture_size, + mode="opt", lambda_tv=0.01, - verbose=verbose + verbose=verbose, ) texture = Image.fromarray(texture) @@ -458,9 +617,11 @@ def to_glb( material = trimesh.visual.material.PBRMaterial( roughnessFactor=1.0, baseColorTexture=texture, - baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8) + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), + ) + mesh = trimesh.Trimesh( + vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material) ) - mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)) return mesh @@ -472,56 +633,70 @@ def simplify_gs( """ Simplify 3D Gaussians NOTE: this function is not used in the current implementation for the unsatisfactory performance. - + Args: gs (Gaussian): 3D Gaussian. simplify (float): Ratio of Gaussians to remove in simplification. 
""" if simplify <= 0: return gs - + # simplify - observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100) - observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations] - + observations, extrinsics, intrinsics = render_multiview( + gs, resolution=1024, nviews=100 + ) + observations = [ + torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) + for obs in observations + ] + # Following https://arxiv.org/pdf/2411.06019 - renderer = GaussianRenderer({ + renderer = GaussianRenderer( + { "resolution": 1024, "near": 0.8, "far": 1.6, "ssaa": 1, - "bg_color": (0,0,0), - }) + "bg_color": (0, 0, 0), + } + ) new_gs = Gaussian(**gs.init_params) new_gs._features_dc = gs._features_dc.clone() - new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None + new_gs._features_rest = ( + gs._features_rest.clone() if gs._features_rest is not None else None + ) new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) - + start_lr = [1e-4, 1e-3, 5e-3, 0.025] end_lr = [1e-6, 1e-5, 5e-5, 0.00025] - optimizer = torch.optim.Adam([ - {"params": new_gs._xyz, "lr": start_lr[0]}, - {"params": new_gs._rotation, "lr": start_lr[1]}, - {"params": new_gs._scaling, "lr": start_lr[2]}, - {"params": new_gs._opacity, "lr": start_lr[3]}, - ], lr=start_lr[0]) - + optimizer = torch.optim.Adam( + [ + {"params": new_gs._xyz, "lr": start_lr[0]}, + {"params": new_gs._rotation, "lr": start_lr[1]}, + {"params": new_gs._scaling, "lr": start_lr[2]}, + {"params": new_gs._opacity, "lr": start_lr[3]}, + ], + lr=start_lr[0], + ) + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): - return start_lr * (end_lr / start_lr) ** (step / total_steps) + return start_lr * (end_lr / start_lr) ** (step / total_steps) def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): - return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) - + return end_lr + 0.5 * (start_lr - end_lr) * ( + 1 + np.cos(np.pi * step / total_steps) + ) + _zeta = new_gs.get_opacity.clone().detach().squeeze() _lambda = torch.zeros_like(_zeta) _delta = 1e-7 _interval = 10 num_target = int((1 - simplify) * _zeta.shape[0]) - - with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar: + + with tqdm(total=2500, disable=not verbose, desc="Simplifying Gaussian") as pbar: for i in range(2500): # prune if i % 100 == 0: @@ -532,21 +707,28 @@ def simplify_gs( new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) new_gs._features_dc = new_gs._features_dc[mask] - new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None + new_gs._features_rest = ( + new_gs._features_rest[mask] + if new_gs._features_rest is not None + else None + ) _zeta = _zeta[mask] _lambda = _lambda[mask] # update optimizer state - for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]): - stored_state = optimizer.state[param_group['params'][0]] - if 'exp_avg' in stored_state: - stored_state['exp_avg'] = stored_state['exp_avg'][mask] - stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask] - del optimizer.state[param_group['params'][0]] - param_group['params'][0] = new_param - 
optimizer.state[param_group['params'][0]] = stored_state + for param_group, new_param in zip( + optimizer.param_groups, + [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity], + ): + stored_state = optimizer.state[param_group["params"][0]] + if "exp_avg" in stored_state: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + del optimizer.state[param_group["params"][0]] + param_group["params"][0] = new_param + optimizer.state[param_group["params"][0]] = stored_state opacity = new_gs.get_opacity.squeeze() - + # sparisfy if i % _interval == 0: _zeta = _lambda + opacity.detach() @@ -556,32 +738,41 @@ def simplify_gs( _m[index] = 0 _zeta[_m] = 0 _lambda = _lambda + opacity.detach() - _zeta - + # sample a random view view_idx = np.random.randint(len(observations)) observation = observations[view_idx] extrinsic = extrinsics[view_idx] intrinsic = intrinsics[view_idx] - - color = renderer.render(new_gs, extrinsic, intrinsic)['color'] + + color = renderer.render(new_gs, extrinsic, intrinsic)["color"] rgb_loss = torch.nn.functional.l1_loss(color, observation) - loss = rgb_loss + \ - _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2)) - + loss = rgb_loss + _delta * torch.sum( + torch.pow(_lambda + opacity - _zeta, 2) + ) + optimizer.zero_grad() loss.backward() optimizer.step() - + # update lr for j in range(len(optimizer.param_groups)): - optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j]) - - pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()}) + optimizer.param_groups[j]["lr"] = cosine_anealing( + optimizer, i, 2500, start_lr[j], end_lr[j] + ) + + pbar.set_postfix( + { + "loss": rgb_loss.item(), + "num": opacity.shape[0], + "lambda": _lambda.mean().item(), + } + ) pbar.update() - + new_gs._xyz = new_gs._xyz.data new_gs._rotation = new_gs._rotation.data new_gs._scaling = new_gs._scaling.data new_gs._opacity = new_gs._opacity.data - + return new_gs diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py index 5b668c277b51f4930991912a80573adc79364028..51f0af9e6b57e3385bc033882f7c602799891729 100644 --- a/trellis/utils/random_utils.py +++ b/trellis/utils/random_utils.py @@ -2,6 +2,7 @@ import numpy as np PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + def radical_inverse(base, n): val = 0 inv_base = 1.0 / base @@ -13,12 +14,15 @@ def radical_inverse(base, n): inv_base_n *= inv_base return val + def halton_sequence(dim, n): return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + def hammersley_sequence(dim, n, num_samples): return [n / num_samples] + halton_sequence(dim - 1, n) + def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): u, v = hammersley_sequence(2, n, num_samples) u += offset[0] / num_samples @@ -27,4 +31,4 @@ def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 theta = np.arccos(1 - 2 * u) - np.pi / 2 phi = v * 2 * np.pi - return [phi, theta] \ No newline at end of file + return [phi, theta] diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py index 8187c84f305d51540e88ae5b634a484a74c16e95..cf462c9c82839ef703012312f164a8a57a52777b 100644 --- a/trellis/utils/render_utils.py +++ b/trellis/utils/render_utils.py @@ -25,12 +25,21 @@ def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): fov = torch.deg2rad(torch.tensor(float(fov))).cuda() 
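        # Editor's note (not part of the patch): the camera convention here is
        # z-up with the camera orbiting the origin at radius r; yaw sweeps around
        # the z-axis, pitch lifts toward the pole, and extrinsics_look_at aims
        # the camera back at the center before intrinsics are built from the fov.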
         yaw = torch.tensor(float(yaw)).cuda()
         pitch = torch.tensor(float(pitch)).cuda()
-        orig = torch.tensor([
-            torch.sin(yaw) * torch.cos(pitch),
-            torch.cos(yaw) * torch.cos(pitch),
-            torch.sin(pitch),
-        ]).cuda() * r
-        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+        orig = (
+            torch.tensor(
+                [
+                    torch.sin(yaw) * torch.cos(pitch),
+                    torch.cos(yaw) * torch.cos(pitch),
+                    torch.sin(pitch),
+                ]
+            ).cuda()
+            * r
+        )
+        extr = utils3d.torch.extrinsics_look_at(
+            orig,
+            torch.tensor([0, 0, 0]).float().cuda(),
+            torch.tensor([0, 0, 1]).float().cuda(),
+        )
         intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
         extrinsics.append(extr)
         intrinsics.append(intr)
@@ -40,60 +49,93 @@
     return extrinsics, intrinsics


-def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
+def render_frames(
+    sample,
+    extrinsics,
+    intrinsics,
+    options={},
+    colors_overwrite=None,
+    verbose=True,
+    **kwargs,
+):
     if isinstance(sample, Octree):
         renderer = OctreeRenderer()
-        renderer.rendering_options.resolution = options.get('resolution', 512)
-        renderer.rendering_options.near = options.get('near', 0.8)
-        renderer.rendering_options.far = options.get('far', 1.6)
-        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
-        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+        renderer.rendering_options.resolution = options.get("resolution", 512)
+        renderer.rendering_options.near = options.get("near", 0.8)
+        renderer.rendering_options.far = options.get("far", 1.6)
+        renderer.rendering_options.bg_color = options.get("bg_color", (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get("ssaa", 4)
         renderer.pipe.primitive = sample.primitive
     elif isinstance(sample, Gaussian):
         renderer = GaussianRenderer()
-        renderer.rendering_options.resolution = options.get('resolution', 512)
-        renderer.rendering_options.near = options.get('near', 0.8)
-        renderer.rendering_options.far = options.get('far', 1.6)
-        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
-        renderer.rendering_options.ssaa = options.get('ssaa', 1)
-        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
+        renderer.rendering_options.resolution = options.get("resolution", 512)
+        renderer.rendering_options.near = options.get("near", 0.8)
+        renderer.rendering_options.far = options.get("far", 1.6)
+        renderer.rendering_options.bg_color = options.get("bg_color", (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get("ssaa", 1)
+        renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1)
         renderer.pipe.use_mip_gaussian = True
     elif isinstance(sample, MeshExtractResult):
         renderer = MeshRenderer()
-        renderer.rendering_options.resolution = options.get('resolution', 512)
-        renderer.rendering_options.near = options.get('near', 1)
-        renderer.rendering_options.far = options.get('far', 100)
-        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+        renderer.rendering_options.resolution = options.get("resolution", 512)
+        renderer.rendering_options.near = options.get("near", 1)
+        renderer.rendering_options.far = options.get("far", 100)
+        renderer.rendering_options.ssaa = options.get("ssaa", 4)
     else:
-        raise ValueError(f'Unsupported sample type: {type(sample)}')
-    
+        raise ValueError(f"Unsupported sample type: {type(sample)}")
+
     rets = {}
-    for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
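+    # Render one frame per camera, accumulating per-channel lists
+    # ("color"/"depth" for octrees and Gaussians, "normal" for meshes).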
+    for j, (extr, intr) in tqdm(
+        enumerate(zip(extrinsics, intrinsics)), desc="Rendering", disable=not verbose
+    ):
         if not isinstance(sample, MeshExtractResult):
             res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
-            if 'color' not in rets: rets['color'] = []
-            if 'depth' not in rets: rets['depth'] = []
-            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
-            if 'percent_depth' in res:
-                rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
-            elif 'depth' in res:
-                rets['depth'].append(res['depth'].detach().cpu().numpy())
+            if "color" not in rets:
+                rets["color"] = []
+            if "depth" not in rets:
+                rets["depth"] = []
+            rets["color"].append(
+                np.clip(
+                    res["color"].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
+                ).astype(np.uint8)
+            )
+            if "percent_depth" in res:
+                rets["depth"].append(res["percent_depth"].detach().cpu().numpy())
+            elif "depth" in res:
+                rets["depth"].append(res["depth"].detach().cpu().numpy())
             else:
-                rets['depth'].append(None)
+                rets["depth"].append(None)
         else:
             res = renderer.render(sample, extr, intr)
-            if 'normal' not in rets: rets['normal'] = []
-            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            if "normal" not in rets:
+                rets["normal"] = []
+            rets["normal"].append(
+                np.clip(
+                    res["normal"].detach().cpu().numpy().transpose(1, 2, 0) * 255,
+                    0,
+                    255,
+                ).astype(np.uint8)
+            )
     return rets


-def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
+def render_video(
+    sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs
+):
     yaws = torch.linspace(0, 2 * 3.1415, num_frames)
     pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
     yaws = yaws.tolist()
     pitch = pitch.tolist()
-    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
-    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
+        yaws, pitch, r, fov
+    )
+    return render_frames(
+        sample,
+        extrinsics,
+        intrinsics,
+        {"resolution": resolution, "bg_color": bg_color},
+        **kwargs,
+    )


 def render_multiview(sample, resolution=512, nviews=30):
@@ -102,15 +144,38 @@ def render_multiview(sample, resolution=512, nviews=30):
     cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
     yaws = [cam[0] for cam in cams]
     pitchs = [cam[1] for cam in cams]
-    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
-    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
-    return res['color'], extrinsics, intrinsics
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
+        yaws, pitchs, r, fov
+    )
+    res = render_frames(
+        sample,
+        extrinsics,
+        intrinsics,
+        {"resolution": resolution, "bg_color": (0, 0, 0)},
+    )
+    return res["color"], extrinsics, intrinsics


-def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
-    yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+def render_snapshot(
+    samples,
+    resolution=512,
+    bg_color=(0, 0, 0),
+    offset=(-16 / 180 * np.pi, 20 / 180 * np.pi),
+    r=10,
+    fov=8,
+    **kwargs,
+):
+    yaw = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
     yaw_offset = offset[0]
     yaw = [y + yaw_offset for y in yaw]
     pitch = [offset[1] for _ in range(4)]
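+    # Four fixed viewpoints a quarter turn apart, shifted by the given offset.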
-    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
-    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
+        yaw, pitch, r, fov
+    )
+    return render_frames(
+        samples,
+        extrinsics,
+        intrinsics,
+        {"resolution": resolution, "bg_color": bg_color},
+        **kwargs,
+    )